[
  {
    "path": ".Rbuildignore",
    "content": "^data-raw$\n^Meta$\n^doc$\n^\\.travis\\.yml$\n^pkg.Rproj$\nfigure$\ncache$\n"
  },
  {
    "path": ".gitignore",
    "content": "Meta\ndoc\ninst/doc\n\n## Files\n# Emacs autosave files\n*~\n\\#*#\n# Don't put data in the repo\n*.csv\n*.feather\n\n# R stuff\n*.Rout\n*.Rhistory\n*.RData\n*.Rapp.history\n\n# Mac stuff\n*.DS_Store\n\n\n# C++ stuff\n*.o\n*.so\n*.dll\n\ntest.R\n\n*-vignette.pdf"
  },
  {
    "path": ".travis.yml",
    "content": "# R for travis: see documentation at https://docs.travis-ci.com/user/languages/r\n\nlanguage: r\nr:\n    - 3.5.1  \nsudo: false\ncache: packages\nwarnings_are_errors: false\nr_binary_packages:\n    - dplyr\n    - magrittr\n    - ggplot2\n    - glmnet\n    - plyr\n    - kableExtra"
  },
  {
    "path": "DESCRIPTION",
    "content": "Package: augsynth\nTitle: The Augmented Synthetic Control Method\nVersion: 0.2.0\nAuthors@R: person(\"Eli\", \"Ben-Michael\", email = \"ebenmichael@berkeley.edu\", role = c(\"aut\", \"cre\"))\nDescription: A package implementing the Augmented Synthetic Controls Method.\nDepends:\n    R (>= 3.5.0)\nImports:\n    dplyr,\n    tidyr,\n    magrittr,\n    ggplot2,\n    MASS,\n    LiblineaR,\n    Formula,\n    Matrix,\n    osqp,\n    rlang,\n    purrr,\n    FNN\nRemotes:\n    susanathey/MCPanel\nLicense: MIT + file LICENSE\nEncoding: UTF-8\nLazyData: true\nRoxygenNote: 7.2.3\nSuggests:\n    testthat,\n    CausalImpact,\n    keras,\n    gsynth,\n    knitr,\n    rmarkdown,\n    softImpute,\n    MCPanel,\n    glmnet,\n    randomForest,\n    kableExtra,\n    ggrepel\nVignetteBuilder: knitr\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2018 Elijahu Ben-Michael\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "NAMESPACE",
    "content": "# Generated by roxygen2: do not edit by hand\n\nS3method(plot,augsynth)\nS3method(plot,augsynth_multiout)\nS3method(plot,multisynth)\nS3method(plot,summary.augsynth)\nS3method(plot,summary.augsynth_multiout)\nS3method(plot,summary.multisynth)\nS3method(predict,augsynth)\nS3method(predict,augsynth_multiout)\nS3method(predict,multisynth)\nS3method(print,augsynth)\nS3method(print,augsynth_multiout)\nS3method(print,multisynth)\nS3method(print,summary.augsynth)\nS3method(print,summary.augsynth_multiout)\nS3method(print,summary.multisynth)\nS3method(summary,augsynth)\nS3method(summary,augsynth_multiout)\nS3method(summary,multisynth)\nexport(augsynth)\nexport(augsynth_multiout)\nexport(multisynth)\nexport(rdirichlet_b)\nexport(rmultinom_b)\nexport(rwild_b)\nexport(single_augsynth)\nimport(dplyr)\nimport(tidyr)\nimportFrom(ggplot2,aes)\nimportFrom(graphics,plot)\nimportFrom(magrittr,\"%>%\")\nimportFrom(purrr,reduce)\nimportFrom(stats,coef)\nimportFrom(stats,delete.response)\nimportFrom(stats,formula)\nimportFrom(stats,lm)\nimportFrom(stats,model.frame)\nimportFrom(stats,model.matrix)\nimportFrom(stats,na.omit)\nimportFrom(stats,poly)\nimportFrom(stats,predict)\nimportFrom(stats,sd)\nimportFrom(stats,terms)\nimportFrom(stats,update)\nimportFrom(utils,capture.output)\n"
  },
  {
    "path": "R/augsynth.R",
    "content": "################################################################################\n## Main functions for single-period treatment augmented synthetic controls Method\n################################################################################\n\n\n#' Fit Augmented SCM\n#' \n#' @param form outcome ~ treatment | auxiliary covariates\n#' @param unit Name of unit column\n#' @param time Name of time column\n#' @param t_int Time of intervention\n#' @param data Panel data as dataframe\n#' @param progfunc What function to use to impute control outcomes\n#'                 ridge=Ridge regression (allows for standard errors),\n#'                 none=No outcome model,\n#'                 en=Elastic Net, RF=Random Forest, GSYN=gSynth,\n#'                 mcp=MCPanel, \n#'                 cits=Comparative Interrupted Time Series,\n#'                 causalimpact=Bayesian structural time series with CausalImpact\n#' @param scm Whether the SCM weighting function is used\n#' @param fixedeff Whether to include a unit fixed effect, default F \n#' @param cov_agg Covariate aggregation functions, if NULL then use mean with NAs omitted\n#' @param ... optional arguments for outcome model\n#'\n#' @return augsynth object that contains:\n#'         \\itemize{\n#'          \\item{\"weights\"}{Ridge ASCM weights}\n#'          \\item{\"l2_imbalance\"}{Imbalance in pre-period outcomes, measured by the L2 norm}\n#'          \\item{\"scaled_l2_imbalance\"}{L2 imbalance scaled by L2 imbalance of uniform weights}\n#'          \\item{\"mhat\"}{Outcome model estimate}\n#'          \\item{\"data\"}{Panel data as matrices}\n#'         }\n#' @export\nsingle_augsynth <- function(form, unit, time, t_int, data,\n                     progfunc = \"ridge\",\n                     scm=T,\n                     fixedeff = FALSE,\n                     cov_agg=NULL, ...) 
{\n    call_name <- match.call()\n\n    form <- Formula::Formula(form)\n    unit <- enquo(unit)\n    time <- enquo(time)\n\n    ## format data\n    outcome <- terms(formula(form, rhs=1))[[2]]\n    trt <- terms(formula(form, rhs=1))[[3]]\n\n    wide <- format_data(outcome, trt, unit, time, t_int, data)\n    synth_data <- do.call(format_synth, wide)\n    \n    treated_units <- data %>% filter(!!trt == 1) %>% distinct(!!unit) %>% pull(!!unit)\n    control_units <- data %>% filter(!(!!unit %in% treated_units)) %>% \n                        distinct(!!unit) %>% arrange(!!unit) %>% pull(!!unit)\n    ## add covariates\n    if(length(form)[2] == 2) {\n        Z <- extract_covariates(form, unit, time, t_int, data, cov_agg)\n    } else {\n        Z <- NULL\n    }\n    \n    # fit augmented SCM\n    augsynth <- fit_augsynth_internal(wide, synth_data, Z, progfunc, \n                                      scm, fixedeff, ...)\n    \n    # add some extra data\n    augsynth$data$time <- data %>% distinct(!!time) %>%\n                                   arrange(!!time) %>% pull(!!time)\n    augsynth$call <- call_name\n    augsynth$t_int <- t_int \n    \n    augsynth$weights <- matrix(augsynth$weights)\n    rownames(augsynth$weights) <- control_units\n\n    return(augsynth)\n}\n\n\n#' Internal function to fit augmented SCM\n#' @param wide Data formatted from format_data\n#' @param synth_data Data formatted from format_synth\n#' @param Z Matrix of auxiliary covariates\n#' @param progfunc outcome model to use\n#' @param scm Whether to fit SCM\n#' @param fixedeff Whether to de-mean synth\n#' @param V V matrix for Synth, default NULL\n#' @param ... Extra args for outcome model\n#' \n#' @noRd\n#' \nfit_augsynth_internal <- function(wide, synth_data, Z, progfunc,\n                                  scm, fixedeff, V = NULL, ...) 
{\n\n    n <- nrow(wide$X)\n    t0 <- ncol(wide$X)\n    ttot <- t0 + ncol(wide$y)\n    if(fixedeff) {\n        demeaned <- demean_data(wide, synth_data)\n        fit_wide <- demeaned$wide\n        fit_synth_data <- demeaned$synth_data\n        mhat <- demeaned$mhat\n    } else {\n        fit_wide <- wide\n        fit_synth_data <- synth_data\n        mhat <- matrix(0, n, ttot)\n    }\n    if (is.null(progfunc)) {\n        progfunc = \"none\"\n    }\n    progfunc = tolower(progfunc)\n    ## fit augsynth\n    if(progfunc == \"ridge\") {\n        # Ridge ASCM\n        augsynth <- do.call(fit_ridgeaug_formatted,\n                            list(wide_data = fit_wide,\n                                 synth_data = fit_synth_data,\n                                 Z = Z, V = V, scm = scm, ...))\n    } else if(progfunc == \"none\") {\n        ## Just SCM\n        augsynth <- do.call(fit_ridgeaug_formatted,\n                        c(list(wide_data = fit_wide, \n                               synth_data = fit_synth_data,\n                               Z = Z, ridge = F, scm = T, V = V, ...)))\n    } else {\n        ## Other outcome models\n        progfuncs = c(\"ridge\", \"none\", \"en\", \"rf\", \"gsyn\", \"mcp\",\n                      \"cits\", \"causalimpact\", \"seq2seq\")\n        if (progfunc %in% progfuncs) {\n            augsynth <- fit_augsyn(fit_wide, fit_synth_data, \n                                   progfunc, scm, ...)\n        } else {\n            stop(\"progfunc must be one of 'Ridge', 'None', 'EN', 'RF', 'GSYN', 'MCP', 'CITS', 'CausalImpact', 'seq2seq'\")\n        }\n        \n    }\n\n    augsynth$mhat <- mhat + cbind(matrix(0, nrow = n, ncol = t0), \n                                  augsynth$mhat)\n    augsynth$data <- wide\n    augsynth$data$Z <- Z\n    augsynth$data$synth_data <- synth_data\n    augsynth$progfunc <- progfunc\n    augsynth$scm <- scm\n    augsynth$fixedeff <- fixedeff\n    augsynth$extra_args <- list(...)\n    if(progfunc == \"ridge\") {\n   
     augsynth$extra_args$lambda <- augsynth$lambda\n    } else if(progfunc == \"gsyn\") {\n        augsynth$extra_args$r <- ncol(augsynth$params$factor)\n        augsynth$extra_args$CV <- 0\n    }\n    ##format output\n    class(augsynth) <- \"augsynth\"\n    return(augsynth)\n}\n\n#' Get prediction of ATT or average outcome under control\n#' @param object augsynth object\n#' @param att If TRUE, return the ATT, if FALSE, return imputed counterfactual\n#' @param ... Optional arguments\n#'\n#' @return Vector of predicted post-treatment control averages\n#' @export\npredict.augsynth <- function(object, att = F, ...) {\n    # if (\"att\" %in% names(list(...))) {\n    #     att <- list(...)$att\n    # } else {\n    #     att <- F\n    # }\n    augsynth <- object\n    \n    X <- augsynth$data$X\n    y <- augsynth$data$y\n    comb <- cbind(X, y)\n    trt <- augsynth$data$trt\n    mhat <- augsynth$mhat\n    \n    m1 <- colMeans(mhat[trt==1,,drop=F])\n\n    resid <- (comb[trt==0,,drop=F] - mhat[trt==0,,drop=F])\n\n    y0 <- m1 + t(resid) %*% augsynth$weights\n    if(att) {\n        return(colMeans(comb[trt == 1,, drop = F]) - c(y0))\n    } else {\n        rnames <- rownames(y0)\n        y0_vec <- c(y0)\n        names(y0_vec) <- rnames\n        return(y0_vec)\n    }\n}\n\n\n#' Print function for augsynth\n#' @param x augsynth object\n#' @param ... Optional arguments\n#' @export\nprint.augsynth <- function(x, ...) 
{\n    augsynth <- x\n    \n    ## straight from lm\n    cat(\"\\nCall:\\n\", paste(deparse(augsynth$call), sep=\"\\n\", collapse=\"\\n\"), \"\\n\\n\", sep=\"\")\n\n    ## print att estimates\n    tint <- ncol(augsynth$data$X)\n    ttotal <- tint + ncol(augsynth$data$y)\n    att_post <- predict(augsynth, att = T)[(tint + 1):ttotal]\n\n    cat(paste(\"Average ATT Estimate: \",\n              format(round(mean(att_post),3), nsmall = 3), \"\\n\\n\", sep=\"\"))\n}\n\n\n#' Plot function for augsynth\n#' @importFrom graphics plot\n#' \n#' @param x Augsynth object to be plotted\n#' @param inf Boolean, whether to get confidence intervals around the point estimates\n#' @param cv If True, plot cross validation MSE against hyper-parameter, otherwise plot effects\n#' @param ... Optional arguments\n#' @export\nplot.augsynth <- function(x, inf = T, cv = F, ...) {\n    # if (\"se\" %in% names(list(...))) {\n    #     se <- list(...)$se\n    # } else {\n    #     se <- T\n    # }\n\n    augsynth <- x\n    \n    if (cv == T) {\n        errors = data.frame(lambdas = augsynth$lambdas,\n                            errors = augsynth$lambda_errors,\n                            errors_se = augsynth$lambda_errors_se)\n        p <- ggplot2::ggplot(errors, ggplot2::aes(x = lambdas, y = errors)) +\n              ggplot2::geom_point(size = 2) + \n              ggplot2::geom_errorbar(\n                ggplot2::aes(ymin = errors,\n                             ymax = errors + errors_se),\n                width=0.2, size = 0.5) \n        p <- p + ggplot2::labs(title = bquote(\"Cross Validation MSE over \" ~ lambda),\n                              x = expression(lambda), y = \"Cross Validation MSE\", \n                              parse = TRUE)\n        p <- p + ggplot2::scale_x_log10()\n        \n        # find minimum and min + 1se lambda to plot\n        min_lambda <- choose_lambda(augsynth$lambdas,\n                                   augsynth$lambda_errors,\n                                   
augsynth$lambda_errors_se,\n                                   F)\n        min_1se_lambda <- choose_lambda(augsynth$lambdas,\n                                       augsynth$lambda_errors,\n                                       augsynth$lambda_errors_se,\n                                       T)\n        min_lambda_index <- which(augsynth$lambdas == min_lambda)\n        min_1se_lambda_index <- which(augsynth$lambdas == min_1se_lambda)\n\n        p <- p + ggplot2::geom_point(\n            ggplot2::aes(x = min_lambda, \n                         y = augsynth$lambda_errors[min_lambda_index]),\n            color = \"gold\")\n        p + ggplot2::geom_point(\n              ggplot2::aes(x = min_1se_lambda,\n                           y = augsynth$lambda_errors[min_1se_lambda_index]),\n              color = \"gold\") +\n            ggplot2::theme_bw()\n    } else {\n        plot(summary(augsynth, ...), inf = inf)\n    }\n}\n\n\n#' Summary function for augsynth\n#' @param object augsynth object\n#' @param inf Boolean, whether to get confidence intervals around the point estimates\n#' @param inf_type Type of inference algorithm. Options are\n#'         \\itemize{\n#'          \\item{\"conformal\"}{Conformal inference (default)}\n#'          \\item{\"jackknife+\"}{Jackknife+ algorithm over time periods}\n#'          \\item{\"jackknife\"}{Jackknife over units}\n#'         }\n#' @param linear_effect Boolean, whether to invert the conformal inference hypothesis test to get confidence intervals for a linear-in-time treatment effect: intercept + slope * time \n#' @param ... 
Optional arguments for inference, for more details for each `inf_type` see\n#'         \\itemize{\n#'          \\item{\"conformal\"}{`conformal_inf`}\n#'          \\item{\"jackknife+\"}{`time_jackknife_plus`}\n#'          \\item{\"jackknife\"}{`jackknife_se_single`}\n#'         }\n#' @export\nsummary.augsynth <- function(object, inf = T, inf_type = \"conformal\",\n                             linear_effect = F,\n                             ...) {\n    augsynth <- object\n    summ <- list()\n\n    t0 <- ncol(augsynth$data$X)\n    t_final <- t0 + ncol(augsynth$data$y)\n\n    if(inf) {\n        if(inf_type == \"jackknife\") {\n            att_se <- jackknife_se_single(augsynth)\n        } else if(inf_type == \"jackknife+\") {\n            att_se <- time_jackknife_plus(augsynth, ...)\n        } else if(inf_type == \"conformal\") {\n          att_se <- conformal_inf(augsynth, ...)\n          # get CIs for linear treatment effects\n          if(linear_effect) {\n            att_linear <- conformal_inf_linear(augsynth, ...)\n          }\n        } else {\n            stop(paste(inf_type, \"is not a valid choice of 'inf_type'\"))\n        }\n\n        att <- data.frame(Time = augsynth$data$time,\n                          Estimate = att_se$att[1:t_final])\n        if(inf_type == \"jackknife\") {\n          att$Std.Error <- att_se$se[1:t_final]\n          att_avg_se <- att_se$se[t_final + 1]\n        } else {\n          att_avg_se <- NA\n        }\n        att_avg <- att_se$att[t_final + 1]\n        if(inf_type %in% c(\"jackknife+\", \"nonpar_bs\", \"t_dist\", \"conformal\")) {\n            att$lower_bound <- att_se$lb[1:t_final]\n            att$upper_bound <- att_se$ub[1:t_final]\n        }\n        if(inf_type == \"conformal\") {\n          att$p_val <- att_se$p_val[1:t_final]\n        }\n\n    } else {\n        t0 <- ncol(augsynth$data$X)\n        t_final <- t0 + ncol(augsynth$data$y)\n        att_est <- predict(augsynth, att = T)\n        att <- data.frame(Time = 
augsynth$data$time,\n                          Estimate = att_est)\n        att$Std.Error <- NA\n        att_avg <- mean(att_est[(t0 + 1):t_final])\n        att_avg_se <- NA\n    }\n\n    summ$att <- att\n\n    if(inf) {\n      if(inf_type %in% c(\"jackknife+\")) {\n        summ$average_att <- data.frame(Value = \"Average Post-Treatment Effect\",\n                              Estimate = att_avg, Std.Error = att_avg_se)\n        summ$average_att$lower_bound <- att_se$lb[t_final + 1]\n        summ$average_att$upper_bound <- att_se$ub[t_final + 1]\n        summ$alpha <-  att_se$alpha\n      }\n      if(inf_type == \"conformal\") {\n        # summ$average_att$p_val <- att_se$p_val[t_final + 1]\n        # summ$average_att$lower_bound <- att_se$lb[t_final + 1]\n        # summ$average_att$upper_bound <- att_se$ub[t_final + 1]\n        # summ$alpha <-  att_se$alpha\n        if(linear_effect) {\n          summ$average_att <- data.frame(\n                                Value = c(\"Average Post-Treatment Effect\",\n                                          \"Treatment Effect Intercept\",\n                                          \"Treatment Effect Slope\"),\n                                Estimate = c(att_avg, att_linear$est_int,\n                                             att_linear$est_slope),\n                                Std.Error = c(att_avg_se, NA, NA),\n                                p_val = c(att_se$p_val[t_final + 1], NA, NA),\n                                lower_bound = c(att_se$lb[t_final + 1],\n                                            att_linear$ci_int[1],\n                                            att_linear$ci_slope[1]),\n                                upper_bound =  c(att_se$ub[t_final + 1],\n                                            att_linear$ci_int[2],\n                                            att_linear$ci_slope[2])\n          )\n        } else {\n          summ$average_att <- data.frame(\n                                Value = 
c(\"Average Post-Treatment Effect\"),\n                                Estimate = att_avg,\n                                Std.Error = att_avg_se,\n                                p_val = att_se$p_val[t_final + 1],\n                                lower_bound = att_se$lb[t_final + 1],\n                                upper_bound =  att_se$ub[t_final + 1]\n          )\n\n        }\n        summ$alpha <-  att_se$alpha\n      }\n    } else {\n              summ$average_att <- data.frame(Value = \"Average Post-Treatment Effect\",\n                              Estimate = att_avg, Std.Error = att_avg_se)\n    }\n    summ$t_int <- augsynth$t_int\n    summ$call <- augsynth$call\n    summ$l2_imbalance <- augsynth$l2_imbalance\n    summ$scaled_l2_imbalance <- augsynth$scaled_l2_imbalance\n    if(!is.null(augsynth$covariate_l2_imbalance)) {\n      summ$covariate_l2_imbalance <- augsynth$covariate_l2_imbalance\n      summ$scaled_covariate_l2_imbalance <- augsynth$scaled_covariate_l2_imbalance\n    }\n    ## get estimated bias\n\n    if(tolower(augsynth$progfunc) == \"ridge\") {\n        mhat <- augsynth$ridge_mhat\n        w <- augsynth$synw\n    } else {\n        mhat <- augsynth$mhat\n        w <- augsynth$weights\n    }\n    trt <- augsynth$data$trt\n    m1 <- colMeans(mhat[trt==1,,drop=F])\n\n    if(tolower(augsynth$progfunc) == \"none\" | (!augsynth$scm)) {\n        summ$bias_est <- NA\n    } else {\n      summ$bias_est <- m1 - t(mhat[trt==0,,drop=F]) %*% w\n    }\n    \n    \n    summ$inf_type <- if(inf) inf_type else \"None\"\n    class(summ) <- \"summary.augsynth\"\n    return(summ)\n}\n\n#' Print function for summary function for augsynth\n#' @param x summary object\n#' @param ... Optional arguments\n#' @export\nprint.summary.augsynth <- function(x, ...) 
{\n    summ <- x\n    \n    ## straight from lm\n    cat(\"\\nCall:\\n\", paste(deparse(summ$call), sep=\"\\n\", collapse=\"\\n\"), \"\\n\\n\", sep=\"\")\n\n    t_final <- nrow(summ$att)\n\n    ## distinction between pre and post treatment\n    att_est <- summ$att$Estimate\n    t_total <- length(att_est)\n    t_int <- summ$att %>% filter(Time <= summ$t_int) %>% nrow()\n    \n    att_pre <- att_est[1:(t_int-1)]\n    att_post <- att_est[t_int:t_total]\n\n\n    out_msg <- \"\"\n\n\n    # print out average post treatment estimate\n    att_post <- summ$average_att$Estimate[1]\n    se_est <- summ$att$Std.Error\n    if(summ$inf_type == \"jackknife\") {\n      se_avg <- summ$average_att$Std.Error\n\n      out_msg <- paste(\"Average ATT Estimate (Jackknife Std. Error): \",\n                        format(round(att_post,3), nsmall=3), \n                        \"  (\",\n                        format(round(se_avg,3)), \")\\n\")\n      inf_type <- \"Jackknife over units\"\n    } else if(summ$inf_type == \"conformal\") {\n      p_val <- summ$average_att$p_val[1]\n      out_msg <- paste(\"Average ATT Estimate (p Value for Joint Null): \",\n                        format(att_post, digits = 3), \n                        \"  (\",\n                        format(p_val, digits = 2), \")\\n\")\n      inf_type <- \"Conformal inference\"\n      if(\"Treatment Effect Slope\" %in% summ$average_att$Value) {\n        lowers <- summ$average_att$lower_bound[2:3]\n        uppers <- summ$average_att$upper_bound[2:3]\n        out_msg_line2 <- paste0(\"Confidence intervals for linear-in-time treatment effects (Intercept + Slope * Time)\\n\",\n        \"\\tIntercept: [\", format(lowers[1], digits = 3), \",\",\n        format(uppers[1], digits = 3), \"]\\n\",\n        \"\\tSlope: [\", format(lowers[2], digits = 3), \",\",\n        format(uppers[2], digits = 3), \"]\\n\")\n        out_msg <- paste0(out_msg, out_msg_line2)\n      }\n    } else if(summ$inf_type == \"jackknife+\") {\n      out_msg <- 
paste(\"Average ATT Estimate: \",\n                        format(round(att_post,3), nsmall=3), \"\\n\")\n      inf_type <- \"Jackknife+ over time periods\"\n    } else {\n      out_msg <- paste(\"Average ATT Estimate: \",\n                        format(round(att_post,3), nsmall=3), \"\\n\")\n      inf_type <- \"None\"\n    }\n\n\n    out_msg <- paste(out_msg, \n              \"L2 Imbalance: \",\n              format(round(summ$l2_imbalance,3), nsmall=3), \"\\n\",\n              \"Percent improvement from uniform weights: \",\n              format(round(1 - summ$scaled_l2_imbalance,3)*100), \"%\\n\\n\",\n              sep=\"\")\n  if(!is.null(summ$covariate_l2_imbalance)) {\n\n    out_msg <- paste(out_msg,\n                     \"Covariate L2 Imbalance: \",\n                     format(round(summ$covariate_l2_imbalance,3), \n                                  nsmall=3),\n                    \"\\n\",\n                     \"Percent improvement from uniform weights: \",\n                     format(round(1 - summ$scaled_covariate_l2_imbalance,3)*100), \n                     \"%\\n\\n\",\n                     sep=\"\")\n\n  }\n  out_msg <- paste(out_msg, \n              \"Avg Estimated Bias: \",\n              format(round(mean(summ$bias_est), 3),nsmall=3), \"\\n\\n\",\n              \"Inference type: \",\n              inf_type,\n              \"\\n\\n\",\n              sep=\"\")\n  cat(out_msg)\n\n    if(summ$inf_type == \"jackknife\") {\n      out_att <- summ$att[t_int:t_final,] %>% \n              select(Time, Estimate, Std.Error)\n    } else if(summ$inf_type == \"conformal\") {\n      out_att <- summ$att[t_int:t_final,] %>% \n              select(Time, Estimate, lower_bound, upper_bound, p_val)\n      names(out_att) <- c(\"Time\", \"Estimate\", \n                          paste0((1 - summ$alpha) * 100, \"% CI Lower Bound\"),\n                          paste0((1 - summ$alpha) * 100, \"% CI Upper Bound\"),\n                          paste0(\"p Value\"))\n    } else 
if(summ$inf_type == \"jackknife+\") {\n      out_att <- summ$att[t_int:t_final,] %>% \n              select(Time, Estimate, lower_bound, upper_bound)\n      names(out_att) <- c(\"Time\", \"Estimate\", \n                          paste0((1 - summ$alpha) * 100, \"% CI Lower Bound\"),\n                          paste0((1 - summ$alpha) * 100, \"% CI Upper Bound\"))\n    } else {\n      out_att <- summ$att[t_int:t_final,] %>% \n              select(Time, Estimate)\n    }\n    out_att %>%\n      mutate_at(vars(-Time), ~ round(., 3)) %>%\n      print(row.names = F)\n\n    \n}\n\n#' Plot function for summary function for augsynth\n#' @param x Summary object\n#' @param inf Boolean, whether to plot confidence intervals\n#' @param ... Optional arguments\n#' @export\nplot.summary.augsynth <- function(x, inf = T, ...) {\n    summ <- x\n    # if (\"inf\" %in% names(list(...))) {\n    #     inf <- list(...)$inf\n    # } else {\n    #     inf <- T\n    # }\n    \n    p <- summ$att %>%\n        ggplot2::ggplot(ggplot2::aes(x=Time, y=Estimate))\n    if(inf) {\n        if(all(is.na(summ$att$lower_bound))) {\n            p <- p + ggplot2::geom_ribbon(ggplot2::aes(ymin=Estimate-2*Std.Error,\n                        ymax=Estimate+2*Std.Error),\n                    alpha=0.2)\n        } else {\n            p <- p + ggplot2::geom_ribbon(ggplot2::aes(ymin=lower_bound,\n                        ymax=upper_bound),\n                    alpha=0.2)\n        }\n\n    }\n    p + ggplot2::geom_line() +\n        ggplot2::geom_vline(xintercept=summ$t_int, lty=2) +\n        ggplot2::geom_hline(yintercept=0, lty=2) + \n        ggplot2::theme_bw()\n\n}\n\n\n\n#' augsynth\n#' \n#' @description A package implementing the Augmented Synthetic Controls Method\n#' @docType package\n#' @name augsynth-package\n#' @importFrom magrittr \"%>%\"\n#' @importFrom purrr reduce\n#' @import dplyr\n#' @import tidyr\n#' @importFrom stats terms\n#' @importFrom stats formula\n#' @importFrom stats update \n#' @importFrom 
stats delete.response \n#' @importFrom stats model.matrix \n#' @importFrom stats model.frame \n#' @importFrom stats na.omit\nNULL\n"
  },
  {
    "path": "R/augsynth_pre.R",
    "content": "################################################################################\n## Main function for the augmented synthetic controls Method\n################################################################################\n\n\n#' Fit Augmented SCM\n#' @param form outcome ~ treatment | auxiliary covariates\n#' @param unit Name of unit column\n#' @param time Name of time column\n#' @param data Panel data as dataframe\n#' @param t_int Time of intervention (used for single-period treatment only)\n#' @param ... Optional arguments\n#' \\itemize{\n#'   \\item Single period augsynth with/without multiple outcomes\n#'     \\itemize{\n#'       \\item{\"progfunc\"}{What function to use to impute control outcomes: Ridge=Ridge regression (allows for standard errors), None=No outcome model, EN=Elastic Net, RF=Random Forest, GSYN=gSynth, MCP=MCPanel, CITS=CITS, CausalImpact=Bayesian structural time series with CausalImpact, seq2seq=Sequence to sequence learning with feedforward nets}\n#'       \\item{\"scm\"}{Whether the SCM weighting function is used}\n#'       \\item{\"fixedeff\"}{Whether to include a unit fixed effect, default F }\n#'       \\item{\"cov_agg\"}{Covariate aggregation functions, if NULL then use mean with NAs omitted}\n#'     }\n#'   \\item Multi period (staggered) augsynth\n#'    \\itemize{\n#'          \\item{\"relative\"}{Whether to compute balance by relative time}\n#'          \\item{\"n_leads\"}{How long past treatment effects should be estimated for}\n#'          \\item{\"n_lags\"}{Number of pre-treatment periods to balance, default is to balance all periods}\n#'          \\item{\"alpha\"}{Fraction of balance for individual balance}\n#'          \\item{\"lambda\"}{Regularization hyperparameter, default = 0}\n#'          \\item{\"force\"}{Include \"none\", \"unit\", \"time\", \"two-way\" fixed effects. 
Default: \"two-way\"}\n#'          \\item{\"n_factors\"}{Number of factors for interactive fixed effects, default does CV}\n#'         }\n#' }\n#' \n#' @return augsynth object that contains:\n#'         \\itemize{\n#'          \\item{\"weights\"}{weights}\n#'          \\item{\"data\"}{Panel data as matrices}\n#'         }\n#' @export\n#' \naugsynth <- function(form, unit, time, data, t_int=NULL, ...) {\n\n  call_name <- match.call()\n\n  form <- Formula::Formula(form)\n  unit_quosure <- enquo(unit)\n  time_quosure <- enquo(time)\n  \n\n  ## format data\n  outcome <- terms(formula(form, rhs=1))[[2]]\n  trt <- terms(formula(form, rhs=1))[[3]]\n\n  # check for multiple outcomes\n  multi_outcome <- length(outcome) != 1\n\n  ## get first treatment times\n  trt_time <- data %>%\n      group_by(!!unit_quosure) %>%\n      filter(!all(!!trt == 0)) %>%\n      summarise(trt_time = min((!!time_quosure)[(!!trt) == 1])) %>%\n      mutate(trt_time = replace_na(as.numeric(trt_time), Inf))\n\n  num_trt_years <- sum(is.finite(unique(trt_time$trt_time)))\n\n  if(multi_outcome & num_trt_years > 1) {\n    stop(\"augsynth is not currently implemented for more than one outcome and more than one treated unit\")\n  } else if(num_trt_years > 1) {\n    message(\"More than one treatment time found. Running multisynth.\")\n    if(\"progfunc\" %in% names(list(...))) {\n      warning(\"`progfunc` is not an argument for multisynth, so it is ignored\")\n    }\n    return(multisynth(form, !!enquo(unit), !!enquo(time), data, ...)) \n  } else {\n    if (is.null(t_int)) {\n      t_int <- trt_time %>% filter(is.finite(trt_time)) %>%\n        summarise(t_int = max(trt_time)) %>% pull(t_int)\n    }\n    if(!multi_outcome) {\n      message(\"One outcome and one treatment time found. 
Running single_augsynth.\")\n      return(single_augsynth(form, !!enquo(unit), !!enquo(time), t_int,\n                             data = data, ...))\n    } else {\n      message(\"Multiple outcomes and one treatment time found. Running augsynth_multiout.\")\n      return(augsynth_multiout(form, !!enquo(unit), !!enquo(time), t_int,\n                               data = data, ...))\n    }\n  }\n}\n"
  },
  {
    "path": "R/cv.R",
    "content": "drop_time_t <- function(wide_data, Z, t_drop) {\n  new_wide_data <- list()\n  new_wide_data$trt <- wide_data$trt\n  \n  if (is.list(wide_data$X)) {\n    # TODO\n  } else {\n    new_wide_data$X <- wide_data$X[, -t_drop, drop = F]\n    new_wide_data$y <- cbind(wide_data$X[, t_drop, drop = F], \n                             wide_data$y)\n    \n    X0 <- new_wide_data$X[new_wide_data$trt == 0,, drop = F]\n    x1 <- matrix(colMeans(new_wide_data$X[new_wide_data$trt == 1,, drop = F]),\n                 ncol=1)\n    y0 <- new_wide_data$y[new_wide_data$trt == 0,, drop = F]\n    y1 <- colMeans(new_wide_data$y[new_wide_data$trt == 1,, drop = F])\n    \n    new_synth_data <- list()\n    new_synth_data$Z0 <- t(X0)\n    new_synth_data$X0 <- t(X0)\n    new_synth_data$Z1 <- x1\n    new_synth_data$X1 <- x1\n    \n    return(list(wide_data = new_wide_data,\n                synth_data = new_synth_data,\n                Z = Z)) \n  }\n}\n\ndrop_time_and_refit <- function(wide_data, Z, t_drop, progfunc, scm, fixedeff, ...) {\n  new_data <- drop_time_t(wide_data, Z, t_drop)\n  new_ascm <- do.call(fit_augsynth_internal,\n                      c(list(wide = new_data$wide_data,\n                             synth_data = new_data$synth_data,\n                             Z = new_data$Z,\n                             progfunc = progfunc,\n                             scm = scm,\n                             fixedeff = fixedeff, ...)))\n  return(new_ascm)\n}\n\ncv_internal <- function(wide_data, Z, progfunc, scm, fixedeff, lambdas, holdout_periods, ...) 
{\n  X <- wide_data$X\n  lambda_error_vals <- vapply(lambdas, function(lambda){\n    errors <- apply(holdout_periods, 1, function(t_drop){\n      new_ascm <- drop_time_and_refit(wide_data, Z, t_drop, progfunc, scm, fixedeff, lambda = lambda, ...)\n      err <- sum((predict(new_ascm, att = T)[(ncol(X)-length(t_drop)+1):ncol(X)])^2)\n      err\n    })\n    lambda_error <- mean(errors)\n    lambda_error_se <- sd(errors) / sqrt(length(errors))\n    c(lambda_error, lambda_error_se)\n  }, numeric(2))\n  return(list(lambda_errors = lambda_error_vals[1,], lambda_errors_se = lambda_error_vals[2,]))\n}\n\ncv_ridge <- function(wide_data, synth_data, Z, progfunc, scm, fixedeff, how = 'time', holdout_length = 1, lambdas = NULL, \n               lambda_min_ratio = 1e-8, n_lambda = 20, lambda_max = NULL, min_1se = T, V = NULL, ...) {\n  X <- wide_data$X\n  trt <- wide_data$trt\n  if (is.null(lambdas)) {\n    if(is.null(lambda_max)) {\n      X_cent <- apply(X, 2, function(x) x - mean(x[trt==0]))\n      X_c <- X_cent[trt==0,,drop=FALSE]\n      t0 <- ncol(X_c)\n      \n      if(is.null(V)) {\n        V <- diag(rep(1, t0))\n      } else if(is.vector(V)) {\n        V <- diag(V)\n      } else if(ncol(V) == 1 & nrow(V) == t0) {\n        V <- diag(c(V))\n      } else if(ncol(V) == t0 & nrow(V) == 1) {\n        V <- diag(c(V))\n      } else if(nrow(V) == t0) {\n      } else {\n        stop(\"`V` must be a vector with t0 elements or a t0xt0 matrix\")\n      }\n      X_c <- X_c %*% V\n\n      if(!is.null(Z)) {\n        Z_cent <- apply(Z, 2, function(x) x - mean(x[trt==0]))\n        Z_c <- Z_cent[trt==0,,drop=FALSE]\n        Xc_hat <- Z_c %*% solve(t(Z_c) %*% Z_c) %*% t(Z_c) %*% X_c\n        res_c <- X_c - Xc_hat\n        X_c <- res_c\n      }\n      lambda_max <- svd(X_c)$d[1] ^ 2\n    }\n    lambdas <- create_lambda_list(lambda_max, lambda_min_ratio, n_lambda)\n  }\n  \n  if (how == 'time') {\n    period_starts <- 1:(ncol(X) - holdout_length)\n    if (holdout_length == 1) {\n      
holdout_periods <- matrix(period_starts, nrow = length(period_starts), ncol = 1)\n    } else {\n      holdout_periods <- t(vapply(period_starts, function(t) t:(t+holdout_length-1), numeric(holdout_length)))\n    }\n    results <- cv_internal(wide_data, Z, progfunc, scm, fixedeff, lambdas, holdout_periods, ...)\n    lambda <- choose_lambda(lambdas, results$lambda_errors, results$lambda_errors_se, min_1se)\n    return(list(lambda = lambda, lambdas = lambdas, lambda_errors = results$lambda_errors, lambda_errors_se = results$lambda_errors_se))\n  }\n}"
  },
  {
    "path": "R/data.R",
    "content": "#' Economic indicators for US states from 1990-2016\n#' \n#' \n#' @format A dataframe with 5250 rows and 32 variables:\n#' \\describe{\n#'  \\item{fips}{FIPS code for each state}\n#'  \\item{year}{Year of measurement}\n#'  \\item{qtr}{Quarter (1-4) of measurement}\n#'  \\item{state}{Name of State}\n#'  \\item{gdp}{Gross State Product (millions of $) Values before 2005 are linearly interpolated between years}\n#'  \\item{revenuepop}{State and local revenue per capita}\n#'  \\item{rev_state_total}{State total general revenue (millions of $)}\n#'  \\item{rev_local_total}{Local total general revenue (millions of $)}\n#'  \\item{popestimate}{Population estimate}\n#'  \\item{qtrly_estabs_count}{Count of establishments for a given quarter}\n#'  \\item{month1_emplvl, month2_emplvl, month3_emplvl}{ Employment level for first, second, and third months of a given quarter}\n#'  \\item{total_qtrly_wages}{Total wages for a givne quarter}\n#'  \\item{taxable_qtrly_wage}{Taxable wages for a given quarter}\n#'  \\item{avg_wkly_wage}{Average weekly wage for a given quarter}\n#'  \\item{year_qtr}{Year and quarter combined into one continuous variable}\n#'  \\item{treated}{Whether the state passed tax cuts before the given year and quareter}\n#'  \\item{lngdpcapita}{Natural log of GDP per capita}\n#'  \\item{emplvlcapita}{Average employment level per capita}\n#'  \\item{Xcapita}{Per capita value of X}\n#'  \\item{abb}{State abbreviation}\n#' }\n\"kansas\""
  },
  {
    "path": "R/eligible_donors.R",
    "content": "##############################################################################\n## Code to get eligible donor units based on covariates\n##############################################################################\n\nget_donors <- function(X, y, trt, Z, time_cohort, n_lags,\n                       n_leads, how = \"knn\", \n                       exact_covariates = NULL, ...) {\n\n  # first get eligible donors by treatment time\n  donors <- get_eligible_donors(trt, time_cohort, n_leads)\n\n  # get donors with no NA values\n  nona_donors <- get_nona_donors(X, y, trt, n_lags, n_leads, time_cohort)\n\n  donors <- lapply(1:length(donors),\n                     function(j) donors[[j]] & nona_donors[[j]])\n\n  # if Z isn't NULL, futher restrict the donors by matching\n  if(!is.null(Z)) {\n    if(ncol(Z) != 0) {\n      donors <- get_matched_donors(trt, Z, donors, how, exact_covariates, ...)\n    }\n  }\n\n  return(donors)\n}\n\n\nget_eligible_donors <- function(trt, time_cohort, n_leads) {\n\n    # get treatment times\n    if(time_cohort) {\n        grps <- unique(trt[is.finite(trt)])\n    } else {\n        grps <- trt[is.finite(trt)]\n    }\n\n    J <- length(grps)\n\n    # only allow weights on donors treated after n_leads\n    donors <- lapply(1:J, function(j) trt > n_leads + grps[j])\n\n    return(donors)\n}\n\n#' Get donors that don't have missing outcomes where treated units have outcomes\nget_nona_donors <- function(X, y, trt, n_lags, n_leads, time_cohort) {\n\n  n <- length(trt)\n  # find na treatment times\n  fulldat <- cbind(X, y)\n  is_na <- is.na(fulldat[is.finite(trt), , drop = F])\n  # aggregate by time cohort\n  if(time_cohort) {\n    grps <- unique(trt[is.finite(trt)])\n    # if doing a time cohort, convert the boolean mask\n    finite_trt <- trt[is.finite(trt)]\n    is_na <- t(sapply(grps,\n                    function(tj) apply(is_na[finite_trt == tj, , drop = F],\n                                       2, all)))\n  } else {\n      grps 
<- trt[is.finite(trt)]\n  }\n  not_na <- !is.na(fulldat)\n  J <- length(grps)\n  lapply(1:J,\n             function(j) {\n               idxs <- max(grps[j] - n_lags + 1, 1):min(grps[j] + n_leads,\n                                                        ncol(fulldat))\n               isna_j <- is_na[j, idxs]\n               apply(not_na[, idxs, drop = F][, !isna_j, drop = F], 1, all)\n        }) -> donors\n\n  return(donors)\n}\n\nget_matched_donors <- function(trt, Z, donors, how, exact_covariates = NULL, k = NULL, ...) {\n\n  J <- sum(is.finite(trt))\n  trt_idx <- which(is.finite(trt))\n  if(is.null(exact_covariates)) {\n    if(how == \"exact\") {\n      return(\n        lapply(1:J,\n            function(j) donors[[j]] & apply(t(Z) == Z[trt_idx[j], ], 2, all)\n        )\n      )\n    } else if(how == \"knn\") {\n        return(get_knn_donors(trt, Z, donors, k))\n    } else {\n      stop(\"Option for exact matching must be in ('exact', 'knn')\")\n    }\n  } else {\n    if(how == \"exact\") {\n      # compare only the exact-match covariates, so both sides have the same columns\n      return(\n        lapply(1:J,\n            function(j) donors[[j]] &\n              apply(t(Z[, exact_covariates, drop = F]) == \n                Z[trt_idx[j], exact_covariates], 2, all)\n        )\n      )\n    } else if(how == \"knn\") {\n        donors <- lapply(1:J,\n            function(j) { donors[[j]] &\n              apply(t(Z[, exact_covariates, drop = F]) == \n                Z[trt_idx[j],exact_covariates], 2, all)\n            }\n              )\n        approx_covs <- which(!colnames(Z) %in% exact_covariates)\n        if(length(approx_covs) != 0) {\n          return(get_knn_donors(trt, Z[, approx_covs, drop = F], donors, k))\n        } else {\n          return(donors)\n        }\n        \n    } else {\n      stop(\"Option for exact matching must be in ('exact', 'knn')\")\n    }\n  }\n\n}\n\nget_knn_donors <- function(trt, Z, donors, k) {\n\n  if(is.null(k)) {\n    stop(\"Number of neighbors for knn not selected, please choose k.\")\n  }\n  # knn matching within 
time cohort\n  trt_idxs <- which(is.finite(trt))\n  lapply(1:length(trt_idxs), \n        function(j) {\n          idx <- trt_idxs[j]\n          # covariates for the j-th treated unit\n          Z_tj <- Z[idx, , drop = F]\n\n          # get donors for treated cohort\n          donors_tj <- donors[[j]]\n          Z_donors_tj <- Z[donors_tj, , drop = F]\n          # check that k is less than the number of donors\n          # if not, warn and return all donors as matches\n          if(k >= nrow(Z_donors_tj)) {\n            warning(paste(\"Number of potential donor units is less than\",\n                          \"the number of required matches,\",\n                          \"returning all donors as matches\"))\n            return(donors_tj)\n          }\n          # do knn matching\n          nn <- FNN::get.knnx(data = Z_donors_tj, query = Z_tj, k = k)\n          # map neighbor positions among the donors back to positions among all units\n          donors_j <- logical(length(donors_tj))\n          true_idx <- which(donors_tj)[nn$nn.index[1, ]]\n          donors_j[true_idx] <- TRUE\n          return(donors_j)\n         }) -> matches\n  names(matches) <- trt_idxs\n  return(matches)\n}"
  },
  {
    "path": "R/fit_synth.R",
    "content": "#######################################################\n# Helper scripts to fit synthetic controls to simulations\n#######################################################\n\n#' Make a V matrix from a vector (or null)\nmake_V_matrix <- function(t0, V) {\n  if(is.null(V)) {\n        V <- diag(rep(1, t0))\n    } else if(is.vector(V)) {\n        if(length(V) != t0) {\n          stop(paste(\"`V` must be a vector with\", t0, \"elements or a\", t0, \n                     \"x\", t0, \"matrix\"))\n        }\n        V <- diag(V)\n    } else if(ncol(V) == 1 & nrow(V) == t0) {\n        V <- diag(c(V))\n    } else if(ncol(V) == t0 & nrow(V) == 1) {\n        V <- diag(c(V))\n    } else if(nrow(V) == t0) {\n    } else {\n        stop(paste(\"`V` must be a vector with\", t0, \"elements or a\", t0, \n                     \"x\", t0, \"matrix\"))\n    }\n\n  return(V)\n}\n\n#' Fit synthetic controls on outcomes after formatting data\n#' @param synth_data Panel data in format of Synth::dataprep\n#' @param V Matrix to scale the obejctive by\n#' @noRd\n#' @return \\itemize{\n#'          \\item{\"weights\"}{Synth weights}\n#'          \\item{\"l2_imbalance\"}{Imbalance in pre-period outcomes, measured by the L2 norm}\n#'          \\item{\"scaled_l2_imbalance\"}{L2 imbalance scaled by L2 imbalance of uniform weights}\n#' }\nfit_synth_formatted <- function(synth_data, V = NULL) {\n\n\n    t0 <- dim(synth_data$Z0)[1]\n    ## if no  is supplied, set equal to 1\n\n    V <- make_V_matrix(t0, V)\n\n    weights <- synth_qp(synth_data$X1, t(synth_data$X0), V)\n    l2_imbalance <- sqrt(sum((synth_data$Z0 %*% weights - synth_data$Z1)^2))\n\n    ## primal objective value scaled by least squares difference for mean\n    uni_w <- matrix(1/ncol(synth_data$Z0), nrow=ncol(synth_data$Z0), ncol=1)\n    unif_l2_imbalance <- sqrt(sum((synth_data$Z0 %*% uni_w - synth_data$Z1)^2))\n    scaled_l2_imbalance <- l2_imbalance / unif_l2_imbalance\n\n    return(list(weights=weights,\n                
l2_imbalance=l2_imbalance,\n                scaled_l2_imbalance=scaled_l2_imbalance))\n}\n\n#' Solve the synth QP directly\n#' @param X1 Target vector\n#' @param X0 Matrix of control outcomes\n#' @param V Scaling matrix\n#' @noRd\nsynth_qp <- function(X1, X0, V) {\n    \n    Pmat <- X0 %*% V %*% t(X0)\n    qvec <- - t(X1) %*% V %*% t(X0)\n\n    n0 <- nrow(X0)\n    A <- rbind(rep(1, n0), diag(n0))\n    l <- c(1, numeric(n0))\n    u <- c(1, rep(1, n0))\n\n    settings = osqp::osqpSettings(verbose = FALSE,\n                                  eps_rel = 1e-8,\n                                  eps_abs = 1e-8)\n    sol <- osqp::solve_osqp(P = Pmat, q = qvec,\n                            A = A, l = l, u = u, \n                            pars = settings)\n\n    return(sol$x)\n}\n"
  },
  {
    "path": "R/format.R",
    "content": "################################################################################\n## Scripts to format panel data into matrices\n################################################################################\n\n#' Format \"long\" panel data into \"wide\" program evaluation matrices\n#' @param outcome Name of outcome column\n#' @param trt Name of treatment column\n#' @param unit Name of unit column\n#' @param time Name of time column\n#' @param t_int Time of intervention\n#' @param data Panel data as dataframe\n#' @noRd\n#' @return \\itemize{\n#'          \\item{\"X\"}{Matrix of pre-treatment outcomes}\n#'          \\item{\"trt\"}{Vector of treatment assignments}\n#'          \\item{\"y\"}{Matrix of post-treatment outcomes}\n#'         }\nformat_data <- function(outcome, trt, unit, time, t_int, data) {\n\n    ## pre treatment outcomes\n    X <- data %>%\n        filter(!!time < t_int) %>%\n        select(!!unit, !!time, !!outcome) %>%\n        spread(!!time, !!outcome) %>%\n        select(-!!unit) %>%\n        as.matrix()\n    \n\n\n    ## post treatment outcomes\n    y <- data %>%\n        filter(!!time >= t_int) %>%\n        select(!!unit, !!time, !!outcome) %>%\n        spread(!!time, !!outcome) %>%\n        select(-!!unit) %>%\n        as.matrix()\n\n\n    ## treatment status\n    trt <- data %>%\n        select(!!unit, !!trt) %>%\n        group_by(!!unit) %>%\n        summarise(trt = max(!!trt)) %>%\n        ungroup() %>%\n        pull(trt)\n\n    return(list(X=X, trt=trt, y=y))\n}\n\n\n#' Format \"long\" panel data into \"wide\" program evaluation matrices\n#' @param outcomes Vectors of names of outcome columns\n#' @param trt Name of treatment column\n#' @param unit Name of unit column\n#' @param time Name of time column\n#' @param t_int Time of intervention\n#' @param data Panel data as dataframe\n#' @noRd\n#' @return \\itemize{\n#'          \\item{\"X\"}{List of matrices of pre-treatment outcomes}\n#'          \\item{\"trt\"}{Vector of 
treatment assignments}\n#'          \\item{\"y\"}{List of matrices of post-treatment outcomes}\n#'         }\nformat_data_multi <- function(outcomes, trt, unit, time, t_int, data) {\n\n\n    lapply(outcomes, \n        function(outcome) format_data(outcome, trt, unit, \n                                     time, t_int, data)\n          ) -> formats\n\n    # X <- simplify2array(lapply(formats, function(x) x$X))\n    # y <- simplify2array(lapply(formats, function(x) x$y))\n    # X <- lapply(formats, function(x) t(na.omit(t(x$X))))\n    X <- lapply(formats, `[[`, \"X\")\n    y <- lapply(formats, function(x) t(na.omit(t(x$y))))\n    trt <- formats[[1]]$trt\n    return(list(X = X, trt = trt, y = y))\n}\n\n\n\n\n#' Format \"long\" panel data into \"wide\" program evaluation matrices with staggered adoption\n#' @param outcome Name of outcome column\n#' @param trt Name of treatment column\n#' @param unit Name of unit column\n#' @param time Name of time column\n#' @param data Panel data as dataframe\n#' @noRd\n#' @return \\itemize{\n#'          \\item{\"X\"}{Matrix of pre-treatment outcomes}\n#'          \\item{\"trt\"}{Vector of treatment assignments}\n#'          \\item{\"y\"}{Matrix of post-treatment outcomes}\n#'         }\nformat_data_stag <- function(outcome, trt, unit, time, data) {\n\n    # arrange data by time first\n    data <- data %>% arrange(!!time)\n      \n    ## get first treatment times\n    trt_time <- data %>%\n        group_by(!!unit) %>%\n        summarise(trt_time=(!!time)[(!!trt) == 1][1]) %>%\n        mutate(trt_time=replace_na(as.numeric(trt_time), Inf))\n    \n\n    t_int <- trt_time %>% filter(is.finite(trt_time)) %>%\n        summarise(t_int=max(trt_time)) %>% pull(t_int)\n\n    ## ## boolean mask of available data for treatment groups\n    ## mask <- data %>% inner_join(trt_time %>%\n    ##                             filter(is.finite(trt_time))) %>%\n    ##     filter(!!time < t_int) %>%\n    ##     mutate(trt=1-!!trt) %>%\n    ##     
select(!!unit, !!time, trt_time, trt) %>%\n    ##     spread(!!time, trt) %>% \n    ##     group_by(trt_time) %>% \n    ##     summarise_all(list(max)) %>%\n    ##     arrange(trt_time) %>% \n    ##     select(-trt_time, -!!unit) %>%\n    ##     as.matrix()\n\n    ## boolean mask of available data for treatment groups\n    mask <- data %>% inner_join(trt_time %>%\n                                filter(is.finite(trt_time)),\n                                by = rlang::as_name(unit)) %>%\n        filter(!!time < t_int) %>%\n        mutate(trt=1-!!trt) %>%\n        select(!!unit, !!time, trt_time, trt) %>%\n        spread(!!time, trt) %>% \n        ## arrange(!!unit) %>% \n        select(-trt_time, -!!unit) %>%\n        as.matrix()\n    \n    # outcomes as a matrix\n    Xy <- data %>%\n        select(!!unit, !!time, !!outcome) %>%\n        spread(!!time, !!outcome) %>%\n        select(-!!unit) %>%\n        as.matrix()\n\n    pre_times <- data %>% filter(!!time < t_int) %>%\n        distinct(!!time) %>% pull(!!time)\n    post_times <- data %>% filter(!!time >= t_int) %>%\n        distinct(!!time) %>% pull(!!time)\n    X <- Xy[, as.character(pre_times), drop = F]\n    y <- Xy[, as.character(post_times), drop = F]\n\n    if(nrow(X) != nrow(y)) {\n      stop(\"There are not the same number of units after the last unit is treated as before the last unit is treated\")\n    }\n\n    t_vec <- data %>% pull(!!time) %>% unique() %>% sort()\n    trt <- sapply(trt_time$trt_time, function(x) which(t_vec == x)-1) %>%\n        as.numeric() %>%\n        replace_na(Inf)\n   \n\n    units <- data %>%\n        filter(!!time < t_int) %>%\n        select(!!unit, !!time, !!outcome) %>%\n        spread(!!time, !!outcome) %>%\n        pull(!!unit)\n\n    \n    return(list(X=X,\n                trt=trt,\n                y=y,\n                mask=mask,\n                time = t_vec,\n                units=units))\n}\n\n\n#' Format program eval matrices into synth form\n#'\n#' @param X Matrix 
of pre-treatment outcomes\n#' @param trt Vector of treatment assignments\n#' @param y Matrix of post-treatment outcomes\n#' @noRd\n#' @return List with data formatted as Synth::dataprep\nformat_synth <- function(X, trt, y) {\n\n\n    synth_data <- list()\n\n    ## pre-treatment values as covariates\n    synth_data$Z0 <- t(X[trt==0,,drop=F])\n\n    ## average treated units together\n    synth_data$Z1 <- as.matrix((colMeans(X[trt==1,,drop=F])), ncol=1)\n\n    ## combine everything together also\n    synth_data$Y0plot <- t(cbind(X[trt==0,,drop=F], y[trt==0,,drop=F]))\n    synth_data$Y1plot <- as.matrix(colMeans(\n        cbind(X[trt==1,,drop=F], y[trt==1,,drop=F])), ncol=1)\n\n\n    ## predictors are pre-period outcomes\n    synth_data$X0 <- synth_data$Z0\n    synth_data$X1 <- synth_data$Z1\n\n    return(synth_data)\n    \n}\n\n#' Remove unit means \n#' @param wide_data X, y, trt\n#' @param synth_data List with data formatted as Synth::dataprep\n#' @noRd\ndemean_data <- function(wide_data, synth_data) {\n\n    # pre treatment means\n    means <- rowMeans(wide_data$X)\n\n    new_wide_data <- list()\n    new_X <- wide_data$X - means\n    trt <- wide_data$trt\n\n    new_wide_data$X <- new_X\n    new_wide_data$y <- wide_data$y - means\n    new_wide_data$trt <- trt\n\n    new_synth_data <- list()\n    new_synth_data$X0 <- t(new_X[trt == 0,, drop = FALSE])\n    new_synth_data$Z0 <- new_synth_data$X0\n    new_synth_data$X1 <- as.matrix((colMeans(new_X[trt==1,,drop = F])), \n                                   ncol = 1)\n    new_synth_data$Z1 <- new_synth_data$X1\n\n\n    # estimate post-treatment as pre-treatment means\n    mhat <- replicate(ncol(wide_data$X) + ncol(wide_data$y), means)\n\n    return(list(wide = new_wide_data,\n                synth_data = new_synth_data,\n                mhat = mhat))\n}\n\n#' Helper function to extract covariate matrix from data\n#' @param form Formula as outcome ~ treatment | covariates\n#' @param unit Name of unit column\n#' @param time 
Name of time column\n#' @param t_int Time of intervention\n#' @param data Panel data as dataframe\n#' @param cov_agg Covariate aggregation function\n#' @noRd\nextract_covariates <- function(form, unit, time, t_int, data, cov_agg) {\n\n    ## if no aggregation functions, use the mean (omitting NAs)\n    if(is.null(cov_agg)) {\n        cov_agg <- c(function(x) mean(x, na.rm=T))\n    }\n\n    cov_form <- update(formula(delete.response(terms(form, rhs=2, data=data))),\n                        ~. - 1) ## ensure that there is no intercept\n\n    ## pull out relevant covariates and aggregate\n    pre_data <- data %>% \n        filter(!! (time) < t_int)\n\n    model.matrix(cov_form,\n                    model.frame(cov_form, pre_data,\n                                na.action=NULL) ) %>%\n        data.frame() %>%\n        mutate(unit=pull(pre_data, !!unit)) %>%\n        group_by(unit) %>%\n        summarise_all(cov_agg) -> Z\n\n    # recombine with any missing units and convert to matrix\n    data %>% distinct(!!unit) %>%\n      rename(unit = !!unit) %>%\n      left_join(Z, by = \"unit\") %>%\n      arrange(unit) %>%\n      select(-unit) %>%\n      as.matrix() -> Z\n    \n    if(nrow(distinct(data, !!unit))  != nrow(Z)) {\n      stop(\"Some units missing all covariate data\")\n    }\n\n    # check if any covariates have no variation\n    Zsds <- apply(Z, 2, sd)\n\n    if(any(Zsds == 0)) {\n      zero_covs <- paste(colnames(Z)[Zsds == 0], collapse = \", \")\n      stop(paste(\"The following covariates have no variation across units:\",\n                 zero_covs))\n    }\n    return(Z)\n}\n\n#' Check that we can actually run multisynth on the data\n#' @param wide Output of format_data_stag\n#' @param fixedeff Whether to include a unit fixed effect\n#' @param n_leads How long past treatment effects should be estimated for, default is number of post treatment periods for last treated unit\n#' @param n_lags Number of pre-treatment periods to balance, default is to balance 
all periods\ncheck_data_stag <- function(wide, fixedeff, n_leads, n_lags) {\n\n  # If there are less than 5 pre-treatment outcomes, give a warning\n  less_5 <- wide$units[wide$trt < 5]\n  warn_msg <- \"\"\n  if(length(less_5) != 0) {\n    warn_msg <- paste0(\n      warn_msg,\n      \"The following units have less than 5 pre-treatment outcomes: (\",\n      paste(less_5, collapse = \",\"),\n      \"). Be cautious!\"\n    )\n  }\n\n  # check if there are any always treated units\n  always_trt <- wide$units[wide$trt == 0]\n\n  # If including a fixed effect, check that there is more than one pretreatment\n  # outcome for each unit\n  n1 <- wide$units[wide$trt == 1]\n\n  err_msg <- \"\"\n  if(length(always_trt) != 0) {\n    err_msg <- paste0(\n      err_msg,\n      \"The following units are always treated and should be removed: (\",\n      paste(always_trt, collapse = \",\"),\n      \")\\n\")\n  }\n\n  if(length(n1) != 0 & fixedeff) {\n    if(nchar(err_msg) > 0) {\n      err_msg <- paste0(err_msg, \"  Also: \")\n    }\n    err_msg <- paste0(\n      err_msg,\n      \"You are including a fixed effect with `fixedeff = T`, but the \",\n      \"following units only have one pre-treatment outcome: (\",\n      paste(n1, collapse = \",\"),\n      \"). Either remove these units or set `fixedeff = F`.\\n\"\n    )\n  }\n  # check if there are never treated units\n  if(max(wide$trt) < ncol(wide$X) + ncol(wide$y)) {\n    if(nchar(err_msg) > 0) {\n      err_msg <- paste0(err_msg, \"  Also: \")\n    }\n    err_msg <- paste0(\n      err_msg,\n      \"All units are eventually treated. The last treatment time is \",\n      wide$time[max(wide$trt)],\n      \". To run multisynth, remove all periods after this time. \",\n      \"Units treated at this time will be considered 'never treated' in the \",\n      \"narrowed sample.\\n\"\n    )\n  }\n\n  if(nchar(warn_msg) > 0) {\n    warning(warn_msg)\n  }\n  if(nchar(err_msg) > 0) {\n    stop(err_msg)\n  }\n\n}"
  },
  {
    "path": "R/globalVariables.R",
    "content": "utils::globalVariables(c(\"time\", \"val\", \"post\", \"weight\", \".\", \"Time\",\n                         \"Estimate\", \"Std.Error\", \"Level\", \"last_time\",\n                         \"is_avg\", \"label\", \"Outcome\", \"unit\", \"obs\",\n                         \"lambdas\", \"errors_se\",\n                         \"upper_bound\", \"lower_bound\"))"
  },
  {
    "path": "R/highdim.R",
    "content": "################################################################################\n## Methods to use flexible outcome models\n################################################################################\n\n##### Augmented SCM with general outcome models\n\n#' Use zero weights, do nothing but output everything in the right way\n#' @param synth_data Panel data in format of Synth::dataprep\n#' @noRd\n#' @return \\itemize{\n#'          \\item{\"weights\"}{Synth weights}\n#'          \\item{\"l2_imbalance\"}{Imbalance in pre-period outcomes, measured by the L2 norm}\n#'          \\item{\"scaled_l2_imbalance\"}{L2 imbalance scaled by L2 imbalance of uniform weights}\n#' }\nfit_zero_weights <- function(synth_data) {\n    \n    ## Imbalance is uniform weights imbalance\n    uni_w <- matrix(1/ncol(synth_data$Z0), nrow=ncol(synth_data$Z0), ncol=1)\n    unif_l2_imbalance <- sqrt(sum((synth_data$Z0 %*% uni_w - synth_data$Z1)^2))\n    scaled_l2_imbalance <- 1\n    \n    return(list(weights=rep(0, ncol(synth_data$Z0)),\n                l2_imbalance=unif_l2_imbalance,\n                scaled_l2_imbalance=scaled_l2_imbalance))\n}\n\n\n\n#' Fit E[Y(0)|X] and for each post-period and balance pre-period\n#'\n#' @param wide_data Output of `format_ipw`\n#' @param synth_data Output of `synth_data`\n#' @param fit_progscore Function to fit prognostic score\n#' @param fit_weights Function to fit synth weights\n#' @param ... optional arguments for outcome model\n#' @noRd\n#' @return \\itemize{\n#'          \\item{\"weights\"}{Ridge ASCM weights}\n#'          \\item{\"l2_imbalance\"}{Imbalance in pre-period outcomes, measured by the L2 norm}\n#'          \\item{\"scaled_l2_imbalance\"}{L2 imbalance scaled by L2 imbalance of uniform weights}\n#'          \\item{\"mhat\"}{Outcome model estimate}\n#' }\nfit_augsyn_formatted <- function(wide_data, synth_data,\n                                fit_progscore, fit_weights, ...) 
{\n\n\n    X <- wide_data$X\n    y <- wide_data$y\n    trt <- wide_data$trt\n    \n    ## fit prognostic scores\n    fitout <- do.call(fit_progscore,\n                          list(X=X, y=y, trt=trt, ...))\n    \n    ## fit synth\n    syn <- fit_weights(synth_data)\n\n    syn$params <- fitout$params\n\n    syn$mhat <- fitout$y0hat\n    \n    return(syn)\n}\n\n\n#' Fit outcome model and balance pre-period\n#' @param wide_data Output of `format_ipw`\n#' @param synth_data Output of `synth_data`\n#' @param progfunc What function to use to impute control outcomes\n#'                 EN=Elastic Net, RF=Random Forest, GSYN=gSynth,\n#'                 Comp=softImpute, MCP=MCPanel, CITS=CITS\n#'                 CausalImpact=Bayesian structural time series with CausalImpact\n#'                 seq2seq=Sequence to sequence learning with feedforward nets\n#' @param scm Whether the SCM weighting function is used\n#' @param ... optional arguments for outcome model\n#' @noRd\n#' @return \\itemize{\n#'          \\item{\"weights\"}{Ridge ASCM weights}\n#'          \\item{\"l2_imbalance\"}{Imbalance in pre-period outcomes, measured by the L2 norm}\n#'          \\item{\"scaled_l2_imbalance\"}{L2 imbalance scaled by L2 imbalance of uniform weights}\n#'          \\item{\"mhat\"}{Outcome model estimate}\n#' }\nfit_augsyn <- function(wide_data, synth_data,\n                       progfunc=c(\"EN\", \"RF\", \"GSYN\", \"MCP\",\"CITS\", \"CausalImpact\", \"seq2seq\"),\n                       scm=T, ...) 
{\n    ## prognostic score and weight functions to use\n    progfunc = tolower(progfunc)\n    if(progfunc == \"en\") {\n        progf <- fit_prog_reg\n    } else if(progfunc == \"rf\") {\n        progf <- fit_prog_rf\n    } else if(progfunc == \"gsyn\"){\n        progf <- fit_prog_gsynth\n    } else if(progfunc == \"mcp\"){\n        progf <- fit_prog_mcpanel\n    } else if(progfunc == \"cits\") {\n        progf <- fit_prog_cits\n    } else if(progfunc == \"causalimpact\") {\n        progf <- fit_prog_causalimpact\n    } else if(progfunc == \"seq2seq\"){\n        progf <- fit_prog_seq2seq\n    } else {\n        stop(\"progfunc must be one of 'EN', 'RF', 'GSYN', 'MCP', 'CITS', 'CausalImpact', 'seq2seq'\")\n    }\n\n    if(scm) {\n        weightf <- fit_synth_formatted\n    } else {\n        ## still fit synth even if none\n        ## TODO: This is a dumb wasteful hack\n        weightf <- fit_zero_weights\n    }\n    return(fit_augsyn_formatted(wide_data, synth_data,\n                                progf, weightf, ...))\n}\n\n\n\n### Combine synth and gsynth by balancing pre-period residuals\n#' Fit outcome model and balance residuals\n#'\n#' @param wide_data Output of `format_data`\n#' @param synth_data Output of `format_synth`\n#' @param fit_progscore Function to fit prognostic score\n#' @param fit_weights Function to fit synth weights\n#' @param ... optional arguments for outcome model\n#' @noRd\n#' @return \\itemize{\n#'          \\item{\"weights\"}{Ridge ASCM weights}\n#'          \\item{\"l2_imbalance\"}{Imbalance in pre-period outcomes, measured by the L2 norm}\n#'          \\item{\"scaled_l2_imbalance\"}{L2 imbalance scaled by L2 imbalance of uniform weights}\n#'          \\item{\"mhat\"}{Outcome model estimate}\n#' }\nfit_residaug_formatted <- function(wide_data, synth_data,\n                                  fit_progscore, fit_weights, ...) 
{\n\n\n    X <- wide_data$X\n    y <- wide_data$y\n    trt <- wide_data$trt\n\n    ## fit prognostic scores\n    fitout <- do.call(fit_progscore, list(X=X, y=y, trt=trt, ...))\n\n    \n    y0hat <- fitout$y0hat\n\n    ## get residuals\n    ctrl_resids <- fitout$params$ctrl_resids\n    trt_resids <- fitout$params$trt_resids\n    \n    ## replace outcomes with pre-period residuals\n    t0 <- dim(X)[2]\n\n    synth_data$Z0 <- ctrl_resids[1:t0, ]\n    synth_data$Z1 <- as.matrix(trt_resids[1:t0])\n    \n    ## fit synth weights\n    syn <- fit_weights(synth_data)\n\n    syn$params <- fitout$params    \n\n    ## return predicted values for treatment and control\n    syn$mhat <- y0hat\n    \n    return(syn)\n}\n#' Fit outcome model and balance residuals\n#'\n#' @param wide_data Output of `format_data`\n#' @param synth_data Output of `format_synth`\n#' @param progfunc What function to use to impute control outcomes\n#'                 GSYN=gSynth, MCP=MCPanel,\n#'                 CITS=Comparative interrupted time series\n#'                 CausalImpact=Bayesian structural time series with CausalImpact\n#' @param weightfunc What function to use to fit weights\n#'                   SCM=Vanilla Synthetic Controls\n#'                   NONE=No reweighting, just outcome model\n#' @param ... optional arguments for outcome model\n#' @noRd\n#' @return \\itemize{\n#'          \\item{\"weights\"}{Ridge ASCM weights}\n#'          \\item{\"l2_imbalance\"}{Imbalance in pre-period outcomes, measured by the L2 norm}\n#'          \\item{\"scaled_l2_imbalance\"}{L2 imbalance scaled by L2 imbalance of uniform weights}\n#'          \\item{\"mhat\"}{Outcome model estimate}\n#' }\nfit_residaug <- function(wide_data, synth_data,\n                        progfunc=c(\"GSYN\", \"MCP\", \"CITS\", \"CausalImpact\"),\n                        weightfunc=c(\"SC\",\"ENT\", \"SVD\", \"NONE\"), ...) 
{\n\n    ## the defaults are vectors of choices; use the first element\n    progfunc <- progfunc[1]\n    weightfunc <- weightfunc[1]\n\n    ## prognostic score and weight functions to use\n    if(progfunc == \"GSYN\"){\n        progf <- fit_prog_gsynth\n    } else if(progfunc == \"MCP\"){\n        progf <- fit_prog_mcpanel\n    } else if(progfunc == \"CITS\") {\n        progf <- fit_prog_cits\n    } else if(progfunc == \"CausalImpact\") {\n        progf <- fit_prog_causalimpact\n    } else {\n        stop(\"progfunc must be one of 'GSYN', 'MCP', 'CITS', 'CausalImpact'\")\n    }\n\n    \n    ## weight function to use (\"SC\" is the default in the signature, \"SCM\" is also accepted)\n    if(weightfunc %in% c(\"SC\", \"SCM\")) {\n        weightf <- fit_synth_formatted\n    } else if(weightfunc == \"NONE\") {\n        ## still fit synth even if none\n        ## TODO: This is a dumb wasteful hack\n        weightf <- fit_synth_formatted\n    } else {\n        stop(\"weightfunc must be one of 'SC', 'SCM', 'NONE'\")\n    }\n\n    return(fit_residaug_formatted(wide_data, synth_data,\n                                  progf, weightf, ...))\n}\n\n"
  },
  {
    "path": "R/inference.R",
    "content": "################################################################################\n## Code for inference\n################################################################################\n\n#' Jackknife+ algorithm over time\n#' @param ascm Fitted `augsynth` object\n#' @param alpha Confidence level\n#' @param conservative Whether to use the conservative jackknife+ procedure\n#' @return List that contains:\n#'         \\itemize{\n#'          \\item{\"att\"}{Vector of ATT estimates}\n#'          \\item{\"heldout_att\"}{Vector of ATT estimates with the time period held out}\n#'          \\item{\"se\"}{Standard error, always NA but returned for compatibility}\n#'          \\item{\"lb\"}{Lower bound of 1 - alpha confidence interval}\n#'          \\item{\"ub\"}{Upper bound of 1 - alpha confidence interval}\n#'          \\item{\"alpha\"}{Level of confidence interval}\n#'         }\ntime_jackknife_plus <- function(ascm, alpha = 0.05, conservative = F) {\n    wide_data <- ascm$data\n    synth_data <- ascm$data$synth_data\n    n <- nrow(wide_data$X)\n    n_c <- dim(synth_data$Z0)[2]\n    Z <- wide_data$Z\n\n    t0 <- dim(synth_data$Z0)[1]\n    tpost <- ncol(wide_data$y)\n    t_final <- dim(synth_data$Y0plot)[1]\n\n    jack_ests <- lapply(1:t0, \n        function(tdrop) {\n            # drop unit i\n            new_data <- drop_time_t(wide_data, Z, tdrop)\n            # refit\n            new_ascm <- do.call(fit_augsynth_internal,\n                    c(list(wide = new_data$wide,\n                            synth_data = new_data$synth_data,\n                            Z = new_data$Z,\n                            progfunc = ascm$progfunc,\n                            scm = ascm$scm,\n                            fixedeff = ascm$fixedeff),\n                        ascm$extra_args))\n            # get ATT estimates and held out error for time t\n            # t0 is prediction for held out time\n            est <- predict(new_ascm, att = F)[(t0 +1):t_final]\n       
     est <- c(est, mean(est))\n            err <- c(colMeans(wide_data$X[wide_data$trt == 1,\n                                         tdrop,\n                                         drop = F]) -\n                    predict(new_ascm, att = F)[t0])\n            list(err, rbind(est + abs(err), est - abs(err), est + err, est))\n        })\n    # get errors and jackknife distribution\n    held_out_errs <- vapply(jack_ests, `[[`, numeric(1), 1)\n    jack_dist <- vapply(jack_ests, `[[`,\n                        matrix(0, nrow = 4, ncol = tpost + 1), 2)\n\n    out <- list()\n    att <- predict(ascm, att = T)\n    out$att <- c(att, \n                 mean(att[(t0 + 1):t_final]))\n    # held out ATT\n    out$heldout_att <- c(held_out_errs, \n                          att[(t0 + 1):t_final], \n                          mean(att[(t0 + 1):t_final]))\n\n    # out$se <- rep(NA, 10 + tpost)\n    if(conservative) {\n        qerr <- stats::quantile(abs(held_out_errs), 1 - alpha)\n        out$lb <- c(rep(NA, t0), apply(jack_dist[4,,], 1, min) - qerr)\n        out$ub <- c(rep(NA, t0), apply(jack_dist[4,,], 1, max) + qerr)\n    } else {\n        out$lb <- c(rep(NA, t0), apply(jack_dist[2,,], 1, stats::quantile, alpha / 2))\n        out$ub <- c(rep(NA, t0), apply(jack_dist[1,,], 1, stats::quantile, 1 - alpha / 2))\n    }\n    # shift back to ATT scale\n    y1 <- predict(ascm, att = F) + att\n    y1 <-  c(y1, mean(y1[(t0 + 1):t_final]))\n    shifted_lb <- y1 - out$ub\n    shifted_ub <- y1 - out$lb\n    out$lb <- shifted_lb\n    out$ub <- shifted_ub\n    out$alpha <- alpha\n\n\n    return(out)\n}\n\n#' Drop time period from pre-treatment data\n#' @param wide_data (X, y, trt)\n#' @param Z Covariates matrix\n#' @param t_drop Time to drop\n#' @noRd\ndrop_time_t <- function(wide_data, Z, t_drop) {\n\n        new_wide_data <- list()\n        new_wide_data$trt <- wide_data$trt\n        new_wide_data$X <- wide_data$X[, -t_drop, drop = F]\n        new_wide_data$y <- cbind(wide_data$X[, t_drop, 
drop = F], \n                                 wide_data$y)\n\n        X0 <- new_wide_data$X[new_wide_data$trt == 0,, drop = F]\n        x1 <- matrix(colMeans(new_wide_data$X[new_wide_data$trt == 1,,\n                                              drop = F]),\n                     ncol=1)\n        y0 <- new_wide_data$y[new_wide_data$trt == 0,, drop = F]\n        y1 <- colMeans(new_wide_data$y[new_wide_data$trt == 1,, drop = F])\n\n        new_synth_data <- list()\n        new_synth_data$Z0 <- t(X0)\n        new_synth_data$X0 <- t(X0)\n        new_synth_data$Z1 <- x1\n        new_synth_data$X1 <- x1\n\n        return(list(wide_data = new_wide_data,\n                    synth_data = new_synth_data,\n                    Z = Z)) \n}\n\n#' Conformal inference procedure to compute p-values and point-wise confidence intervals\n#' @param ascm Fitted `augsynth` object\n#' @param alpha Confidence level\n#' @param stat_func Function to compute test statistic\n#' @param type Either \"iid\" for iid permutations or \"block\" for moving block permutations; default is \"iid\"\n#' @param q The norm for the test statistic `((sum(x ^ q))) ^ (1/q)`\n#' @param ns Number of resamples for \"iid\" permutations\n#' @param grid_size Number of grid points to use when inverting the hypothesis test\n#' @return List that contains:\n#'         \\itemize{\n#'          \\item{\"att\"}{Vector of ATT estimates}\n#'          \\item{\"heldout_att\"}{Vector of ATT estimates with the time period held out}\n#'          \\item{\"se\"}{Standard error, always NA but returned for compatibility}\n#'          \\item{\"lb\"}{Lower bound of 1 - alpha confidence interval}\n#'          \\item{\"ub\"}{Upper bound of 1 - alpha confidence interval}\n#'          \\item{\"p_val\"}{p-value for test of no post-treatment effect}\n#'          \\item{\"alpha\"}{Level of confidence interval}\n#'         }\nconformal_inf <- function(ascm, alpha = 0.05, \n                          stat_func = NULL, type = \"iid\",\n               
           q = 1, ns = 1000, grid_size = 50) {\n  wide_data <- ascm$data\n  synth_data <- ascm$data$synth_data\n  n <- nrow(wide_data$X)\n  n_c <- dim(synth_data$Z0)[2]\n  Z <- wide_data$Z\n\n  t0 <- dim(synth_data$Z0)[1]\n  tpost <- ncol(wide_data$y)\n  t_final <- dim(synth_data$Y0plot)[1]\n\n  # grid of nulls\n  att <- predict(ascm, att = T)\n  post_att <- att[(t0 +1):t_final]\n  post_sd <- sqrt(mean(post_att ^ 2))\n  # iterate over post-treatment periods to get pointwise CIs\n  vapply(1:tpost,\n         function(j) {\n          # fit using t0 + j as a pre-treatment period and get residuals\n          new_wide_data <- wide_data\n          new_wide_data$X <- cbind(wide_data$X, wide_data$y[, j, drop = TRUE])\n          if(tpost > 1) {\n            new_wide_data$y <- wide_data$y[, -j, drop = FALSE]\n          } else {\n            # the post period has to be *something*, so use a placeholder\n            new_wide_data$y <- matrix(1, nrow = n, ncol = 1)\n          }\n\n\n          # make a grid around the estimated ATT\n          grid <- seq(att[t0 + j] - 2 * post_sd, att[t0 + j] + 2 * post_sd,\n                      length.out = grid_size)\n          compute_permute_ci(new_wide_data, ascm, grid, 1, alpha, type,\n                             q, ns, stat_func)\n         },\n         numeric(3)) -> cis\n\n  # test a null post-treatment effect\n  new_wide_data <- wide_data\n  new_wide_data$X <- cbind(wide_data$X, wide_data$y)\n  new_wide_data$y <- matrix(1, nrow = n, ncol = 1)\n  null_p <- compute_permute_pval(new_wide_data, ascm, 0, ncol(wide_data$y), \n                                 type, q, ns, stat_func)\n  out <- list()\n  att <- predict(ascm, att = T)\n  out$att <- c(att, mean(att[(t0 + 1):t_final]))\n  # out$se <- rep(NA, t_final)\n  # out$sigma <- NA\n  out$lb <- c(rep(NA, t0), cis[1, ], NA)\n  out$ub <- c(rep(NA, t0), cis[2, ], NA)\n  out$p_val <- c(rep(NA, t0), cis[3, ], null_p)\n  out$alpha <- alpha\n  return(out)\n}\n\n\n#' Conformal inference procedure to compute a 
confidence interval for a linear in time effect\n#' @param ascm Fitted `augsynth` object\n#' @param alpha Confidence level\n#' @param stat_func Function to compute test statistic\n#' @param type Either \"iid\" for iid permutations or \"block\" for moving block permutations; default is \"iid\"\n#' @param q The norm for the test static `((sum(x ^ q))) ^ (1/q)`\n#' @param ns Number of resamples for \"iid\" permutations\n#' @param grid_size Number of grid points to use when inverting the hypothesis test\n#' @return List that contains:\n#'         \\itemize{\n#'          \\item{\"att\"}{Vector of ATT estimates}\n#'          \\item{\"heldout_att\"}{Vector of ATT estimates with the time period held out}\n#'          \\item{\"se\"}{Standard error, always NA but returned for compatibility}\n#'          \\item{\"lb\"}{Lower bound of 1 - alpha confidence interval}\n#'          \\item{\"ub\"}{Upper bound of 1 - alpha confidence interval}\n#'          \\item{\"p_val\"}{p-value for test of no post-treatment effect}\n#'          \\item{\"alpha\"}{Level of confidence interval}\n#'         }\nconformal_inf_linear <- function(ascm, alpha = 0.05, \n                          stat_func = NULL, type = \"iid\",\n                          q = 1, ns = 1000, grid_size = 50) {\n  wide_data <- ascm$data\n  synth_data <- ascm$data$synth_data\n  n <- nrow(wide_data$X)\n  n_c <- dim(synth_data$Z0)[2]\n  Z <- wide_data$Z\n\n  t0 <- dim(synth_data$Z0)[1]\n  tpost <- ncol(wide_data$y)\n  t_final <- dim(synth_data$Y0plot)[1]\n\n  # grid of nulls\n  att <- predict(ascm, att = T)\n  post_att <- att[(t0 +1):t_final]\n  post_second <- sqrt(mean(post_att^2))\n\n  # grid for slope\n  # use ols to get pilot estimate\n  ts <- 1:tpost\n  lm_out <- summary(lm(post_att ~ ts))$coefficients\n  # grid for intercept\n  grid_int <- seq(lm_out[1,1] - 2 * post_second,\n                  lm_out[1,1] + 2 * post_second,\n                  length.out = grid_size)\n  if(tpost == 2) {\n    warning(paste0(\"There are 2 
post-treatment time periods, so a linear model has a perfect fit. A confidence interval for the slope may not be reasonable here.\"))\n    grid_slope <- seq(lm_out[2,1] - abs(lm_out[2,1]),\n                  lm_out[2,1] + abs(lm_out[2,1]),\n                  length.out = grid_size)\n  } else if(tpost <= 1) {\n    stop(\"There is only one post-treatment time period, so an intercept and a slope cannot be computed.\")\n  } else {\n    grid_slope <- seq(lm_out[2,1] - 4 * lm_out[2,2] * sqrt(tpost),\n                  lm_out[2,1] + 4 * lm_out[2,2] * sqrt(tpost),\n                  length.out = grid_size)\n  }\n\n  # test a null post-treatment effect\n  new_wide_data <- wide_data\n  new_wide_data$X <- cbind(wide_data$X, wide_data$y)\n  new_wide_data$y <- matrix(1, nrow = n, ncol = 1)\n  null_p <- compute_permute_pval(new_wide_data, ascm, 0, ncol(wide_data$y), \n                                 type, q, ns, stat_func)\n\n  # confidence interval for linear in time treatment effects\n  cis <- compute_permute_ci_linear(new_wide_data, ascm, grid_int, grid_slope,\n                                   ncol(wide_data$y), alpha, type, q, ns, stat_func)\n\n  return(cis)\n}\n\n\n#' Compute conformal test statistics\n#' @param wide_data List containing pre- and post-treatment outcomes and outcome vector\n#' @param ascm Fitted `augsynth` object\n#' @param h0 Null hypothesis to test\n#' @param post_length Number of post-treatment periods\n#' @param type Either \"iid\" for iid permutations or \"block\" for moving block permutations\n#' @param q The norm for the test static `((sum(x ^ q))) ^ (1/q)`\n#' @param ns Number of resamples for \"iid\" permutations\n#' @param stat_func Function to compute test statistic\n#' \n#' @return List that contains:\n#'         \\itemize{\n#'          \\item{\"resids\"}{Residuals after enforcing the null}\n#'          \\item{\"test_stats\"}{Permutation distribution of test statistics}\n#'          \\item{\"stat_func\"}{Test statistic function}\n#'         
}\n#' @noRd\ncompute_permute_test_stats <- function(wide_data, ascm, h0,\n                                       post_length, type,\n                                       q, ns, stat_func) {\n  # format data\n  new_wide_data <- wide_data\n  t0 <- ncol(wide_data$X) - post_length\n  tpost <- t0 + post_length\n  # adjust outcomes for null\n  new_wide_data$X[wide_data$trt == 1,(t0 + 1):tpost ] <- new_wide_data$X[wide_data$trt == 1,(t0 + 1):tpost] - h0\n  X0 <- new_wide_data$X[new_wide_data$trt == 0,, drop = F]\n  x1 <- matrix(colMeans(new_wide_data$X[new_wide_data$trt == 1,, drop = F]),\n              ncol=1)\n\n  new_synth_data <- list()\n  new_synth_data$Z0 <- t(X0)\n  new_synth_data$X0 <- t(X0)\n  new_synth_data$Z1 <- x1\n  new_synth_data$X1 <- x1\n\n  # fit synth with adjusted data and get residuals\n  new_ascm <- do.call(fit_augsynth_internal,\n                    c(list(wide = new_wide_data,\n                            synth_data = new_synth_data,\n                            Z = wide_data$Z,\n                            progfunc = ascm$progfunc,\n                            scm = ascm$scm,\n                            fixedeff = ascm$fixedeff),\n                        ascm$extra_args))\n  resids <- predict(new_ascm, att = T)[1:tpost]\n  # permute residuals and compute test statistic\n  if(is.null(stat_func)) {\n    stat_func <- function(x) (sum(abs(x) ^ q)  / sqrt(length(x))) ^ (1 / q)\n  }\n  if(type == \"iid\") {\n    test_stats <- sapply(1:ns, \n                        function(x) {\n                          reorder <- sample(resids)\n                          stat_func(reorder[(t0 + 1):tpost])\n                        })\n  } else {\n    ## increment time by one step and wrap\n    test_stats <- sapply(1:tpost,\n                        function(j) {\n                          reorder <- resids[(0:tpost -1 + j) %% tpost + 1]\n                          stat_func(reorder[(t0 + 1):tpost])\n                        })\n  }\n  \n  return(list(resids = resids,\n  
            test_stats = test_stats,\n              stat_func = stat_func))\n}\n\n\n#' Compute conformal p-value\n#' @param wide_data List containing pre- and post-treatment outcomes and outcome vector\n#' @param ascm Fitted `augsynth` object\n#' @param h0 Null hypothesis to test\n#' @param post_length Number of post-treatment periods\n#' @param type Either \"iid\" for iid permutations or \"block\" for moving block permutations\n#' @param q The norm for the test statistic `((sum(x ^ q))) ^ (1/q)`\n#' @param ns Number of resamples for \"iid\" permutations\n#' @param stat_func Function to compute test statistic\n#' \n#' @return Computed p-value\n#' @noRd\ncompute_permute_pval <- function(wide_data, ascm, h0,\n                                 post_length, type,\n                                 q, ns, stat_func) {\n  t0 <- ncol(wide_data$X) - post_length\n  tpost <- t0 + post_length\n  out <- compute_permute_test_stats(wide_data, ascm, h0,\n                                    post_length, type, q, ns, stat_func)\n  mean(out$stat_func(out$resids[(t0 + 1):tpost]) <= out$test_stats)\n}\n\n#' Compute conformal confidence interval by inverting the permutation test\n#' @param wide_data List containing pre- and post-treatment outcomes and outcome vector\n#' @param ascm Fitted `augsynth` object\n#' @param grid Set of null hypotheses to test for inversion\n#' @param post_length Number of post-treatment periods\n#' @param alpha Confidence level\n#' @param type Either \"iid\" for iid permutations or \"block\" for moving block permutations\n#' @param q The norm for the test statistic `((sum(x ^ q))) ^ (1/q)`\n#' @param ns Number of resamples for \"iid\" permutations\n#' @param stat_func Function to compute test statistic\n#' \n#' @return (lower bound of interval, upper bound of interval, p-value for null of 0 effect)\n#' @noRd\ncompute_permute_ci <- function(wide_data, ascm, grid,\n                               post_length, alpha, type,\n                               q, ns, stat_func) {\n  # make sure 0 is in the grid\n  grid <- c(grid, 0)\n  ps <- sapply(grid, 
\n              function(x) {\n                compute_permute_pval(wide_data, ascm, x, \n                                     post_length, type, q, ns, stat_func)\n              })\n  c(min(grid[ps >= alpha]), max(grid[ps >= alpha]), ps[grid == 0])\n}\n\n\n\n#' Compute conformal confidence interval for a linear model for effects\n#' int + slope * time\n#' @param wide_data List containing pre- and post-treatment outcomes and outcome vector\n#' @param ascm Fitted `augsynth` object\n#' @param grid_int Set of null hypothesis values for the intercept\n#' @param grid_slope Set of null hypothesis values for the slope\n#' @param post_length Number of post-treatment periods\n#' @param type Either \"iid\" for iid permutations or \"block\" for moving block permutations\n#' @param q The norm for the test static `((sum(x ^ q))) ^ (1/q)`\n#' @param ns Number of resamples for \"iid\" permutations\n#' @param stat_func Function to compute test statistic\n#' \n#' @return (lower bound of interval, upper bound of interval, p-value for null of 0 effect)\n#' @noRd\ncompute_permute_ci_linear <- function(wide_data, ascm, grid_int, grid_slope,\n                               post_length, alpha, type,\n                               q, ns, stat_func) {\n  # make sure 0 is in both grids\n  # grid_int <- c(grid_int, 0)\n  # grid_slope <- c(grid_slope, 0)\n  # make the combined grid\n  grid_comb <- expand.grid(grid_int, grid_slope)\n  grid_comb$p_val <-apply(grid_comb, 1,\n              function(x) {\n                compute_permute_pval(wide_data, ascm, x[1] + x[2] * (1:post_length), \n                                     post_length, type, q, ns, stat_func)\n              })\n  ci_int <- c(min(grid_comb[grid_comb$p_val >= alpha, 1]),\n              max(grid_comb[grid_comb$p_val >= alpha, 1]))\n  ci_slope <- c(min(grid_comb[grid_comb$p_val >= alpha, 2]),\n                   max(grid_comb[grid_comb$p_val >= alpha, 2]))\n  int_slope_est <- as.numeric(grid_comb[which.max(grid_comb$p_val), 
1:2])\n  return(list(est_int = int_slope_est[1], ci_int = ci_int,\n              est_slope = int_slope_est[2], ci_slope = ci_slope))\n}\n\n\n#' Jackknife+ algorithm over time\n#' @param ascm Fitted `augsynth` object\n#' @param alpha Confidence level\n#' @param conservative Whether to use the conservative jackknife+ procedure\n#' @return List that contains:\n#'         \\itemize{\n#'          \\item{\"att\"}{Vector of ATT estimates}\n#'          \\item{\"heldout_att\"}{Vector of ATT estimates with the time period held out}\n#'          \\item{\"se\"}{Standard error, always NA but returned for compatibility}\n#'          \\item{\"lb\"}{Lower bound of 1 - alpha confidence interval}\n#'          \\item{\"ub\"}{Upper bound of 1 - alpha confidence interval}\n#'          \\item{\"alpha\"}{Level of confidence interval}\n#'         }\ntime_jackknife_plus_multiout <- function(ascm_multi, alpha = 0.05, conservative = F) {\n    wide_data <- ascm_multi$data\n    data_list <- ascm_multi$data_list\n\n    n <- nrow(wide_data$X)\n    k <- length(data_list$X)\n\n\n    t0 <- min(sapply(data_list$X, ncol))\n    tpost <- max(sapply(data_list$y, ncol))\n    t_final <- t0 + tpost\n    Z <- wide_data$Z\n\n    jack_ests <- lapply(1:t0, \n        function(tdrop) {\n            # drop unit i\n            new_data_list <- drop_time_t_multiout(data_list, Z, tdrop)\n            # refit\n            new_ascm <- do.call(fit_augsynth_multiout_internal,\n                    c(list(wide_list = new_data_list,\n                            combine_method = ascm_multi$combine_method,\n                            Z = data_list$Z,\n                            progfunc = ascm_multi$progfunc,\n                            scm = ascm_multi$scm,\n                            fixedeff = ascm_multi$fixedeff,\n                            outcomes_str = ascm_multi$outcomes),\n                        ascm_multi$extra_args))\n            # get ATT estimates and held out error for time t\n            # t0 is 
prediction for held out time\n            est <- predict(new_ascm, att = F)[(t0 +1):t_final, , drop = F]\n            est <- rbind(est, colMeans(est))\n            # err <- c(colMeans(wide_data$X[wide_data$trt == 1,\n            #                              tdrop,\n            #                              drop = F]) -\n            #         predict(new_ascm, att = F)[t0])\n            err <- c(predict(new_ascm, att = T)[t0, , drop = F])\n            list(err, t(t(est) + abs(err)), t(t(est) - abs(err)), t(t(est) + err), est)\n        })\n    # get errors and jackknife distribution\n    held_out_errs <- matrix(vapply(jack_ests, `[[`, numeric(k), 1), nrow = k)\n    jack_dist_high <- vapply(jack_ests, `[[`,\n                        matrix(0, nrow = tpost + 1, ncol = k), 2)\n    jack_dist_low <- vapply(jack_ests, `[[`,\n                        matrix(0, nrow = tpost + 1, ncol = k), 3)\n    jack_dist_cons <- vapply(jack_ests, `[[`,\n                        matrix(0, nrow = tpost + 1, ncol = k), 4)\n\n    out <- list()\n    att <- predict(ascm_multi, att = T)\n    out$att <- rbind(att, \n                      colMeans(att[(t0 + 1):t_final, , drop = F]))\n    # held out ATT\n\n    out$heldout_att <- rbind(t(held_out_errs), \n                              att[(t0 + 1):t_final, , drop = F], \n                              colMeans(att[(t0 + 1):t_final, , drop = F]))\n    if(conservative) {\n        qerr <- apply(abs(held_out_errs), 1, \n                      stats::quantile, 1 - alpha, type = 1)\n        out$lb <- rbind(matrix(NA, nrow = t0, ncol = k),\n                        t(t(apply(jack_dist_cons, 1:2, min)) - qerr))\n        out$ub <- rbind(matrix(NA, nrow = t0, ncol = k),\n                        t(t(apply(jack_dist_cons, 1:2, max)) + qerr))\n\n    } else {\n        out$lb <- rbind(matrix(NA, nrow = t0, ncol = k),\n                        apply(jack_dist_low, 1:2,\n                              stats::quantile, alpha, type = 1))\n        out$ub <- rbind(matrix(NA, 
nrow = t0, ncol = k), \n                        apply(jack_dist_high, 1:2, \n                              stats::quantile, 1 - alpha, type = 1))\n    }\n    # shift back to ATT scale\n    y1 <- predict(ascm_multi, att = F) + att\n    y1 <-  rbind(y1, colMeans(y1[(t0 + 1):t_final, , drop = F]))\n    shifted_lb <- y1 - out$ub\n    shifted_ub <- y1 - out$lb\n    out$lb <- shifted_lb\n    out$ub <- shifted_ub\n    out$alpha <- alpha\n\n\n    return(out)\n}\n\n#' Drop time period from pre-treatment data\n#' @param wide_data (X, y, trt)\n#' @param Z Covariates matrix\n#' @param t_drop Time to drop\n#' @noRd\ndrop_time_t_multiout <- function(data_list, Z, t_drop) {\n\n        new_data_list <- list()\n        new_data_list$trt <- data_list$trt\n        new_data_list$X <- lapply(data_list$X,\n                                  function(x) x[, -t_drop, drop = F])\n        new_data_list$y <- lapply(1:length(data_list$y),\n                                  function(k) {\n                                    cbind(data_list$X[[k]][, t_drop, drop = F], \n                                          data_list$y[[k]])\n                                  })\n        return(new_data_list)\n}\n\n\n#' Conformal inference procedure to compute p-values and point-wise confidence intervals\n#' @param ascm Fitted `augsynth` object\n#' @param alpha Confidence level\n#' @param stat_func Function to compute test statistic\n#' @param type Either \"iid\" for iid permutations or \"block\" for moving block permutations\n#' @param q The norm for the test static `((sum(x ^ q))) ^ (1/q)`\n#' @param ns Number of resamples for \"iid\" permutations\n#' @param grid_size Number of grid points to use when inverting the hypothesis test (default is 1, so only to test joint null)\n#' @return List that contains:\n#'         \\itemize{\n#'          \\item{\"att\"}{Vector of ATT estimates}\n#'          \\item{\"heldout_att\"}{Vector of ATT estimates with the time period held out}\n#'          \\item{\"se\"}{Standard 
error, always NA but returned for compatibility}\n#'          \\item{\"lb\"}{Lower bound of 1 - alpha confidence interval}\n#'          \\item{\"ub\"}{Upper bound of 1 - alpha confidence interval}\n#'          \\item{\"p_val\"}{p-value for test of no post-treatment effect}\n#'          \\item{\"alpha\"}{Level of confidence interval}\n#'         }\nconformal_inf_multiout <- function(ascm_multi, alpha = 0.05, \n                                    stat_func = NULL, type = \"iid\",\n                                    q = 1, ns = 1000, grid_size = 1,\n                                    lin_h0 = NULL) {\n  wide_data <- ascm_multi$data\n  data_list <- ascm_multi$data_list\n\n  n <- nrow(wide_data$X)\n  k <- length(data_list$X)\n\n\n  t0 <- min(sapply(data_list$X, ncol))\n  tpost <- max(sapply(data_list$y, ncol))\n  t_final <- t0 + tpost\n\n  # grid of nulls\n  att <- predict(ascm_multi, att = T)\n  post_att <- att[(t0 +1):t_final,, drop = F]\n  post_sd <- apply(post_att, 2, function(x) sqrt(mean(x ^ 2, na.rm = T)))\n  # iterate over post-treatment periods to get pointwise CIs\n  \n  vapply(1:tpost,\n         function(j) {\n          # fit using t0 + j as a pre-treatment period and get residuals\n          new_data_list <- data_list\n          new_data_list$X <- lapply(1:k,\n              function(i) {\n                Xi <- cbind(data_list$X[[i]], data_list$y[[i]][, j, drop = TRUE])\n                colnames(Xi) <- c(colnames(data_list$X[[i]]),\n                                  colnames(data_list$y[[i]])[j])\n                Xi\n          })\n          \n          \n          if(tpost > 1) {\n            new_data_list$y <- lapply(1:k,\n              function(i) {\n                data_list$y[[i]][, -j, drop = FALSE]\n            })\n          } else {\n            # set the post period has to be *something*\n            new_data_list$y <- lapply(1:k,\n              function(i) {\n                x <- matrix(1, nrow = n, ncol = 1)\n                colnames(x) <- 
max(as.numeric(colnames(data_list$y[[i]]))) + 1\n                x\n            })\n          }\n\n\n          # make a grid around the estimated ATT\n          if(is.null(lin_h0)) {\n            grid <- lapply(1:k, \n            function(i) {\n              seq(att[t0 + j, i] - 2 * post_sd[i], att[t0 + j, i] + 2 * post_sd[i],\n                    length.out = grid_size)\n            })\n          } else {\n            grid <- seq(min(att[t0 + j, ]) - 2 * max(post_sd),\n                max(att[t0 + j, ]) + 2 * max(post_sd),\n                length.out = grid_size)\n          }\n          if(grid_size > 1) {\n            compute_permute_ci_multiout(new_data_list, ascm_multi, grid, 1,\n                                    alpha, type, q, ns, lin_h0, stat_func)\n          } else {\n            rbind(matrix(0, ncol = k, nrow = 2),\n              compute_permute_pval_multiout(new_data_list, ascm_multi, numeric(k),\n                                          1, type, q, ns, stat_func))\n          }\n          \n         },\n         matrix(0, ncol = k, nrow=3)) -> cis\n  # test a null post-treatment effect\n\n  new_data_list <- data_list\n  new_data_list$X <- lapply(1:k,\n      function(i) {\n        Xi <- cbind(data_list$X[[i]], data_list$y[[i]])\n        colnames(Xi) <- c(colnames(data_list$X[[i]]),\n                          colnames(data_list$y[[i]]))\n        Xi\n      })\n  # set post treatment to be *something*\n  new_data_list$y <- lapply(1:k,\n      function(i) {\n        data_list$y[[i]][, 1, drop = FALSE]\n    })\n  null_p <- compute_permute_pval_multiout(new_data_list, ascm_multi,\n                                          numeric(k), \n                                          tpost, type, q, ns, stat_func)\n  if(is.null(lin_h0)) {\n    grid <- lapply(1:k, \n            function(i) {\n              seq(min(att[(t0 + 1):t_final, i]) - 4 * post_sd[i],\n                  max(att[(t0 + 1):t_final, i]) + 4 * post_sd[i],\n                    length.out = grid_size)\n  
          })\n  } else {\n    grid <- seq(min(att[t0 + 1, ]) - 3 * max(post_sd),\n                max(att[t0 + 1, ]) + 3 * max(post_sd),\n                length.out = grid_size)\n  }\n  null_ci <- compute_permute_ci_multiout(new_data_list, ascm_multi, grid,\n                                          tpost, alpha, type, q, ns,\n                                          lin_h0, stat_func)\n  out <- list()\n  att <- predict(ascm_multi, att = T)\n  out$att <- rbind(att, apply(att[(t0 + 1):t_final, , drop = F], 2, mean))\n  out$lb <- rbind(matrix(NA, nrow = t0, ncol = k),\n                  t(matrix(cis[1, ,], nrow = k)),\n                  # rep(NA, k)\n                  null_ci[1,]\n                  )\n  colnames(out$lb) <- ascm_multi$outcomes\n  out$ub <- rbind(matrix(NA, nrow = t0, ncol = k),\n                  t(matrix(cis[2, ,], nrow = k)),\n                  # rep(NA, k)\n                  null_ci[2,]\n                  )\n  colnames(out$ub) <- ascm_multi$outcomes\n  out$p_val <- rbind(matrix(NA, nrow = t0, ncol = k),\n                  t(matrix(cis[3, ,], nrow = k)),\n                  # rep(null_p, k)\n                  null_ci[3,])\n  colnames(out$p_val) <- ascm_multi$outcomes\n  out$alpha <- alpha\n  return(out)\n}\n\n\n\n#' Compute conformal test statistics\n#' @param wide_data List containing pre- and post-treatment outcomes and outcome vector\n#' @param ascm Fitted `augsynth` object\n#' @param h0 Null hypothesis to test\n#' @param post_length Number of post-treatment periods\n#' @param type Either \"iid\" for iid permutations or \"block\" for moving block permutations\n#' @param q The norm for the test static `((sum(x ^ q))) ^ (1/q)`\n#' @param ns Number of resamples for \"iid\" permutations\n#' @param stat_func Function to compute test statistic\n#' \n#' @return List that contains:\n#'         \\itemize{\n#'          \\item{\"resids\"}{Residuals after enforcing the null}\n#'          \\item{\"test_stats\"}{Permutation distribution of test statistics}\n#' 
         \\item{\"stat_func\"}{Test statistic function}\n#'         }\n#' @noRd\ncompute_permute_test_stats_multiout <- function(data_list, ascm_multi, h0,\n                                              post_length, type,\n                                              q, ns, stat_func) {\n  # format data\n  new_data_list <- data_list\n  t0 <- ncol(data_list$X[[1]]) - post_length\n  tpost <- t0 + post_length\n  k <- length(data_list$X)\n  # adjust outcome i for its own null h0[i]\n  for(i in 1:k) {\n    new_data_list$X[[i]][data_list$trt == 1,(t0 + 1):tpost ] <- new_data_list$X[[i]][data_list$trt == 1,(t0 + 1):tpost] - h0[i]\n  }\n  # fit synth with adjusted data and get residuals\n  new_ascm <- do.call(fit_augsynth_multiout_internal,\n                    c(list(wide_list = new_data_list,\n                            combine_method = ascm_multi$combine_method,\n                            Z = data_list$Z,\n                            progfunc = ascm_multi$progfunc,\n                            scm = ascm_multi$scm,\n                            fixedeff = ascm_multi$fixedeff,\n                            outcomes_str = ascm_multi$outcomes),\n                        ascm_multi$extra_args))\n\n  resids <- predict(new_ascm, att = T)[1:tpost, , drop = F]\n\n  # permute residuals and compute test statistic\n  if(is.null(stat_func)) {\n    stat_func <- function(x) {\n      x <- na.omit(x)\n      (sum(abs(x) ^ q)  / sqrt(length(x))) ^ (1 / q)\n    }\n  }\n  if(type == \"iid\") {\n    test_stats <- sapply(1:ns, \n                        function(x) {\n                          idxs <- sample(1:nrow(resids))\n                          reorder <- resids[idxs, , drop = F]\n                          apply(reorder[(t0 + 1):tpost, ,drop = F], 2, stat_func)\n                        })\n  } else {\n    ## increment time by one step and wrap\n    test_stats <- sapply(0:(tpost - 1),\n                        function(j) {\n                          reorder <- resids[(0:(tpost -1) + j) %% tpost + 1, 
,drop = F]\n                          if(!all(dim(reorder) == dim(resids))) {\n                            stop(\"Error in block resampling\")\n                          }\n                          apply(reorder[(t0 + 1):tpost, , drop = F], 2, stat_func)\n                        })\n  }\n  \n  return(list(resids = resids,\n              test_stats = matrix(test_stats, nrow = k),\n              stat_func = stat_func))\n}\n\n\n#' Compute conformal p-value\n#' @param data_list List containing pre- and post-treatment outcomes and outcome vector\n#' @param ascm_multi Fitted `augsynth_multiout` object\n#' @param h0 Null hypothesis to test\n#' @param post_length Number of post-treatment periods\n#' @param type Either \"iid\" for iid permutations or \"block\" for moving block permutations\n#' @param q The norm for the test statistic `((sum(x ^ q))) ^ (1/q)`\n#' @param ns Number of resamples for \"iid\" permutations\n#' @param stat_func Function to compute test statistic\n#' \n#' @return Computed p-value\n#' @noRd\ncompute_permute_pval_multiout <- function(data_list, ascm_multi, h0,\n                                        post_length, type,\n                                        q, ns, stat_func) {\n  t0 <- ncol(data_list$X[[1]]) - post_length\n  tpost <- t0 + post_length\n\n  out <- compute_permute_test_stats_multiout(data_list, ascm_multi, h0,\n                                          post_length, type, q, ns, stat_func)\n  k <- length(data_list$X)\n\n  comb_stat <- mean(apply(out$resids[(t0 + 1):tpost, , drop = F], 2, out$stat_func), na.rm = TRUE)\n  comb_test_stats <- apply(out$test_stats, 2, mean, na.rm = TRUE)\n  # if(h0 == 0) {\n  #   hist(comb_test_stats)\n  #   abline(v = comb_stat)\n  #   print(mean(comb_stat <= comb_test_stats))\n  #   print(1 - mean(comb_stat > comb_test_stats))\n  # }\n  1 - mean(comb_stat > comb_test_stats)\n}\n\n#' Compute conformal confidence interval by test inversion\n#' @param data_list List containing pre- and post-treatment outcomes and outcome vector\n#' @param ascm_multi Fitted 
`augsynth` object\n#' @param grid Set of null hypotheses to test for inversion\n#' @param post_length Number of post-treatment periods\n#' @param type Either \"iid\" for iid permutations or \"block\" for moving block permutations\n#' @param q The norm for the test statistic `((sum(x ^ q))) ^ (1/q)`\n#' @param ns Number of resamples for \"iid\" permutations\n#' @param stat_func Function to compute test statistic\n#' \n#' @return (lower bound of interval, upper bound of interval, p-value for null of 0 effect)\n#' @noRd\ncompute_permute_ci_multiout <- function(data_list, ascm_multi, grid,\n                                      post_length, alpha, type,\n                                      q, ns, lin_h0 = NULL, stat_func) {\n  # make sure 0 is in the grid\n  if(is.null(lin_h0)) {\n    grid <- lapply(grid, function(x) c(x, 0))\n    k <- length(grid)\n    # get all combinations of grid\n    grid <- expand.grid(grid)\n    grid_low <- NULL\n  } else {\n    k <- length(lin_h0)\n    # keep track of low dimensional grid\n    grid_low <- c(grid, 0)\n    # transform into high dimensional grid with linear hypothesis\n    grid <- sapply(lin_h0, function(x) x * grid_low)\n  }\n  ps <- apply(grid, 1,\n              function(x) {\n                compute_permute_pval_multiout(data_list, ascm_multi, x, \n                                     post_length, type, q, ns, stat_func)\n              })\n  sapply(1:k, \n    function(i) c(min(grid[ps >= alpha, i]), \n                  max(grid[ps >= alpha, i]), \n                  ps[apply(grid == 0, 1, all)]))\n}\n\n\n\n\n#' Drop unit i from data\n#' @param wide_data (X, y, trt)\n#' @param Z Covariates matrix\n#' @param i Unit to drop\n#' @noRd\ndrop_unit_i <- function(wide_data, Z, i) {\n\n        new_wide_data <- list()\n        new_wide_data$trt <- wide_data$trt[-i]\n        new_wide_data$X <- wide_data$X[-i,, drop = F]\n        new_wide_data$y <- wide_data$y[-i,, drop = F]\n\n        X0 <- new_wide_data$X[new_wide_data$trt == 0,, drop = 
F]\n        x1 <- matrix(colMeans(new_wide_data$X[new_wide_data$trt == 1,, drop = F]),\n                     ncol=1)\n        y0 <- new_wide_data$y[new_wide_data$trt == 0,, drop = F]\n        y1 <- colMeans(new_wide_data$y[new_wide_data$trt == 1,, drop = F])\n\n        new_synth_data <- list()\n        new_synth_data$Z0 <- t(X0)\n        new_synth_data$X0 <- t(X0)\n        new_synth_data$Z1 <- x1\n        new_synth_data$X1 <- x1\n        new_Z <- if(!is.null(Z)) Z[-i, , drop = F] else NULL\n\n        return(list(wide_data = new_wide_data,\n                    synth_data = new_synth_data,\n                    Z = new_Z))\n}\n\n#' Drop unit i from data\n#' @param wide_list (X, y, trt)\n#' @param Z Covariates matrix\n#' @param i Unit to drop\n#' @noRd\ndrop_unit_i_multiout <- function(wide_list, Z, i) {\n\n        new_wide_data <- list()\n        new_wide_data$trt <- wide_list$trt[-i]\n        new_wide_data$X <- lapply(wide_list$X, function(x) x[-i,, drop = F])\n        new_wide_data$y <- lapply(wide_list$y, function(x) x[-i,, drop = F])\n        new_Z <- if(!is.null(Z)) Z[-i, , drop = F] else NULL\n\n        return(list(wide_list = new_wide_data,\n                    Z = new_Z))\n}\n\n\n#' Estimate standard errors for single ASCM with the jackknife\n#' Do this for ridge-augmented synth\n#' @param ascm Fitted augsynth object\n#' \n#' @return List that contains:\n#'         \\itemize{\n#'          \\item{\"att\"}{Vector of ATT estimates}\n#'          \\item{\"se\"}{Standard error estimate}\n#'          \\item{\"lb\"}{Lower bound of 1 - alpha confidence interval}\n#'          \\item{\"ub\"}{Upper bound of 1 - alpha confidence interval}\n#'          \\item{\"alpha\"}{Level of confidence interval}\n#'         }\njackknife_se_single <- function(ascm) {\n\n    wide_data <- ascm$data\n    synth_data <- ascm$data$synth_data\n    n <- nrow(wide_data$X)\n    n_c <- dim(synth_data$Z0)[2]\n    Z <- wide_data$Z\n\n    t0 <- dim(synth_data$Z0)[1]\n    tpost <- ncol(wide_data$y)\n   
 t_final <- dim(synth_data$Y0plot)[1]\n    errs <- matrix(0, n_c, t_final - t0)\n\n\n    # only drop out control units with non-zero weights\n    nnz_weights <- numeric(n)\n    nnz_weights[wide_data$trt == 0] <- round(ascm$weights, 3) != 0\n    # if more than one unit is treated, include them in the jackknife\n    if(sum(wide_data$trt) > 1) {\n      nnz_weights[wide_data$trt == 1] <- 1\n    }\n\n    trt_idxs <- (1:n)[as.logical(nnz_weights)]\n\n\n    # jackknife estimates\n    ests <- vapply(trt_idxs,\n                   function(i) {\n                       # drop unit i\n                       new_data <- drop_unit_i(wide_data, Z, i)\n                       # refit\n                       new_ascm <- do.call(fit_augsynth_internal,\n                                c(list(wide = new_data$wide,\n                                       synth_data = new_data$synth_data,\n                                       Z = new_data$Z,\n                                       progfunc = ascm$progfunc,\n                                       scm = ascm$scm,\n                                       fixedeff = ascm$fixedeff),\n                                  ascm$extra_args))\n                       # get ATT estimates\n                       est <- predict(new_ascm, att = T)[(t0 + 1):t_final]\n                       c(est, mean(est))\n                   },\n                   numeric(tpost + 1))\n    # convert to matrix\n    ests <- matrix(ests, nrow = tpost + 1, ncol = length(trt_idxs))\n    ## standard errors\n    se2 <- apply(ests, 1,\n                 function(x) (n - 1) / n * sum((x - mean(x, na.rm = T)) ^ 2))\n    se <- sqrt(se2)\n\n    out <- list()\n    att <- predict(ascm, att = T)\n    out$att <- c(att, mean(att[(t0 + 1):t_final]))\n\n    out$se <- c(rep(NA, t0), se)\n    # out$sigma <- NA\n    return(out)\n}\n\n\n#' Compute standard errors using the jackknife\n#' @param multisynth fitted multisynth object\n#' @param relative Whether to compute effects according to 
relative time\n#' @noRd\njackknife_se_multi <- function(multisynth, relative=NULL, alpha = 0.05, att_weight = NULL) {\n    ## get info from the multisynth object\n    if(is.null(relative)) {\n        relative <- multisynth$relative\n    }\n    n_leads <- multisynth$n_leads\n    n <- nrow(multisynth$data$X)\n    att <- predict(multisynth, att=T, att_weight = att_weight)\n    outddim <- nrow(att)\n\n    J <- length(multisynth$grps)\n\n    ## drop each unit and estimate overall treatment effect\n    jack_est <- vapply(1:n,\n                       function(i) {\n                           msyn_i <- drop_unit_i_multi(multisynth, i)\n                           pred <- predict(msyn_i[[1]], relative=relative, att=T, att_weight = att_weight)\n                           if(nrow(pred) < outddim) {\n                               pred <- rbind(\n                                   pred[1:(nrow(pred)-1), ],\n                                   matrix(NA, nrow=outddim-nrow(pred), ncol=ncol(pred)),\n                                   pred[nrow(pred), ]\n                               )\n                           }\n                           if(length(msyn_i[[2]]) != 0) {\n                               out <- matrix(NA, nrow=nrow(pred), ncol=(J+1))\n                               out[,-(msyn_i[[2]]+1)] <- pred\n                           } else {\n                               out <- pred\n                           }\n                           out\n                       },\n                       matrix(0, nrow=outddim,ncol=(J+1)))\n\n    se2 <- apply(jack_est, c(1,2),\n                function(x) (n-1) / n * sum((x - mean(x,na.rm=T))^2, na.rm=T))\n    lower_bound <- att - qnorm(1 - alpha / 2) * sqrt(se2)\n    upper_bound <- att + qnorm(1 - alpha / 2) * sqrt(se2)\n    return(list(att = att, se = sqrt(se2),\n                lower_bound = lower_bound, upper_bound = upper_bound))\n\n}\n\n#' Helper function to drop unit i and refit\n#' @param msyn multisynth_object\n#' @param i 
Unit to drop\n#' @noRd\ndrop_unit_i_multi <- function(msyn, i) {\n\n    n <- nrow(msyn$data$X)\n    time_cohort <- msyn$time_cohort\n    which_t <- (1:n)[is.finite(msyn$data$trt)]\n\n    not_miss_j <- which_t %in% setdiff(which_t, i)\n\n    # drop unit i from data\n    drop_i <- msyn$data\n    drop_i$X <- msyn$data$X[-i, , drop = F]\n    drop_i$y <- msyn$data$y[-i, , drop = F]\n    drop_i$trt <- msyn$data$trt[-i]\n    drop_i$mask <- msyn$data$mask[not_miss_j,, drop = F]\n\n    if(!is.null(msyn$data$Z)) {\n      drop_i$Z <- msyn$data$Z[-i, , drop = F]\n    } else {\n      drop_i$Z <- NULL\n    }\n\n    long_df <- msyn$long_df\n    unit <- colnames(long_df)[1]\n    # make alphabetical, because the ith unit is the index in alphabetical ordering\n    long_df <- long_df[order(long_df[, unit, drop = TRUE]),]\n    ith_unit <- unique(long_df[,unit, drop = TRUE])[i]\n    long_df <- long_df[long_df[,unit, drop = TRUE] != ith_unit,]\n\n    # re-fit everything\n    args_list <- list(wide = drop_i, relative = msyn$relative,\n                      n_leads = msyn$n_leads, n_lags = msyn$n_lags,\n                      nu = msyn$nu, lambda = msyn$lambda,\n                      V = msyn$V,\n                      force = msyn$force, n_factors = msyn$n_factors,\n                      scm = msyn$scm, time_w = msyn$time_w,\n                      lambda_t = msyn$lambda_t,\n                      fit_resids = msyn$fit_resids,\n                      time_cohort = msyn$time_cohort, long_df = long_df,\n                      how_match = msyn$how_match)\n    msyn_i <- do.call(multisynth_formatted, c(args_list, msyn$extra_pars))\n\n    # check for dropped treated units/time periods\n    if(time_cohort) {\n        dropped <- which(!msyn$grps %in% msyn_i$grps)\n    } else {\n        dropped <- which(!not_miss_j)\n    }\n    return(list(msyn_i,\n                dropped))\n}\n\n\n#' Estimate standard errors for multi outcome ascm with jackknife\n#' @param ascm Fitted augsynth object\n#' 
@noRd\njackknife_se_multiout <- function(ascm) {\n\n    wide_data <- ascm$data\n    wide_list <- ascm$data_list\n    n <- nrow(wide_data$X)\n    Z <- wide_data$Z\n\n\n    # only drop out control units with non-zero weights\n    nnz_weights <- numeric(n)\n    nnz_weights[wide_data$trt == 0] <- round(ascm$weights, 3) != 0\n\n    trt_idxs <- (1:n)[as.logical(nnz_weights)]\n\n    # jackknife estimates\n    ests <- lapply(trt_idxs,\n                   function(i) {\n                       # drop unit i\n                       new_data <- drop_unit_i_multiout(wide_list, Z, i)\n                       # refit\n                       new_ascm <- do.call(fit_augsynth_multiout_internal,\n                                c(list(wide = new_data$wide,\n                                       combine_method = ascm$combine_method,\n                                       Z = new_data$Z,\n                                       progfunc = ascm$progfunc,\n                                       scm = ascm$scm,\n                                       fixedeff = ascm$fixedeff,\n                                       outcomes_str = ascm$outcomes),\n                                  ascm$extra_args))\n                        new_ascm$outcomes <- ascm$outcomes\n                        new_ascm$data_list <- ascm$data_list\n                        new_ascm$data$time <- ascm$data$time\n                       # get ATT estimates\n                       est <- predict(new_ascm, att = T)\n                       est <- est[as.numeric(rownames(est)) >= ascm$t_int,, drop = F]\n                       rbind(est, colMeans(est, na.rm = T))\n                   })\n    ests <- simplify2array(ests)\n    ## standard errors\n    se2 <- apply(ests, c(1, 2),\n                 function(x) (n - 1) / n * sum((x - mean(x, na.rm = T)) ^ 2))\n    se <- sqrt(se2)\n    out <- list()\n    att <- predict(ascm, att = T)\n    att_post <- colMeans(att[as.numeric(rownames(att)) >= ascm$t_int,, drop = F],\n                     
    na.rm = T)\n    out$att <- rbind(att, att_post)\n    t0 <- sum(as.numeric(rownames(att)) < ascm$t_int)\n    out$se <- rbind(matrix(NA, t0, ncol(se)), se)\n    out$sigma <- NA\n    return(out)\n}\n\n\n\n#' Compute the weighted bootstrap distribution\n#' @param multisynth fitted multisynth object\n#' @param rweight Function to draw random weights as a function of n (e.g rweight(n))\n#' @param relative Whether to compute effects according to relative time\n#' @noRd\nweighted_bootstrap_multi <- function(multisynth,\n                                    rweight = rwild_b,\n                                    n_boot = 1000,\n                                    alpha = 0.05,\n                                    att_weight = NULL,\n                                    relative=NULL) {\n  ## get info from the multisynth object\n  if(is.null(relative)) {\n      relative <- multisynth$relative\n  }\n\n  n <- nrow(multisynth$data$X)\n  att <- predict(multisynth, att=T, att_weight = att_weight)\n  outddim <- nrow(att)\n  n1 <- sum(is.finite(multisynth$data$trt))\n  J <- length(multisynth$grps)\n\n\n  # draw random weights to get bootstrap distribution\n  bs_est <- vapply(1:n_boot,\n                      function(i) {\n                        Z <- rweight(n)# / sqrt(n1)\n\n                        predict(multisynth, att=T, att_weight = att_weight, bs_weight = Z) - sum(Z) / n1 * att\n                      },\n                      matrix(0, nrow=outddim,ncol=(J+1)))\n\n  se2 <- apply(bs_est, c(1,2),\n              function(x) mean((x - mean(x))^2, na.rm=T))\n  bias <- apply(bs_est, c(1,2),\n              function(x) mean(x, na.rm=T))\n  upper_bound <- att - apply(bs_est, c(1,2),\n              function(x) quantile(x, alpha / 2, na.rm = T))\n  \n  lower_bound <- att - apply(bs_est, c(1,2),\n              function(x) quantile(x, 1 - alpha / 2, na.rm = T))\n\n  return(list(att = att,\n              bias = bias,\n              se = sqrt(se2),\n              upper_bound = 
upper_bound,\n              lower_bound = lower_bound))\n\n}\n\n#' Bayesian bootstrap\n#' @param n Number of units\n#' @export\nrdirichlet_b <- function(n) {\n  Z <- as.numeric(rgamma(n, 1, 1))\n  return(Z / sum(Z) * n)\n}\n\n#' Non-parametric bootstrap\n#' @param n Number of units\n#' @export\nrmultinom_b <- function(n) as.numeric(rmultinom(1, n, rep(1 / n, n)))\n\n#' Wild bootstrap (Mammen 1993)\n#' @param n Number of units\n#' @export\nrwild_b <- function(n) {\n  sample(c(-(sqrt(5) - 1) / 2, (sqrt(5) + 1) / 2 ), n,\n         replace = TRUE,\n         prob = c((sqrt(5) + 1)/ (2 * sqrt(5)), (sqrt(5) - 1) / (2 * sqrt(5))))\n}"
  },
  {
    "path": "R/multi_outcomes.R",
    "content": "#' Fit Augmented SCM with multiple outcomes\n#' @param form outcome ~ treatment | auxillary covariates\n#' @param unit Name of unit column\n#' @param time Name of time column\n#' @param t_int Time of intervention\n#' @param data Panel data as dataframe\n#' @param progfunc What function to use to impute control outcomes\n#'                 Ridge=Ridge regression (allows for standard errors),\n#'                 None=No outcome model,\n#' @param scm Whether the SCM weighting function is used\n#' @param fixedeff Whether to include a unit fixed effect, default F \n#' @param cov_agg Covariate aggregation functions, if NULL then use mean with NAs omitted\n#' @param combine_method How to combine outcomes: `concat` concatenates outcomes and `avg` averages them, default: 'avg'\n#' @param ... optional arguments for outcome model\n#'\n#' @return augsynth object that contains:\n#'         \\itemize{\n#'          \\item{\"weights\"}{Ridge ASCM weights}\n#'          \\item{\"l2_imbalance\"}{Imbalance in pre-period outcomes, measured by the L2 norm}\n#'          \\item{\"scaled_l2_imbalance\"}{L2 imbalance scaled by L2 imbalance of uniform weights}\n#'          \\item{\"mhat\"}{Outcome model estimate}\n#'          \\item{\"data\"}{Panel data as matrices}\n#'         }\n#' @export\naugsynth_multiout <- function(form, unit, time, t_int, data,\n                              progfunc=c(\"Ridge\", \"None\"),\n                              scm=T,\n                              fixedeff = FALSE,\n                              cov_agg=NULL,\n                              combine_method = \"avg\",\n                              ...) 
{\n    call_name <- match.call()\n\n    form <- Formula::Formula(form)\n    unit <- enquo(unit)\n    time <- enquo(time)\n    \n    ## format data\n    outcome <- terms(formula(form, rhs=1))[[2]]\n    trt <- terms(formula(form, rhs=1))[[3]]\n\n    outcomes_str <- all.vars(outcome)\n    outcomes <- sapply(outcomes_str, quo)\n    # get outcomes as a list\n    wide_list <- format_data_multi(outcomes, trt, unit, time, t_int, data)\n    \n\n    \n\n    ## add covariates\n    if(length(form)[2] == 2) {\n        cov_form <- paste(deparse(terms(formula(form, rhs = 2))[[3]]), collapse = \"\")\n        new_form <- as.formula(paste(\"~\", cov_form))\n        Z <- extract_covariates(new_form, unit, time, t_int, data, cov_agg)\n    } else {\n        Z <- NULL\n    }\n\n    # only allow ridge augmentation\n    if(! tolower(progfunc) %in% c(\"none\", \"ridge\")) {\n      stop(paste(progfunc, \"is not a valid augmentation function with multiple outcomes. Only `none` and `ridge` are allowed for `progfunc`\"))\n    }\n\n    # fit augmented SCM\n    augsynth <- fit_augsynth_multiout_internal(wide_list, combine_method, Z,\n                                               progfunc, scm,\n                                               fixedeff, outcomes_str, ...)\n\n    # add some extra data\n    augsynth$data$time <- data %>% distinct(!!time) %>% pull(!!time)\n    augsynth$call <- call_name\n    augsynth$t_int <- t_int \n    augsynth$combine_method <- combine_method\n\n    treated_units <- data %>% filter(!!trt == 1) %>% distinct(!!unit) %>% pull(!!unit)\n    control_units <- data %>% filter(!(!!unit %in% treated_units)) %>% \n                        distinct(!!unit) %>% pull(!!unit)\n    augsynth$weights <- matrix(augsynth$weights)\n    rownames(augsynth$weights) <- control_units\n\n    return(augsynth)\n}\n\n\n#' Internal function to fit augmented SCM with multiple outcomes\n#' @param wide_list List of matrices for each outcome formatted from format_data\n#' @param 
combine_method How to combine outcomes\n#' @param Z Matrix of auxiliary covariates\n#' @param progfunc outcome model to use\n#' @param scm Whether to fit SCM\n#' @param fixedeff Whether to de-mean synth\n#' @param ... Extra args for outcome model\n#' @noRd\nfit_augsynth_multiout_internal <- function(wide_list, combine_method, Z,\n                                           progfunc, scm, fixedeff, \n                                           outcomes_str, ...) {\n\n\n    # combine into a matrix for fitting and balancing\n    out <- combine_outcomes(wide_list, combine_method, fixedeff, ...)\n    wide_bal <- out$wide_bal\n    mhat <- out$mhat\n    V <- out$V\n    synth_data <- do.call(format_synth, wide_bal)\n\n    # set Y1 and Y0plot to be raw concatenated outcomes\n    X <- do.call(cbind, wide_list$X)\n    y <- do.call(cbind, wide_list$y)\n    trt <- wide_list$trt\n    synth_data$Y0plot <- t(cbind(X, y)[trt == 0,, drop = F])\n    synth_data$Y1plot <- colMeans(cbind(X, y)[trt == 1,, drop = F])\n\n\n    augsynth <- fit_augsynth_internal(wide_bal, synth_data, Z, progfunc, \n                                      scm, fixedeff = F, V = V, ...)\n\n    # potentially add back in fixed effects\n    augsynth$mhat <- mhat# + augsynth$mhat\n\n    augsynth$data <- list(X = X, trt = trt, y = y, Z = Z)\n    augsynth$data_list <- wide_list\n    augsynth$outcomes <- outcomes_str\n    # change fixedeff flag to match input (rather than fixedeff = F in fit_augsynth_internal)\n    augsynth$fixedeff <- fixedeff\n    ##format output\n    class(augsynth) <- c(\"augsynth_multiout\", \"augsynth\")\n    return(augsynth)\n}\n\n#' Helper function to combine multiple outcomes into a single balance matrix\n#' @param wide_list List of lists of pre/post treatment data for each outcome\n#' @param combine_method How to combine outcomes\n#' @param fixedeff Whether to take out unit fixed effects or not\n#' @param nu Weighting between concatenated and averaged objectives\n#' @param ... 
Extra arguments for combination\n#' @noRd\n#' @return \\itemize{\n#'          \\item{\"X\"}{Matrix of combined pre-treatment outcomes}\n#'          \\item{\"trt\"}{Vector of treatment assignments}\n#'          \\item{\"y\"}{Matrix of combined post-treatment outcomes}\n#'         }\ncombine_outcomes <- function(wide_list, combine_method, fixedeff,\n                             nu = NULL, ...) {\n  n_outs <- length(wide_list$X)\n  n_units <- Map(nrow, wide_list$X) %>% Reduce(max, .)\n  # take out unit fixed effects\n  demean_j <- function(j) {\n    means <- rowMeans(wide_list$X[[j]], na.rm = TRUE)\n\n    new_wide_data <- list()\n    new_X <- wide_list$X[[j]] - means\n    new_y <- wide_list$y[[j]] - means\n\n    new_wide_data$X <- new_X\n    new_wide_data$y <- new_y\n    new_wide_data$mhat_pre <- replicate(ncol(wide_list$X[[j]]),\n                                        means)\n    new_wide_data$mhat_post <- replicate(ncol(wide_list$y[[j]]),\n                                        means)\n    return(new_wide_data)\n  }\n  if(fixedeff) {\n    new_wide_list <- lapply(1:n_outs, demean_j)\n    wide_list$X <- lapply(new_wide_list, function(x) x$X)\n    wide_list$y <- lapply(new_wide_list, function(x) x$y)\n    mhat_pre <- lapply(new_wide_list, function(x) x$mhat_pre)\n    mhat_post <- lapply(new_wide_list, function(x) x$mhat_post)\n  } else {\n    mhat_pre <- lapply(\n      1:n_outs,\n      function(j) matrix(0, nrow = n_units, ncol = ncol(wide_list$X[[j]])))\n    mhat_post <- lapply(\n      1:n_outs,\n      function(j) matrix(0, nrow = n_units, ncol = ncol(wide_list$y[[j]])))\n  }\n\n    # combine outcomes\n    if(combine_method == \"concat\") {\n      # center X and scale by overall variance for outcome\n      # X <- lapply(wide_list$X, function(x) t(t(x) - colMeans(x)) / sd(x))\n      wide_bal <- list(X = do.call(cbind, lapply(wide_list$X, function(x) t(na.omit(t(x))))),\n                        y = do.call(cbind, lapply(wide_list$y, function(x) t(na.omit(t(x))))),\n   
                     trt = wide_list$trt)\n\n      # V matrix scales by inverse variance for outcome and number of periods\n      V <- do.call(c, \n          lapply(wide_list$X, \n            function(x) rep(1 / (sqrt(nrow(na.omit(t(x)))) * \n                    sd(x[wide_list$trt == 0, , drop = F], na.rm=T)),\n                    nrow(na.omit(t(x))))))\n\n    # } else if(combine_method == \"svd\") {\n    #     wide_bal <- list(X = do.call(cbind, wide_list$X),\n    #                      y = do.call(cbind, wide_list$y),\n    #                      trt = wide_list$trt)\n\n    #     # first get the standard deviations of the outcomes to put on the same scale\n    #     sds <- do.call(c, \n    #         lapply(wide_list$X, \n    #             function(x) rep((sqrt(ncol(x)) * sd(x, na.rm=T)), ncol(x))))\n\n    #     # do an SVD on centered and scaled outcomes\n    #     X0 <- wide_bal$X[wide_bal$trt == 0, , drop = FALSE]\n    #     X0 <- t((t(X0) - colMeans(X0)) / sds)\n    #     k <- if(is.null(k)) ncol(X0) else k\n    #     V <- diag(1 / sds) %*% svd(X0)$v[, 1:k, drop = FALSE]\n  } else if(combine_method == \"avg\") {\n      # average pre-treatment outcomes, dividing by standard deviation and removing missing values\n      X_avg <- rowMeans(simplify2array(lapply(wide_list$X,\n                                  function(x) (x - mean(x[wide_list$trt == 0,], na.rm = TRUE)) / sd(x[wide_list$trt == 0,], na.rm = TRUE))), dims = 2, na.rm = TRUE)\n      # remove any time periods with NAs\n      X_avg <- t(na.omit(t(X_avg)))\n    wide_bal <- list(X = X_avg,\n        y = rowMeans(simplify2array(wide_list$y), dims = 2, na.rm = TRUE),\n        trt = wide_list$trt)\n\n    V <- diag(ncol(wide_bal$X))\n\n  } else if(combine_method == \"avg_concat\") {\n      # average pre-treatment outcomes, dividing by standard deviation and removing missing values\n      # standardize the outcomes\n      X_list_std<- lapply(wide_list$X,function(x) (x - mean(x[wide_list$trt == 0,], na.rm = TRUE)) / 
sd(x[wide_list$trt == 0,], na.rm = TRUE))\n\n\n      X_avg <- rowMeans(simplify2array(X_list_std), dims = 2, na.rm = TRUE)\n      # remove any time periods with NAs\n      X_avg <- t(na.omit(t(X_avg)))\n\n      X_concat <- do.call(cbind, lapply(X_list_std, function(x) t(na.omit(t(x)))))\n\n      # V matrix assigns weight nu to the averaged objective and (1 - nu) to the concatenated objective\n      # V <- c(rep(sqrt(nu), ncol(X_avg)),\n      #         sqrt(1 - nu) / sqrt(n_outs) * do.call(c, \n      #             lapply(wide_list$X,\n      #               function(x) rep(1 / (sqrt(nrow(na.omit(t(x)))) * \n      #                       sd(x[wide_list$trt == 0, , drop = F], na.rm=T)),\n      #                       nrow(na.omit(t(x))))))\n      # )\n      V <- c(rep(sqrt(nu), ncol(X_avg)), rep(sqrt(1 - nu) / sqrt(n_outs), ncol(X_concat)))\n      wide_bal <- list(\n        X = cbind(X_avg, X_concat),\n        y = do.call(cbind, lapply(wide_list$y, function(x) t(na.omit(t(x))))),\n        trt = wide_list$trt\n      )\n    } else {\n        stop(paste(\"combine_method should be one of ('avg', 'concat', 'avg_concat'),\", \n            combine_method, \" is not a valid combining option\"))\n    }\n\n  mhat_pre <- do.call(cbind, mhat_pre)\n  mhat_post <- do.call(cbind, mhat_post)\n  mhat <- cbind(mhat_pre, mhat_post)\n\n  return(list(wide_bal = wide_bal, mhat = mhat, V = V))\n\n}\n\n#' Get prediction of ATT or average outcome under control\n#' @param object augsynth_multiout object\n#' @param ... Optional arguments, including \\itemize{\\item{\"att\"}{Whether to return the ATT or average outcome under control}}\n#'\n#' @return Vector of predicted post-treatment control averages\n#' @export\npredict.augsynth_multiout <- function(object, ...) 
{\n    if (\"att\" %in% names(list(...))) {\n        att <- list(...)$att\n    } else {\n        att <- F\n    }\n\n    # call augsynth predict\n    pred <- NextMethod()\n\n    # separate out by outcome\n    n_outs <- length(object$data_list$X)\n    max_t <- max(sapply(1:n_outs, \n      function(k) ncol(object$data_list$X[[k]]) + ncol(object$data_list$y[[k]])))\n    pred_reshape <- matrix(NA, ncol = n_outs, \n                               nrow = max_t)\n    colnames <- lapply(1:n_outs, \n      function(k) colnames(cbind(object$data_list$X[[k]], \n                                 object$data_list$y[[k]])))\n    rownames(pred_reshape) <- colnames[[which.max(sapply(colnames, length))]]\n    colnames(pred_reshape) <- object$outcomes\n    # get outcome names for predictions\n\n    pre_outs <- do.call(c, \n                        sapply(1:n_outs, \n                               function(j) {\n                                   rep(object$outcomes[j],\n                                       ncol(object$data_list$X[[j]]))\n                               }, simplify = FALSE))\n    \n    post_outs <- do.call(c,\n                         sapply(1:n_outs, \n                                function(j) {\n                                    rep(object$outcomes[j],\n                                        ncol(object$data_list$y[[j]]))\n                               }, simplify = FALSE))\n    # print(pred)\n    # print(cbind(names(pred), c(pre_outs, post_outs)))\n    \n    pred_reshape[cbind(names(pred), c(pre_outs, post_outs))] <- pred\n    return(pred_reshape)\n}\n\n\n#' Print function for augsynth\n#' @param x augsynth_multiout object\n#' @param ... Optional arguments\n#' @export\nprint.augsynth_multiout <- function(x, ...) 
{\n    ## straight from lm\n    cat(\"\\nCall:\\n\", paste(deparse(x$call), sep=\"\\n\", collapse=\"\\n\"), \"\\n\\n\", sep=\"\")\n\n    ## print att estimates\n    att <- predict(x, att = T)\n    att_post <- data.frame(\n        colMeans(att[as.numeric(rownames(att)) >= x$t_int,, drop = F]))\n    names(att_post) <- c(\"\")\n    cat(\"Average ATT Estimate:\\n\")\n    print(att_post)\n    cat(\"\\n\\n\")\n}\n\n#' Summary function for augsynth\n#' @param object augsynth_multiout object\n#' @param inf whether or not to perform inference\n#' @param inf_type Type of inference, default is \"conformal\"\n#' @param grid_size Grid to compute prediction intervals over, default is 1 and only p-values are computed\n#' @param ... Optional arguments, including \\itemize{\\item{\"se\"}{Whether to plot standard error}}\n#' @export\nsummary.augsynth_multiout <- function(object, inf = T, inf_type = \"conformal\", grid_size = 1, ...) {\n    \n\n    summ <- list()\n\n    if(inf) {\n        if(inf_type == \"conformal\") {\n          if(grid_size > 1) {\n            cat(paste0(\"A grid size of \", grid_size, \" will require \",\n                           grid_size, \"^\", length(object$outcomes),\n                           \" = \", grid_size ^ length(object$outcomes),\n                           \" evaluations. 
This could take a while...\"))\n          }\n          att_se <- conformal_inf_multiout(object, grid_size = grid_size, ...)\n        } else {\n          stop(\"Only conformal inference is supported for multiple outcomes\")\n        }\n        # if(inf_type == \"jackknife\") {\n        #     att_se <- jackknife_se_multiout(object)\n        # } else if(inf_type == \"jackknife+\") {\n        #   att_se <- time_jackknife_plus_multiout(object, ...)\n        # } else if(inf_type == \"conformal\") {\n        #   att_se <- conformal_inf_multiout(object, ...)\n        # } else {\n        #     stop(paste(inf_type, \"is not a valid choice of 'inf_type'\"))\n        # }\n\n        t_final <- nrow(att_se$att)\n\n        att_df <- data.frame(att_se$att[1:(t_final - 1),, drop=F])\n        names(att_df) <- object$outcomes\n        att_df$Time <- object$data$time\n        att_df <- att_df %>% gather(Outcome, Estimate, -Time)\n\n        # if(inf_type == \"jackknife\") {\n        #   se_df <- data.frame(att_se$se[1:(t_final - 1),, drop=F])\n        #   names(se_df) <- object$outcomes\n        #   se_df$Time <- object$data$time\n        #   se_df <- se_df %>% gather(Outcome, Std.Error, -Time)\n\n        #   att <- inner_join(att_df, se_df, by = c(\"Time\", \"Outcome\"))\n        # } else if(inf_type %in% c(\"conformal\", \"jackknife+\")) {\n          \n        lb_df <- data.frame(att_se$lb[1:(t_final - 1),, drop=F])\n        names(lb_df) <- object$outcomes\n        lb_df$Time <- object$data$time\n        lb_df <- lb_df %>% gather(Outcome, lower_bound, -Time)\n\n        ub_df <- data.frame(att_se$ub[1:(t_final - 1),, drop=F])\n        names(ub_df) <- object$outcomes\n        ub_df$Time <- object$data$time\n        ub_df <- ub_df %>% gather(Outcome, upper_bound, -Time)\n\n        att <- inner_join(att_df, lb_df, by = c(\"Time\", \"Outcome\")) %>%\n            inner_join(ub_df, by = c(\"Time\", \"Outcome\")) \n          # if(inf_type == \"conformal\") {\n\n          pval_df <- 
data.frame(att_se$p_val[1:(t_final - 1),, drop=F])\n          names(pval_df) <- object$outcomes\n          pval_df$Time <- object$data$time\n          pval_df <- pval_df %>% gather(Outcome, p_val, -Time)\n          att <- inner_join(att, pval_df, by = c(\"Time\", \"Outcome\")) \n          # }\n        # }\n        if(grid_size == 1) {\n          att <- att %>% mutate(lower_bound = NA, upper_bound = NA)\n        }\n\n        att_avg <- data.frame(att_se$att[t_final,, drop = F])\n        names(att_avg) <- object$outcomes\n        att_avg <- gather(att_avg, Outcome, Estimate)\n\n        # if(inf_type == \"jackknife\") {\n        #   att_avg_se <- data.frame(att_se$se[t_final,, drop = F])\n        #   names(att_avg_se) <- object$outcomes\n        #   att_avg_se <- gather(att_avg_se, Outcome, Std.Error)\n        #   average_att <- inner_join(att_avg, att_avg_se, by=\"Outcome\")\n        # } else if(inf_type %in% c(\"conformal\", \"jackknife+\")){\n        att_avg_lb <- data.frame(att_se$lb[t_final,, drop = F])\n        names(att_avg_lb) <- object$outcomes\n        att_avg_lb <- gather(att_avg_lb, Outcome, lower_bound)\n\n        att_avg_ub <- data.frame(att_se$ub[t_final,, drop = F])\n        names(att_avg_ub) <- object$outcomes\n        att_avg_ub <- gather(att_avg_ub, Outcome, upper_bound)\n        \n\n        average_att <- inner_join(att_avg, att_avg_lb, by=\"Outcome\") %>%\n            inner_join(att_avg_ub, by = \"Outcome\")\n          \n          # if(inf_type == \"conformal\") {\n        att_avg_pval <- data.frame(att_se$p_val[t_final,, drop = F])\n        names(att_avg_pval) <- object$outcomes\n        att_avg_pval <- gather(att_avg_pval, Outcome, p_val)\n\n        average_att <- inner_join(average_att, att_avg_pval, by = \"Outcome\")\n\n        if(grid_size == 1) {\n          average_att <- average_att %>% mutate(lower_bound = NA, upper_bound = NA)\n        }\n          # }\n        # } else {\n        #   average_att <- gather(att_avg, Outcome, Estimate)\n    
    # }\n        \n\n    } else {\n        att_est <- predict(object, att = T)\n        att_df <- data.frame(att_est)\n        names(att_df) <- object$outcomes\n        att_df$Time <- object$data$time\n        att <- att_df %>% gather(Outcome, Estimate, -Time)\n        att$Std.Error <- NA\n        t_int <- min(sapply(object$data_list$X, ncol))\n        att_avg <- data.frame(t(colMeans(\n            att_est[t_int:nrow(att_est),, drop = F])))\n        names(att_avg) <- object$outcomes\n        average_att <- gather(att_avg, Outcome, Estimate)\n        average_att$Std.Error <- NA\n    }\n\n      # get average of all outcomes\n    sds <- data.frame(Outcome = object$outcomes,\n                      sdo = sapply(object$data_list$X,\n                                    function(x)  sd(x[object$data_list$trt == 0,], na.rm = TRUE)))\n\n    att %>%\n      inner_join(sds, by = \"Outcome\") %>%\n      mutate(Estimate = Estimate / sdo) %>%\n      group_by(Time) %>%\n      summarise(Estimate = mean(Estimate, na.rm = TRUE)) %>%\n      mutate(Outcome = \"Average\") %>%\n      bind_rows(att, .) 
-> att\n\n    summ$att <- att\n    summ$average_att <- average_att\n    summ$t_int <- object$t_int\n    summ$call <- object$call\n    summ$l2_imbalance <- object$l2_imbalance\n    summ$scaled_l2_imbalance <- object$scaled_l2_imbalance\n    summ$inf_type <- inf_type\n    ## get estimated bias\n\n    if(object$progfunc == \"Ridge\") {\n        mhat <- object$ridge_mhat\n        w <- object$synw\n    } else {\n        mhat <- object$mhat\n        w <- object$weights\n    }\n    trt <- object$data$trt\n    m1 <- colMeans(mhat[trt==1,,drop=F])\n\n    summ$bias_est <- m1 - t(mhat[trt==0,,drop=F]) %*% w\n    if(object$progfunc == \"None\" | (!object$scm)) {\n        summ$bias_est <- NA\n    }\n\n    \n\n    \n    class(summ) <- \"summary.augsynth_multiout\"\n    return(summ)\n}\n\n\n#' Print function for summary function for augsynth\n#' @param x summary.augsynth_multiout object\n#' @param ... Optional arguments\n#' @export\nprint.summary.augsynth_multiout <- function(x, ...) {\n    ## straight from lm\n    cat(\"\\nCall:\\n\", paste(deparse(x$call), sep=\"\\n\", collapse=\"\\n\"), \"\\n\\n\", sep=\"\")\n    \n    att_est <- x$att$Estimate\n    ## get pre-treatment fit by outcome\n    imbal <- x$att %>% \n        filter(Time < x$t_int) %>%\n        group_by(Outcome) %>%\n        summarise(Pre.RMSE = sqrt(mean(Estimate ^ 2, na.rm = TRUE)))\n\n    cat(paste(\"Overall L2 Imbalance (Scaled):\",\n              format(round(x$l2_imbalance,3), nsmall=3), \"  (\",\n              format(round(x$scaled_l2_imbalance,3), nsmall=3), \")\\n\\n\",\n            #   \"Avg Estimated Bias: \",\n            #   format(round(mean(summ$bias_est), 3),nsmall=3), \"\\n\\n\",\n              sep=\"\"))\n    cat(\"Average ATT Estimate:\\n\")\n    print(inner_join(x$average_att, imbal, by = \"Outcome\"))\n    cat(\"\\n\\n\")\n}\n\n\n#' Plot function for summary function for augsynth\n#' @importFrom graphics plot\n#' @param x summary.augsynth_multiout object\n#' @param inf Boolean, whether to plot 
uncertainty intervals, default TRUE\n#' @param plt_avg Boolean, whether to plot the average of the outcomes, default FALSE\n#' @param ... Optional arguments for summary function\n#' \n#' @export\nplot.augsynth_multiout  <- function(x, inf = T, plt_avg = F, ...) {\n  plot(summary(x, ...), inf =  inf, plt_avg = plt_avg)\n}\n\n#' Plot function for summary function for augsynth\n#' @param x summary.augsynth_multiout object\n#' @param inf Boolean, whether to plot uncertainty intervals, default TRUE\n#' @param plt_avg Boolean, whether to plot the average of the outcomes, default FALSE\n#' \n#' @export\nplot.summary.augsynth_multiout <- function(x, inf = F, plt_avg = F, ...) {\n    if(plt_avg) {\n      p <- x$att %>%\n        ggplot2::ggplot(ggplot2::aes(x=Time, y=Estimate))\n    } else {\n      p <- x$att %>%\n        filter(Outcome != \"Average\") %>% \n        ggplot2::ggplot(ggplot2::aes(x=Time, y=Estimate))\n    }\n    if(inf) {\n      if(x$inf_type == \"jackknife\") {\n        p <- p + ggplot2::geom_ribbon(ggplot2::aes(ymin=Estimate-2*Std.Error,\n                        ymax=Estimate+2*Std.Error),\n                    alpha=0.2, data = . %>% filter(Outcome != \"Average\"))\n      } else if(x$inf_type %in% c(\"conformal\", \"jackknife+\")) {\n        p <- p + ggplot2::geom_ribbon(ggplot2::aes(ymin=lower_bound,\n                        ymax=upper_bound),\n                    alpha=0.2, data =  . %>% filter(Outcome != \"Average\"))\n      }\n\n    }\n    p + ggplot2::geom_line() +\n        ggplot2::geom_vline(xintercept=x$t_int, lty=2) +\n        ggplot2::geom_hline(yintercept=0, lty=2) +\n        ggplot2::facet_wrap(~ Outcome, scales = \"free_y\") +\n        ggplot2::theme_bw()\n\n}"
  },
  {
    "path": "R/multi_synth_qp.R",
    "content": "################################################################################\n## Solve the multisynth problem as a QP\n################################################################################\n\n\n#' Internal function to fit synth with staggered adoption with a QP solver\n#' @param X Matrix of pre-final intervention outcomes, or list of such matrices after transformations\n#' @param trt Vector of treatment levels/times\n#' @param mask Matrix with indicators for observed pre-intervention times for each treatment group\n#' @param n_leads Number of time periods after treatment to impute control values.\n#'            For units treated at time T_j, all units treated after T_j + n_leads\n#'            will be used as control values. If larger than the number of periods,\n#'            only never-treated units (pure controls) will be used as comparison units\n#' @param n_lags Number of pre-treatment periods to balance, default is to balance all periods\n#' @param relative Whether to re-index time according to treatment date, default T\n#' @param nu Hyper-parameter that controls trade-off between overall and individual balance.\n#'              Larger values of nu place more emphasis on overall (global) balance.\n#'              Balance measure is\n#'                nu ||global|| + (1-nu) ||individual||\n#'              Default: 0\n#' @param lambda Regularization hyper-parameter. 
Default, 0\n#' @param time_cohort Whether to average synthetic controls into time cohorts\n#' @param norm_pool Normalizing value for pooled objective, default: number of treated units squared\n#' @param norm_sep Normalizing value for separate objective, default: number of treated units\n#' @param verbose Whether to print logs for osqp\n#' @param eps_rel Relative error tolerance for osqp\n#' @param eps_abs Absolute error tolerance for osqp\n#' @noRd\n#' @return \\itemize{\n#'          \\item{\"weights\"}{Matrix of unit weights}\n#'          \\item{\"imbalance\"}{Matrix of overall and group specific imbalance}\n#'          \\item{\"global_l2\"}{Imbalance overall}\n#'          \\item{\"ind_l2\"}{Matrix of imbalance for each group}\n#'         }\nmultisynth_qp <- function(X, trt, mask, Z = NULL, n_leads=NULL, n_lags=NULL,\n                          relative=T, nu=0, lambda=0, V = NULL, time_cohort = FALSE,\n                          donors = NULL, norm_pool = NULL, norm_sep = NULL,\n                          verbose = FALSE, \n                          eps_rel=1e-4, eps_abs=1e-4) {\n\n    # if Z has no columns then set it to NULL\n    if(!is.null(Z)) {\n      if(ncol(Z) == 0) {\n        Z <- NULL\n      }\n    }\n\n    n <- if(typeof(X) == \"list\") dim(X[[1]])[1] else dim(X)[1]\n    d <- if(typeof(X) == \"list\") dim(X[[1]])[2] else dim(X)[2]\n\n    if(is.null(n_leads)) {\n        n_leads <- d+1\n    } else if(n_leads > d) {\n        n_leads <- d+1\n    }\n    if(is.null(n_lags)) {\n        n_lags <- d\n    } else if(n_lags > d) {\n        n_lags <- d\n    }\n    V <- make_V_matrix(n_lags, V)\n    ## treatment times\n    if(time_cohort) {\n        grps <- unique(trt[is.finite(trt)])\n        which_t <- lapply(grps, function(tj) (1:n)[trt == tj])\n        # if doing a time cohort, convert the boolean mask\n        mask <- unique(mask)\n    } else {\n        grps <- trt[is.finite(trt)]\n        which_t <- (1:n)[is.finite(trt)]\n    }\n\n\n    J <- length(grps)\n    
if(is.null(norm_sep)) {\n      norm_sep <- 1#J\n    }\n    if(is.null(norm_pool)) {\n      norm_pool <- 1#J ^ 2\n    }\n    n1 <- sapply(1:J, function(j) length(which_t[[j]]))\n\n    # if no specific donors passed in,\n    # then all donors treated after n_lags are eligible\n    if(is.null(donors)) {\n      donors <- get_eligible_donors(trt, time_cohort, n_leads)\n    }\n\n    ## handle X differently if it is a list\n    if(typeof(X) == \"list\") {\n        x_t <- lapply(1:J, function(j) colSums(X[[j]][which_t[[j]], mask[j,]==1, drop=F]))\n        \n        # Xc contains pre-treatment data for valid donor units\n        Xc <- lapply(1:nrow(mask),\n                 function(j) X[[j]][donors[[j]], mask[j,]==1, drop=F])\n        \n        # std dev of outcomes for first treatment time\n        sdx <- sd(X[[1]][is.finite(trt)])\n    } else {\n        x_t <- lapply(1:J, function(j) colSums(X[which_t[[j]], mask[j,]==1, drop=F]))        \n        \n        # Xc contains pre-treatment data for valid donor units\n        Xc <- lapply(1:nrow(mask),\n                 function(j) X[donors[[j]], mask[j,]==1, drop=F])\n\n        # std dev of outcomes\n        sdx <- sd(X[is.finite(trt)])\n    }\n\n    # get covariates for donors\n    if(!is.null(Z)) {\n      # scale covariates to have same variance as pure control outcomes\n      Z_scale <- sdx * apply(Z, 2, \n        function(z) (z - mean(z[!is.finite(trt)])) / sd(z[!is.finite(trt)]))\n\n      z_t <- lapply(1:J, function(j) colSums(Z_scale[which_t[[j]], , drop = F]))\n      Zc <- lapply(1:J, function(j) Z_scale[donors[[j]], , drop = F])\n    } else {\n      z_t <- lapply(1:J, function(j) c(0))\n      Zc <- lapply(1:J, function(j) Matrix::Matrix(0,\n                                                   nrow = sum(donors[[j]]),\n                                                   ncol = 1))\n    }\n    dz <- ncol(Zc[[1]])\n\n    # replace NA values with zero\n    x_t <- lapply(x_t, function(xtk) tidyr::replace_na(xtk, 0))\n    Xc <- 
lapply(Xc, function(xck) apply(xck, 2, tidyr::replace_na, 0))\n    ## make matrices for QP\n    n0s <- sapply(Xc, nrow)\n    if(any(n0s == 0)) {\n      stop(\"Some treated units have no possible donor units!\")\n    }\n    n0 <- sum(n0s)\n\n    const_mats <- make_constraint_mats(trt, grps, n_leads, n_lags, Xc, Zc, d, n1)\n    Amat <- const_mats$Amat\n    lvec <- const_mats$lvec\n    uvec <- const_mats$uvec\n\n    ## quadratic balance measures\n\n    qvec <- make_qvec(Xc, x_t, z_t, nu, n_lags, d, V, norm_pool, norm_sep)\n\n    Pmat <- make_Pmat(Xc, x_t, dz, nu, n_lags, lambda, d, V, norm_pool, norm_sep)\n\n    ## Optimize\n    settings <- do.call(osqp::osqpSettings, \n                        c(list(verbose = verbose, \n                               eps_rel = eps_rel, \n                               eps_abs = eps_abs)))\n\n    out <- osqp::solve_osqp(Pmat, qvec, Amat, lvec, uvec, pars = settings)\n\n    ## get weights\n    total_ctrls <- n0 * J\n    weights <- matrix(out$x[1:total_ctrls], nrow = n0)\n    nj0 <- as.numeric(lapply(Xc, nrow))\n    nj0cumsum <- c(0, cumsum(nj0))\n    imbalance <- vapply(1:J,\n                        function(j) {\n                            dj <- length(x_t[[j]])\n                            ndim <- min(dj, n_lags)\n                            c(numeric(d-ndim),\n                            x_t[[j]][(dj-ndim+1):dj] -\n                                t(Xc[[j]][,(dj-ndim+1):dj, drop = F]) %*% \n                                out$x[(nj0cumsum[j] + 1):nj0cumsum[j + 1]])\n                        },\n                        numeric(d))\n    avg_imbal <- rowMeans(t(t(imbalance)))\n\n    Vsq <- t(V) %*% V\n    global_l2 <- c(sqrt(t(avg_imbal[(d - n_lags + 1):d]) %*% Vsq %*%\n                          avg_imbal[(d - n_lags + 1):d])) / sqrt(d)\n    avg_l2 <- mean(apply(imbalance, 2,\n                  function(x) c(sqrt(t(x[(d - n_lags + 1):d]) %*% Vsq %*%\n                                x[(d - n_lags + 1):d]))))\n    ind_l2 <- sqrt(mean(\n   
   apply(imbalance, 2,\n      function(x) c(x[(d - n_lags + 1):d] %*% Vsq %*% x[(d - n_lags + 1):d]) /\n          sum(x[(d - n_lags + 1):d] != 0))))\n    # pad weights with zeros for treated units and divide by number of treated units\n    vapply(1:J,\n           function(j) {\n             weightj <-  numeric(n)\n             weightj[donors[[j]]] <- out$x[(nj0cumsum[j] + 1):nj0cumsum[j + 1]]\n             weightj\n           },\n           numeric(n)) -> weights\n\n    weights <- t(t(weights) / n1)\n    # manually enforce non-negativity constraint\n    # (osqp solver only enforces constraint up to a tolerance)\n    weights <- pmax(weights, 0)\n\n    output <- list(weights = weights,\n                   imbalance = cbind(avg_imbal, imbalance),\n                   global_l2 = global_l2,\n                   ind_l2 = ind_l2,\n                   avg_l2 = avg_l2,\n                   V = V)\n\n    if(!is.null(Z)) {\n      # imbalance in auxiliary covariates\n      z_t <- sapply(1:J, function(j) colMeans(Z[which_t[[j]], , drop = F]))\n      imbal_z <- z_t - t(Z) %*% weights\n      avg_imbal_z <- rowSums(t(t(imbal_z) * n1)) / sum(n1)\n      global_l2_z <- sqrt(sum(avg_imbal_z ^ 2))\n      ind_l2_z <- sum(apply(imbal_z, 2, function(x) sqrt(sum(x ^ 2))))\n      imbal_z <- cbind(avg_imbal_z, imbal_z)\n      rownames(imbal_z) <- colnames(Z)\n\n      output$imbalance_aux <- imbal_z\n      output$global_l2_aux <- global_l2_z\n      output$ind_l2_aux <- ind_l2_z\n    }\n    \n\n    \n    \n\n    return(output)\n}\n\n\n#' Create constraint matrices for multisynth QP\n#' @param trt Vector of treatment levels/times\n#' @param grps Treatment times\n#' @param n_leads Number of time periods after treatment to impute control values.\n#' @param n_lags Number of pre-treatment periods to balance\n#' @param Xc List of outcomes for possible comparison units\n#' @param d Max number of lagged outcomes\n#' @param n1 Vector of number of treated units per cohort\n#' @noRd\n#' @return \n#'         
\\itemize{\n#'          \\item{\"Amat\"}{Linear constraint matrix}\n#'          \\item{\"lvec\"}{Lower bounds for linear constraints}\n#'          \\item{\"uvec\"}{Upper bounds for linear constraints}\n#'         }\nmake_constraint_mats <- function(trt, grps, n_leads, n_lags, Xc, Zc, d, n1) {\n\n    J <- length(grps)\n    idxs0  <- trt  > n_leads + min(grps)\n\n    n0 <- sum(idxs0)\n\n    ## sum to n1 constraint\n    A1 <- do.call(Matrix::bdiag, lapply(1:(J), function(x) rep(1, n0)))\n    A1 <- Matrix::bdiag(lapply(1:J, function(j) rep(1, nrow(Xc[[j]]))))\n    \n    Amat <- as.matrix(Matrix::t(A1))\n    Amat <- Matrix::rbind2(Matrix::t(A1), Matrix::Diagonal(nrow(A1)))\n\n    dz <- ncol(Zc[[1]])\n    # constraints for transformed weights\n    A_trans1 <- do.call(Matrix::bdiag,\n                       lapply(1:J,\n                        function(j)  {\n                            dj <- ncol(Xc[[j]])\n                            ndim <- min(dj, n_lags)\n                            max_dim <- min(d, n_lags)\n                            mat <- Xc[[j]][, (dj - ndim + 1):dj, drop = F]\n                            n0 <- nrow(mat)\n                            zero_mat <- Matrix::Matrix(0, n0, max_dim - ndim)\n                            Matrix::t(cbind(zero_mat, mat))\n                       }))\n\n    # sum of total number of pre-periods\n    sum_tj <- min(d, n_lags) * J\n    A_trans2 <- - Matrix::Diagonal(sum_tj)\n    A_trans <- Matrix::cbind2(\n      Matrix::cbind2(A_trans1, A_trans2),\n      Matrix::Matrix(0, nrow = nrow(A_trans1), ncol = dz * J))\n\n    # constraints for transformed weights on auxiliary covariates\n    A_transz <- Matrix::t(Matrix::bdiag(Zc))\n    A_transz <- Matrix::cbind2(\n      Matrix::cbind2(A_transz, \n                     Matrix::Matrix(0, nrow = nrow(A_transz), ncol = sum_tj)),\n      -Matrix::Diagonal(dz * J))\n\n    # add in zero columns for transformed weights\n    Amat <- Matrix::cbind2(Amat, \n                           
Matrix::Matrix(0,\n                                          nrow = nrow(Amat),\n                                          ncol = sum_tj + dz * J))\n    Amat <- Matrix::rbind2(Matrix::rbind2(Amat, A_trans), A_transz)\n\n    lvec <- c(n1, # sum to n1 constraint\n              numeric(nrow(A1)), # lower bound by zero\n              numeric(sum_tj), # constrain transformed weights\n              numeric(dz * J) # constrain transformed weights\n             ) \n    \n    uvec <- c(n1, #sum to n1 constraint\n              rep(Inf, nrow(A1)),\n              numeric(sum_tj), # constrain transformed weights\n              numeric(dz * J) # constrain transformed weights\n              )\n\n\n    return(list(Amat = Amat, lvec = lvec, uvec = uvec))\n}\n\n#' Make the vector in the QP\n#' @param Xc List of outcomes for possible comparison units\n#' @param x_t List of outcomes for treated units\n#' @param nu Hyperparameter between global and individual balance\n#' @param n_lags Number of lags to balance\n#' @param d Largest number of pre-intervention time periods\n#' @param V Scaling matrix\n#' @param norm_pool Normalizing value for pooled objective\n#' @param norm_sep Normalizing value for separate objective\n#' @noRd\nmake_qvec <- function(Xc, x_t, z_t, nu, n_lags, d, V, norm_pool, norm_sep) {\n\n    J <- length(x_t)\n    Vsq <- t(V) %*% V\n    qvec <- lapply(1:J,\n                   function(j) {\n                       dj <- length(x_t[[j]])\n                       ndim <- min(dj, n_lags)\n                       max_dim <- min(d, n_lags)\n                       vec <- x_t[[j]][(dj - ndim + 1):dj] / ndim\n                       Vsq %*% c(numeric(max_dim - ndim), vec)\n                   })\n\n    avg_target_vec <- lapply(x_t,\n                            function(xtk) {\n                                dk <- length(xtk)\n                                ndim <- min(dk, n_lags)\n                                max_dim <- min(d, n_lags)\n                                
c(numeric(max_dim - ndim), \n                                    xtk[(dk - ndim + 1):dk])\n                            }) %>% reduce(`+`) %*% Vsq\n    qvec_avg <- rep(avg_target_vec, J)\n    # qvec <- - (nu * qvec_avg / n_lags + (1 - nu) * reduce(qvec, c))\n    # qvec <- - (nu * qvec_avg / (J ^ 2 * n_lags) +\n    #           (1 - nu) * reduce(qvec, c) / J)\n    qvec <- - (nu * qvec_avg / (norm_pool * n_lags * J ^ 2) +\n               (1 - nu) * reduce(qvec, c) / (norm_sep * J))\n\n    qvec_avg_z <- z_t %>% reduce(`+`)\n    qvec_avg_z <- rep(qvec_avg_z, J)\n    # qvec_z <- - (nu * qvec_avg_z + (1 - nu) * reduce(z_t, c)) / length(z_t[[1]])\n    # qvec_z <- - (nu * qvec_avg_z / J ^2 +\n    #              (1 - nu) * reduce(z_t, c) / J) / length(z_t[[1]])\n    qvec_z <- - (nu * qvec_avg_z / (norm_pool * J ^ 2) +\n                 (1 - nu) * reduce(z_t, c) / (norm_sep * J)) / length(z_t[[1]])\n\n    total_ctrls <- lapply(Xc, nrow) %>% reduce(`+`)\n    return(c(numeric(total_ctrls), qvec, qvec_z))\n}\n\n\n#' Make the matrix in the QP\n#' @param Xc List of outcomes for possible comparison units\n#' @param x_t List of outcomes for treated units\n#' @param nu Hyperparameter between global and individual balance\n#' @param n_lags Number of lags to balance\n#' @param lambda Regularization hyperparameter\n#' @param d Largest number of pre-intervention time periods\n#' @param V Scaling matrix\n#' @param norm_pool Normalizing value for pooled objective\n#' @param norm_sep Normalizing value for separate objective\n#' @noRd\nmake_Pmat <- function(Xc, x_t, dz, nu, n_lags, lambda, d, V,\n                      norm_pool, norm_sep) {\n\n    J <- length(x_t)\n\n    Vsq <- t(V) %*% V\n    ndims <- vapply(1:J,\n                    function(j) min(length(x_t[[j]]), n_lags),\n                    numeric(1))\n    max_dim <- min(d, n_lags)\n    total_dim <- sum(ndims)\n    total_dim <- max_dim * J\n    V1 <- Matrix::bdiag(lapply(ndims, \n                        function(ndim) 
Matrix::Diagonal(max_dim, 1 / ndim)))\n    V1 <- Matrix::bdiag(lapply(ndims, function(ndim) Vsq / ndim))\n    tile_sparse <- function(j) {\n        kronecker(Matrix::Matrix(1, nrow = j, ncol = j), Vsq)\n    }\n    tile_sparse_cov <- function(d, j) {\n        kronecker(Matrix::Matrix(1, nrow = j, ncol = j),\n                  Matrix::Diagonal(d))\n    }\n    V2 <- tile_sparse(J) / n_lags\n    # Pmat <- nu * V2 + (1 - nu) * V1\n    # Pmat <- nu * V2 / J ^ 2 + (1 - nu) * V1 / J\n    Pmat <- nu * V2 / (norm_pool * J ^ 2) + (1 - nu) * V1 / (norm_sep * J)\n    V1_z <- Matrix::Diagonal(dz * J, 1 / dz)\n    V2_z <- tile_sparse_cov(dz, J) / dz\n    # Pmat_z <- nu * V2_z + (1 - nu) * V1_z\n    # Pmat_z <- nu * V2_z / J ^ 2 + (1 - nu) * V1_z / J\n    Pmat_z <- nu * V2_z / (norm_pool * J ^ 2) + (1 - nu) * V1_z / (norm_sep * J)\n    # combine\n    total_ctrls <- lapply(Xc, nrow) %>% reduce(`+`)\n    Pmat <- Matrix::bdiag(Matrix::Matrix(0, nrow = total_ctrls,\n                                         ncol = total_ctrls),\n                          Pmat, Pmat_z)\n    I0 <- Matrix::bdiag(Matrix::Diagonal(total_ctrls),\n                        Matrix::Matrix(0, nrow = total_dim + dz * J,\n                                          ncol = total_dim + dz * J))\n    return(Pmat + lambda * I0)\n\n}"
  },
  {
    "path": "R/multisynth_class.R",
    "content": "################################################################################\n## Fitting, plotting, summarizing staggered synth\n################################################################################\n\n#' Fit staggered synth\n#' @param form outcome ~ treatment | weighting covariates | approximate matching covariates | exact matching covariates\n#' \\itemize{\n#'    \\item{outcome}{Name of the outcome of interest}\n#'    \\item{treatment}{Name of the treatment assignment variable}\n#'    \\item{weighting covariates}{Auxiliary covariates to weight on}\n#'    \\item{approximate matching covariates}{Auxiliary covariates to approximately match on before weighting}\n#'    \\item{exact matching covariates}{Auxiliary covariates to exactly match on before weighting}\n#' }\n#' If covariates are time-varying, their average value before the first unit is treated will be used. This can be changed by supplying a custom aggregation function to cov_agg.\n#' @param unit Name of unit column\n#' @param time Name of time column\n#' @param data Panel data as dataframe\n#' @param n_leads How long past treatment effects should be estimated for, default is number of post treatment periods for last treated unit\n#' @param n_lags Number of pre-treatment periods to balance, default is to balance all periods\n#' @param nu Fraction of balance for individual balance\n#' @param lambda Regularization hyperparameter, default = 0\n#' @param V Scaling matrix for synth optimization, default NULL is identity\n#' @param fixedeff Whether to include a unit fixed effect, default TRUE\n#' @param n_factors Number of factors for interactive fixed effects, setting to NULL fits with CV, default is 0\n#' @param scm Whether to fit scm weights\n#' @param time_cohort Whether to average synthetic controls into time cohorts, default FALSE\n#' @param cov_agg Covariate aggregation function\n#' @param eps_abs Absolute error tolerance for osqp\n#' @param eps_rel Relative error tolerance 
for osqp\n#' @param verbose Whether to print logs for osqp\n#' @param ... Extra arguments\n#' \n#' @return multisynth object that contains:\n#'         \\itemize{\n#'          \\item{\"weights\"}{weights matrix where each column is a set of weights for a treated unit}\n#'          \\item{\"data\"}{Panel data as matrices}\n#'          \\item{\"imbalance\"}{Matrix of treatment minus synthetic control for pre-treatment time periods, each column corresponds to a treated unit}\n#'          \\item{\"global_l2\"}{L2 imbalance for the pooled synthetic control}\n#'          \\item{\"scaled_global_l2\"}{L2 imbalance for the pooled synthetic control, scaled by the imbalance for uniform weights}\n#'          \\item{\"ind_l2\"}{Average L2 imbalance for the individual synthetic controls}\n#'          \\item{\"scaled_ind_l2\"}{Average L2 imbalance for the individual synthetic controls, scaled by the imbalance for uniform weights}\n#'         \\item{\"n_leads\", \"n_lags\"}{Number of post treatment outcomes (leads) and pre-treatment outcomes (lags) to include in the analysis}\n#'          \\item{\"nu\"}{Fraction of balance for individual balance}\n#'          \\item{\"lambda\"}{Regularization hyperparameter}\n#'          \\item{\"scm\"}{Whether to fit scm weights}\n#'          \\item{\"grps\"}{Time periods for treated units}\n#'          \\item{\"y0hat\"}{Pilot estimates of control outcomes}\n#'          \\item{\"residuals\"}{Difference between the observed outcomes and the pilot estimates}\n#'          \\item{\"n_factors\"}{Number of factors for interactive fixed effects}\n#'         }\n#' @export\nmultisynth <- function(form, unit, time, data,\n                       n_leads=NULL, n_lags=NULL,\n                       nu=NULL, lambda=0, V = NULL,\n                       fixedeff = TRUE,\n                       n_factors=0,\n                       scm=T,\n                       time_cohort = F,\n                       how_match = \"knn\",\n                       cov_agg = 
NULL,\n                       eps_abs = 1e-4,\n                       eps_rel = 1e-4,\n                       verbose = FALSE, ...) {\n    call_name <- match.call()\n\n    form <- Formula::Formula(form)\n    unit <- enquo(unit)\n    time <- enquo(time)\n\n    ## format data\n    outcome <- terms(formula(form, rhs=1))[[2]]\n    trt <- terms(formula(form, rhs=1))[[3]]\n    wide <- format_data_stag(outcome, trt, unit, time, data)\n\n    check_data_stag(wide, fixedeff, n_leads, n_lags)\n\n    force <- if(fixedeff) 3 else 2\n\n    # get covariates\n    if(length(form)[2] == 2) {\n      Z <- extract_covariates(form, unit, time, wide$time[min(wide$trt) + 1],\n                              data, cov_agg)\n    } else if(length(form)[2] == 3) {\n      app_form <- Formula::Formula(formula(form, rhs = 1:2))\n      Z_weight <- extract_covariates(app_form, unit, time,\n                                  wide$time[min(wide$trt) + 1],\n                                  data, cov_agg)\n      exact_form <- Formula::Formula(formula(form, rhs = c(1,3)))\n      Z_match<- extract_covariates(exact_form, unit, time,\n                                  wide$time[min(wide$trt) + 1],\n                                  data, cov_agg)\n      Z <- cbind(Z_weight, Z_match)\n      wide$match_covariates <- colnames(Z_match)\n    } else if(length(form)[2] == 4) {\n      if(time_cohort) {\n        stop(\"Aggregating by time cohort and matching on covariates are not \",\n             \"implemented together. 
If matching then you cannot aggregate \",\n             \"by time cohort.\")\n      }\n      weight_form <- Formula::Formula(formula(form, rhs = c(1,2)))\n      Z_weight <- extract_covariates(weight_form, unit, time,\n                                      wide$time[min(wide$trt) + 1],\n                                      data, cov_agg)\n      app_form <- Formula::Formula(formula(form, rhs = c(1,3)))\n      Z_app <- extract_covariates(app_form, unit, time,\n                                  wide$time[min(wide$trt) + 1],\n                                  data, cov_agg)\n      exact_form <- Formula::Formula(formula(form, rhs = c(1,4)))\n      Z_exact <- extract_covariates(exact_form, unit, time,\n                                  wide$time[min(wide$trt) + 1],\n                                  data, cov_agg)\n      Z <- cbind(Z_weight, Z_app, Z_exact)\n      wide$exact_covariates <- colnames(Z_exact)\n      wide$match_covariates <- c(colnames(Z_app), wide$exact_covariates)\n    } else {\n        Z <- NULL\n    }\n    wide$Z <- Z\n\n    # if n_leads is NULL set it to be the largest possible number of leads\n    # for the last treated unit\n    if(is.null(n_leads)) {\n        n_leads <- ncol(wide$y)\n    } else if(n_leads > max(apply(1-wide$mask, 1, sum, na.rm = T)) +\n                                                              ncol(wide$y)) {\n        n_leads <- max(apply(1-wide$mask, 1, sum, na.rm = T)) + ncol(wide$y)\n    }\n\n    ## if n_lags is NULL set it to the largest number of pre-treatment periods\n    if(is.null(n_lags)) {\n        n_lags <- ncol(wide$X)\n    } else if(n_lags > ncol(wide$X)) {\n        n_lags <- ncol(wide$X)\n    }\n\n    long_df <- data[c(quo_name(unit), quo_name(time), as.character(trt), as.character(outcome))]\n\n    msynth <- multisynth_formatted(wide = wide, relative = T,\n                                n_leads = n_leads, n_lags = n_lags,\n                                nu = nu, lambda = lambda, V = V,\n                            
    force = force, n_factors = n_factors,\n                                scm = scm, time_cohort = time_cohort,\n                                time_w = F, lambda_t = 0,\n                                fit_resids = TRUE, eps_abs = eps_abs,\n                                eps_rel = eps_rel, verbose = verbose, long_df = long_df, \n                                how_match = how_match, ...)\n    \n   \n    units <- data %>% arrange(!!unit) %>% distinct(!!unit) %>% pull(!!unit)\n    rownames(msynth$weights) <- units\n\n\n    if(scm) {\n        ## Get imbalance for uniform weights on raw data\n\n        ## TODO: Get rid of this stupid hack of just fitting the weights again with big lambda\n        unif <- multisynth_qp(X=wide$X, ## X=residuals[,1:ncol(wide$X)],\n                            trt=wide$trt,\n                            mask=wide$mask,\n                            Z = Z[, ! colnames(Z) %in% wide$match_covariates,\n                                  drop = F],\n                            n_leads=n_leads,\n                            n_lags=n_lags,\n                            relative=T,\n                            nu=0, lambda=1e10,\n                            V = V,\n                            time_cohort = time_cohort,\n                            donors = msynth$donors,\n                            eps_rel = eps_rel, \n                            eps_abs = eps_abs,\n                            verbose = verbose)\n        ## scaled global balance\n        ## msynth$scaled_global_l2 <- msynth$global_l2  / sqrt(sum(unif$imbalance[,1]^2))\n        msynth$scaled_global_l2 <- msynth$global_l2  / unif$global_l2\n\n        ## balance for individual estimates\n        ## msynth$scaled_ind_l2 <- msynth$ind_l2  / sqrt(sum(unif$imbalance[,-1]^2))\n        msynth$scaled_ind_l2 <- msynth$ind_l2  / unif$ind_l2\n    }\n\n    msynth$call <- call_name\n\n    return(msynth)\n\n}\n\n\n#' Internal function to fit staggered synth with formatted data\n#' @param wide List 
containing data elements\n#' @param relative Whether to compute balance by relative time\n#' @param n_leads How long past treatment effects should be estimated for\n#' @param n_lags Number of pre-treatment periods to balance, default is to balance all periods\n#' @param nu Fraction of balance for individual balance\n#' @param lambda Regularization hyperparameter, default = 0\n#' @param V Scaling matrix for synth optimization, default NULL is identity\n#' @param force c(0,1,2,3) what type of fixed effects to include\n#' @param n_factors Number of factors for interactive fixed effects, default does CV\n#' @param scm Whether to fit scm weights\n#' @param time_cohort Whether to average synthetic controls into time cohorts\n#' @param time_w Whether to fit time weights\n#' @param lambda_t Regularization for time regression\n#' @param fit_resids Whether to fit SCM on the residuals or not\n#' @param eps_abs Absolute error tolerance for osqp\n#' @param eps_rel Relative error tolerance for osqp\n#' @param verbose Whether to print logs for osqp\n#' @param long_df A long dataframe with 4 columns in the order unit, time, trt, outcome\n#' @param ... Extra arguments\n#' @noRd\n#' @return multisynth object\nmultisynth_formatted <- function(wide, relative=T, n_leads, n_lags,\n                       nu, lambda, V,\n                       force,\n                       n_factors,\n                       scm, time_cohort, \n                       time_w, lambda_t,\n                       fit_resids,\n                       eps_abs, eps_rel,\n                       verbose, long_df, \n                       how_match, ...) 
{\n    ## average together treatment groups\n    ## grps <- unique(wide$trt) %>% sort()\n    if(time_cohort) {\n        grps <- unique(wide$trt[is.finite(wide$trt)])\n    } else {\n        grps <- wide$trt[is.finite(wide$trt)]\n    }\n    J <- length(grps)\n\n    ## fit outcome models\n    if(time_w) {\n        # Autoregressive model\n        out <- fit_time_reg(cbind(wide$X, wide$y), wide$trt,\n                            n_leads, lambda_t, ...)\n        y0hat <- out$y0hat\n        residuals <- out$residuals\n        params <- out$time_weights\n    } else if(is.null(n_factors)) {\n        out <- tryCatch({\n            fit_gsynth_multi(long_df, cbind(wide$X, wide$y), wide$trt, force=force)\n        }, error = function(error_condition) {\n            stop(\"Cannot run CV because there are too few pre-treatment periods.\")\n        })\n\n        y0hat <- out$y0hat\n        params <- out$params\n        n_factors <- ncol(params$factor)\n        ## get residuals from outcome model\n        residuals <- cbind(wide$X, wide$y) - y0hat\n        \n    } else if (n_factors != 0) {\n        ## if number of factors is provided don't do CV\n        out <- fit_gsynth_multi(long_df, cbind(wide$X, wide$y), wide$trt,\n                                r=n_factors, CV=0, force=force)\n        y0hat <- out$y0hat\n        params <- out$params        \n        \n        ## get residuals from outcome model\n        residuals <- cbind(wide$X, wide$y) - y0hat\n    } else if(force == 0 & n_factors == 0) {\n        # if no fixed effects or factors, just take out \n        # control averages at each time point\n        # time fixed effects from pure controls\n        pure_ctrl <- cbind(wide$X, wide$y)[!is.finite(wide$trt), , drop = F]\n        y0hat <- matrix(colMeans(pure_ctrl, na.rm = TRUE),\n                          nrow = nrow(wide$X), ncol = ncol(pure_ctrl), \n                          byrow = T)\n        residuals <- cbind(wide$X, wide$y) - y0hat\n        params <- NULL\n    } else {\n 
       ## take out pre-treatment averages\n        fullmask <- cbind(wide$mask, matrix(0, nrow=nrow(wide$mask),\n                                            ncol=ncol(wide$y)))\n        out <- fit_feff(cbind(wide$X, wide$y), wide$trt, fullmask, force, time_cohort)\n        y0hat <- out$y0hat\n        residuals <- out$residuals\n        params <- NULL\n    }\n\n    ## balance the residuals\n    if(fit_resids) {\n        if(time_w) {\n            # fit scm on residuals after taking out unit fixed effects\n            fullmask <- cbind(wide$mask, matrix(0, nrow=nrow(wide$mask),\n                                            ncol=ncol(wide$y)))\n            out <- fit_feff(cbind(wide$X, wide$y), wide$trt, fullmask, force, time_cohort)\n            bal_mat <- lapply(out$residuals, function(x) x[,1:ncol(wide$X)])\n        } else if(typeof(residuals) == \"list\") {\n            bal_mat <- lapply(residuals, function(x) x[,1:ncol(wide$X)])\n        } else {\n            bal_mat <- residuals[,1:ncol(wide$X)]\n        }\n    } else {\n        # if not balancing residuals, then take out control averages\n        # for each time\n        ctrl_avg <- matrix(colMeans(wide$X[!is.finite(wide$trt), , drop = F]),\n                          nrow = nrow(wide$X), ncol = ncol(wide$X), byrow = T)\n        bal_mat <- wide$X - ctrl_avg\n        bal_mat <- wide$X\n    }\n    \n\n    if(scm) {\n\n        # get eligible set of donor units based on covariates\n        donors <- get_donors(wide$X, wide$y, wide$trt,\n                             wide$Z[, colnames(wide$Z) %in% \n                                      wide$match_covariates, drop = F],\n                             time_cohort, n_lags, n_leads, how = how_match,\n                             exact_covariates = wide$exact_covariates, ...)\n        # run separate synth for scaling\n        sep_fit <- multisynth_qp(X=bal_mat,\n                                    trt=wide$trt,\n                                    mask=wide$mask,\n           
                         Z = wide$Z[, !colnames(wide$Z) %in%\n                                                  wide$match_covariates,\n                                                  drop = F],\n                                    n_leads=n_leads,\n                                    n_lags=n_lags,\n                                    relative=relative,\n                                    nu=0, lambda=lambda,\n                                    V = V,\n                                    time_cohort = time_cohort,\n                                    donors = donors,\n                                    eps_rel = eps_rel,\n                                    eps_abs = eps_abs,\n                                    verbose = verbose)\n        # if no nu value is provided, use default based on\n        # global and individual imbalance for separate synth\n        if(is.null(nu)) {\n            # select nu by triangle inequality ratio\n            glbl <- sep_fit$global_l2 * sqrt(nrow(sep_fit$imbalance))\n            ind <- sep_fit$avg_l2\n            nu <- glbl / ind\n\n        }\n\n        msynth <- multisynth_qp(X=bal_mat,\n                                trt=wide$trt,\n                                mask=wide$mask,\n                                Z = wide$Z[, !colnames(wide$Z) %in%\n                                                  wide$match_covariates,\n                                                  drop = F],\n                                n_leads=n_leads,\n                                n_lags=n_lags,\n                                relative=relative,\n                                nu=nu, lambda=lambda,\n                                V = V,\n                                time_cohort = time_cohort,\n                                donors = donors,\n                                norm_pool = sep_fit$global_l2 ^ 2,\n                                norm_sep = sep_fit$ind_l2 ^ 2,\n                                eps_rel = eps_rel,\n              
                  eps_abs = eps_abs,\n                                verbose = verbose)\n    } else {\n        msynth <- list(weights = matrix(0, nrow = nrow(wide$X), ncol = J),\n                       imbalance=NA,\n                       global_l2=NA,\n                       ind_l2=NA)\n    }\n\n    ## put in data and hyperparams\n    msynth$data <- wide\n    msynth$relative <- relative\n    msynth$n_leads <- n_leads\n    msynth$n_lags <- n_lags\n    msynth$nu <- nu\n    msynth$lambda <- lambda\n    msynth$scm <- scm\n    msynth$time_cohort <- time_cohort\n\n\n    msynth$grps <- grps\n    msynth$y0hat <- y0hat\n    msynth$residuals <- residuals\n\n    msynth$n_factors <- n_factors\n    msynth$force <- force\n\n\n    ## outcome model parameters\n    msynth$params <- params\n\n    # more arguments\n    msynth$scm <- scm\n    msynth$time_w <- time_w\n    msynth$lambda_t <- lambda_t\n    msynth$fit_resids <- fit_resids\n    msynth$extra_pars <- c(list(eps_abs = eps_abs, \n                                eps_rel = eps_rel, \n                                verbose = verbose), \n                           list(...))\n    msynth$long_df <- long_df\n\n    msynth$how_match <- how_match\n    msynth$donors <- donors\n    ##format output\n    class(msynth) <- \"multisynth\"\n    return(msynth)\n}\n\n\n\n\n\n\n#' Get prediction of average outcome under control or ATT\n#' @param object Fit multisynth object\n#' @param att If TRUE, return the ATT, if FALSE, return imputed counterfactual\n#' @param att_weight Weights to place on individual units/cohorts when averaging\n#' @param bs_weight Weight to perturb units by for weighted bootstrap\n#' @param ... Optional arguments\n#'\n#' @return Matrix of predicted post-treatment control outcomes for each treated unit\n#' @export\npredict.multisynth <- function(object, att = F, att_weight = NULL, bs_weight = NULL, ...) 
{\n\n    multisynth <- object\n    relative <- T\n    \n    time_cohort <- multisynth$time_cohort\n    if(is.null(relative)) {\n        relative <- multisynth$relative\n    }\n    n_leads <- multisynth$n_leads\n    d <- ncol(multisynth$data$X)\n    n <- nrow(multisynth$data$X)\n    fulldat <- cbind(multisynth$data$X, multisynth$data$y)\n    ttot <- ncol(fulldat)\n    grps <- multisynth$grps\n    J <- length(grps)\n\n    if(is.null(bs_weight)) {\n      # bs_weight <- rep(1 / sqrt(sum(is.finite(multisynth$data$trt))), n)\n      bs_weight <- rep(1, n)\n    }\n\n    if(time_cohort) {\n        which_t <- lapply(grps, \n                          function(tj) (1:n)[multisynth$data$trt == tj])\n        mask <- unique(multisynth$data$mask)\n    } else {\n        which_t <- (1:n)[is.finite(multisynth$data$trt)]\n        mask <- multisynth$data$mask\n    }\n    \n\n    n1 <- sapply(1:J, function(j) length(which_t[[j]]))\n\n    fullmask <- cbind(mask, matrix(0, nrow = J, ncol = (ttot - d)))\n    \n\n    ## estimate the post-treatment values to get att estimates\n    mu1hat <- vapply(1:J,\n                     function(j) colMeans((bs_weight * fulldat)[which_t[[j]],\n                                                              , drop=FALSE]),\n                     numeric(ttot))\n\n\n\n    ## get average outcome model estimates and reweight residuals\n    if(typeof(multisynth$y0hat) == \"list\") {\n        mu0hat <- vapply(1:J,\n                        function(j) {\n                            y0hat <- colMeans(\n                              (bs_weight * multisynth$y0hat[[j]])[which_t[[j]],\n                                                                , drop=FALSE])\n                            weightsj <- multisynth$weights[,j] * bs_weight\n                            resj <- multisynth$residuals[[j]][weightsj != 0,, drop = F]\n                            y0hat + t(resj) %*% weightsj[weightsj != 0]\n                        }\n                       , numeric(ttot)\n      
                  )\n    } else {\n        mu0hat <- vapply(1:J,\n                        function(j) {\n                            y0hat <- colMeans(\n                              (bs_weight * multisynth$y0hat)[which_t[[j]],\n                                                              , drop=FALSE])\n                            weightsj <- multisynth$weights[, j] * bs_weight\n                            resj <- multisynth$residuals[weightsj != 0,, drop = F]\n                            y0hat + t(resj) %*% weightsj[weightsj != 0]\n                        }\n                       , numeric(ttot)\n                        )\n    }\n\n    tauhat <- mu1hat - mu0hat\n\n    if(is.null(att_weight)) {\n      att_weight <- rep(1, J)\n    }\n    ## re-index time if relative to treatment\n    if(relative) {\n        total_len <- min(d + n_leads, ttot + d - min(grps)) ## total length of predictions\n        mu0hat <- vapply(1:J,\n                         function(j) {\n                             vec <- c(rep(NA, d-grps[j]),\n                                      mu0hat[1:grps[j],j],\n                                      mu0hat[(grps[j]+1):(min(grps[j] + n_leads, ttot)), j])\n                             ## last row is post-treatment average\n                             c(vec,\n                               rep(NA, total_len - length(vec)),\n                               mean(mu0hat[(grps[j]+1):(min(grps[j] + n_leads, ttot)), j]))\n                               \n                         },\n                         numeric(total_len +1\n                                 ))\n        \n        tauhat <- vapply(1:J,\n                         function(j) {\n                             vec <- c(rep(NA, d-grps[j]),\n                                      tauhat[1:grps[j],j],\n                                      tauhat[(grps[j]+1):(min(grps[j] + n_leads, ttot)), j])\n                             ## last row is post-treatment average\n                             c(vec,\n  
                             rep(NA, total_len - length(vec)),\n                               mean(tauhat[(grps[j]+1):(min(grps[j] + n_leads, ttot)), j]))\n                         },\n                         numeric(total_len +1\n                                 ))\n        # re-index unit weights if they change over time\n        if(is.null(dim(att_weight))) {\n          if(J == 1) {\n            att_weight <- matrix(replicate(total_len + 1, att_weight), ncol = 1)\n          } else {\n            att_weight <- t(replicate(total_len + 1, att_weight))\n          }\n        }\n        att_weight_new <- vapply(1:J,\n                        function(j) {\n                            vec <- c(rep(NA, d-grps[j]),\n                                    att_weight[1:grps[j],j],\n                                    att_weight[(grps[j]+1):(min(grps[j] + n_leads, ttot)), j])\n                            ## last row is post-treatment average\n                            c(vec,\n                              rep(NA, total_len - length(vec)),\n                              mean(att_weight[(grps[j]+1):(min(grps[j] + n_leads, ttot)), j]))\n                              \n                        },\n                        numeric(total_len +1\n                                ))\n          \n        ## get the overall average estimate\n        avg <- apply(mu0hat, 1, function(z) sum(n1 * z, na.rm=T) / sum(n1 * !is.na(z)))\n        avg <- sapply(1:nrow(mu0hat),\n        function(k)  {\n          sum(n1 * mu0hat[k,] * att_weight_new[k,], na.rm=T) /\n            sum(n1 * (!is.na(mu0hat[k,])) * att_weight_new[k, ], na.rm = T)\n        })\n        mu0hat <- cbind(avg, mu0hat)\n\n        avg <- apply(tauhat, 1, function(z) sum(n1 * z, na.rm=T) / sum(n1 * !is.na(z)))\n        avg <- sapply(1:nrow(mu0hat),\n        function(k)  {\n          sum(n1 * tauhat[k,] * att_weight_new[k,], na.rm=T) /\n            sum(n1 * (!is.na(tauhat[k,])) * att_weight_new[k, ], na.rm = T)\n        })\n        
tauhat <- cbind(avg, tauhat)\n        \n    } else {\n\n        ## remove all estimates for t > T_j + n_leads\n        vapply(1:J,\n               function(j) c(mu0hat[1:min(grps[j]+n_leads, ttot),j],\n                             rep(NA, max(0, ttot-(grps[j] + n_leads)))),\n               numeric(ttot)) -> mu0hat\n\n        vapply(1:J,\n               function(j) c(tauhat[1:min(grps[j]+n_leads, ttot),j],\n                             rep(NA, max(0, ttot-(grps[j] + n_leads)))),\n               numeric(ttot)) -> tauhat\n\n        \n        ## only average currently treated units\n        avg1 <- rowSums(t(fullmask) *  mu0hat * n1) /\n                rowSums(t(fullmask) *  n1)\n        avg2 <- rowSums(t(1-fullmask) *  mu0hat * n1) /\n            rowSums(t(1-fullmask) *  n1)\n        avg <- replace_na(avg1, 0) * apply(fullmask, 2, min) +\n            replace_na(avg2,0) * apply(1-fullmask, 2, max)\n        cbind(avg, mu0hat) -> mu0hat\n\n        ## only average currently treated units\n        avg1 <- rowSums(t(fullmask) *  tauhat * n1) /\n            rowSums(t(fullmask) *  n1)\n        avg2 <- rowSums(t(1-fullmask) *  tauhat * n1) /\n            rowSums(t(1-fullmask) *  n1)\n        avg <- replace_na(avg1, 0) * apply(fullmask, 2, min) +\n            replace_na(avg2,0) * apply(1 - fullmask, 2, max)\n        cbind(avg, tauhat) -> tauhat\n    }\n    \n\n    if(att) {\n        return(tauhat)\n    } else {\n        return(mu0hat)\n    }\n}\n\n\n#' Print function for multisynth\n#' @param x multisynth object\n#' @param ... Optional arguments\n#' @export\nprint.multisynth <- function(x, att_weight = NULL, ...) 
{\n    multisynth <- x\n    \n    ## straight from lm\n    cat(\"\\nCall:\\n\", paste(deparse(multisynth$call), \n        sep=\"\\n\", collapse=\"\\n\"), \"\\n\\n\", sep=\"\")\n\n    # print att estimates\n    att_post <- predict(multisynth, att=T, att_weight = att_weight)[,1]\n    att_post <- att_post[length(att_post)]\n\n    cat(paste(\"Average ATT Estimate: \",\n              format(round(mean(att_post),3), nsmall = 3), \"\\n\\n\", sep=\"\"))\n}\n\n\n\n#' Plot function for multisynth\n#' @importFrom graphics plot\n#' @param x Augsynth object to be plotted\n#' @param inf_type Type of inference to perform:\n#'  \\itemize{\n#'    \\item{bootstrap}{Wild bootstrap, the default option}\n#'    \\item{jackknife}{Jackknife}\n#' }\n#' @param inf Whether to compute and plot confidence intervals\n#' @param levels Which units/groups to plot, default is every group\n#' @param label Whether to label the individual levels\n#' @param weights Whether to plot the weights, default = FALSE\n#' @param ... Optional arguments\n#' @export\nplot.multisynth <- function(x, inf_type = \"bootstrap\", inf = T,\n                            levels = NULL, label = T, \n                            weights = FALSE, ...) 
{\n\n    if(weights) {\n      ever_trt <- x$data$units[is.finite(x$data$trt)]\n      never_trt <- x$data$units[!is.finite(x$data$trt)]\n      weights <- data.frame(x$weights)\n\n      colnames(weights) <- ever_trt\n      weights$unit <- factor(rownames(weights),\n                             levels = sort(rownames(weights), decreasing = TRUE))\n\n      # plotting the weights\n      weights %>%\n        tidyr::pivot_longer(-unit, names_to = \"trt_unit\", values_to = \"weight\") %>%\n        ggplot2::ggplot(aes(x = trt_unit, y = unit, fill = round(weight, 3))) + \n        ggplot2::geom_tile(color = \"white\", size=.5) +\n        ggplot2::scale_fill_gradient(low = \"white\", high = \"black\", limits=c(-0.01,1.01)) +\n        ggplot2::guides(fill = \"none\") +\n        ggplot2::xlab(\"Treated Unit\") +\n        ggplot2::ylab(\"Donor Unit\") +\n        ggplot2::theme_bw() + \n        ggplot2::theme(axis.ticks.x = ggplot2::element_blank(),\n              axis.ticks.y = ggplot2::element_blank())\n    }\n    else {\n      plot(summary(x, inf_type = inf_type, ...),\n         inf = inf, levels = levels, label = label)\n    }\n}\n\n\n\n\n\n#' Summary function for multisynth\n#' @param object multisynth object\n#' @param inf_type Type of inference to perform:\n#'  \\itemize{\n#'    \\item{bootstrap}{Wild bootstrap, the default option}\n#'    \\item{jackknife}{Jackknife}\n#' }\n#' @param ... 
Optional arguments\n#' @param att_weight Weights to place on individual units/cohorts when averaging\n#' \n#' @return summary.multisynth object that contains:\n#'         \\itemize{\n#'          \\item{\"att\"}{Dataframe with ATT estimates, standard errors for each treated unit}\n#'          \\item{\"global_l2\"}{L2 imbalance for the pooled synthetic control}\n#'          \\item{\"scaled_global_l2\"}{L2 imbalance for the pooled synthetic control, scaled by the imbalance for uniform weights}\n#'          \\item{\"ind_l2\"}{Average L2 imbalance for the individual synthetic controls}\n#'          \\item{\"scaled_ind_l2\"}{Average L2 imbalance for the individual synthetic controls, scaled by the imbalance for uniform weights}\n#'         \\item{\"n_leads\", \"n_lags\"}{Number of post-treatment outcomes (leads) and pre-treatment outcomes (lags) to include in the analysis}\n#'         }\n#' @export\nsummary.multisynth <- function(object, inf_type = \"bootstrap\", att_weight = NULL, ...) {\n\n    multisynth <- object\n    \n    relative <- T\n\n    n_leads <- multisynth$n_leads\n    d <- ncol(multisynth$data$X)\n    n <- nrow(multisynth$data$X)\n    ttot <- d + ncol(multisynth$data$y)\n\n    trt <- multisynth$data$trt\n    time_cohort <- multisynth$time_cohort\n    if(time_cohort) {\n        grps <- unique(trt[is.finite(trt)])\n        which_t <- lapply(grps, function(tj) (1:n)[trt == tj])\n    } else {\n        grps <- trt[is.finite(trt)]\n        which_t <- (1:n)[is.finite(trt)]\n    }\n    \n    # grps <- unique(multisynth$data$trt) %>% sort()\n    J <- length(grps)\n    \n    # which_t <- (1:n)[is.finite(multisynth$data$trt)]\n    times <- multisynth$data$time\n    \n    summ <- list()\n    ## post treatment estimate for each group and overall\n    # att <- predict(multisynth, relative, att=T)\n    \n    if(inf_type == \"jackknife\") {\n        attse <- jackknife_se_multi(multisynth, relative, att_weight = att_weight, ...)\n    } else if(inf_type == \"bootstrap\") {\n        if(object$force == 2) {\n          warning(\"Wild bootstrap without 
including a unit fixed effect \",\n                  \"is likely to be very conservative!\")\n        }\n        attse <- weighted_bootstrap_multi(multisynth, att_weight = att_weight, ...)\n    } else {\n        att <- predict(multisynth, relative, att=T, att_weight = att_weight)\n        attse <- list(att = att,\n                      se = matrix(NA, nrow(att), ncol(att)),\n                      upper_bound = matrix(NA, nrow(att), ncol(att)),\n                      lower_bound = matrix(NA, nrow(att), ncol(att)))\n    }\n    \n\n    if(relative) {\n        att <- data.frame(cbind(c(-(d-1):min(n_leads, ttot-min(grps)), NA),\n                                attse$att))\n        if(time_cohort) {\n            col_names <- c(\"Time\", \"Average\", \n                            as.character(times[grps + 1]))\n        } else {\n            col_names <- c(\"Time\", \"Average\", \n                            as.character(multisynth$data$units[which_t]))\n        }\n        names(att) <- col_names\n        att %>% gather(Level, Estimate, -Time) %>%\n            rename(\"Time\"=Time) %>%\n            mutate(Time=Time-1) -> att\n\n        se <- data.frame(cbind(c(-(d-1):min(n_leads, ttot-min(grps)), NA),\n                               attse$se))                            \n        names(se) <- col_names\n        se %>% gather(Level, Std.Error, -Time) %>%\n            rename(\"Time\"=Time) %>%\n            mutate(Time=Time-1) -> se\n        lower_bound <- data.frame(cbind(c(-(d-1):min(n_leads, ttot-min(grps)), NA),\n                                  attse$lower_bound))\n        names(lower_bound) <- col_names\n        lower_bound %>% gather(Level, lower_bound, -Time) %>%\n          rename(\"Time\"=Time) %>%\n          mutate(Time=Time-1) -> lower_bound\n\n        upper_bound <- data.frame(cbind(c(-(d-1):min(n_leads, ttot-min(grps)), NA),\n                                        attse$upper_bound))\n        names(upper_bound) <- col_names\n        upper_bound %>% 
gather(Level, upper_bound, -Time) %>%\n          rename(\"Time\"=Time) %>%\n            mutate(Time=Time-1) -> upper_bound\n\n    } else {\n        att <- data.frame(cbind(times, attse$att))\n        names(att) <- c(\"Time\", \"Average\", times[grps[1:J]])        \n        att %>% gather(Level, Estimate, -Time) -> att\n\n        se <- data.frame(cbind(times, attse$se))\n        names(se) <- c(\"Time\", \"Average\", times[grps[1:J]])        \n        se %>% gather(Level, Std.Error, -Time) -> se\n\n    }\n\n    summ$att <- inner_join(att, se, by = c(\"Time\", \"Level\")) %>%\n      inner_join(lower_bound, by = c(\"Time\", \"Level\")) %>%\n        inner_join(upper_bound, by = c(\"Time\", \"Level\"))\n\n    summ$relative <- relative\n    summ$grps <- grps\n    summ$call <- multisynth$call\n    summ$global_l2 <- multisynth$global_l2\n    summ$scaled_global_l2 <- multisynth$scaled_global_l2\n\n    summ$ind_l2 <- multisynth$ind_l2\n    summ$scaled_ind_l2 <- multisynth$scaled_ind_l2\n\n    summ$n_leads <- multisynth$n_leads\n    summ$n_lags <- multisynth$n_lags\n\n    class(summ) <- \"summary.multisynth\"\n    return(summ)\n}\n\n#' Print function for summary function for multisynth\n#' @param x summary object\n#' @param level Which unit/group to print results for, default is the overall average\n#' @param ... Optional arguments\n#' @export\nprint.summary.multisynth <- function(x, level = \"Average\", ...) 
{\n\n    summ <- x\n    \n    ## straight from lm\n    cat(\"\\nCall:\\n\", paste(deparse(summ$call), sep=\"\\n\", collapse=\"\\n\"), \"\\n\\n\", sep=\"\")\n\n    first_lvl <- summ$att %>% filter(Level != \"Average\") %>% pull(Level) %>% min()\n    \n    ## get ATT estimates for treatment level, post treatment\n    if(summ$relative) {\n        summ$att %>%\n            filter(Time >= 0, Level==level) %>%\n            rename(\"Time Since Treatment\"=Time) -> att_est\n    } else if(level == \"Average\") {\n        summ$att %>% filter(Time > first_lvl, Level==\"Average\") -> att_est\n    } else {\n        summ$att %>% filter(Time > level, Level==level) -> att_est\n    }\n\n    cat(paste(\"Average ATT Estimate (Std. Error): \",\n              summ$att %>%\n                  filter(Level == level, is.na(Time)) %>%\n                  pull(Estimate) %>%\n                  round(3) %>% format(nsmall=3),\n              \"  (\",\n              summ$att %>%\n                  filter(Level == level, is.na(Time)) %>%\n                  pull(Std.Error) %>%\n                  round(3) %>% format(nsmall=3),\n              \")\\n\\n\", sep=\"\"))\n    \n    cat(paste(\"Global L2 Imbalance: \",\n              format(round(summ$global_l2,3), nsmall=3), \"\\n\",\n              \"Scaled Global L2 Imbalance: \",\n              format(round(summ$scaled_global_l2,3), nsmall=3), \"\\n\",\n              \"Percent improvement from uniform global weights: \", \n              format(round(1-summ$scaled_global_l2,3)*100), \"\\n\\n\",\n              \"Individual L2 Imbalance: \",\n              format(round(summ$ind_l2,3), nsmall=3), \"\\n\",\n              \"Scaled Individual L2 Imbalance: \", \n              format(round(summ$scaled_ind_l2,3), nsmall=3), \"\\n\",\n              \"Percent improvement from uniform individual weights: \", \n              format(round(1-summ$scaled_ind_l2,3)*100), \"\\t\",\n              \"\\n\\n\",\n              sep=\"\"))\n\n\n    print(att_est, 
row.names=F)\n\n}\n\n#' Plot function for summary function for multisynth\n#' @importFrom ggplot2 aes\n#' \n#' @param x summary object\n#' @param inf Whether to plot confidence intervals\n#' @param levels Which units/groups to plot, default is every group\n#' @param label Whether to label the individual levels\n#' @param weights Whether to plot the weights, default = FALSE\n#' @param ... Optional arguments\n#' @export\nplot.summary.multisynth <- function(x, inf = T, levels = NULL, label = T,\n                                    weights = FALSE, ...) {\n\n\n    if(weights) {\n        ever_trt <- x$data$units[is.finite(x$data$trt)]\n        never_trt <- x$data$units[!is.finite(x$data$trt)]\n        weights <- data.frame(x$weights)\n\n        colnames(weights) <- ever_trt\n        weights$unit <- factor(rownames(weights),\n                              levels = sort(rownames(weights), decreasing = TRUE))\n\n        # plotting the weights\n        weights %>%\n          tidyr::pivot_longer(-unit, names_to = \"trt_unit\", values_to = \"weight\") %>%\n          ggplot2::ggplot(aes(x = trt_unit, y = unit, fill = round(weight, 3))) + \n          ggplot2::geom_tile(color = \"white\", size=.5) +\n          ggplot2::scale_fill_gradient(low = \"white\", high = \"black\", limits=c(-0.01,1.01)) +\n          ggplot2::guides(fill = \"none\") +\n          ggplot2::xlab(\"Treated Unit\") +\n          ggplot2::ylab(\"Donor Unit\") +\n          ggplot2::theme_bw() + \n          ggplot2::theme(axis.ticks.x = ggplot2::element_blank(),\n                axis.ticks.y = ggplot2::element_blank())\n    }\n\n    summ <- x\n    \n    ## get the last time period for each level\n    summ$att %>%\n        filter(!is.na(Estimate),\n               Time >= -summ$n_lags,\n               Time <= summ$n_leads) %>%\n        group_by(Level) %>%\n        summarise(last_time = max(Time)) -> last_times\n\n    if(is.null(levels)) levels <- unique(summ$att$Level)\n\n    summ$att %>% inner_join(last_times) 
%>%\n        filter(Level %in% levels) %>%\n        mutate(label = ifelse(Time == last_time, Level, NA),\n               is_avg = ifelse((\"Average\" %in% levels) * (Level == \"Average\"),\n                               \"A\", \"B\")) %>%\n        ggplot2::ggplot(ggplot2::aes(x = Time, y = Estimate,\n                                     group = Level,\n                                     color = is_avg,\n                                     alpha = is_avg)) +\n            ggplot2::geom_line(size = 1) +\n            ggplot2::geom_point(size = 1) -> p\n            \n        if(label) {\n          p <- p + ggrepel::geom_label_repel(ggplot2::aes(label = label),\n                                      nudge_x = 1, na.rm = T)\n        } \n        p <- p + ggplot2::geom_hline(yintercept = 0, lty = 2)\n\n    if(summ$relative) {\n        p <- p + ggplot2::geom_vline(xintercept = 0, lty = 2) +\n            ggplot2::xlab(\"Time Relative to Treatment\")\n    } else {\n        p <- p + ggplot2::geom_vline(aes(xintercept = as.numeric(Level)),\n                                     lty = 2, alpha = 0.5,\n                                     summ$att %>% filter(Level != \"Average\"))\n    }\n\n    ## add ses\n    if(inf) {\n        max_time <- max(summ$att$Time, na.rm = T)\n        if(max_time == 0) {\n          error_plt <- ggplot2::geom_errorbar\n          clr <- \"black\"\n          alph <- 1\n        } else {\n          error_plt <- ggplot2::geom_ribbon\n          clr <- NA\n          alph <- 0.2\n        }\n        if(\"Average\" %in% levels) {\n            p <- p + error_plt(\n                ggplot2::aes(ymin=lower_bound,\n                             ymax=upper_bound),\n                alpha = alph, color=clr,\n                data = summ$att %>% \n                        filter(Level == \"Average\",\n                               Time >= 0))\n            \n        } else {\n            p <- p + error_plt(\n                ggplot2::aes(ymin=lower_bound,\n                  
           ymax=upper_bound),\n                             data = . %>% filter(Time >= 0),\n                alpha = alph, color = clr)\n        }\n    }\n\n    p <- p + ggplot2::scale_alpha_manual(values=c(1, 0.5)) +\n        ggplot2::scale_color_manual(values=c(\"#333333\", \"#818181\")) +\n        ggplot2::guides(alpha = \"none\", color = \"none\") + \n        ggplot2::theme_bw()\n    return(p)\n\n}\n"
  },
  {
    "path": "R/outcome_models.R",
    "content": "################################################################################\n## Code to fit various outcome models\n################################################################################\n\n#' Use a separate regularized regression for each post period\n#' to fit E[Y(0)|X]\n#' @importFrom stats poly\n#' @importFrom stats coef\n#'\n#' @param X Matrix of covariates/lagged outcomes\n#' @param y Matrix of post-period outcomes\n#' @param trt Vector of treatment indicator\n#' @param alpha Mixing between L1 and L2, default: 1 (LASSO)\n#' @param lambda Regularization hyperparameter, if null then CV\n#' @param poly_order Order of polynomial to fit, default 1\n#' @param type How to fit outcome model(s)\n#'             \\itemize{\n#'              \\item{sep }{Separate outcome models}\n#'              \\item{avg }{Average responses into 1 outcome}\n#'              \\item{multi }{Use multi response regression in glmnet}}\n#' @param ... optional arguments for outcome model\n#' @noRd\n#' @return \\itemize{\n#'           \\item{y0hat }{Predicted outcome under control}\n#'           \\item{params }{Regression parameters}}\nfit_prog_reg <- function(X, y, trt, alpha=1, lambda=NULL,\n                         poly_order=1, type=\"sep\", ...) 
{\n    if(!requireNamespace(\"glmnet\", quietly = TRUE)) {\n        stop(\"In order to fit an elastic net outcome model, you must install the glmnet package.\")\n    }\n    \n    extra_params = list(...)\n    if (length(extra_params) > 0) {\n        warning(\"Unused parameters when using elastic net: \", paste(names(extra_params), collapse = \", \"))\n    }\n    \n    X <- matrix(poly(matrix(X),degree=poly_order), nrow=dim(X)[1])\n\n    ## helper function to fit regression with CV\n    outfit <- function(x, y) {\n        if(is.null(lambda)) {\n            lam <- glmnet::cv.glmnet(x, y, alpha=alpha, grouped=FALSE)$lambda.min\n        } else {\n            lam <- lambda\n        }\n        fit <- glmnet::glmnet(x, y, alpha=alpha,\n                              lambda=lam)\n        \n        return(as.matrix(coef(fit)))\n    }\n\n    if(type==\"avg\") {\n        ## if fitting the average post period value, stack post periods together\n        stacky <- c(y)\n        stackx <- do.call(rbind,\n                          lapply(1:dim(y)[2],\n                                 function(x) X))\n        stacktrt <- rep(trt, dim(y)[2])\n        regweights <- outfit(stackx[stacktrt==0,],\n                             stacky[stacktrt==0])\n    } else if(type==\"sep\"){\n        ## fit separate regressions for each post period\n        regweights <- apply(as.matrix(y), 2,\n                            function(yt) outfit(X[trt==0,],\n                                                yt[trt==0]))\n    } else {\n        ## fit multi response regression\n        lam <- glmnet::cv.glmnet(X, y, family=\"mgaussian\",\n                                 alpha=alpha, grouped=FALSE)$lambda.min\n        fit <- glmnet::glmnet(X, y, family=\"mgaussian\",\n                              alpha=alpha,\n                              lambda=lam)\n        regweights <- as.matrix(do.call(cbind, coef(fit)))\n    }\n\n\n    ## Get predicted values\n    y0hat <- cbind(rep(1, dim(X)[1]),\n                   
X) %*% regweights\n\n    return(list(y0hat = y0hat,\n                params  = regweights))\n}\n\n\n\n#' Use a separate random forest regression for each post period\n#' to fit E[Y(0)|X]\n#'\n#' @param X Matrix of covariates/lagged outcomes\n#' @param y Matrix of post-period outcomes\n#' @param trt Vector of treatment indicator\n#' @param avg Predict the average post-treatment outcome\n#' @param ... optional arguments for outcome model\n#' @noRd\n#' @return \\itemize{\n#'           \\item{y0hat }{Predicted outcome under control}\n#'           \\item{params }{Regression parameters}}\nfit_prog_rf <- function(X, y, trt, avg=FALSE, ...) {\n\n    if(!requireNamespace(\"randomForest\", quietly = TRUE)) {\n        stop(\"In order to fit a random forest outcome model, you must install the randomForest package.\")\n    }\n    \n    extra_params = list(...)\n    if (length(extra_params) > 0) {\n        warning(\"Unused parameters when using random forest: \", paste(names(extra_params), collapse = \", \"))\n    }\n\n    \n    ## helper function to fit RF\n    outfit <- function(x, y) {\n            fit <- randomForest::randomForest(x, y)\n            return(fit)\n    }\n\n\n    if(avg | dim(y)[2] == 1) {\n        ## if fitting the average post period value, stack post periods together\n        stacky <- c(y)\n        stackx <- do.call(rbind,\n                          lapply(1:dim(y)[2],\n                                 function(x) X))\n        stacktrt <- rep(trt, dim(y)[2])\n        fit <- outfit(stackx[stacktrt==0,],\n                      stacky[stacktrt==0])\n\n        ## predict outcome\n        y0hat <- matrix(predict(fit, X), ncol=1)\n\n        \n        ## keep feature importances\n        imports <- randomForest::importance(fit)\n\n        \n    } else {\n        ## fit separate regressions for each post period\n        fits <- apply(as.matrix(y), 2,\n                      function(yt) outfit(X[trt==0,],\n                                          yt[trt==0]))\n     
   \n        ## predict outcome\n        y0hat <- lapply(fits, function(fit) as.matrix(predict(fit,X))) %>%\n            bind_rows() %>%\n            as.matrix()\n\n        \n        ## keep feature importances\n        imports <- lapply(fits, function(fit) randomForest::importance(fit)) %>%\n            bind_rows() %>%\n            as.matrix()\n\n    }\n\n\n    return(list(y0hat=y0hat,\n                params=imports))\n    \n}\n\n\n#' Use gsynth to fit factor model for E[Y(0)|X]\n#'\n#' @param X Matrix of covariates/lagged outcomes\n#' @param y Matrix of post-period outcomes\n#' @param trt Vector of treatment indicator\n#' @param r Number of factors to use (or start with if CV==1)\n#' @param r.end Max number of factors to consider if CV==1\n#' @param force Fixed effects (0=none, 1=unit, 2=time, 3=two-way)\n#' @param CV Whether to do CV (0=no CV, 1=yes CV)\n#' @param ... optional arguments for outcome model\n#' @noRd\n#' @return \\itemize{\n#'           \\item{y0hat }{Predicted outcome under control}\n#'           \\item{params }{Regression parameters}}\nfit_prog_gsynth <- function(X, y, trt, r=0, r.end=5, force=3, CV=1, ...) 
{\n    if(!requireNamespace(\"gsynth\", quietly = TRUE)) {\n        stop(\"In order to fit generalized synthetic controls, you must install the gsynth package.\")\n    }\n    extra_params = list(...)\n    if (length(extra_params) > 0) {\n        warning(\"Unused parameters when using gSynth: \", paste(names(extra_params), collapse = \", \"))\n    }\n    \n    df_x = data.frame(X, check.names=FALSE)\n    df_x$unit = rownames(df_x)\n    df_x$trt = rep(0, nrow(df_x))\n    df_x <- df_x %>% select(unit, trt, everything())\n    long_df_x = gather(df_x, time, obs, -c(unit,trt))\n\n    df_y = data.frame(y, check.names=FALSE)\n    df_y$unit = rownames(df_y)\n    df_y$trt = trt\n    df_y <- df_y %>% select(unit, trt, everything())\n    long_df_y = gather(df_y, time, obs, -c(unit,trt))\n    long_df = rbind(long_df_x, long_df_y)\n\n    transform(long_df, time = as.numeric(time))\n    transform(long_df, unit = as.numeric(unit))\n    gsyn <- gsynth::gsynth(data = long_df, Y = \"obs\", D = \"trt\", \n                           index = c(\"unit\", \"time\"), force = force, CV = CV, r = r)\n\n    t0 <- dim(X)[2]\n    t_final <- t0 + dim(y)[2]\n    n <- dim(X)[1]\n    ## get predicted outcomes\n    y0hat <- matrix(0, nrow=n, ncol=(t_final-t0))\n    y0hat[trt==0,]  <- t(gsyn$Y.co[(t0+1):t_final,,drop=FALSE] -\n                             gsyn$est.co$residuals[(t0+1):t_final,,drop=FALSE])\n\n    y0hat[trt==1,] <- gsyn$Y.ct[(t0+1):t_final,]\n\n    ## add treated prediction for whole pre-period\n    gsyn$est.co$Y.ct <- gsyn$Y.ct\n\n    ## control and treated residuals\n    gsyn$est.co$ctrl_resids <- gsyn$est.co$residuals\n    gsyn$est.co$trt_resids <- colMeans(cbind(X[trt==1,,drop=FALSE],\n                                             y[trt==1,,drop=FALSE])) -\n        rowMeans(gsyn$est.co$Y.ct)\n    return(list(y0hat=y0hat,\n                params=gsyn$est.co))\n}\n\n\n#' Use Athey (2017) matrix completion panel data code\n#'\n#' @param X Matrix of covariates/lagged outcomes\n#' @param 
y Matrix of post-period outcomes\n#' @param trt Vector of treatment indicator\n#' @param unit_fixed Whether to estimate unit fixed effects\n#' @param time_fixed Whether to estimate time fixed effects\n#' @param ... optional arguments for outcome model\n#' @noRd\n#' @return \\itemize{\n#'           \\item{y0hat }{Predicted outcome under control}\n#'           \\item{params }{Regression parameters}}\nfit_prog_mcpanel <- function(X, y, trt, unit_fixed=1, time_fixed=1, ...) {\n\n\n    if(!requireNamespace(\"MCPanel\", quietly = TRUE)) {\n        stop(\"In order to fit matrix completion, you must install the MCPanel package.\")\n    }\n    \n    extra_params = list(...)\n    if (length(extra_params) > 0) {\n        warning(\"Unused parameters when using MCPanel: \", paste(names(extra_params), collapse = \", \"))\n    }\n    \n    ## create matrix and missingness matrix\n\n    t0 <- dim(X)[2]\n    t_final <- t0 + dim(y)[2]\n    n <- dim(X)[1]    \n    \n    fullmat <- cbind(X, y)\n    maskmat <- matrix(1, nrow=nrow(fullmat), ncol=ncol(fullmat))\n    maskmat[trt==1, (t0+1):t_final] <- 0\n\n    ## estimate matrix\n    mcp <- MCPanel::mcnnm_cv(fullmat, maskmat,\n                             to_estimate_u=unit_fixed, to_estimate_v=time_fixed)\n    \n    ## impute matrix\n    imp_mat <- mcp$L +\n        sweep(matrix(0, nrow=nrow(fullmat), ncol=ncol(fullmat)), 1, mcp$u, \"+\") + # unit fixed\n        sweep(matrix(0, nrow=nrow(fullmat), ncol=ncol(fullmat)), 2, mcp$v, \"+\") # time fixed\n    \n    \n    trtmat <- matrix(0, ncol=n, nrow=t_final)\n    trtmat[t0:t_final, trt == 1] <- 1\n\n    ## get predicted outcomes\n    y0hat <- imp_mat[,(t0+1):t_final,drop=FALSE]\n    params <- mcp\n\n    params$trt_resids <- colMeans(cbind(X[trt==1,,drop=FALSE],\n                                        y[trt==1,,drop=FALSE])) -\n        rowMeans(imp_mat[trt==1,,drop=FALSE])\n\n    params$ctrl_resids <- t(cbind(X[trt==0,,drop=FALSE],\n                                y[trt==0,,drop=FALSE]) - 
imp_mat[trt==0,,drop=FALSE])\n    params$Y.ct <- t(imp_mat[trt==1,,drop=FALSE])\n    return(list(y0hat=y0hat,\n                params=params))\n    \n}\n\n\n#' Fit a comparative interrupted time series\n#' to fit E[Y(0)|X]\n#' @importFrom stats lm\n#' @importFrom stats predict\n#'\n#' @param X Matrix of covariates/lagged outcomes\n#' @param y Matrix of post-period outcomes\n#' @param trt Vector of treatment indicator\n#' @param poly_order Order of time trend polynomial to fit, default 1\n#' @param weights Weights to use in WLS, default is no weights\n#' @param ... optional arguments for outcome model\n#' @noRd\n#' @return \\itemize{\n#'           \\item{y0hat }{Predicted outcome under control}\n#'           \\item{params }{Regression parameters}}\nfit_prog_cits <- function(X, y, trt, poly_order=1, weights=NULL, ...) {\n\n    extra_params = list(...)\n    if (length(extra_params) > 0) {\n        warning(\"Unused parameters when using CITS: \", paste(names(extra_params), collapse = \", \"))\n    }\n    \n    ## combine back into a panel structure\n    ids <- 1:nrow(X)\n    t0 <- dim(X)[2]\n    t_final <- t0 + dim(y)[2]\n    n <- nrow(X)\n\n\n    if(is.null(weights)) {\n        weights <- rep(1, n)\n    }\n    \n    pnl1 <- data.frame(X)\n    colnames(pnl1) <- 1:(t0)\n\n    pnl1 <- pnl1 %>% mutate(trt=trt, post=0, id=ids, weight=weights) %>%\n        gather(time, val, -trt, -post, -id, -weight) %>%\n        mutate(time=as.numeric(time))\n\n    pnl2 <- data.frame(y)\n    colnames(pnl2) <- (t0+1):t_final\n    pnl2 <- pnl2 %>% mutate(trt=trt, post=1, id=ids, weight=weights) %>%\n        gather(time, val, -trt, -post, -id, -weight) %>%\n        mutate(time=as.numeric(time))\n    \n    \n    pnl <- bind_rows(pnl1, pnl2)\n    \n    ## fit regression\n    if(poly_order == \"fixed\") {\n        fit <- pnl %>%\n            filter(!((post==1) & (trt==1))) %>% ## filter out post-period treated outcomes\n            lm(val ~  as.factor(id) + as.factor(time),\n              .,\n    
          weights = .$weight \n              )\n    } else if(poly_order > 0) {\n        fit <- pnl %>%\n            filter(!((post==1) & (trt==1))) %>% ## filter out post-period treated outcomes\n        lm(val ~ poly(time, poly_order) + post + trt + poly(time * trt, poly_order),\n              ., \n              weights = .$weight\n              )\n    } else {\n\n        fit <- pnl %>%\n            filter(!((post==1) & (trt==1))) %>% ## filter out post-period treated outcomes\n            lm(val ~  post + trt,\n              .,\n              weights = .$weight \n              )\n    }\n\n    \n    ## get predicted post-period outcomes\n    \n    y0hat <- matrix(0, nrow=n, ncol=(t_final-t0))\n    y0hat[trt==0,]  <- matrix(predict(fit,\n                                      pnl %>% filter(post==1 & trt==0)),\n                              ncol=ncol(y))\n\n    y0hat[trt==1,] <- matrix(predict(fit,\n                                     pnl %>% filter(post==1 & trt==1)),\n                             ncol=ncol(y))\n\n\n    params <- list()\n\n    \n    ## add treated prediction for whole pre-period\n    params$Y.ct <- matrix(predict(fit,\n                                  pnl %>% filter(trt==1),\n                                  ncol=(ncol(X) + ncol(y))))\n\n    ## and control prediction\n    ctrl_pred <- matrix(predict(fit,\n                                pnl %>% filter(trt==0)),\n                                ncol=(ncol(X) + ncol(y)))\n\n    ## control and treated residuals\n    params$ctrl_resids <- t(cbind(X[trt==0,,drop=FALSE],\n                                y[trt==0,,drop=FALSE])) - \n        t(ctrl_pred)\n    params$trt_resids <- colMeans(cbind(X[trt==1,,drop=FALSE],\n                                            y[trt==1,,drop=FALSE])) -\n        rowMeans(params$Y.ct)\n    \n    return(list(y0hat=y0hat,\n                params=params))\n    \n}\n\n\n\n\n#' Fit a bayesian structural time series\n#' to fit E[Y(0)|X]\n#'\n#' @param X Matrix of 
covariates/lagged outcomes\n#' @param y Matrix of post-period outcomes\n#' @param trt Vector of treatment indicator\n#' @param ... optional arguments for outcome model\n#' @noRd\n#' @return \\itemize{\n#'           \\item{y0hat }{Predicted outcome under control}\n#'           \\item{params }{Model parameters}}\nfit_prog_causalimpact <- function(X, y, trt, ...) {\n\n\n    if(!requireNamespace(\"CausalImpact\", quietly = TRUE)) {\n        stop(\"In order to fit Bayesian structural time series, you must install the CausalImpact package.\")\n    }\n    \n    extra_params = list(...)\n    if (length(extra_params) > 0) {\n        warning(\"Unused parameters when using Bayesian structural time series with CausalImpact: \", paste(names(extra_params), collapse = \", \"))\n    }\n\n    ## structure data accordingly\n    ids <- 1:nrow(X)\n    t0 <- dim(X)[2]\n    t_final <- t0 + dim(y)[2]\n    n <- nrow(X)\n\n    comb <- cbind(X, y)\n\n    imp_dat <- t(rbind(colMeans(comb[trt==1,,drop=F]), comb[trt==0,,drop=F]))\n\n    \n    ## get predicted post-period outcomes\n    ## TODO: is this the way to use CausalImpact??\n    ci_func <- function(i) {\n        ## fit causal impact using controls\n        CausalImpact::CausalImpact(t(rbind(comb[i,], comb[-i,][trt[-i]==0,])),\n                                   pre.period=c(1, t0), post.period=c(t0+1, t_final)\n                                   )$series$point.pred\n        \n    }\n\n    y0hat <- t(sapply(1:n, ci_func))\n\n    params <- list()\n\n    \n    ## add treated prediction for whole pre-period\n    params$Y.ct <- t(y0hat[trt==1,,drop=F])\n\n    ## and control prediction\n    ctrl_pred <- y0hat[trt==0,,drop=F]\n\n    ## control and treated residuals\n    params$ctrl_resids <- t(cbind(X[trt==0,,drop=FALSE],\n                                y[trt==0,,drop=FALSE])) - \n        t(ctrl_pred)\n    \n    params$trt_resids <- colMeans(cbind(X[trt==1,,drop=FALSE],\n                                            y[trt==1,,drop=FALSE])) -\n       
 rowMeans(params$Y.ct)\n    return(list(y0hat=y0hat[,(t0+1):t_final, drop=F],\n                params=params))\n    \n}\n\n\n\n\n#' Fit a seq2seq model with a feedforward net\n#' to fit E[Y(0)|X]\n#'\n#' @param X Matrix of covariates/lagged outcomes\n#' @param y Matrix of post-period outcomes\n#' @param trt Vector of treatment indicator\n#' @param layers List of (n_hidden_units, activation function) pairs to define layers\n#' @param epochs Number of epochs for training\n#' @param patience Number of epochs to wait before early stopping\n#' @param val_split Proportion of control units to use for validation\n#' @param verbose Whether to print training progress\n#' @param ... optional arguments for outcome model\n#' @noRd\n#' @return \\itemize{\n#'           \\item{y0hat }{Predicted outcome under control}\n#'           \\item{params }{Model parameters}}\nfit_prog_seq2seq <- function(X, y, trt,\n                             layers=list(c(50, \"relu\"), c(5, \"relu\")),\n                             epochs=500,\n                             patience=5,\n                             val_split=0.2,\n                             verbose=F, ...) 
{\n\n    if(!requireNamespace(\"keras\", quietly = TRUE)) {\n        stop(\"In order to fit a neural network, you must install the keras package.\")\n    }\n    \n    extra_params = list(...)\n    if (length(extra_params) > 0) {\n        warning(\"Unused parameters when building sequence to sequence learning with feedforward nets: \", paste(names(extra_params), collapse = \", \"))\n    }\n    \n    ## structure data accordingly\n    ids <- 1:nrow(X)\n    t0 <- dim(X)[2]\n    t_final <- t0 + dim(y)[2]\n    n <- nrow(X)\n\n\n    Xctrl <- X[trt==0,,drop=F]\n    yctrl <- y[trt==0,,drop=F]\n\n    ## create first layer\n    model <- keras::keras_model_sequential() %>%\n        keras::layer_dense(units = layers[[1]][1], activation = layers[[1]][2],\n                    input_shape = ncol(Xctrl))\n\n    ## add layers\n    for(layer in layers[-1]) {\n        model %>% keras::layer_dense(units = layer[1], activation = layer[2])\n    }\n\n    ## output layer\n    model %>% keras::layer_dense(units=ncol(yctrl))\n\n    ## compile\n    model %>% keras::compile(optimizer=\"rmsprop\", loss=\"mse\", metrics=c(\"mae\")) \n\n    ## fit model\n    learn <- model %>%\n        keras::fit(x=Xctrl, y=yctrl,\n            epochs=epochs,\n            batch_size=nrow(Xctrl),\n            validation_split=val_split,\n            callbacks=list(keras::callback_early_stopping(patience=patience)),\n            verbose=verbose)\n\n    ## predict for everything\n    y0hat <- model %>% predict(X)\n    params=list(model=model, learn=learn)\n    \n    return(list(y0hat=y0hat,\n                params=params))\n}\n\n\n\n"
  },
  {
    "path": "R/outcome_multi.R",
    "content": "################################################################################\n## Fitting outcome models for multiple treatment groups\n################################################################################\n\n\n#' Use gsynth to fit factor model with \n#' @importFrom utils capture.output\n#' @param long_df A long dataframe with 4 columns in the order unit, time, trt, outcome\n#' @param X Matrix of outcomes\n#' @param trt Vector of treatment status for each unit\n#' @param r Number of factors to use (or start with if CV==1)\n#' @param r.end Max number of factors to consider if CV==1\n#' @param force Fixed effects (0=none, 1=unit, 2=time, 3=two-way)\n#' @param CV Whether to do CV (0=no CV, 1=yes CV)\n#' @noRd\n#' @return \\itemize{\n#'           \\item{y0hat }{Predicted outcome under control}\n#'           \\item{params }{Regression parameters}}\nfit_gsynth_multi <- function(long_df, X, trt, r=0, force=3, CV=1) {\n    if(!requireNamespace(\"gsynth\", quietly = TRUE)) {\n        stop(\"In order to fit generalized synthetic controls, you must install the gsynth package.\")\n    }\n    ttot <- ncol(X)\n    n <- nrow(X)\n\n    labels <- colnames(long_df)\n    gsyn <- gsynth::gsynth(data = long_df, Y = labels[4], D = labels[3], index = c(labels[1], labels[2]), force = force, CV = CV, r=r)\n    \n    y0hat <- matrix(0, nrow=n, ncol=ttot)\n    y0hat[!is.finite(trt),]  <- t(gsyn$Y.co - gsyn$est.co$residuals)\n    \n    y0hat[is.finite(trt),] <- t(gsyn$Y.ct)\n    \n    ## add treated prediction for whole pre-period\n    gsyn$est.co$Y.ct <- gsyn$Y.ct\n    return(list(y0hat=y0hat,\n                params=gsyn$est.co))\n}\n\n\n\n#' Get fixed effects from pre-treatment data for each level\n#'\n#' @param X Matrix of outcomes\n#' @param trt Vector of treatment status for each unit\n#' @param mask Matrix of treatment statuses\n#' @param force Fixed effects: 1=\"unit\", 2=\"time\", 3=\"two-way\"\n#' @param time_cohort Boolean indicating whether to use 
time cohorts\n#' @noRd\n#' @return \\itemize{\n#'           \\item{y0hat }{Predicted outcome under control}\n#'           \\item{params }{Regression parameters}}\nfit_feff <- function(X, trt, mask, force, time_cohort) {\n\n    ttot <- ncol(X)\n    n <- nrow(X)\n    # grps <- trt[is.finite(trt)]\n    # iterate over treatment cohorts\n    grps <- unique(trt[is.finite(trt)])\n    J <- length(grps)\n    which_t <- (1:n)[is.finite(trt)]\n\n    if(force %in% c(2,3)) {\n        ## compute time fixed effects from pure controls\n        time_eff <- matrix(colMeans(X[!is.finite(trt),, drop = F],\n                            na.rm = TRUE),\n                            nrow=nrow(X),\n                            ncol=ncol(X),\n                            byrow=T)\n    } else {\n      time_eff <- matrix(0, nrow = nrow(X), ncol = ncol(X))\n    }\n    residuals <- X - time_eff\n    y0hat <- time_eff\n    if(force %in% c(1,3)) {\n\n        ## compute unit fixed effects from pre-intervention outcomes\n        unit_eff <- lapply(grps, \n                            function(tj) matrix(\n                                            rowMeans(residuals[, 1:tj, drop = F],\n                                                     na.rm = TRUE),\n                                            nrow=nrow(X), ncol=ncol(X)))\n        residuals <- lapply(1:J, function(j) residuals -\n                                                unit_eff[[j]])\n        y0hat <- unit_eff\n    }\n\n    if(force == 3) {\n        y0hat <- lapply(unit_eff, function(ufj) time_eff + ufj)\n    }\n    \n    # go from treatment cohorts to individuals\n    if(force %in% c(1,3) & !time_cohort) {\n      names(residuals) <- as.character(grps)\n      residuals <- residuals[as.character(trt[is.finite(trt)])]\n      names(y0hat) <- as.character(grps)\n      y0hat <- y0hat[as.character(trt[is.finite(trt)])]\n    }\n    \n    return(list(y0hat = y0hat,\n                residuals = residuals))\n    \n}\n\n"
  },
  {
    "path": "R/ridge.R",
    "content": "################################################################################\n## Ridge-augmented SCM\n################################################################################\n\n#' Ridge augmented weights (possibly with covariates)\n#'\n#' @param wide_data Output of `format_data`\n#' @param synth_data Output of `format_synth`\n#' @param Z Matrix of covariates, default is  NULL\n#' @param lambda Ridge hyper-parameter, if NULL use CV\n#' @param ridge Include ridge or not\n#' @param scm Include SCM or not\n#' @param lambda_min_ratio Ratio of the smallest to largest lambda when tuning lambda values\n#' @param n_lambda Number of lambdas to consider between the smallest and largest lambda value\n#' @param lambda_max Initial (largest) lambda, if NULL sets it to be (1+norm(X_1-X_c))^2\n#' @param holdout_length Length of conseuctive holdout period for when tuning lambdas \n#' @param min_1se If TRUE, chooses the maximum lambda within 1 standard error of the lambda that minimizes the CV error, if FALSE chooses the optimal lambda; default TRUE\n#' @param V V matrix for synth, default NULL\n#' @param residualize Whether to residualize auxiliary covariates or balance directly, default TRUE\n#' @param ... 
optional arguments for outcome model\n#' @noRd\n#' @return \\itemize{\n#'          \\item{\"weights\"}{Ridge ASCM weights}\n#'          \\item{\"l2_imbalance\"}{Imbalance in pre-period outcomes, measured by the L2 norm}\n#'          \\item{\"scaled_l2_imbalance\"}{L2 imbalance scaled by L2 imbalance of uniform weights}\n#'          \\item{\"mhat\"}{Outcome model estimate (zero in this case)}\n#'          \\item{\"lambda\"}{Value of the ridge hyperparameter}\n#'          \\item{\"ridge_mhat\"}{The ridge regression predictions (for estimating the bias)}\n#'          \\item{\"synw\"}{The synth weights(for estimating the bias)}\n#'          \\item{\"lambdas\"}{List of lambda values evaluated to tune ridge regression}\n#'          \\item{\"lambda_errors\"}{\"The MSE associated with each lambda term in lambdas.\"}\n#'          \\item{\"lambda_errors_se\"}{\"The SE of the MSE associated with each lambda term in lambdas.\"}\n#' }\nfit_ridgeaug_formatted <- function(wide_data, synth_data,\n                                   Z=NULL, lambda=NULL, ridge=T, scm=T,\n                                   lambda_min_ratio = 1e-8, n_lambda = 20,\n                                   lambda_max = NULL,\n                                   holdout_length = 1, min_1se = T,\n                                   V = NULL,\n                                   residualize = FALSE, ...) 
{\n    extra_params = list(...)\n    if (length(extra_params) > 0) {\n        warning(\"Unused parameters in using ridge augmented weights: \", paste(names(extra_params), collapse = \", \"))\n    }\n\n    X <- wide_data$X\n    y <- wide_data$y\n    trt <- wide_data$trt\n\n    lambda_errors <- NULL\n    lambda_errors_se <- NULL\n    lambdas <- NULL\n\n    ## center outcomes\n    X_cent <- apply(X, 2, function(x) x - mean(x[trt==0]))\n    X_c <- X_cent[trt==0,,drop=FALSE]\n    X_1 <- matrix(colMeans(X_cent[trt==1,,drop=FALSE]), nrow=1)\n    y_cent <- apply(y, 2, function(x) x - mean(x[trt==0]))\n    y_c <- y_cent[trt==0,,drop=FALSE]\n\n    t0 <- ncol(X_c)\n\n    V <- make_V_matrix(t0, V)\n\n    # apply V matrix transformation\n    X_c <- X_c %*% V\n    X_1 <- X_1 %*% V\n\n    new_synth_data <- synth_data\n\n\n    ## if there are auxiliary covariates, use them\n    if(!is.null(Z)) {\n        ## center covariates\n        Z_cent <- apply(Z, 2, function(x) x - mean(x[trt==0]))\n        Z_c <- Z_cent[trt==0,,drop=FALSE]\n        Z_1 <- matrix(colMeans(Z_cent[trt==1,,drop=FALSE]), nrow=1)\n\n        if(residualize) {\n          ## regress out covariates\n          Xc_hat <- Z_c %*% solve(t(Z_c) %*% Z_c) %*% t(Z_c) %*% X_c\n          X1_hat <- Z_1 %*% solve(t(Z_c) %*% Z_c) %*% t(Z_c) %*% X_c\n\n          # take residuals\n          res_t <- X_1  - X1_hat\n          res_c <- X_c - Xc_hat\n\n          X_c <- res_c\n          X_1 <- res_t\n\n          X_cent[trt == 0,] <- res_c\n          X_cent[trt == 1,] <- res_t\n\n\n          new_synth_data$Z1 <- t(res_t)\n          new_synth_data$X1 <- t(res_t)\n          new_synth_data$Z0 <- t(res_c)\n          new_synth_data$X0 <- t(res_c)\n        } else {\n            # standardize covariates to be on the same scale as the outcomes\n            sdz <-  apply(Z_c, 2, sd)\n            sdx <- sd(X_c)\n            Z_c <- sdx * t(t(Z_c) / sdz)\n            Z_1 <- sdx * Z_1 / sdz\n\n          # concatenate\n          X_c <- cbind(X_c, 
Z_c)\n          X_1 <- cbind(X_1, Z_1)\n          new_synth_data$Z1 <- t(X_1)\n          new_synth_data$X1 <- t(X_1)\n          new_synth_data$Z0 <- t(X_c)\n          new_synth_data$X0 <- t(X_c)\n          V <- diag(ncol(X_c))\n        }\n    } else {\n        new_synth_data$Z1 <- t(X_1)\n        new_synth_data$X1 <- t(X_1)\n        new_synth_data$Z0 <- t(X_c)\n        new_synth_data$X0 <- t(X_c)\n    }\n    out <- fit_ridgeaug_inner(X_c, X_1, trt, new_synth_data,\n                               lambda, ridge, scm,\n                               lambda_min_ratio, n_lambda,\n                               lambda_max,\n                               holdout_length, min_1se)\n\n    weights <- out$weights\n    synw <- out$synw\n    lambda <- out$lambda\n    lambdas <- out$lambdas\n    lambda_errors <- out$lambda_errors\n    lambda_errors_se <- out$lambda_errors_se\n\n    # add back in covariate weights\n    if(!is.null(Z)) {\n        if(residualize) {\n          no_cov_weights <- weights\n          ridge_w <- t(t(Z_1) - t(Z_c) %*% weights) %*% \n                      solve(t(Z_c) %*% Z_c) %*% t(Z_c)\n          weights <- weights + t(ridge_w)\n        } else {\n          no_cov_weights <- NULL\n        }\n    }\n\n    l2_imbalance <- sqrt(sum((synth_data$X0 %*% weights - synth_data$X1)^2))\n\n    ## primal objective value scaled by least squares difference for mean\n    uni_w <- matrix(1/ncol(synth_data$X0), nrow=ncol(synth_data$X0), ncol=1)\n    unif_l2_imbalance <- sqrt(sum((synth_data$X0 %*% uni_w - synth_data$X1)^2))\n    scaled_l2_imabalance <- l2_imbalance / unif_l2_imbalance\n\n\n    ## no outcome model\n    mhat <- matrix(0, nrow=nrow(y), ncol=ncol(y))\n    ridge_mhat <- mhat\n    if(!is.null(Z)) {\n      if(residualize) {\n        ridge_mhat <- ridge_mhat + Z_cent %*% solve(t(Z_c) %*% Z_c) %*%\n                        t(Z_c) %*% y_c\n\n        ## regress out covariates for outcomes\n        yc_hat <- ridge_mhat[trt == 0,, drop = F]\n        # take residuals of 
outcomes\n        y_c <- y_c - yc_hat\n      } else {\n        X_cent <- cbind(X_cent, Z_cent)\n      }\n    }\n\n    if(ridge) {\n        ridge_mhat <- ridge_mhat + X_cent %*% solve(t(X_c) %*% X_c +\n                        lambda * diag(ncol(X_c))) %*%\n                        t(X_c) %*% y_c\n    }\n\n    output <- list(weights = weights,\n                l2_imbalance = l2_imbalance,\n                scaled_l2_imbalance = scaled_l2_imabalance,\n                mhat = mhat,\n                lambda = lambda,\n                ridge_mhat = ridge_mhat,\n                synw = synw,\n                lambdas = lambdas,\n                lambda_errors = lambda_errors,\n                lambda_errors_se = lambda_errors_se)\n\n    if(!is.null(Z)) {\n        output$no_cov_weights <- no_cov_weights\n\n        z_l2_imbalance <- sqrt(sum((t(Z_c) %*% weights - t(Z_1))^2))\n        z_unif_l2_imbalance <- sqrt(sum((t(Z_c) %*% uni_w - t(Z_1))^2))\n        z_scaled_l2_imbalance <- z_l2_imbalance / z_unif_l2_imbalance\n\n        output$covariate_l2_imbalance <- z_l2_imbalance\n        output$scaled_covariate_l2_imbalance <- z_scaled_l2_imbalance\n\n    }\n    return(output)\n}\n\n#' Helper function to fit ridge ASCM\n#' @param X_c Matrix of control lagged outcomes\n#' @param X_1 Vector of treated lagged outcomes\n#' @param trt Vector of treatment indicators\n#' @param synth_data Output of `format_synth`\n#' @param lambda Ridge hyper-parameter, if NULL use CV\n#' @param ridge Include ridge or not\n#' @param scm Include SCM or not\n#' @param lambda_min_ratio Ratio of the smallest to largest lambda when tuning lambda values\n#' @param n_lambda Number of lambdas to consider between the smallest and largest lambda value\n#' @param lambda_max Initial (largest) lambda, if NULL sets it to be (1+norm(X_1-X_c))^2\n#' @param holdout_length Length of consecutive holdout period when tuning lambdas \n#' @param min_1se If TRUE, chooses the maximum lambda within 1 standard error of the lambda 
that minimizes the CV error, if FALSE chooses the optimal lambda; default TRUE\n#' @noRd\n#' @return \\itemize{\n#'          \\item{\"weights\"}{Ridge ASCM weights}\n#'          \\item{\"lambda\"}{Value of the ridge hyperparameter}\n#'          \\item{\"synw\"}{The synth weights(for estimating the bias)}\n#'          \\item{\"lambdas\"}{List of lambda values evaluated to tune ridge regression}\n#'          \\item{\"lambda_errors\"}{\"The MSE associated with each lambda term in lambdas.\"}\n#'          \\item{\"lambda_errors_se\"}{\"The SE of the MSE associated with each lambda term in lambdas.\"}\n#' }\nfit_ridgeaug_inner <- function(X_c, X_1, trt, synth_data,\n                               lambda, ridge, scm,\n                               lambda_min_ratio, n_lambda,\n                               lambda_max,\n                               holdout_length, min_1se) {\n    lambda_errors <- NULL\n    lambda_errors_se <- NULL\n    lambdas <- NULL\n\n    ## if SCM fit scm\n    if(scm) {\n        syn <- fit_synth_formatted(synth_data)$weights\n    } else {\n        ## else use uniform weights\n        syn <- rep(1 / sum(trt == 0), sum(trt == 0))\n    }\n    if(ridge) {\n        if(is.null(lambda)) {\n            cv_out <- cv_lambda(X_c, X_1, synth_data, trt, holdout_length, scm,\n                      lambda_max, lambda_min_ratio, n_lambda, min_1se)\n\n            lambda <- cv_out$lambda\n            lambda_errors <- cv_out$lambda_errors\n            lambda_errors_se <- cv_out$lambda_errors_se\n            lambdas <- cv_out$lambdas\n        }\n        # get ridge weights\n        ridge_w <- t(t(X_1) - t(X_c) %*% syn) %*%\n                    solve(t(X_c) %*% X_c  + lambda * diag(ncol(X_c))) %*% t(X_c)\n    } else {\n        ridge_w <- matrix(0, ncol = sum(trt == 0), nrow=1)\n    }\n    ## combine weights\n    weights <- syn + t(ridge_w)\n\n    return(list(weights = weights,\n                synw = syn,\n                lambda = lambda,\n                lambdas = 
lambdas,\n                lambda_errors = lambda_errors,\n                lambda_errors_se = lambda_errors_se))\n}\n\n\n\n#' Choose max lambda as the largest eigenvalue of X_c'X_c (the squared largest singular value of X_c)\n#' @param X_c matrix of control lagged outcomes\n#' @noRd\n#' @return max lambda\nget_lambda_max <- function(X_c) {\n    svd(X_c)$d[1] ^ 2\n}\n#' Create list of lambdas\n#' @param lambda_min_ratio Ratio of the smallest to largest lambda when tuning lambda values\n#' @param n_lambda Number of lambdas to consider between the smallest and largest lambda value\n#' @param lambda_max Initial (largest) lambda\n#' @noRd\n#' @return List of lambdas\ncreate_lambda_list <- function(lambda_max, lambda_min_ratio, n_lambda) {\n    scaler <- (lambda_min_ratio) ^ (1/n_lambda)\n    # geometric grid from lambda_max down to lambda_max * lambda_min_ratio\n    lambdas <- lambda_max * (scaler ^ (0:n_lambda))\n    return(lambdas)\n}\n\n#' Choose either the lambda that minimizes CV MSE or largest lambda within 1 se of min\n#' @param lambdas list of lambdas\n#' @param lambda_errors The MSE associated with each lambda term in lambdas.\n#' @param lambda_errors_se The SE of the MSE associated with each lambda\n#' @param min_1se If TRUE, chooses the maximum lambda within 1 standard error of the lambda that minimizes the CV error, if FALSE chooses the optimal lambda; default TRUE\n#' @noRd\n#' @return optimal lambda\nchoose_lambda <- function(lambdas, lambda_errors, lambda_errors_se, min_1se) {\n    # lambda with smallest error\n    min_idx <- which.min(lambda_errors)\n    min_error <- lambda_errors[min_idx]\n    min_se <- lambda_errors_se[min_idx]\n    lambda_min <- lambdas[min_idx]\n    # max lambda with error within one se of min\n    lambda_1se <- max(lambdas[lambda_errors <= min_error + min_se])\n    return(if(min_1se) lambda_1se else lambda_min)\n}\n\n#' Choose best lambda with CV\n#' @param X_c Matrix of control lagged outcomes\n#' @param X_1 Vector of treated lagged outcomes\n#' @param synth_data Output of `format_synth`\n#' @param 
trt Vector of treatment indicators\n#' @param holdout_length Length of the consecutive holdout period used when tuning lambdas\n#' @param scm Include SCM or not\n#' @param lambda_max Initial (largest) lambda, if NULL sets it to be the squared largest singular value of X_c\n#' @param lambda_min_ratio Ratio of the smallest to largest lambda when tuning lambda values\n#' @param n_lambda Number of lambdas to consider between the smallest and largest lambda value\n#' @param min_1se If TRUE, chooses the maximum lambda within 1 standard error of the lambda that minimizes the CV error, if FALSE chooses the optimal lambda\n#' @noRd\n#' @return \\itemize{\n#'          \\item{\"lambda\"}{Value of the ridge hyperparameter}\n#'          \\item{\"lambdas\"}{List of lambda values evaluated to tune ridge regression}\n#'          \\item{\"lambda_errors\"}{\"The MSE associated with each lambda term in lambdas.\"}\n#'          \\item{\"lambda_errors_se\"}{\"The SE of the MSE associated with each lambda term in lambdas.\"}\n#' }\ncv_lambda <- function(X_c, X_1, synth_data, trt, holdout_length, scm,\n                      lambda_max, lambda_min_ratio, n_lambda, min_1se) {\n    if(is.null(lambda_max)) {\n        lambda_max <- get_lambda_max(X_c) \n    }\n\n    lambdas <- create_lambda_list(lambda_max, lambda_min_ratio, n_lambda)\n    \n    lambda_out <- get_lambda_errors(lambdas, X_c, X_1,\n                                        synth_data, trt,\n                                        holdout_length, scm)\n    lambda_errors <- lambda_out$lambda_errors\n    lambda_errors_se <- lambda_out$lambda_errors_se\n\n    lambda <- choose_lambda(lambdas, lambda_errors, lambda_errors_se, min_1se)\n\n    return(list(lambda = lambda, lambda_errors = lambda_errors,\n                lambda_errors_se = lambda_errors_se, lambdas = lambdas))\n}\n"
  },
  {
    "path": "R/ridge_lambda.R",
    "content": "################################################################################\n## Function to calculate error on different lambda values if using Ridge Augmented SCM\n################################################################################\n\n#' Get Lambda Errors\n#' @importFrom stats sd\n#'\n#' @param lambdas Vector of lambda values to compute errors for\n#' @param X_c Matrix of control group pre-treatment outcomes\n#' @param X_t Matrix of treatment group pre-treatment outcomes\n#' @param synth_data Output of `format_synth`\n#' @param trt Boolean vector of treatment assignments\n#' @param holdout_length Length of conseuctive holdout period for when tuning lambdas\n#' @param scm Include SCM or not\n#' @noRd\n#' @return List of lambda errors for each corresponding lambda in the lambdas parameter.\nget_lambda_errors <- function(lambdas, X_c, X_t, synth_data, trt, holdout_length=1, scm=T) {\n  # vector that stores the sum MSE across all CV sets\n  errors <- matrix(0, nrow = ncol(X_c) - holdout_length, ncol = length(lambdas))\n  lambda_errors = numeric(length(lambdas)) \n  lambda_errors_se = numeric(length(lambdas)) \n\n  for (i in 1:(ncol(X_c) - holdout_length)) {\n    X_0 <- X_c[,-(i:(i + holdout_length - 1))]\n    X_1 <- matrix(X_t[-(i:(i + holdout_length - 1))])\n    X_0v <- X_c[,i:(i + holdout_length - 1)]\n    X_1v <- matrix(X_t[i:(i + holdout_length - 1)], ncol = 1)\n    new_synth_data <- synth_data\n    new_synth_data$Z1 <- X_1\n    new_synth_data$X1 <- X_1\n    new_synth_data$Z0 <- t(X_0)\n    new_synth_data$X0 <- t(X_0)\n\n    if(scm) {\n      syn <- fit_synth_formatted(new_synth_data)$weights\n    } else {\n      syn <- rep(1/sum(trt==0), sum(trt==0))\n    }\n\n    for (j in 1:length(lambdas)) {\n      ridge_weights <- t(X_1 - t(X_0) %*% syn) %*% solve(t(X_0) %*% X_0 + lambdas[j] * diag(ncol(X_0))) %*% t(X_0)\n      aug_weights <- syn + t(ridge_weights)\n      error <- X_1v - t(X_0v) %*% aug_weights\n      # take sum of errors 
across the holdout time periods\n      error <- sum(error ^ 2)\n      errors[i, j] <-  error\n      # lambda_errors[j] <- lambda_errors[j] + error\n    }\n  }\n  lambda_errors <- apply(errors, 2, mean)\n  lambda_errors_se <- apply(errors, 2, function(x) sd(x) / sqrt(length(x)))\n  return(list(lambda_errors = lambda_errors, \n              lambda_errors_se = lambda_errors_se))\n}"
  },
  {
    "path": "R/time_regression_multi.R",
    "content": "##############################################################################\n## Outcome regression with multiple treated units\n##############################################################################\n\n#' Fit a time regression\n#' @param X Matrix of outcomes\n#' @param trt Vector of treatment status for each unit\n#' @param n_leads How long past treatment effects should be estimated for\n#' @param reg_param Regularization hyperparameter\n#' @param lowlim Lower bound for coefs\n#' @param uplim upper bound for coefs\n#' @param ... Extra optimization hyperparameters\n#' @noRd\n#' @return \\itemize{\n#'           \\item{y0hat }{List of predicted outcome under control}\n#'           \\item{residuals }{List of residuals}\n#'           \\item{params }{Regression parameters}}\nfit_time_reg <- function(X, trt, n_leads, reg_param, lowlim = 0, uplim = 1, ...) {\n\n    grps <- trt[is.finite(trt)]\n    J <- length(grps)\n    tmax <- max(trt[is.finite(trt)])\n\n    # fit QP\n    reg_weights <- fit_time_reg_qp_(X, trt, n_leads, lowlim, uplim, reg_param, ...)\n\n    # get predicted outcomes (repeated as a matrix) and residuals\n    y0hat <- lapply(1:J, \n        function(j) {\n            # compute time fixed effects from pure controls\n            time_eff <- matrix(colMeans(X[!is.finite(trt),]),\n                              nrow=nrow(X), ncol=ncol(X),\n                              byrow=T)\n            Xj <- X - time_eff\n            zero_mat <- matrix(0, nrow = nrow(X), ncol = (tmax - grps[j]))\n            Xj <- cbind(zero_mat, Xj[, 1:grps[j], drop = F])\n            # take out pure control means\n            y0hatj <- Xj %*% reg_weights[,j, drop = F]\n            matrix(y0hatj, nrow=nrow(X), ncol=ncol(X)) + time_eff\n        })\n\n    residuals <- lapply(1:J, function(j) X - y0hat[[j]])\n\n    return(list(y0hat = y0hat,\n                residuals = residuals,\n                time_weights = reg_weights))\n}\n\n\n#' Fit a time regression\n#' 
@param X Matrix of outcomes\n#' @param trt Vector of treatment status for each unit\n#' @param n_leads How long past treatment effects should be estimated for\n#' @param reg_param Regularization hyperparameter\n#' @param lowlim Lower bound for coefs\n#' @param uplim upper bound for coefs\n#' @param ... Extra optimization hyperparameters\n#' @noRd\n#' @return reg_weights Fitted regression weights\nfit_time_reg_qp_ <- function(X, trt, n_leads, lowlim, uplim, reg_param, ...) {\n\n    grps <- trt[is.finite(trt)]\n    J <- length(grps)\n    ttot <- ncol(X)\n    max_trt <- max(grps)\n\n    # get data in the right form\n    data_mats <- collect_data(X, trt, n_leads)\n\n    # create constraint matrices\n    constraints <- make_constraints(J, grps, lowlim, uplim)\n\n    # get components of QP\n    Qmat <- get_Qmat(data_mats$pre_mats)\n    pvec <- get_pvec(data_mats$pre_mats, data_mats$post_vecs)\n\n    I0 <- get_regularization_matrix(J, max_trt, reg_param)\n    # add in regularization\n    # I0 <- Matrix::bdiag(reg_param1 * Matrix::Diagonal(max_trt), \n    #                     reg_param2 * Matrix::Diagonal(J * max_trt))\n    \n\n    \n    Qmat <- Qmat + I0\n\n    # fit QP\n    settings <- osqp::osqpSettings(verbose = FALSE, ...)\n    out <- osqp::solve_osqp(Qmat, pvec, constraints$Amat, \n                            constraints$lvec, constraints$uvec, \n                            pars=settings)\n\n    # collect as matrix\n    # reg_weights <- matrix(out$x, ncol = J + 1)\n    reg_weights <- matrix(out$x, ncol = J)\n    # pooled <- reg_weights[,1]\n    # add in common component\n    # reg_weights <- reg_weights[, 1] + reg_weights[, -1]\n    # reverse to calendar time\n    # reg_weights <- reg_weights[nrow(reg_weights):1, ]\n    return(reg_weights)\n}\n\n\n#' Organize data for the QP\n#' @param X Matrix of outcomes\n#' @param trt Vector of treatment status for each unit\n#' @param n_leads How long past treatment effects should be estimated for\n#' @noRd\ncollect_data <- 
function(X, trt, n_leads) {\n\n    grps <- trt[is.finite(trt)]\n    J <- length(grps)\n    ttot <- ncol(X)\n    max_trt <- max(grps)\n\n\n    # sapply(1:ncol(X), \n    #        function(tj) {\n    #            mean(X[trt >= tj])\n    #        }) -> ctrl_means\n\n    # X <- t(t(X) - ctrl_means)\n\n    # get pre-treatment matrices\n    lapply(grps, function(tj) {\n        # donor unit pre tj outcomes\n        idxs <- trt > tj + n_leads\n        pre_mat <- cbind(#1, \n                         matrix(0, nrow = nrow(X), ncol = (max_trt - tj)),\n                         X[, 1:tj, drop = F])\n        # subtract out pure control means\n        pre_mat <- t(t(pre_mat) - colMeans(pre_mat[!is.finite(trt),,drop = F]))\n        # restrict to units that won't be treated w/in n_leads\n        pre_mat[idxs,,drop = F]\n    }) -> pre_mats\n\n\n    # get post treatment averages\n    lapply(grps, \n           function(tj) {\n               # avg of donor units post tj outcomes\n               idxs <- trt > tj + n_leads\n               donors <- rowMeans(X[, (tj + 1):(tj + n_leads), drop = F])\n               # subtract out pure control means\n               donors <- donors - mean(donors[!is.finite(trt)])\n               # restrict to units that won't be treated w/in n_leads\n               donors[idxs]\n           }) -> post_vecs\n    return(list(pre_mats = pre_mats, post_vecs = post_vecs))\n}\n\n\nget_Qmat <- function(pre_mats) {\n\n    #### matrix in QP\n    cov_mats <- lapply(pre_mats, function(x) t(x) %*% x)\n    # unit specific covariance matrices\n    Qmat <- Matrix::bdiag(cov_mats)\n    return(Qmat)\n}\n\n\nget_Qmat_pool <- function(pre_mats) {\n\n    #### matrix in QP\n    cov_mats <- lapply(pre_mats, function(x) t(x) %*% x)\n    pooled_cov <- Reduce(`+`, cov_mats)\n    cov_mats_bind <-do.call(rbind, cov_mats)\n    # unit specific covariance matrices\n    Qmat <- Matrix::bdiag(cov_mats)\n    # pooling terms\n    Qmat <- rbind(t(cov_mats_bind), Qmat)\n    Qmat <- 
cbind(rbind(pooled_cov, cov_mats_bind), Qmat)\n\n    return(Qmat)\n}\n\n\nget_pvec <- function(pre_mats, post_vecs) {\n    # vector in QP\n    lapply(1:length(pre_mats), function(j) {\n        t(pre_mats[[j]]) %*% post_vecs[[j]]\n    }) -> pvec_list\n\n    pvec <- do.call(c, pvec_list)\n    \n    return(-2 * pvec)\n}\n\n\nget_pvec_pool <- function(pre_mats, post_vecs) {\n    # vector in QP\n    lapply(1:length(pre_mats), function(j) {\n        t(pre_mats[[j]]) %*% post_vecs[[j]]\n    }) -> pvec_list\n\n    pvec_pool <- Reduce(`+`, pvec_list)\n    pvec <- do.call(c, pvec_list)\n    pvec <- c(pvec_pool, pvec)\n\n    return(-2 * pvec)\n}\n\n\nmake_constraints <- function(J, grps, lowlim, uplim) {\n\n    tmax <- max(grps)\n    # sum to 1 constraints\n    A1 <- Matrix::t(Matrix::bdiag(lapply(1:J, function(j) c(0, rep(1, tmax)))))\n    A1 <- Matrix::t(Matrix::bdiag(lapply(1:J, function(j) rep(1, tmax))))\n    l1 <- rep(1, J)\n    # l1 <- rep(-Inf, J)\n    u1 <- rep(1, J)\n    # u1 <- rep(Inf, J)\n\n    # upper lower limits\n    diag_w_intercept <- Matrix::bdiag(list(0, Matrix::Diagonal(tmax)))[-1, ]\n    A2 <- Matrix::bdiag(lapply(1:J, function(j) diag_w_intercept))\n    A2 <- Matrix::Diagonal(J * tmax)\n    # make sure that only weighting times that exist\n    l2 <- sapply(1:J, function(j) {\n        c(rep(0, tmax - grps[j]), rep(lowlim, grps[j]))\n    })\n    \n    u2 <- sapply(1:J, function(j) {\n        c(rep(0, tmax - grps[j]), rep(uplim, grps[j]))\n    })\n\n    # combine\n\n    Amat <- rbind(A1, A2)\n    lvec <- c(l1, l2)\n    uvec <- c(u1, u2)\n    return(list(Amat = Amat, lvec = lvec, uvec = uvec))\n}\n\nmake_constraints_pool <- function(J, grps, lowlim, uplim) {\n\n    tmax <- max(grps)\n    # sum to 1 constraints\n    A1 <- cbind(0,\n                matrix(1, ncol = tmax, nrow = J),\n                Matrix::t(Matrix::bdiag(lapply(1:J, function(j) c(0, rep(1, tmax))))))\n    \n    l1 <- rep(1, J)\n    # l1 <- rep(-Inf, J)\n    u1 <- rep(1, J)\n    # u1 <- 
rep(Inf, J)\n\n    # upper lower limits\n    diag_w_intercept <- Matrix::bdiag(list(0, Matrix::Diagonal(tmax)))[-1, ]\n    pool_A2 <- do.call(rbind, lapply(1:J, function(j) diag_w_intercept))\n    A2 <- Matrix::bdiag(lapply(1:J, function(j) diag_w_intercept))\n    A2 <- cbind(pool_A2, A2)\n\n    # restrict global intercept to 0\n    A2 <- rbind(c(1, numeric(ncol(A2) - 1)), A2)\n    \n    # make sure that only weighting times that exist\n    l2 <- sapply(1:J, function(j) {\n        c(rep(0, tmax - grps[j]), rep(lowlim, grps[j]))\n    })\n    l2 <- c(0, l2)\n    \n    u2 <- sapply(1:J, function(j) {\n        c(rep(0, tmax - grps[j]), rep(uplim, grps[j]))\n    })\n    u2 <- c(0, u2)\n    # l2 <- rep(lowlim, J * tmax)\n    # u2 <- rep(uplim, J  * tmax)\n\n    # combine\n\n    Amat <- rbind(A1, A2)\n    lvec <- c(l1, l2)\n    uvec <- c(u1, u2)\n    return(list(Amat = Amat, lvec = lvec, uvec = uvec))\n}\n\nget_regularization_matrix <- function(J, max_trt, reg_param) {\n\n    single_reg_mat <- Matrix::bdiag(list(0, Matrix::Diagonal(max_trt)))\n    I0 <- Matrix::bdiag(lapply(1:J,function(j) single_reg_mat))\n    I0 <- reg_param * Matrix::Diagonal(J * max_trt)\n    return(reg_param * I0)\n}\n\nget_regularization_matrix_pool <- function(J, max_trt, reg_param1, reg_param2) {\n\n    single_reg_mat <- Matrix::bdiag(list(0, Matrix::Diagonal(max_trt)))\n        grp_reg_mats <- Matrix::bdiag(lapply(1:J,function(j) single_reg_mat))\n        I0 <- Matrix::bdiag(reg_param1 * single_reg_mat,\n                            reg_param2 * grp_reg_mats)\n\n    return(I0)\n}"
  },
  {
    "path": "README.md",
    "content": "# augsynth: Augmented Synthetic Control Method\n[![Build Status](https://travis-ci.org/ebenmichael/augsynth.svg?branch=master)](https://travis-ci.org/ebenmichael/augsynth) [![Project Status: Active  The project has reached a stable, usable state and is being actively developed.](https://www.repostatus.org/badges/latest/active.svg)](https://www.repostatus.org/#active)[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)\n\n\n\n\n\n## Overview\nThis package implements the Augmented Synthetic Control Method (ASCM).\n\nFor a more detailed description of the main functionality check out:\n- [the vignette for simultaneous adoption](https://github.com/ebenmichael/augsynth/blob/master/vignettes/singlesynth-vignette.md)\n- [the vignette for staggered adoption](https://github.com/ebenmichael/augsynth/blob/master/vignettes/multisynth-vignette.md)\n\n## Installation\nTo install this package, first ensure that `devtools` is installed with\n\n```\ninstall.packages(\"devtools\")\n```\n\nthen install the package from GitHub with\n\n```\ndevtools::install_github(\"ebenmichael/augsynth\")\n```\n\n## Basic usage\nTo get started, use a panel dataset with an `outcome` measure, a `treatment` indicator, a `unit` indicator, a `time` variable, and an intervention time `t_int`. Then run\n\n\n```\nasyn <- augsynth(outcome ~ trt, unit, time, t_int, data)\n```\n"
  },
  {
    "path": "data-raw/clean_kansas.R",
    "content": "library(haven)\nlibrary(tidyverse)\n\nkansas <- read_dta(\"kansas_longer2.dta\")\nstate_abb <- read_csv(\"us-state-ansi-fips.csv\") %>%\n                rename(fips = st, abb = stusps) %>%\n                mutate(fips = as.numeric(fips)) %>%\n                select(fips, abb)\nkansas <- kansas %>%\n        rename(fips=Fips) %>%\n        filter(year >= 1990,\n               !is.na(fips), # filter out all of US\n               fips != 11, # filter out DC\n            #    year_qtr >= 2005 | year_qtr == round(year_qtr)\n               ) %>%\n        # interpolate GDP\n        mutate(year_qtr = year + qtr / 4 - 0.25, # combine year and quarter\n               fips = as.integer(fips), # state id\n               treated = 1 * (fips == 20) * (year_qtr >= 2012.25),\n               gdp = ifelse((qtr == 1) | (year >= 2005), gdp, NA),\n               popestimate = ifelse((qtr == 1), popestimate, NA)) %>%\n        # interpolate GDP and population\n        group_by(fips) %>%\n        arrange(year_qtr) %>%\n        mutate(gdp = approx(year_qtr, gdp, year_qtr)$y,\n               popestimate = approx(year_qtr, popestimate, year_qtr)$y) %>%\n        ungroup() %>% arrange(fips, year_qtr) %>%\n        mutate(gdpcapita = gdp / popestimate * 1e6,\n               lngdp = log(gdp),\n               lngdpcapita = log(gdpcapita),\n               revstatecapita = rev_state_total / popestimate * 1e6,\n               revlocalcapita = rev_local_total / popestimate * 1e6,\n               emplvl1capita = month1_emplvl / popestimate,\n               emplvl2capita = month2_emplvl / popestimate,\n               emplvl3capita = month3_emplvl / popestimate,\n               emplvlcapita = (month1_emplvl + month2_emplvl + month3_emplvl) / (3 * popestimate),\n               totalwagescapita = total_qtrly_wages / popestimate,\n               taxwagescapita = taxable_qtrly_wages / popestimate,\n               avgwklywagecapita = avg_wkly_wage,\n               estabscapita = 
qtrly_estabs_count / popestimate) %>%\n        filter(year_qtr <= 2016) %>%\n        inner_join(state_abb)\n\nfor (name in colnames(kansas)) {\n        attributes(kansas[[name]])$label = NULL \n}\n\n"
  },
  {
    "path": "man/augsynth-package.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/augsynth.R\n\\docType{package}\n\\name{augsynth-package}\n\\alias{augsynth-package}\n\\title{augsynth}\n\\description{\nA package implementing the Augmented Synthetic Controls Method\n}\n"
  },
  {
    "path": "man/augsynth.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/augsynth_pre.R\n\\name{augsynth}\n\\alias{augsynth}\n\\title{Fit Augmented SCM}\n\\usage{\naugsynth(form, unit, time, data, t_int = NULL, ...)\n}\n\\arguments{\n\\item{form}{outcome ~ treatment | auxillary covariates}\n\n\\item{unit}{Name of unit column}\n\n\\item{time}{Name of time column}\n\n\\item{data}{Panel data as dataframe}\n\n\\item{t_int}{Time of intervention (used for single-period treatment only)}\n\n\\item{...}{Optional arguments\n\\itemize{\n  \\item Single period augsynth with/without multiple outcomes\n    \\itemize{\n      \\item{\"progfunc\"}{What function to use to impute control outcomes: Ridge=Ridge regression (allows for standard errors), None=No outcome model, EN=Elastic Net, RF=Random Forest, GSYN=gSynth, MCP=MCPanel, CITS=CITS, CausalImpact=Bayesian structural time series with CausalImpact, seq2seq=Sequence to sequence learning with feedforward nets}\n      \\item{\"scm\"}{Whether the SCM weighting function is used}\n      \\item{\"fixedeff\"}{Whether to include a unit fixed effect, default F }\n      \\item{\"cov_agg\"}{Covariate aggregation functions, if NULL then use mean with NAs omitted}\n    }\n  \\item Multi period (staggered) augsynth\n   \\itemize{\n         \\item{\"relative\"}{Whether to compute balance by relative time}\n         \\item{\"n_leads\"}{How long past treatment effects should be estimated for}\n         \\item{\"n_lags\"}{Number of pre-treatment periods to balance, default is to balance all periods}\n         \\item{\"alpha\"}{Fraction of balance for individual balance}\n         \\item{\"lambda\"}{Regularization hyperparameter, default = 0}\n         \\item{\"force\"}{Include \"none\", \"unit\", \"time\", \"two-way\" fixed effects. 
Default: \"two-way\"}\n         \\item{\"n_factors\"}{Number of factors for interactive fixed effects, default does CV}\n        }\n}}\n}\n\\value{\naugsynth object that contains:\n        \\itemize{\n         \\item{\"weights\"}{weights}\n         \\item{\"data\"}{Panel data as matrices}\n        }\n}\n\\description{\nFit Augmented SCM\n}\n"
  },
  {
    "path": "man/augsynth_multiout.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/multi_outcomes.R\n\\name{augsynth_multiout}\n\\alias{augsynth_multiout}\n\\title{Fit Augmented SCM with multiple outcomes}\n\\usage{\naugsynth_multiout(\n  form,\n  unit,\n  time,\n  t_int,\n  data,\n  progfunc = c(\"Ridge\", \"None\"),\n  scm = T,\n  fixedeff = FALSE,\n  cov_agg = NULL,\n  combine_method = \"avg\",\n  ...\n)\n}\n\\arguments{\n\\item{form}{outcome ~ treatment | auxillary covariates}\n\n\\item{unit}{Name of unit column}\n\n\\item{time}{Name of time column}\n\n\\item{t_int}{Time of intervention}\n\n\\item{data}{Panel data as dataframe}\n\n\\item{progfunc}{What function to use to impute control outcomes\nRidge=Ridge regression (allows for standard errors),\nNone=No outcome model,}\n\n\\item{scm}{Whether the SCM weighting function is used}\n\n\\item{fixedeff}{Whether to include a unit fixed effect, default F}\n\n\\item{cov_agg}{Covariate aggregation functions, if NULL then use mean with NAs omitted}\n\n\\item{combine_method}{How to combine outcomes: `concat` concatenates outcomes and `avg` averages them, default: 'avg'}\n\n\\item{...}{optional arguments for outcome model}\n}\n\\value{\naugsynth object that contains:\n        \\itemize{\n         \\item{\"weights\"}{Ridge ASCM weights}\n         \\item{\"l2_imbalance\"}{Imbalance in pre-period outcomes, measured by the L2 norm}\n         \\item{\"scaled_l2_imbalance\"}{L2 imbalance scaled by L2 imbalance of uniform weights}\n         \\item{\"mhat\"}{Outcome model estimate}\n         \\item{\"data\"}{Panel data as matrices}\n        }\n}\n\\description{\nFit Augmented SCM with multiple outcomes\n}\n"
  },
  {
    "path": "man/check_data_stag.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/format.R\n\\name{check_data_stag}\n\\alias{check_data_stag}\n\\title{Check that we can actually run multisynth on the data}\n\\usage{\ncheck_data_stag(wide, fixedeff, n_leads, n_lags)\n}\n\\arguments{\n\\item{wide}{Output of format_data_stag}\n\n\\item{fixedeff}{Whether to include a unit fixed effect}\n\n\\item{n_leads}{How long past treatment effects should be estimated for, default is number of post treatment periods for last treated unit}\n\n\\item{n_lags}{Number of pre-treatment periods to balance, default is to balance all periods}\n}\n\\description{\nCheck that we can actually run multisynth on the data\n}\n"
  },
  {
    "path": "man/conformal_inf.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/inference.R\n\\name{conformal_inf}\n\\alias{conformal_inf}\n\\title{Conformal inference procedure to compute p-values and point-wise confidence intervals}\n\\usage{\nconformal_inf(\n  ascm,\n  alpha = 0.05,\n  stat_func = NULL,\n  type = \"iid\",\n  q = 1,\n  ns = 1000,\n  grid_size = 50\n)\n}\n\\arguments{\n\\item{ascm}{Fitted `augsynth` object}\n\n\\item{alpha}{Confidence level}\n\n\\item{stat_func}{Function to compute test statistic}\n\n\\item{type}{Either \"iid\" for iid permutations or \"block\" for moving block permutations; default is \"block\"}\n\n\\item{q}{The norm for the test static `((sum(x ^ q))) ^ (1/q)`}\n\n\\item{ns}{Number of resamples for \"iid\" permutations}\n\n\\item{grid_size}{Number of grid points to use when inverting the hypothesis test}\n}\n\\value{\nList that contains:\n        \\itemize{\n         \\item{\"att\"}{Vector of ATT estimates}\n         \\item{\"heldout_att\"}{Vector of ATT estimates with the time period held out}\n         \\item{\"se\"}{Standard error, always NA but returned for compatibility}\n         \\item{\"lb\"}{Lower bound of 1 - alpha confidence interval}\n         \\item{\"ub\"}{Upper bound of 1 - alpha confidence interval}\n         \\item{\"p_val\"}{p-value for test of no post-treatment effect}\n         \\item{\"alpha\"}{Level of confidence interval}\n        }\n}\n\\description{\nConformal inference procedure to compute p-values and point-wise confidence intervals\n}\n"
  },
  {
    "path": "man/conformal_inf_linear.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/inference.R\n\\name{conformal_inf_linear}\n\\alias{conformal_inf_linear}\n\\title{Conformal inference procedure to compute a confidence interval for a linear in time effect}\n\\usage{\nconformal_inf_linear(\n  ascm,\n  alpha = 0.05,\n  stat_func = NULL,\n  type = \"iid\",\n  q = 1,\n  ns = 1000,\n  grid_size = 50\n)\n}\n\\arguments{\n\\item{ascm}{Fitted `augsynth` object}\n\n\\item{alpha}{Confidence level}\n\n\\item{stat_func}{Function to compute test statistic}\n\n\\item{type}{Either \"iid\" for iid permutations or \"block\" for moving block permutations; default is \"iid\"}\n\n\\item{q}{The norm for the test static `((sum(x ^ q))) ^ (1/q)`}\n\n\\item{ns}{Number of resamples for \"iid\" permutations}\n\n\\item{grid_size}{Number of grid points to use when inverting the hypothesis test}\n}\n\\value{\nList that contains:\n        \\itemize{\n         \\item{\"att\"}{Vector of ATT estimates}\n         \\item{\"heldout_att\"}{Vector of ATT estimates with the time period held out}\n         \\item{\"se\"}{Standard error, always NA but returned for compatibility}\n         \\item{\"lb\"}{Lower bound of 1 - alpha confidence interval}\n         \\item{\"ub\"}{Upper bound of 1 - alpha confidence interval}\n         \\item{\"p_val\"}{p-value for test of no post-treatment effect}\n         \\item{\"alpha\"}{Level of confidence interval}\n        }\n}\n\\description{\nConformal inference procedure to compute a confidence interval for a linear in time effect\n}\n"
  },
  {
    "path": "man/conformal_inf_multiout.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/inference.R\n\\name{conformal_inf_multiout}\n\\alias{conformal_inf_multiout}\n\\title{Conformal inference procedure to compute p-values and point-wise confidence intervals}\n\\usage{\nconformal_inf_multiout(\n  ascm_multi,\n  alpha = 0.05,\n  stat_func = NULL,\n  type = \"iid\",\n  q = 1,\n  ns = 1000,\n  grid_size = 1,\n  lin_h0 = NULL\n)\n}\n\\arguments{\n\\item{alpha}{Confidence level}\n\n\\item{stat_func}{Function to compute test statistic}\n\n\\item{type}{Either \"iid\" for iid permutations or \"block\" for moving block permutations}\n\n\\item{q}{The norm for the test static `((sum(x ^ q))) ^ (1/q)`}\n\n\\item{ns}{Number of resamples for \"iid\" permutations}\n\n\\item{grid_size}{Number of grid points to use when inverting the hypothesis test (default is 1, so only to test joint null)}\n\n\\item{ascm}{Fitted `augsynth` object}\n}\n\\value{\nList that contains:\n        \\itemize{\n         \\item{\"att\"}{Vector of ATT estimates}\n         \\item{\"heldout_att\"}{Vector of ATT estimates with the time period held out}\n         \\item{\"se\"}{Standard error, always NA but returned for compatibility}\n         \\item{\"lb\"}{Lower bound of 1 - alpha confidence interval}\n         \\item{\"ub\"}{Upper bound of 1 - alpha confidence interval}\n         \\item{\"p_val\"}{p-value for test of no post-treatment effect}\n         \\item{\"alpha\"}{Level of confidence interval}\n        }\n}\n\\description{\nConformal inference procedure to compute p-values and point-wise confidence intervals\n}\n"
  },
  {
    "path": "man/get_nona_donors.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/eligible_donors.R\n\\name{get_nona_donors}\n\\alias{get_nona_donors}\n\\title{Get donors that don't have missing outcomes where treated units have outcomes}\n\\usage{\nget_nona_donors(X, y, trt, n_lags, n_leads, time_cohort)\n}\n\\description{\nGet donors that don't have missing outcomes where treated units have outcomes\n}\n"
  },
  {
    "path": "man/jackknife_se_single.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/inference.R\n\\name{jackknife_se_single}\n\\alias{jackknife_se_single}\n\\title{Estimate standard errors for single ASCM with the jackknife\nDo this for ridge-augmented synth}\n\\usage{\njackknife_se_single(ascm)\n}\n\\arguments{\n\\item{ascm}{Fitted augsynth object}\n}\n\\value{\nList that contains:\n        \\itemize{\n         \\item{\"att\"}{Vector of ATT estimates}\n         \\item{\"se\"}{Standard error estimate}\n         \\item{\"lb\"}{Lower bound of 1 - alpha confidence interval}\n         \\item{\"ub\"}{Upper bound of 1 - alpha confidence interval}\n         \\item{\"alpha\"}{Level of confidence interval}\n        }\n}\n\\description{\nEstimate standard errors for single ASCM with the jackknife\nDo this for ridge-augmented synth\n}\n"
  },
  {
    "path": "man/kansas.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/data.R\n\\docType{data}\n\\name{kansas}\n\\alias{kansas}\n\\title{Economic indicators for US states from 1990-2016}\n\\format{\nA dataframe with 5250 rows and 32 variables:\n\\describe{\n \\item{fips}{FIPS code for each state}\n \\item{year}{Year of measurement}\n \\item{qtr}{Quarter (1-4) of measurement}\n \\item{state}{Name of State}\n \\item{gdp}{Gross State Product (millions of $) Values before 2005 are linearly interpolated between years}\n \\item{revenuepop}{State and local revenue per capita}\n \\item{rev_state_total}{State total general revenue (millions of $)}\n \\item{rev_local_total}{Local total general revenue (millions of $)}\n \\item{popestimate}{Population estimate}\n \\item{qtrly_estabs_count}{Count of establishments for a given quarter}\n \\item{month1_emplvl, month2_emplvl, month3_emplvl}{ Employment level for first, second, and third months of a given quarter}\n \\item{total_qtrly_wages}{Total wages for a givne quarter}\n \\item{taxable_qtrly_wage}{Taxable wages for a given quarter}\n \\item{avg_wkly_wage}{Average weekly wage for a given quarter}\n \\item{year_qtr}{Year and quarter combined into one continuous variable}\n \\item{treated}{Whether the state passed tax cuts before the given year and quareter}\n \\item{lngdpcapita}{Natural log of GDP per capita}\n \\item{emplvlcapita}{Average employment level per capita}\n \\item{Xcapita}{Per capita value of X}\n \\item{abb}{State abbreviation}\n}\n}\n\\usage{\nkansas\n}\n\\description{\nEconomic indicators for US states from 1990-2016\n}\n\\keyword{datasets}\n"
  },
  {
    "path": "man/make_V_matrix.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/fit_synth.R\n\\name{make_V_matrix}\n\\alias{make_V_matrix}\n\\title{Make a V matrix from a vector (or null)}\n\\usage{\nmake_V_matrix(t0, V)\n}\n\\description{\nMake a V matrix from a vector (or null)\n}\n"
  },
  {
    "path": "man/multisynth.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/multisynth_class.R\n\\name{multisynth}\n\\alias{multisynth}\n\\title{Fit staggered synth}\n\\usage{\nmultisynth(\n  form,\n  unit,\n  time,\n  data,\n  n_leads = NULL,\n  n_lags = NULL,\n  nu = NULL,\n  lambda = 0,\n  V = NULL,\n  fixedeff = TRUE,\n  n_factors = 0,\n  scm = T,\n  time_cohort = F,\n  how_match = \"knn\",\n  cov_agg = NULL,\n  eps_abs = 1e-04,\n  eps_rel = 1e-04,\n  verbose = FALSE,\n  ...\n)\n}\n\\arguments{\n\\item{form}{outcome ~ treatment | weighting covariates | approximate matching covariates | exact matching covariates\n\\itemize{\n   \\item{outcome}{Name of the outcome of interest}\n   \\item{treatment}{Name of the treatment assignment variable}\n   \\item{weighting covariates}{Auxiliary covariates to weight on}\n   \\item{approximate matching covariates}{Auxiliary covariates to approximately match on before weighting}\n   \\item{exact matching covariates}{Auxiliary covariates to exactly match on before weighting}\n}\nIf covariates are time-varying, their average value before the first unit is treated will be used. 
This can be changed by supplying a custom aggregation function to cov_agg.}\n\n\\item{unit}{Name of unit column}\n\n\\item{time}{Name of time column}\n\n\\item{data}{Panel data as dataframe}\n\n\\item{n_leads}{How long past treatment effects should be estimated for, default is number of post treatment periods for last treated unit}\n\n\\item{n_lags}{Number of pre-treatment periods to balance, default is to balance all periods}\n\n\\item{nu}{Fraction of balance for individual balance}\n\n\\item{lambda}{Regularization hyperparameter, default = 0}\n\n\\item{V}{Scaling matrix for synth optimization, default NULL is identity}\n\n\\item{fixedeff}{Whether to include a unit fixed effect, default TRUE}\n\n\\item{n_factors}{Number of factors for interactive fixed effects, setting to NULL fits with CV, default is 0}\n\n\\item{scm}{Whether to fit scm weights}\n\n\\item{time_cohort}{Whether to average synthetic controls into time cohorts, default FALSE}\n\n\\item{cov_agg}{Covariate aggregation function}\n\n\\item{eps_abs}{Absolute error tolerance for osqp}\n\n\\item{eps_rel}{Relative error tolerance for osqp}\n\n\\item{verbose}{Whether to print logs for osqp}\n\n\\item{...}{Extra arguments}\n}\n\\value{\nmultisynth object that contains:\n        \\itemize{\n         \\item{\"weights\"}{weights matrix where each column is a set of weights for a treated unit}\n         \\item{\"data\"}{Panel data as matrices}\n         \\item{\"imbalance\"}{Matrix of treatment minus synthetic control for pre-treatment time periods, each column corresponds to a treated unit}\n         \\item{\"global_l2\"}{L2 imbalance for the pooled synthetic control}\n         \\item{\"scaled_global_l2\"}{L2 imbalance for the pooled synthetic control, scaled by the imbalance for uniform weights}\n         \\item{\"ind_l2\"}{Average L2 imbalance for the individual synthetic controls}\n         \\item{\"scaled_ind_l2\"}{Average L2 imbalance for the individual synthetic controls, scaled by the imbalance for 
uniform weights}\n        \\item{\"n_leads\", \"n_lags\"}{Number of post treatment outcomes (leads) and pre-treatment outcomes (lags) to include in the analysis}\n         \\item{\"nu\"}{Fraction of balance for individual balance}\n         \\item{\"lambda\"}{Regularization hyperparameter}\n         \\item{\"scm\"}{Whether to fit scm weights}\n         \\item{\"grps\"}{Time periods for treated units}\n         \\item{\"y0hat\"}{Pilot estimates of control outcomes}\n         \\item{\"residuals\"}{Difference between the observed outcomes and the pilot estimates}\n         \\item{\"n_factors\"}{Number of factors for interactive fixed effects}\n        }\n}\n\\description{\nFit staggered synth\n}\n"
  },
  {
    "path": "man/plot.augsynth.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/augsynth.R\n\\name{plot.augsynth}\n\\alias{plot.augsynth}\n\\title{Plot function for augsynth}\n\\usage{\n\\method{plot}{augsynth}(x, inf = T, cv = F, ...)\n}\n\\arguments{\n\\item{x}{Augsynth object to be plotted}\n\n\\item{inf}{Boolean, whether to get confidence intervals around the point estimates}\n\n\\item{cv}{If True, plot cross validation MSE against hyper-parameter, otherwise plot effects}\n\n\\item{...}{Optional arguments}\n}\n\\description{\nPlot function for augsynth\n}\n"
  },
  {
    "path": "man/plot.augsynth_multiout.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/multi_outcomes.R\n\\name{plot.augsynth_multiout}\n\\alias{plot.augsynth_multiout}\n\\title{Plot function for summary function for augsynth}\n\\usage{\n\\method{plot}{augsynth_multiout}(x, inf = T, plt_avg = F, ...)\n}\n\\arguments{\n\\item{x}{augsynth_multiout object}\n\n\\item{inf}{Boolean, whether to plot uncertainty intervals, default TRUE}\n\n\\item{plt_avg}{Boolean, whether to plot the average of the outcomes, default FALSE}\n\n\\item{...}{Optional arguments for summary function}\n}\n\\description{\nPlot function for summary function for augsynth\n}\n"
  },
  {
    "path": "man/plot.multisynth.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/multisynth_class.R\n\\name{plot.multisynth}\n\\alias{plot.multisynth}\n\\title{Plot function for multisynth}\n\\usage{\n\\method{plot}{multisynth}(\n  x,\n  inf_type = \"bootstrap\",\n  inf = T,\n  levels = NULL,\n  label = T,\n  weights = FALSE,\n  ...\n)\n}\n\\arguments{\n\\item{x}{multisynth object to be plotted}\n\n\\item{inf_type}{Type of inference to perform:\n \\itemize{\n   \\item{bootstrap}{Wild bootstrap, the default option}\n   \\item{jackknife}{Jackknife}\n}}\n\n\\item{inf}{Whether to compute and plot confidence intervals}\n\n\\item{levels}{Which units/groups to plot, default is every group}\n\n\\item{label}{Whether to label the individual levels}\n\n\\item{weights}{Whether to plot the weights, default = FALSE}\n\n\\item{...}{Optional arguments}\n}\n\\description{\nPlot function for multisynth\n}\n"
  },
  {
    "path": "man/plot.summary.augsynth.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/augsynth.R\n\\name{plot.summary.augsynth}\n\\alias{plot.summary.augsynth}\n\\title{Plot function for summary function for augsynth}\n\\usage{\n\\method{plot}{summary.augsynth}(x, inf = T, ...)\n}\n\\arguments{\n\\item{x}{Summary object}\n\n\\item{inf}{Boolean, whether to plot confidence intervals}\n\n\\item{...}{Optional arguments}\n}\n\\description{\nPlot function for summary function for augsynth\n}\n"
  },
  {
    "path": "man/plot.summary.augsynth_multiout.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/multi_outcomes.R\n\\name{plot.summary.augsynth_multiout}\n\\alias{plot.summary.augsynth_multiout}\n\\title{Plot function for summary function for augsynth}\n\\usage{\n\\method{plot}{summary.augsynth_multiout}(x, inf = F, plt_avg = F, ...)\n}\n\\arguments{\n\\item{x}{summary.augsynth_multiout object}\n\n\\item{inf}{Boolean, whether to plot uncertainty intervals, default FALSE}\n\n\\item{plt_avg}{Boolean, whether to plot the average of the outcomes, default FALSE}\n}\n\\description{\nPlot function for summary function for augsynth\n}\n"
  },
  {
    "path": "man/plot.summary.multisynth.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/multisynth_class.R\n\\name{plot.summary.multisynth}\n\\alias{plot.summary.multisynth}\n\\title{Plot function for summary function for multisynth}\n\\usage{\n\\method{plot}{summary.multisynth}(x, inf = T, levels = NULL, label = T, weights = FALSE, ...)\n}\n\\arguments{\n\\item{x}{summary object}\n\n\\item{inf}{Whether to plot confidence intervals}\n\n\\item{levels}{Which units/groups to plot, default is every group}\n\n\\item{label}{Whether to label the individual levels}\n\n\\item{weights}{Whether to plot the weights, default = FALSE}\n\n\\item{...}{Optional arguments}\n}\n\\description{\nPlot function for summary function for multisynth\n}\n"
  },
  {
    "path": "man/predict.augsynth.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/augsynth.R\n\\name{predict.augsynth}\n\\alias{predict.augsynth}\n\\title{Get prediction of ATT or average outcome under control}\n\\usage{\n\\method{predict}{augsynth}(object, att = F, ...)\n}\n\\arguments{\n\\item{object}{augsynth object}\n\n\\item{att}{If TRUE, return the ATT, if FALSE, return imputed counterfactual}\n\n\\item{...}{Optional arguments}\n}\n\\value{\nVector of predicted post-treatment control averages\n}\n\\description{\nGet prediction of ATT or average outcome under control\n}\n"
  },
  {
    "path": "man/predict.augsynth_multiout.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/multi_outcomes.R\n\\name{predict.augsynth_multiout}\n\\alias{predict.augsynth_multiout}\n\\title{Get prediction of ATT or average outcome under control}\n\\usage{\n\\method{predict}{augsynth_multiout}(object, ...)\n}\n\\arguments{\n\\item{object}{augsynth_multiout object}\n\n\\item{...}{Optional arguments, including \\itemize{\\item{\"att\"}{Whether to return the ATT or average outcome under control}}}\n}\n\\value{\nVector of predicted post-treatment control averages\n}\n\\description{\nGet prediction of ATT or average outcome under control\n}\n"
  },
  {
    "path": "man/predict.multisynth.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/multisynth_class.R\n\\name{predict.multisynth}\n\\alias{predict.multisynth}\n\\title{Get prediction of average outcome under control or ATT}\n\\usage{\n\\method{predict}{multisynth}(object, att = F, att_weight = NULL, bs_weight = NULL, ...)\n}\n\\arguments{\n\\item{object}{Fit multisynth object}\n\n\\item{att}{If TRUE, return the ATT, if FALSE, return imputed counterfactual}\n\n\\item{att_weight}{Weights to place on individual units/cohorts when averaging}\n\n\\item{bs_weight}{Weight to perturb units by for weighted bootstrap}\n\n\\item{...}{Optional arguments}\n}\n\\value{\nMatrix of predicted post-treatment control outcomes for each treated unit\n}\n\\description{\nGet prediction of average outcome under control or ATT\n}\n"
  },
  {
    "path": "man/print.augsynth.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/augsynth.R\n\\name{print.augsynth}\n\\alias{print.augsynth}\n\\title{Print function for augsynth}\n\\usage{\n\\method{print}{augsynth}(x, ...)\n}\n\\arguments{\n\\item{x}{augsynth object}\n\n\\item{...}{Optional arguments}\n}\n\\description{\nPrint function for augsynth\n}\n"
  },
  {
    "path": "man/print.augsynth_multiout.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/multi_outcomes.R\n\\name{print.augsynth_multiout}\n\\alias{print.augsynth_multiout}\n\\title{Print function for augsynth}\n\\usage{\n\\method{print}{augsynth_multiout}(x, ...)\n}\n\\arguments{\n\\item{x}{augsynth_multiout object}\n\n\\item{...}{Optional arguments}\n}\n\\description{\nPrint function for augsynth\n}\n"
  },
  {
    "path": "man/print.multisynth.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/multisynth_class.R\n\\name{print.multisynth}\n\\alias{print.multisynth}\n\\title{Print function for multisynth}\n\\usage{\n\\method{print}{multisynth}(x, att_weight = NULL, ...)\n}\n\\arguments{\n\\item{x}{multisynth object}\n\n\\item{att_weight}{Weights to place on individual units/cohorts when averaging}\n\n\\item{...}{Optional arguments}\n}\n\\description{\nPrint function for multisynth\n}\n"
  },
  {
    "path": "man/print.summary.augsynth.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/augsynth.R\n\\name{print.summary.augsynth}\n\\alias{print.summary.augsynth}\n\\title{Print function for summary function for augsynth}\n\\usage{\n\\method{print}{summary.augsynth}(x, ...)\n}\n\\arguments{\n\\item{x}{summary object}\n\n\\item{...}{Optional arguments}\n}\n\\description{\nPrint function for summary function for augsynth\n}\n"
  },
  {
    "path": "man/print.summary.augsynth_multiout.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/multi_outcomes.R\n\\name{print.summary.augsynth_multiout}\n\\alias{print.summary.augsynth_multiout}\n\\title{Print function for summary function for augsynth}\n\\usage{\n\\method{print}{summary.augsynth_multiout}(x, ...)\n}\n\\arguments{\n\\item{x}{summary.augsynth_multiout object}\n\n\\item{...}{Optional arguments}\n}\n\\description{\nPrint function for summary function for augsynth\n}\n"
  },
  {
    "path": "man/print.summary.multisynth.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/multisynth_class.R\n\\name{print.summary.multisynth}\n\\alias{print.summary.multisynth}\n\\title{Print function for summary function for multisynth}\n\\usage{\n\\method{print}{summary.multisynth}(x, level = \"Average\", ...)\n}\n\\arguments{\n\\item{x}{summary object}\n\n\\item{level}{Which unit/group to print results for, default is the overall average}\n\n\\item{...}{Optional arguments}\n}\n\\description{\nPrint function for summary function for multisynth\n}\n"
  },
  {
    "path": "man/rdirichlet_b.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/inference.R\n\\name{rdirichlet_b}\n\\alias{rdirichlet_b}\n\\title{Bayesian bootstrap}\n\\usage{\nrdirichlet_b(n)\n}\n\\arguments{\n\\item{n}{Number of units}\n}\n\\description{\nBayesian bootstrap\n}\n"
  },
  {
    "path": "man/rmultinom_b.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/inference.R\n\\name{rmultinom_b}\n\\alias{rmultinom_b}\n\\title{Non-parametric bootstrap}\n\\usage{\nrmultinom_b(n)\n}\n\\arguments{\n\\item{n}{Number of units}\n}\n\\description{\nNon-parametric bootstrap\n}\n"
  },
  {
    "path": "man/rwild_b.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/inference.R\n\\name{rwild_b}\n\\alias{rwild_b}\n\\title{Wild bootstrap (Mammen 1993)}\n\\usage{\nrwild_b(n)\n}\n\\arguments{\n\\item{n}{Number of units}\n}\n\\description{\nWild bootstrap (Mammen 1993)\n}\n"
  },
  {
    "path": "man/single_augsynth.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/augsynth.R\n\\name{single_augsynth}\n\\alias{single_augsynth}\n\\title{Fit Augmented SCM}\n\\usage{\nsingle_augsynth(\n  form,\n  unit,\n  time,\n  t_int,\n  data,\n  progfunc = \"ridge\",\n  scm = T,\n  fixedeff = FALSE,\n  cov_agg = NULL,\n  ...\n)\n}\n\\arguments{\n\\item{form}{outcome ~ treatment | auxiliary covariates}\n\n\\item{unit}{Name of unit column}\n\n\\item{time}{Name of time column}\n\n\\item{t_int}{Time of intervention}\n\n\\item{data}{Panel data as dataframe}\n\n\\item{progfunc}{What function to use to impute control outcomes\nridge=Ridge regression (allows for standard errors),\nnone=No outcome model,\nen=Elastic Net, RF=Random Forest, GSYN=gSynth,\nmcp=MCPanel,\ncits=Comparative Interrupted Time Series,\ncausalimpact=Bayesian structural time series with CausalImpact}\n\n\\item{scm}{Whether the SCM weighting function is used}\n\n\\item{fixedeff}{Whether to include a unit fixed effect, default F}\n\n\\item{cov_agg}{Covariate aggregation functions, if NULL then use mean with NAs omitted}\n\n\\item{...}{optional arguments for outcome model}\n}\n\\value{\naugsynth object that contains:\n        \\itemize{\n         \\item{\"weights\"}{Ridge ASCM weights}\n         \\item{\"l2_imbalance\"}{Imbalance in pre-period outcomes, measured by the L2 norm}\n         \\item{\"scaled_l2_imbalance\"}{L2 imbalance scaled by L2 imbalance of uniform weights}\n         \\item{\"mhat\"}{Outcome model estimate}\n         \\item{\"data\"}{Panel data as matrices}\n        }\n}\n\\description{\nFit Augmented SCM\n}\n"
  },
  {
    "path": "man/summary.augsynth.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/augsynth.R\n\\name{summary.augsynth}\n\\alias{summary.augsynth}\n\\title{Summary function for augsynth}\n\\usage{\n\\method{summary}{augsynth}(object, inf = T, inf_type = \"conformal\", linear_effect = F, ...)\n}\n\\arguments{\n\\item{object}{augsynth object}\n\n\\item{inf}{Boolean, whether to get confidence intervals around the point estimates}\n\n\\item{inf_type}{Type of inference algorithm. Options are\n\\itemize{\n \\item{\"conformal\"}{Conformal inference (default)}\n \\item{\"jackknife+\"}{Jackknife+ algorithm over time periods}\n \\item{\"jackknife\"}{Jackknife over units}\n}}\n\n\\item{linear_effect}{Boolean, whether to invert the conformal inference hypothesis test to get confidence intervals for a linear-in-time treatment effect: intercept + slope * time}\n\n\\item{...}{Optional arguments for inference, for more details for each `inf_type` see\n\\itemize{\n \\item{\"conformal\"}{`conformal_inf`}\n \\item{\"jackknife+\"}{`time_jackknife_plus`}\n \\item{\"jackknife\"}{`jackknife_se_single`}\n}}\n}\n\\description{\nSummary function for augsynth\n}\n"
  },
  {
    "path": "man/summary.augsynth_multiout.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/multi_outcomes.R\n\\name{summary.augsynth_multiout}\n\\alias{summary.augsynth_multiout}\n\\title{Summary function for augsynth}\n\\usage{\n\\method{summary}{augsynth_multiout}(object, inf = T, inf_type = \"conformal\", grid_size = 1, ...)\n}\n\\arguments{\n\\item{object}{augsynth_multiout object}\n\n\\item{inf}{whether or not to perform inference}\n\n\\item{grid_size}{Grid to compute prediction intervals over, default is 1 and only p-values are computed}\n\n\\item{...}{Optional arguments, including \\itemize{\\item{\"se\"}{Whether to plot standard error}}}\n\n\\item{inf_type}{Type of inference, default is \"conformal\"}\n}\n\\description{\nSummary function for augsynth\n}\n"
  },
  {
    "path": "man/summary.multisynth.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/multisynth_class.R\n\\name{summary.multisynth}\n\\alias{summary.multisynth}\n\\title{Summary function for multisynth}\n\\usage{\n\\method{summary}{multisynth}(object, inf_type = \"bootstrap\", att_weight = NULL, ...)\n}\n\\arguments{\n\\item{object}{multisynth object}\n\n\\item{inf_type}{Type of inference to perform:\n \\itemize{\n   \\item{bootstrap}{Wild bootstrap, the default option}\n   \\item{jackknife}{Jackknife}\n}}\n\n\\item{...}{Optional arguments}\n}\n\\value{\nsummary.multisynth object that contains:\n        \\itemize{\n         \\item{\"att\"}{Dataframe with ATT estimates, standard errors for each treated unit}\n         \\item{\"global_l2\"}{L2 imbalance for the pooled synthetic control}\n         \\item{\"scaled_global_l2\"}{L2 imbalance for the pooled synthetic control, scaled by the imbalance for uniform weights}\n         \\item{\"ind_l2\"}{Average L2 imbalance for the individual synthetic controls}\n         \\item{\"scaled_ind_l2\"}{Average L2 imbalance for the individual synthetic controls, scaled by the imbalance for uniform weights}\n        \\item{\"n_leads\", \"n_lags\"}{Number of post treatment outcomes (leads) and pre-treatment outcomes (lags) to include in the analysis}\n        }\n}\n\\description{\nSummary function for multisynth\n}\n"
  },
  {
    "path": "man/time_jackknife_plus.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/inference.R\n\\name{time_jackknife_plus}\n\\alias{time_jackknife_plus}\n\\title{Jackknife+ algorithm over time}\n\\usage{\ntime_jackknife_plus(ascm, alpha = 0.05, conservative = F)\n}\n\\arguments{\n\\item{ascm}{Fitted `augsynth` object}\n\n\\item{alpha}{Confidence level}\n\n\\item{conservative}{Whether to use the conservative jackknife+ procedure}\n}\n\\value{\nList that contains:\n        \\itemize{\n         \\item{\"att\"}{Vector of ATT estimates}\n         \\item{\"heldout_att\"}{Vector of ATT estimates with the time period held out}\n         \\item{\"se\"}{Standard error, always NA but returned for compatibility}\n         \\item{\"lb\"}{Lower bound of 1 - alpha confidence interval}\n         \\item{\"ub\"}{Upper bound of 1 - alpha confidence interval}\n         \\item{\"alpha\"}{Level of confidence interval}\n        }\n}\n\\description{\nJackknife+ algorithm over time\n}\n"
  },
  {
    "path": "man/time_jackknife_plus_multiout.Rd",
    "content": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/inference.R\n\\name{time_jackknife_plus_multiout}\n\\alias{time_jackknife_plus_multiout}\n\\title{Jackknife+ algorithm over time}\n\\usage{\ntime_jackknife_plus_multiout(ascm_multi, alpha = 0.05, conservative = F)\n}\n\\arguments{\n\\item{ascm_multi}{Fitted `augsynth_multiout` object}\n\n\\item{alpha}{Confidence level}\n\n\\item{conservative}{Whether to use the conservative jackknife+ procedure}\n}\n\\value{\nList that contains:\n        \\itemize{\n         \\item{\"att\"}{Vector of ATT estimates}\n         \\item{\"heldout_att\"}{Vector of ATT estimates with the time period held out}\n         \\item{\"se\"}{Standard error, always NA but returned for compatibility}\n         \\item{\"lb\"}{Lower bound of 1 - alpha confidence interval}\n         \\item{\"ub\"}{Upper bound of 1 - alpha confidence interval}\n         \\item{\"alpha\"}{Level of confidence interval}\n        }\n}\n\\description{\nJackknife+ algorithm over time\n}\n"
  },
  {
    "path": "pkg.Rproj",
    "content": "Version: 1.0\n\nRestoreWorkspace: No\nSaveWorkspace: No\nAlwaysSaveHistory: Default\n\nEnableCodeIndexing: Yes\nEncoding: UTF-8\n\nAutoAppendNewline: Yes\nStripTrailingWhitespace: Yes\n\nBuildType: Package\nPackageUseDevtools: Yes\nPackageInstallArgs: --no-multiarch --with-keep.source\nPackageRoxygenize: rd,collate,namespace\n"
  },
  {
    "path": "tests/testthat/test_augsynth_pre.R",
    "content": "context(\"Testing that top level API runs the right functions\")\n\nlibrary(Synth)\n\n\ntest_that(\"augsynth runs single_synth when there is a single treated unit\", {\n\n  data(basque)\n  basque <- basque %>% mutate(trt = case_when(year < 1975 ~ 0,\n                                              regionno != 17 ~0,\n                                              regionno == 17 ~ 1)) %>%\n      filter(regionno != 1)\n\n  syn <- augsynth(gdpcap ~ trt, regionno, year, basque,\n                  progfunc = \"None\", scm = T, t_int = 1975)\n\n  syn_single <- single_augsynth(gdpcap ~ trt, regionno, year, basque,\n                                progfunc = \"None\", scm = T, t_int = 1975)\n\n  expect_equal(syn$weights, syn_single$weights)\n})\n\n\ntest_that(\"augsynth finds the treated time when is a single treated unit\", {\n\n  data(basque)\n  basque <- basque %>% mutate(trt = case_when(year < 1975 ~ 0,\n                                              regionno != 17 ~0,\n                                              regionno == 17 ~ 1)) %>%\n      filter(regionno != 1)\n\n  syn <- augsynth(gdpcap ~ trt, regionno, year, basque,\n                  progfunc = \"None\", scm = T, t_int = 1975)\n\n  syn2 <- augsynth(gdpcap ~ trt, regionno, year, basque,\n                  progfunc = \"None\", scm = T)\n\n  expect_equal(syn$weights, syn2$weights, tolerance = 1e-6)\n\n  # should work with out of order time as well\n  syn_rev <- augsynth(gdpcap ~ trt, regionno, year,\n                      basque %>% arrange(desc(year)),\n                      progfunc = \"None\", scm = T)\n  expect_equal(syn$weights, syn_rev$weights)\n  expect_equal(predict(syn), predict(syn_rev))\n})\n\n\ntest_that(\"augsynth runs single_synth when there is simultaneous adoption\", {\n\n  data(basque)\n  basque <- basque %>% mutate(trt = case_when(year < 1975 ~ 0,\n                                              !regionno %in% c(16, 17) ~0,\n                                              regionno %in% 
c(16, 17) ~ 1)) %>%\n      filter(regionno != 1)\n\n  syn <- augsynth(gdpcap ~ trt, regionno, year, basque,\n                  progfunc = \"None\", scm = T, t_int = 1975)\n\n  syn_single <- single_augsynth(gdpcap ~ trt, regionno, year, basque,\n                                progfunc = \"None\", scm = T, t_int = 1975)\n\n  expect_equal(syn$weights, syn_single$weights)\n})\n\n\ntest_that(\"augsynth runs multisynth when there is staggered adoption\", {\n\n  data(basque)\n  basque <- basque %>% mutate(trt = case_when((regionno == 17) & (year == 1975) ~ 1,\n                                              (regionno == 16) & (year == 1980) ~ 1,\n                                              TRUE ~ 0)) %>%\n      filter(regionno != 1)\n\n  syn <- augsynth(gdpcap ~ trt, regionno, year, basque, scm = T)\n\n  syn_multi <- multisynth(gdpcap ~ trt, regionno, year, basque)\n\n  expect_equal(syn$weights, syn_multi$weights,  tolerance = 1e-5)\n})\n\n\n\ntest_that(\"augsynth with a single treated unit doesn't depend on unit order\", {\n\n  data(kansas)\n\n\n  syn <- augsynth(lngdpcapita ~ treated | log(revstatecapita), abb,\n                  year_qtr, kansas, progfunc = \"None\", scm = T)\n\n  syn2 <- augsynth(lngdpcapita ~ treated | log(revstatecapita), fips, year_qtr,\n                   kansas %>% arrange(desc(fips)), progfunc = \"None\", scm = T)\n\n\n  expect_equal(predict(syn), predict(syn2))\n\n\n  asyn <- augsynth(lngdpcapita ~ treated | log(revstatecapita), fips,\n                  year_qtr, kansas, progfunc = \"ridge\", scm = T)\n\n  asyn2 <- augsynth(lngdpcapita ~ treated | log(revstatecapita), fips, year_qtr,\n                   kansas %>% arrange(desc(fips)), progfunc = \"ridge\", scm = T)\n\n  expect_equal(c(asyn$weights), c(asyn2$weights))\n  expect_equal(predict(asyn), predict(asyn2))\n\n\n})\n\n\n\n\ntest_that(\"augsynth runs single_synth with progfunc = 'ridge' when there is a single treated unit and no progfunc is specified\", {\n\n  data(basque)\n  basque <- 
basque %>% mutate(trt = case_when(year < 1975 ~ 0,\n                                              regionno != 17 ~0,\n                                              regionno == 17 ~ 1)) %>%\n      filter(regionno != 1)\n\n  syn <- augsynth(gdpcap ~ trt, regionno, year, basque, scm = T)\n\n  syn_single <- single_augsynth(gdpcap ~ trt, regionno, year, basque,\n                                progfunc = \"ridge\", scm = T, t_int = 1975)\n\n  expect_equal(syn$weights, syn_single$weights)\n})\n"
  },
  {
    "path": "tests/testthat/test_format.R",
    "content": "context(\"Test data formatting\")\n\nlibrary(Synth)\ndata(basque)\nbasque <- basque %>% mutate(trt = case_when(year < 1975 ~ 0,\n                                            regionno != 17 ~0,\n                                            regionno == 17 ~ 1)) %>%\n    filter(regionno != 1)\n                            \ntest_that(\"format_data creates matrices with the right dimensions\", {\n    \n    dat <- format_data(quo(gdpcap), quo(trt), quo(regionno), quo(year),1975, basque)\n\n    test_dim <- function(obj, d) {\n        expect_equivalent(dim(obj), d)\n        }\n\n    test_dim(dat$X, c(17, 20))\n    expect_equivalent(length(dat$trt), 17)\n    test_dim(dat$y, c(17, 23))\n}\n)\n\n\ntest_that(\"format_synth creates matrices with the right dimensions\", {\n    \n    dat <- format_data(quo(gdpcap), quo(trt), quo(regionno), quo(year),1975, basque)\n    syn_dat <- format_synth(dat$X, dat$trt, dat$y)\n    test_dim <- function(obj, d) {\n        expect_equivalent(dim(obj), d)\n        }\n\n    test_dim(syn_dat$Z0, c(20, 16))\n    test_dim(syn_dat$Z1, c(20, 1))\n\n    test_dim(syn_dat$Y0plot, c(43, 16))\n    test_dim(syn_dat$Y1plot, c(43, 1))\n\n    expect_equivalent(syn_dat$Z1, syn_dat$X1)\n    expect_equivalent(syn_dat$Z0, syn_dat$X0)\n}\n)\n\n\ntest_that(\"multisynth throws errors when there aren't enough pre-treatment times\",\n  {\n    basque2 <- basque %>%\n      mutate(trt = case_when(\n        regionno == 16 ~ 1,\n        year >= 1975 & regionno == 17 ~ 1,\n        TRUE ~ 0)\n        ) %>%\n      filter(regionno != 1)\n\n  # error from always treated unit\n  expect_warning(\n    expect_error(\n      multisynth(gdpcap ~ trt, regionno, year, basque2)\n    )\n  )\n\n  basque2 <- basque %>%\n      mutate(trt = case_when(\n        regionno == 16 & year >= 1956 ~ 1,\n        year >= 1975 & regionno == 17 ~ 1,\n        TRUE ~ 0)\n        ) %>%\n      filter(regionno != 1)\n\n  # error from one pre-treatment outcome and fixedeff = T\n  expect_warning(\n  
  expect_error(multisynth(gdpcap ~ trt, regionno, year, basque2))\n  )\n\n  # no error from one pre-treatment outcome and fixedeff = F, just warning\n  expect_warning(multisynth(gdpcap ~ trt, regionno, year, basque2, fixedeff = F))\n\n  })\n\n\n  test_that(\"formatting for staggered adoption doesn't care about order of time in data\",\n  {\n    basque2 <- basque %>%\n      # slice(sample(1:n())) %>%\n      mutate(trt = case_when((regionno == 17) & (year >= 1975) ~ 1,\n                              (regionno == 16) & (year >= 1980) ~ 1,\n                                TRUE ~ 0))\n\n      dat <- format_data_stag(quo(gdpcap), quo(trt), quo(regionno), quo(year), basque2)\n\n      # true treatment times\n      true_trt <- c(1975, 1980) - min(basque$year)\n\n      expect_equal(true_trt, sort(dat$trt[is.finite(dat$trt)]))\n\n    basque2 <- basque %>%\n        slice(sample(1:n())) %>%\n        mutate(trt = case_when((regionno == 17) & (year >= 1975) ~ 1,\n                                (regionno == 16) & (year >= 1980) ~ 1,\n                                  TRUE ~ 0))\n\n    dat <- format_data_stag(quo(gdpcap), quo(trt), quo(regionno), quo(year), basque2)\n\n    expect_equal(true_trt, sort(dat$trt[is.finite(dat$trt)]))\n\n  })\n\n\ntest_that(\"augsynth exits with error message if there are no never treated units\",\n  {\n  basque2 <- basque %>%\n    # slice(sample(1:n())) %>%\n    mutate(trt = case_when((regionno == 17) & (year >= 1975) ~ 1,\n                            (year >= 1997) ~ 1,\n                              TRUE ~ 0))\n  expect_error(augsynth(gdpcap ~ trt, regionno, year, basque2), \"1996\")\n\n  })"
  },
  {
    "path": "tests/testthat/test_general.R",
    "content": "context(\"Generally testing the workflow for augsynth\")\n\n\nlibrary(Synth)\ndata(basque)\nbasque <- basque %>% mutate(trt = case_when(year < 1975 ~ 0,\n                                            regionno != 17 ~0,\n                                            regionno == 17 ~ 1)) %>%\n    filter(regionno != 1)\n\n\n                            \ntest_that(\"SCM gives the right answer\", {\n\n    syn <- augsynth(gdpcap ~ trt, regionno, year, basque, progfunc=\"None\", scm=T, t_int=1975)\n    ## average att estimate is as expected\n    expect_equal(-.3686, mean(summary(syn, inf = F)$att$Estimate), tolerance=1e-4)\n\n\n    ## level of balance is as expected\n    expect_equal(.377, syn$l2_imbalance, tolerance=1e-3)\n\n}\n)\n\ntest_that(\"SCM finds the correct t_int and gives the right answer\", {\n\n    syn1 <- augsynth(gdpcap ~ trt, regionno, year, basque,\n                     progfunc=\"None\", scm=T)\n    syn2 <- augsynth(gdpcap ~ trt, regionno, year, basque,\n                     progfunc = \"None\", scm = T, t_int = 1975)\n    ## average att estimate is as expected\n    expect_equal(mean(summary(syn1, inf = F)$att$Estimate), \n                 mean(summary(syn2, inf = F)$att$Estimate), tolerance=1e-4)\n    \n    ## level of balance is as expected\n    expect_equal(syn1$l2_imbalance, syn2$l2_imbalance, tolerance=1e-3)\n    \n}\n)\n\n\ntest_that(\"Ridge ASCM gives the right answer\", {\n\n    asyn <- augsynth(gdpcap ~ trt, regionno, year, basque, progfunc=\"Ridge\",\n                     scm=T, lambda=8)\n\n    ## average att estimate is as expected\n    expect_equal(-.3696, mean(summary(asyn, inf = F)$att$Estimate), tolerance=1e-3)\n\n    ## level of balance is as expected\n    expect_equal(.373, asyn$l2_imbalance, tolerance=1e-3)\n\n}\n)\n\n\n\n\ntest_that(\"SCM after residualizing covariates gives the right answer\", {\n\n  covsyn_resid <- augsynth(gdpcap ~ trt | invest + popdens,\n                      regionno, year, basque,\n                  
    progfunc = \"None\", scm = T,\n                      residualize = T)\n\n  ## average att estimate is as expected\n  expect_equal(-.1443,\n                mean(summary(covsyn_resid, inf = F)$att$Estimate),\n                tolerance = 1e-3)\n\n  ## level of balance is as expected\n  expect_equal(.3720, covsyn_resid$l2_imbalance, tolerance=1e-3)\n\n  # perfect auxiliary covariate balance\n  expect_equal(0, covsyn_resid$covariate_l2_imbalance, tolerance=1e-3)\n\n}\n)\n\ntest_that(\"SCM with covariates jointly gives the right answer\", {\n\n    covsyn_noresid <- augsynth(gdpcap ~ trt | invest + popdens,\n                       regionno, year, basque,\n                       progfunc = \"None\", scm = T,\n                       residualize = F)\n\n    ## average att estimate is as expected\n    expect_equal(-.3345,\n                 mean(summary(covsyn_noresid, inf = F)$att$Estimate),\n                 tolerance = 1e-3)\n\n    ## level of balance is as expected\n    expect_equal(0.659, covsyn_noresid$l2_imbalance, tolerance=1e-3)\n\n    # covariate balance is as expected\n    expect_equal(0.884, covsyn_noresid$covariate_l2_imbalance, tolerance=1e-3)\n\n\n}\n)\n\n\ntest_that(\"Ridge ASCM after residualizing covariates gives the right answer\", {\n\n    covascm_resid <- augsynth(gdpcap ~ trt | invest + popdens,\n                       regionno, year, basque,\n                       progfunc = \"Ridge\", scm = T,\n                       lambda = 1,\n                       residualize = T)\n\n    ## average att estimate is as expected\n    expect_equal(-.123,\n                 mean(summary(covascm_resid, inf = F)$att$Estimate),\n                 tolerance = 1e-3)\n\n    ## level of balance is as expected\n    expect_equal(.347, covascm_resid$l2_imbalance, tolerance=1e-3)\n\n    # perfect auxiliary covariate balance\n    expect_equal(0, covascm_resid$covariate_l2_imbalance, tolerance=1e-3)\n\n}\n)\n\ntest_that(\"Ridge ASCM with covariates jointly gives the right 
answer\", {\n\n    covascm_noresid <- augsynth(gdpcap ~ trt | invest + popdens,\n                       regionno, year, basque,\n                       progfunc = \"Ridge\", scm = T,\n                       lambda = 1,\n                       residualize = F)\n\n    ## average att estimate is as expected\n    expect_equal(-.267,\n                 mean(summary(covascm_noresid, inf = F)$att$Estimate),\n                 tolerance = 1e-3)\n\n    ## level of balance is as expected\n    expect_equal(0.419, covascm_noresid$l2_imbalance, tolerance=1e-3)\n\n    # covariate balance is as expected\n    expect_equal(0.084, covascm_noresid$covariate_l2_imbalance, tolerance=1e-3)\n\n}\n)\n\n\ntest_that(\"Warning given when inputting an unused argument\", {\n\n    expect_warning(\n      augsynth(gdpcap ~ trt| invest + popdens, regionno, year, basque, \n               progfunc=\"Ridge\", scm=T, lambda=8, t_int = 1975, \n               bad_param = \"Unused input parameter\"),\n    )\n})\n"
  },
  {
    "path": "tests/testthat/test_lambda.R",
    "content": "context(\"Testing lambda tuning if ridge is true.\")\n\n\nlibrary(Synth)\ndata(basque)\nbasque <- basque %>% mutate(trt = case_when(year < 1975 ~ 0,\n                                            regionno != 17 ~0,\n                                            regionno == 17 ~ 1)) %>%\n  filter(regionno != 1)\n\n\ntest_that(\"Lambda sequence is generated correctly\", {\n  syn <- augsynth(gdpcap ~ trt, regionno, year, basque, progfunc=\"Ridge\", scm=T)\n  lambdas <- syn$lambdas\n  expect_equivalent(lambdas[length(lambdas)] / lambdas[1], 1e-8)\n  expect_equivalent(lambdas[2] / lambdas[1], lambdas[3] / lambdas[2])\n})\n\ntest_that(\"Smallest lambda is chosen\", {\n  syn <- augsynth(gdpcap ~ trt, regionno, year, basque, \n                  progfunc=\"Ridge\", scm=T, \n                  min_1se = F)\n  expect_equivalent(syn$lambda, syn$lambdas[which.min(syn$lambda_errors)])\n})\n\n\ntest_that(\"Largest lambda within 1 SE of minimum is chosen\", {\n  syn <- augsynth(gdpcap ~ trt, regionno, year, basque, \n                  progfunc=\"Ridge\", scm=T, \n                  min_1se = T)\n  min_idx <- which.min(syn$lambda_errors)\n  min_1se <- max(syn$lambdas[syn$lambda_errors <= \n                              syn$lambda_errors[min_idx] + \n                              syn$lambda_errors_se[min_idx]])\n  expect_equivalent(syn$lambda, min_1se)\n})\n\ntest_that(\"Max lambda is in list of returned lambdas (optional parameters are passed through)\", {\n  syn <- augsynth(gdpcap ~ trt, regionno, year, basque, \n                  progfunc=\"Ridge\", scm=T, \n                  lambda_max = 100)\n  lambdas <- syn$lambdas\n  expect_equivalent(lambdas[1], 100)\n})"
  },
  {
    "path": "tests/testthat/test_load_data.R",
    "content": "context(\"Testing that we can load packaged data\")\n\ntest_that(\"kansas data loads\", {\n    expect_error(data(kansas), NA)\n})"
  },
  {
    "path": "tests/testthat/test_multiple_outcomes.R",
    "content": "context(\"Generally testing the workflow for synth with multiple outcomes\")\n\nlibrary(Synth)\ndata(basque)\nbasque <- basque %>% mutate(trt = case_when(year < 1975 ~ 0,\n                                            regionno != 17 ~0,\n                                            regionno == 17 ~ 1),\n                            gdpcap_sq = gdpcap ^ 2) %>%\n    filter(regionno != 1)\n\ntest_that(\"augsynth and augsynth_multiout are the same without augmentation\", {\n\n    syn1 <- augsynth_multiout(gdpcap + gdpcap_sq  ~ trt, regionno, year, 1975, basque,\n                    progfunc=\"None\", scm=T)\n    syn2 <- augsynth(gdpcap + gdpcap_sq  ~ trt, regionno, year, basque,\n                    progfunc=\"None\", scm=T)\n\n    # weights are the same\n    expect_equal(c(syn1$weights), c(syn2$weights), tolerance=3e-4)\n\n    # estimates are the same\n    expect_equal(c(predict(syn1, att=F)), c(predict(syn2, att = F)), tolerance=5e-5)\n\n\n    ## level of balance is same\n    expect_equal(syn1$l2_imbalance, syn2$l2_imbalance, tolerance=1e-5)        \n})\n\ntest_that(\"augsynth and augsynth_multiout are the same with ridge augmentation\", {\n\n    syn1 <- augsynth_multiout(gdpcap + gdpcap_sq  ~ trt, regionno, year, 1975, basque,\n                    progfunc=\"Ridge\", scm=T, lambda = 10)\n    syn2 <- augsynth(gdpcap + gdpcap_sq  ~ trt, regionno, year, basque,\n                    progfunc=\"Ridge\", scm=T, lambda = 10)\n\n    # weights are the same\n    expect_equal(c(syn1$weights), c(syn2$weights), tolerance=3e-4)\n\n    # estimates are the same\n    expect_equal(c(predict(syn1, att=F)), c(predict(syn2, att = F)), tolerance=5e-5)\n\n\n    ## level of balance is same\n    expect_equal(syn1$l2_imbalance, syn2$l2_imbalance, tolerance=1e-5)        \n})\n\ntest_that(\"augsynth and augsynth_multiout are the same with fixed effects augmentation\", {\n\n    syn1 <- augsynth_multiout(gdpcap + gdpcap_sq  ~ trt, regionno, year, 1975, basque,\n                    
progfunc=\"None\", scm=T, fixedeff = T)\n    syn2 <- augsynth(gdpcap + gdpcap_sq  ~ trt, regionno, year, basque,\n                    progfunc=\"None\", scm=T, fixedeff = T)\n\n    # weights are the same\n    expect_equal(c(syn1$weights), c(syn2$weights), tolerance=3e-4)\n\n    # estimates are the same\n    expect_equal(c(predict(syn1, att=F)), c(predict(syn2, att = F)), tolerance=5e-5)\n\n\n    ## level of balance is same\n    expect_equal(syn1$l2_imbalance, syn2$l2_imbalance, tolerance=1e-5)\n})\n\ntest_that(\"single_augsynth and augsynth_multiout are the same for one outcome\", {\n    syn1 <- augsynth_multiout(gdpcap  ~ trt, regionno, year, 1975, basque,\n                    progfunc=\"None\", scm=T, combine_method = \"concat\")\n    syn2 <- augsynth(gdpcap  ~ trt, regionno, year, basque,\n                    progfunc=\"None\", scm=T)\n\n    # weights are the same\n    expect_equal(c(syn1$weights), c(syn2$weights), tolerance=3e-4)\n\n    # estimates are the same\n    expect_equal(c(predict(syn1, att=F)), unname(predict(syn2, att = F)), tolerance=5e-5)\n\n\n    ## level of balance is same\n    expect_equal(syn1$l2_imbalance, syn2$l2_imbalance, tolerance=1e-5)            \n})\n\ntest_that(\"single_augsynth and augsynth_multiout are the same for one outcome with ridge augmentation\",{\n    syn1 <- augsynth_multiout(gdpcap  ~ trt, regionno, year, 1975, basque,\n                    progfunc=\"Ridge\", scm=T, combine_method = \"concat\")\n    syn2 <- augsynth(gdpcap  ~ trt, regionno, year, basque,\n                    progfunc=\"Ridge\", scm=T)\n\n    # weights are the same\n    expect_equal(c(syn1$weights), c(syn2$weights), tolerance=3e-4)\n\n    # estimates are the same\n    expect_equal(c(predict(syn1, att=F)), unname(predict(syn2, att = F)),\n                 tolerance=5e-5)\n\n\n    ## level of balance is same\n    expect_equal(syn1$l2_imbalance, syn2$l2_imbalance, tolerance=1e-5)            \n})\n\ntest_that(\"single_augsynth and augsynth_multiout are the same 
for one outcome with fixed effect augmentation\", {\n    syn1 <- augsynth_multiout(gdpcap  ~ trt, regionno, year, 1975, basque,\n                    progfunc=\"None\", scm=T, fixedeff = T, combine_method = \"concat\")\n    syn2 <- augsynth(gdpcap  ~ trt, regionno, year, basque,\n                    progfunc=\"None\", scm=T, fixedeff = T)\n\n    # weights are the same\n    expect_equal(c(syn1$weights), c(syn2$weights), tolerance=3e-4)\n\n    # estimates are the same\n    expect_equal(c(predict(syn1, att=F)), unname(predict(syn2, att = F)), tolerance=5e-5)\n\n\n    ## level of balance is same\n    expect_equal(syn1$l2_imbalance, syn2$l2_imbalance, tolerance=1e-5)            \n})\n\n\ntest_that(\"Averaging outcomes with augsynth_multiout gives correct results without fixed effects\", {\n\n  sds <- basque %>% filter(trt == 0, year < 1975) %>%\n          summarise(across(c(gdpcap, gdpcap_sq), sd)) %>%\n          rename(gdpcap_sd = gdpcap, gdpcap_sq_sd = gdpcap_sq)\n\n  basque %>%\n    bind_cols(sds) %>%\n    mutate(avg = gdpcap / gdpcap_sd + gdpcap_sq / gdpcap_sq_sd,\n           avg2 = gdpcap + gdpcap_sq) -> bas_avg\n  \n\n  syn1 <- augsynth_multiout(gdpcap + gdpcap_sq ~ trt, regionno, year, 1975, basque,\n                            progfunc=\"None\", scm=T, fixedeff = F, combine_method = \"avg\")\n  syn2 <- augsynth(avg  ~ trt, regionno, year, bas_avg,\n                    progfunc=\"None\", scm=T, fixedeff = F)\n\n  # weights are the same\n  expect_equal(c(syn1$weights), c(syn2$weights), tolerance=3e-4)\n})\n\ntest_that(\"Averaging outcomes with augsynth_multiout gives correct results with fixed effects\", {\n\n  sds <- basque %>% filter(trt == 0, year < 1975) %>%\n          summarise(across(c(gdpcap, gdpcap_sq), sd)) %>%\n          rename(gdpcap_sd = gdpcap, gdpcap_sq_sd = gdpcap_sq)\n\n  basque %>%\n    bind_cols(sds) %>%\n    mutate(avg = gdpcap / gdpcap_sd + gdpcap_sq / gdpcap_sq_sd,\n           avg2 = gdpcap + gdpcap_sq) -> bas_avg\n  \n\n  syn1 <- 
augsynth_multiout(gdpcap + gdpcap_sq ~ trt, regionno, year, 1975, basque,\n                            progfunc=\"None\", scm=T, fixedeff = T,\n                            combine_method = \"avg\")\n  syn2 <- augsynth(avg  ~ trt, regionno, year, bas_avg,\n                    progfunc=\"None\", scm=T, fixedeff = T)\n\n  # weights are the same\n  expect_equal(c(syn1$weights), c(syn2$weights), tolerance=1e-3)\n})\n\n\n\ntest_that(\"Concatenating outcomes with augsynth_multiout gives correct results without fixed effects\", {\n\n\n  sds <- basque %>% filter(trt == 0, year < 1975) %>%\n          summarise(across(c(gdpcap, gdpcap_sq), sd)) %>%\n          rename(gdpcap_sd = gdpcap, gdpcap_sq_sd = gdpcap_sq)\n\n  basque %>%\n    bind_cols(sds) %>%\n    mutate(gdpcap = gdpcap / gdpcap_sd, gdpcap_sq = gdpcap_sq / gdpcap_sq_sd) %>%\n    select(gdpcap, gdpcap_sq, trt, year, regionno) %>%\n    pivot_longer(-c(regionno, year, trt)) %>%\n    mutate(year = ifelse(name == \"gdpcap\", year, year - 0.5)) -> bas_cat  \n\n  syn1 <- augsynth_multiout(gdpcap + gdpcap_sq ~ trt, regionno, year, 1975, basque,\n                            progfunc=\"None\", scm=T, fixedeff = F,\n                            combine_method = \"concat\")\n  syn2 <- augsynth(value  ~ trt, regionno, year, bas_cat,\n                    progfunc=\"None\", scm=T, fixedeff = F)\n\n  # weights are the same\n  expect_equal(c(syn1$weights), c(syn2$weights), tolerance=5e-4)\n})\n\n"
  },
  {
    "path": "tests/testthat/test_multisynth.R",
    "content": "context(\"Generally testing the workflow for multisynth\")\n\nlibrary(Synth)\ndata(basque)\nbasque <- basque %>% mutate(trt = case_when(year < 1975 ~ 0,\n                                            regionno != 17 ~0,\n                                            regionno == 17 ~ 1)) %>%\n    filter(regionno != 1)\n\n\n                            \ntest_that(\"augsynth and multisynth give the same answer for a single treated unit and no augmentation\", {\n\n    syn <- single_augsynth(gdpcap ~ trt, regionno, year, 1975, basque,\n                    progfunc=\"None\", scm=T, fixedeff = F)\n    msyn <- multisynth(gdpcap ~ trt, regionno, year, basque, nu = 0,\n                       fixedeff = F, scm=T, eps_rel=1e-5, eps_abs=1e-5)\n    \n    # weights are the same-ish\n    expect_equal(c(syn$weights), c(msyn$weights[-16]), tolerance=3e-4)\n\n    # estimates are the same-ish\n    pred_msyn <- predict(msyn, att=F)[,1]\n    pred_msyn <- pred_msyn[-length(pred_msyn)]\n    expect_equal(unname(predict(syn, att=F)), pred_msyn, tolerance=5e-5)\n\n\n    ## level of balance is same-ish expected\n    expect_equal(syn$l2_imbalance, msyn$avg_l2, tolerance=1e-5)\n\n}\n)\n\n\ntest_that(\"Pooling doesn't matter for a single treated unit\", {\n\n    nopool <- multisynth(gdpcap ~ trt, regionno, year, basque, nu = 0,\n                         scm=T, eps_rel=1e-5, eps_abs=1e-5)\n    allpool <- multisynth(gdpcap ~ trt, regionno, year, basque, nu = 1,\n                          scm=T, eps_rel=1e-5, eps_abs=1e-5)\n\n    # weights are the same\n    expect_equal(nopool$weights, allpool$weights)\n\n    # estimates are the same\n    expect_equal(predict(nopool), predict(allpool))\n\n\n    ## level of balance is same-ish expected\n    expect_equal(allpool$ind_l2, nopool$ind_l2)\n\n}\n)\n\n\n\n\n\n                            \ntest_that(\"Separate synth is the same as fitting separate synths\", {\n\n\n    basque2 <- basque %>% mutate(trt = case_when(year < 1975 ~ 0,\n                 
                               !regionno %in% c(16, 17) ~ 0,\n                                                regionno %in% c(16, 17) ~ 1)) %>%\n        filter(regionno != 1)\n\n\n    basque2  %>% filter(regionno != 16) %>% \n        single_augsynth(gdpcap ~ trt, regionno, year, 1975, .,\n                    progfunc=\"None\", scm=T) -> scm17\n    basque2  %>% filter(regionno != 17) %>% \n        single_augsynth(gdpcap ~ trt, regionno, year, 1975, .,\n                    progfunc=\"None\", scm=T) -> scm16\n    \n    msyn <- multisynth(gdpcap ~ trt, regionno, year, basque2, nu = 0,\n                       scm=T, eps_rel=1e-5, eps_abs=1e-5, fixedeff = F)\n    \n    # weights are the same-ish\n    sscm_weights <- unname(c(scm17$weights))\n    mscm_weights <- unname(c(msyn$weights[-c(15, 16), 2]))\n    expect_equal(sscm_weights, mscm_weights, tolerance=3e-2)\n    expect_equal(rownames(scm17$weights), rownames(as.matrix(msyn$weights[-c(15, 16), 2])))\n    # expect_equal(c(scm16$weights), c(msyn$weights[-c(15, 16), 1]), tolerance=3e-2)\n    \n    # estimates are the same-ish\n    pred_msyn <- predict(msyn, att=F)\n    pred_msyn <- pred_msyn[-nrow(pred_msyn), ]\n    expect_equal(unname(predict(scm17, att=F)), pred_msyn[, 3], tolerance=5e-3)\n    expect_equal(unname(predict(scm16, att=F)), pred_msyn[, 2], tolerance=5e-3)\n}\n)\n\ntest_that(\"Limiting number of lags works\", {\n\n\n    basque2 <- basque %>% mutate(trt = case_when(year < 1975 ~ 0,\n                                                !regionno %in% c(16, 17) ~ 0,\n                                                regionno %in% c(16, 17) ~ 1)) %>%\n        filter(regionno != 1)\n\n    expect_error(\n      multisynth(gdpcap ~ trt, regionno, year, basque2, nu = 0,\n                 scm=T, eps_rel=1e-5, eps_abs=1e-5, n_lags =3),\n      NA\n    )\n}\n)\n\ntest_that(\"L2 imbalance computed correctly\", {\n\n  basque2 <- basque %>% mutate(trt = case_when(year < 1975 ~ 0,\n                                              
!regionno %in% c(16, 17) ~ 0,\n                                              regionno %in% c(16, 17) ~ 1)) %>%\n      filter(regionno != 1)\n\n  msyn <- multisynth(gdpcap ~ trt, regionno, year, basque2,\n                scm=T, eps_rel=1e-5, eps_abs=1e-5)\n\n  glbl <- sqrt(mean(msyn$imbalance[,1]^2))\n  ind <- sqrt(mean(\n    apply(msyn$imbalance[, -1], 2,\n          function(x) sum(x ^ 2) / sum(x != 0))))\n  avg_ind <- mean(apply(msyn$imbalance[,-1, drop = F], 2,\n              function(x) sqrt(sum(x ^ 2))))\n  expect_equal(glbl, msyn$global_l2)\n  expect_equal(avg_ind, msyn$avg_l2)\n  expect_equal(ind, msyn$ind_l2)\n})\n\ntest_that(\"V matrix is equivalent to hard thresholding\", {\n\n\n  basque2 <- basque %>% mutate(trt = case_when(year < 1975 ~ 0,\n                                              !regionno %in% c(16, 17) ~ 0,\n                                              regionno %in% c(16, 17) ~ 1)) %>%\n      filter(regionno != 1)\n\n  V <- c(numeric(10), rep(1,5))\n  msyn1 <- multisynth(gdpcap ~ trt, regionno, year, basque2, nu = 0,\n                scm=T, eps_rel=1e-8, eps_abs=1e-8, n_lags = 15, V = V)\n\n  msyn2 <- multisynth(gdpcap ~ trt, regionno, year, basque2, nu = 0,\n                scm=T, eps_rel=1e-8, eps_abs=1e-8, n_lags = 5)\n\n  expect_equal(msyn1$weights, msyn2$weights, tolerance = 1e-5)\n  expect_equal(msyn1$global_l2, msyn2$global_l2, tolerance = 1e-5)\n  expect_equal(msyn1$avg_l2, msyn2$avg_l2, tolerance = 1e-5)\n}\n)\n\ntest_that(\"V matrix is the same for single and multi synth\", {\n\n  V <- exp(seq(log(1e-3), log(1), length.out = 20))\n\n  syn <- augsynth(gdpcap ~ trt, regionno, year, basque, progfunc = \"none\",\n                scm=T, V = V)\n\n  msyn <- multisynth(gdpcap ~ trt, regionno, year, basque,\n                scm=T, eps_rel=1e-8, eps_abs=1e-8, V = V,\n                fixed = F, nu = 0)\n\n  expect_equal(as.numeric(syn$weights), as.numeric(msyn$weights[-16, ]), tolerance = 1e-3)\n}\n)\n\n\n                            
\ntest_that(\"multisynth doesn't depend on unit order\", {\n\n    basque2 <- basque %>% mutate(trt = case_when(year < 1975 ~ 0,\n                                                !regionno %in% c(16, 17) ~ 0,\n                                                regionno %in% c(16, 17) ~ 1)) %>%\n        filter(regionno != 1)\n\n    msyn <- multisynth(gdpcap ~ trt, regionno, year, basque2, nu = 0,\n                       fixedeff = F, scm=T, eps_rel=1e-5, eps_abs=1e-5)\n\n    msyn2 <- multisynth(gdpcap ~ trt, regionno, year,\n                       basque2 %>% arrange(desc(regionno)), nu = 0,\n                       fixedeff = F, scm=T, eps_rel=1e-5, eps_abs=1e-5)\n\n    \n    # weights are the same\n    expect_equal(c(msyn$weights), c(msyn2$weights))\n\n    # estimates are the same\n    expect_equal(predict(msyn), predict(msyn2))\n\n}\n)\n\n\n                            \ntest_that(\"multisynth doesn't depend on time order\", {\n\n    basque2 <- basque %>% mutate(trt = case_when(year < 1975 ~ 0,\n                                                !regionno %in% c(16, 17) ~ 0,\n                                                regionno %in% c(16, 17) ~ 1)) %>%\n        filter(regionno != 1)\n\n    msyn <- multisynth(gdpcap ~ trt, regionno, year, basque2, nu = 0,\n                       fixedeff = F, scm=T, eps_rel=1e-5, eps_abs=1e-5)\n\n    msyn2 <- multisynth(gdpcap ~ trt, regionno, year,\n                       basque2 %>% arrange(desc(year)), nu = 0,\n                       fixedeff = F, scm=T, eps_rel=1e-5, eps_abs=1e-5)\n\n    \n    # weights are the same\n    expect_equal(c(msyn$weights), c(msyn2$weights))\n\n    # estimates are the same\n    expect_equal(predict(msyn), predict(msyn2))\n\n}\n)\n"
  },
  {
    "path": "tests/testthat/test_multisynth_covariates.R",
    "content": "context(\"Testing multisynth with covariates\")\nset.seed(1011)\n\nlibrary(Synth)\ndata(basque)\nbasque <- basque %>% mutate(trt = case_when((regionno == 17) & (year >= 1975) ~ 1,\n                                              (regionno == 16) & (year >= 1980) ~ 1,\n                                              TRUE ~ 0)) %>%\n      filter(regionno != 1)\n\n\nregions <- basque %>% distinct(regionno) %>% pull(regionno)\n\ntest_that(\"Getting eligible donor units by exact matching works\", {\n\n  # binary variable to split on\n  fake_bin <- sample(c(0, 1), length(regions), replace = T)\n  basque %>%\n    inner_join(\n      data.frame(regionno = regions, Z = fake_bin) %>%\n        mutate(Z = case_when(regionno == 17 ~ 0,\n                             regionno == 16 ~ 1,\n                             TRUE ~ Z)\n              ),\n               by = \"regionno\") -> basque2\n\n  msyn <- multisynth(gdpcap ~ trt | 0 | 0| Z, regionno, year, basque2, nu = 0,\n                     scm = T)\n\n  # check that there is actually no weight on donors with different Z\n  expect_equal(sum(msyn$weights[fake_bin == 1, 1]), 1, tolerance = 1e-4)\n  expect_equal(sum(msyn$weights[fake_bin == 0, 1]), 0, tolerance = 1e-4)\n  expect_equal(sum(msyn$weights[fake_bin == 1, 2]), 0, tolerance = 1e-4)\n  expect_equal(sum(msyn$weights[fake_bin == 0, 2]), 1, tolerance = 1e-3)\n\n\n  # again with fixed effect\n  msyn <- multisynth(gdpcap ~ trt | 0 | 0 | Z, regionno, year, basque2, nu = 0,\n                     scm = T, fixedeff = T)\n  # check that there is actually no weight on donors with different Z\n  expect_equal(sum(msyn$weights[fake_bin == 1, 1]), 1, tolerance = 1e-4)\n  expect_equal(sum(msyn$weights[fake_bin == 0, 1]), 0, tolerance = 1e-4)\n  expect_equal(sum(msyn$weights[fake_bin == 1, 2]), 0, tolerance = 1e-4)\n  expect_equal(sum(msyn$weights[fake_bin == 0, 2]), 1, tolerance = 1e-3)\n})\n\ntest_that(\"Getting eligible donor units by exact matching works with factors\", {\n\n 
 # binary variable to split on\n  fake_fac <- sample(c(0, 1, 3), length(regions), replace = T)\n  basque %>%\n    inner_join(\n      data.frame(regionno = regions, Z = fake_fac) %>%\n        mutate(\n          Z = case_when(regionno == 17 ~ 0,\n                             regionno == 16 ~ 1,\n                             TRUE ~ Z),\n               Z = as.factor(Z)\n              ),\n               by = \"regionno\") -> basque2\n\n  msyn <- multisynth(gdpcap ~ trt | 0 | 0 | Z, regionno, year, basque2, nu = 0,\n                     scm = T)\n\n  # check that there is actually no weight on donors with different Z\n  expect_equal(sum(msyn$weights[fake_fac == 1, 1]), 1, tolerance = 1e-4)\n  expect_equal(sum(msyn$weights[fake_fac == 0, 1]), 0, tolerance = 1e-4)\n  expect_equal(sum(msyn$weights[fake_fac == 3, 1]), 0, tolerance = 1e-4)\n  expect_equal(sum(msyn$weights[fake_fac == 1, 2]), 0, tolerance = 1e-4)\n  expect_equal(sum(msyn$weights[fake_fac == 0, 2]), 1, tolerance = 1e-4)\n  expect_equal(sum(msyn$weights[fake_fac == 3, 2]), 0, tolerance = 1e-4)\n\n\n  # again with fixed effect\n  msyn <- multisynth(gdpcap ~ trt | 0 | 0 |Z, regionno, year, basque2, nu = 0,\n                     scm = T, fixedeff = T, how_match = \"exact\")\n  # check that there is actually no weight on donors with different Z\n  expect_equal(sum(msyn$weights[fake_fac == 1, 1]), 1, tolerance = 1e-4)\n  expect_equal(sum(msyn$weights[fake_fac == 0, 1]), 0, tolerance = 1e-4)\n  expect_equal(sum(msyn$weights[fake_fac == 3, 1]), 0, tolerance = 1e-4)\n  expect_equal(sum(msyn$weights[fake_fac == 1, 2]), 0, tolerance = 1e-4)\n  expect_equal(sum(msyn$weights[fake_fac == 0, 2]), 1, tolerance = 1e-4)\n  expect_equal(sum(msyn$weights[fake_fac == 3, 2]), 0, tolerance = 1e-4)\n})\n\ntest_that(\"K-NN finds the right number of neighbors\", {\n\n  # variables to match on\n  Z <- matrix(rnorm(length(regions) * 3), ncol = 3)\n  basque %>%\n    inner_join(\n      data.frame(regionno = regions,\n                 Z1 = 
Z[, 1], Z2 = Z[, 2], Z3 = Z[, 3]),\n      by = \"regionno\") -> basque2\n  \n  dat <- format_data_stag(quo(gdpcap), quo(trt), quo(regionno),\n                          quo(year), basque2)\n  k <- 3\n  donors <- get_eligible_donors(dat$trt, F, 100)\n  knn_donors <- get_knn_donors(dat$trt, Z, donors, k)\n  expect_true(all(sapply(knn_donors, sum) == k))\n\n  k <- 20\n  expect_warning(get_knn_donors(dat$trt, Z, donors, k))\n})\n\ntest_that(\"Getting eligible donor units by knn matching works\", {\n\n  # variables to match on\n  Z <- matrix(rnorm(length(regions) * 3), ncol = 3)\n  basque %>%\n    inner_join(\n      data.frame(regionno = regions,\n                 Z1 = Z[, 1], Z2 = Z[, 2], Z3 = Z[, 3]),\n      by = \"regionno\") -> basque2\n\n  # error if no k is supplied\n  expect_error(multisynth(gdpcap ~ trt | 0 | Z1 + Z2 + Z3, regionno, \n                          year, basque2,\n                          scm = T, how_match = \"knn\"),\n              \"Number of neighbors for knn not selected, please choose k.\")\n\n  k <- 5\n  msyn <- multisynth(gdpcap ~ trt | 0 | Z1 + Z2 + Z3, regionno, year, \n                     basque2, scm = T, how_match = \"knn\", k = k)\n\n  # check that at most k units receive non-0 weight\n  expect_lte(sum(msyn$weights[, 1] != 0), k)\n  expect_lte(sum(msyn$weights[, 2] != 0), k)\n\n  \n\n  # again with fixed effect\n  msyn <- multisynth(gdpcap ~ trt | 0 | Z1 + Z2 + Z3, regionno, year,\n                       basque2, scm = T, fixedeff = T, how_match = \"knn\", k = k)\n  # check that all but k units receive exactly 0 weight\n  expect_lte(sum(msyn$weights[, 1] != 0), k)\n  expect_lte(sum(msyn$weights[, 2] != 0), k) \n\n  # without synth weights, weights are uniform\n  k <- 2\n  unimatch <- multisynth(gdpcap ~ trt| 0 | Z1 + Z2 + Z3, regionno, year,\n                     basque2, scm = T, how_match = \"knn\", k = k, lambda = 1e10)\n\n  expect_equal(unimatch$weights[unimatch$weights != 0 ], rep(1 / k, 2 * k))\n\n  # matching with more neighbors 
is worse\n  unimatch2 <- multisynth(gdpcap ~ trt | 0 | Z1 + Z2 + Z3, regionno, year, basque2,\n                     scm = T, how_match = \"knn\", k = 2.5 * k, lambda = 1e10)\n\n  trtZ <- Z[regions %in% c(16, 17),]\n  imbal1 <- sqrt(sum(sapply(1:2, \n                function(i) sum(unimatch$weights[,i] * (trtZ[i,] - Z) ^ 2 ))))\n  imbal2 <- sqrt(sum(sapply(1:2, \n                function(i) sum(unimatch2$weights[,i] * (trtZ[i,] - Z) ^ 2 ))))\n\n  expect_lt(imbal1, imbal2)\n\n})\n\n\ntest_that(\"Getting eligible donor units by exact and knn matching works\", {\n\n  # binary variable to split on\n  fake_bin <- sample(c(0, 1), length(regions), replace = T)\n\n  # variables to match on\n  Z <- matrix(rnorm(length(regions) * 3), ncol = 3)\n  basque %>%\n    inner_join(\n      data.frame(regionno = regions,\n                 Z1 = Z[, 1], Z2 = Z[, 2], Z3 = Z[, 3],\n                 Z_bin = fake_bin) %>%\n        mutate(Z_bin = case_when(regionno == 17 ~ 0,\n                             regionno == 16 ~ 1,\n                             TRUE ~ Z_bin)),\n      by = \"regionno\") -> basque2\n\n  # error if no k is supplied\n  expect_error(multisynth(gdpcap ~ trt | 0 | Z1 + Z2 + Z3 | Z_bin, regionno, \n                          year, basque2,\n                          scm = T, how_match = \"knn\"),\n              \"Number of neighbors for knn not selected, please choose k.\")\n\n  k <- 3\n  msyn <- multisynth(gdpcap ~ trt | 0 | Z1 + Z2 + Z3 | Z_bin, regionno, year, \n                     basque2, scm = T, how_match = \"knn\", k = k)\n  \n  # check that there is actually no weight on donors with different Z\n  expect_equal(sum(msyn$weights[fake_bin == 1, 1]), 1, tolerance = 1e-4)\n  expect_equal(sum(msyn$weights[fake_bin == 0, 1]), 0, tolerance = 1e-4)\n  expect_equal(sum(msyn$weights[fake_bin == 1, 2]), 0, tolerance = 1e-4)\n  expect_equal(sum(msyn$weights[fake_bin == 0, 2]), 1, tolerance = 1e-4)\n  \n  # check that at most k units receive non-0 weight\n  
expect_lte(sum(msyn$weights[, 1] != 0), k)\n  expect_lte(sum(msyn$weights[, 2] != 0), k)\n  \n  # again with fixed effect\n    msyn <- multisynth(gdpcap ~ trt | 0 | Z1 + Z2 + Z3 | Z_bin, regionno, year,\n                       basque2, scm = T, fixedeff = T, how_match = \"knn\", k = k)\n  # check that at most k units receive non-0 weight\n  expect_lte(sum(msyn$weights[, 1] != 0), k)\n  expect_lte(sum(msyn$weights[, 2] != 0), k)\n\n  # check that there is actually no weight on donors with different Z\n  expect_equal(sum(msyn$weights[fake_bin == 1, 1]), 1, tolerance = 1e-4)\n  expect_equal(sum(msyn$weights[fake_bin == 0, 1]), 0, tolerance = 1e-4)\n  expect_equal(sum(msyn$weights[fake_bin == 1, 2]), 0, tolerance = 1e-4)\n  expect_equal(sum(msyn$weights[fake_bin == 0, 2]), 1, tolerance = 1e-4) \n\n  k <- 3\n  # without synth weights, weights are uniform\n  unimatch <- multisynth(gdpcap ~ trt | 0 | Z1 + Z2 + Z3 | Z_bin, regionno,\n                     year, basque2, scm = T, how_match = \"knn\", k = k, lambda = 1e10)\n\n  expect_equal(unimatch$weights[unimatch$weights != 0 ], rep(1 / k, 2 * k))\n\n  # matching without exact gives better matches\n  unimatch2 <- multisynth(gdpcap ~ trt | 0 | Z1 + Z2 + Z3, regionno, year,\n                    basque2, scm = T, how_match = \"knn\", k = k, lambda = 1e10)\n\n  trtZ <- Z[regions %in% c(16, 17),]\n  imbal1 <- sqrt(sum(sapply(1:2, \n                function(i) sum(unimatch$weights[,i] * (trtZ[i,] - Z) ^ 2 ))))\n  imbal2 <- sqrt(sum(sapply(1:2, \n                function(i) sum(unimatch2$weights[,i] * (trtZ[i,] - Z) ^ 2 ))))\n\n  expect_lt(imbal2, imbal1)\n})\n\n\ntest_that(\"An error is thrown if trying to match with time cohorts or the formula is wrong\", {\n\n  # binary variable to split on\n  fake_bin <- sample(c(0, 1), length(regions), replace = T)\n\n  # variables to match on\n  Z <- matrix(rnorm(length(regions) * 3), ncol = 3)\n  basque %>%\n    inner_join(\n      data.frame(regionno = regions,\n                 Z1 = Z[, 
1], Z2 = Z[, 2], Z3 = Z[, 3],\n                 Z_bin = fake_bin) %>%\n        mutate(Z_bin = case_when(regionno == 17 ~ 0,\n                             regionno == 16 ~ 1,\n                             TRUE ~ Z_bin)),\n      by = \"regionno\") %>% \n    mutate(trt = case_when((regionno == 17) & (year >= 1975) ~ 1,\n                            (regionno == 16) & (year >= 1975) ~ 1,\n                                              TRUE ~ 0)) %>%\n      filter(regionno != 1)-> basque2\n\n  expect_error(multisynth(gdpcap ~ trt | Z1 + Z2, regionno, year, basque2,\n                     time_cohort = T), NA)\n\n  expect_error(multisynth(gdpcap ~ trt | Z1 + Z2 | 0 | Z_bin,\n                          regionno, year, basque2, time_cohort = T))\n})\n\n\ntest_that(\"multisynth with covariates doesn't depend on unit or time order \", {\n\n  data <- read.csv(\"https://dataverse.harvard.edu/api/access/datafile/:persistentId?persistentId=doi:10.7910/DVN/WGWMAV/3UHTLP\", sep=\"\\t\")\n  data %>%\n    filter(!State %in% c(\"DC\", \"WI\"),\n           year >= 1959, year <= 1997) %>%\n    mutate(YearCBrequired = ifelse(is.na(YearCBrequired), \n                                   Inf, YearCBrequired),\n           cbr = 1 * (year >= YearCBrequired)) -> analysis_df\n\n  data %>%\n  select(State, year, agr, pnwht, purban, perinc, studteachratio) %>%\n  group_by(State) %>%\n  summarise(perinc_1959 = perinc[year == 1959],\n            studteachratio_1959 = studteachratio[year == 1959]) %>% \n  # filter to lower 48 where we have data\n  filter(!State %in% c(\"AK\", \"HI\"))  -> cov_data\n\n  analysis_df %>%\n    inner_join(cov_data, by = \"State\") -> analysis_df_covs\n\n  msyn <- multisynth(lnppexpend ~ cbr | perinc_1959 + studteachratio_1959,\n                            State, year, analysis_df_covs)\n\n  msyn_rev_unit <- multisynth(lnppexpend ~ cbr | perinc_1959 + studteachratio_1959,\n                            State, year,\n                            analysis_df_covs %>% 
arrange(desc(State)))\n\n  msyn_rev_time <- multisynth(lnppexpend ~ cbr | perinc_1959 + studteachratio_1959,\n                            State, year,\n                            analysis_df_covs %>% arrange(desc(year)))\n\n  expect_equal(predict(msyn), predict(msyn_rev_time))\n  expect_equal(predict(msyn), predict(msyn_rev_unit))\n\n})"
  },
  {
    "path": "tests/testthat/test_outcome_models.R",
    "content": "context(\"Testing that augmenting synth with different models loads and runs\")\n\n\n\nlibrary(Synth)\ndata(basque)\nbasque <- basque %>% mutate(trt = case_when(year < 1975 ~ 0,\n                                            regionno != 17 ~0,\n                                            regionno == 17 ~ 1)) %>%\n    filter(regionno != 1)\n\n\n                            \ntest_that(\"Augmenting synth with glmnet runs\", {\n\n    if(!requireNamespace(\"glmnet\", quietly = TRUE)) {\n        ## should fail because glmnet isn't installed\n        expect_error(augsynth(gdpcap ~ trt, regionno, year, basque, progfunc=\"EN\", scm=T),\n                     \"you must install the glmnet package\")\n\n        ## install glmnet\n        install.packages(\"glmnet\", repos = \"http://cran.us.r-project.org\")\n    }\n\n    ## should run because glmnet is installed\n    expect_error(augsynth(gdpcap ~ trt, regionno, year, basque, progfunc=\"EN\", scm=T),\n                 NA)    \n}\n)\n\n\n\ntest_that(\"Augmenting synth with random forest runs\", {\n\n    if(!requireNamespace(\"randomForest\", quietly = TRUE)) {\n        ## should fail because randomForest isn't installed\n        expect_error(augsynth(gdpcap ~ trt, regionno, year, basque, progfunc=\"RF\", scm=T),\n                     \"you must install the randomForest package\")\n\n        ## install randomForest\n        install.packages(\"randomForest\", repos = \"http://cran.us.r-project.org\")\n    }\n\n    ## should run because randomForest is installed\n    expect_error(augsynth(gdpcap ~ trt, regionno, year, basque, progfunc=\"RF\", scm=T),\n                 NA)    \n}\n)\n\n\n\n\ntest_that(\"Augmenting synth with gsynth runs and produces the correct result\", {\n\n    if(!requireNamespace(\"gsynth\", quietly = TRUE)) {\n        ## should fail because gsynth isn't installed\n        expect_error(augsynth(gdpcap ~ trt, regionno, year, basque, \n                              progfunc=\"GSYN\", scm=T),\n       
              \"you must install the gsynth package\")\n\n        ## install gsynth\n        install.packages(\"gsynth\", repos = \"http://cran.us.r-project.org\")\n    }\n\n    ## should run because gsynth is installed\n    expect_error(\n      augsynth(gdpcap ~ trt, regionno, year, basque, \n                                progfunc = \"GSYN\", scm = T, CV = 0, r = 4),\n      NA)\n    asyn_gsyn <- augsynth(gdpcap ~ trt, regionno, year, basque,\n                          progfunc = \"GSYN\", scm = F, CV = 0, r = 4)\n    expect_equal(summary(asyn_gsyn, inf = F)$average_att$Estimate, \n                 -0.1444637, tolerance=1e-4) \n}\n)\n\n\n\ntest_that(\"Augmenting synth with MCPanel runs\", {\n\n    if(!requireNamespace(\"MCPanel\", quietly = TRUE)) {\n        ## should fail because MCPanel isn't installed\n        expect_error(augsynth(gdpcap ~ trt, regionno, year, basque, progfunc=\"MCP\", scm=T),\n                     \"you must install the MCPanel package\")\n    } else {\n        ## should run because MCPanel is installed\n        expect_error(augsynth(gdpcap ~ trt, regionno, year, basque, progfunc=\"MCP\", scm=T),\n                     NA)    \n    }\n\n    \n}\n)\n\n\n\n\ntest_that(\"Augmenting synth with CausalImpact runs\", {\n\n    if(!requireNamespace(\"CausalImpact\", quietly = TRUE)) {\n        ## should fail because CausalImpact isn't installed\n        expect_error(augsynth(gdpcap ~ trt, regionno, year, basque, progfunc=\"CausalImpact\", scm=T),\n                     \"you must install the CausalImpact package\")\n\n        ## install CausalImpact\n        install.packages(\"CausalImpact\", repos = \"http://cran.us.r-project.org\")\n    }\n\n    ## should run because CausalImpact is installed\n    expect_error(augsynth(gdpcap ~ trt, regionno, year, basque, progfunc=\"CausalImpact\", scm=T),\n                 NA)    \n}\n)\n"
  },
  {
    "path": "tests/testthat/test_time_cohort.R",
    "content": "context(\"Test time cohort vs unit level analysis\")\n\nlibrary(Synth)\ndata(basque)\nbasque <- basque %>% mutate(trt = case_when(year < 1975 ~ 0,\n                                            regionno != 17 ~0,\n                                            regionno == 17 ~ 1)) %>%\n    filter(regionno != 1)\n\n\n                            \ntest_that(\"multisynth at the unit level and time cohort level give the same answer for a single treated unit and no augmentation\", {\n\n    msyn_unit <- multisynth(gdpcap ~ trt, regionno, year, basque, nu = 0,\n                            time_cohort = F, scm = T,\n                            eps_rel = 1e-5, eps_abs = 1e-5)\n    msyn_time <- multisynth(gdpcap ~ trt, regionno, year, basque, nu = 0,\n                            time_cohort = T, scm = T,\n                            eps_rel = 1e-5, eps_abs = 1e-5)\n    # weights are the same-ish\n    expect_equal(c(msyn_unit$weights), c(msyn_time$weights), tolerance=3e-2)\n\n    # estimates are the same-ish\n    expect_equal(c(predict(msyn_unit, att=F)),\n                 c(predict(msyn_time, att=F)),\n                 tolerance=5e-3)\n\n\n    ## level of balance is same-ish expected\n    expect_equal(msyn_unit$ind_l2, msyn_time$ind_l2, tolerance=1e-3)\n\n}\n)\n\n\ntest_that(\"multisynth at the time cohort level runs\", {\n\n    expect_error(msyn_time <- multisynth(gdpcap ~ trt, regionno, year, basque,\n                            time_cohort = T, scm = T),\n                 NA)\n}\n)"
  },
  {
    "path": "tests/testthat/test_unbalanced_multisynth.R",
    "content": "context(\"Test multisynth for unbalanced panels\")\n\nset.seed(1011)\n\nlibrary(Synth)\ndata(basque)\nbasque <- basque %>% mutate(trt = case_when((regionno == 17) & (year >= 1975) ~ 1,\n                                              (regionno == 16) & (year >= 1980) ~ 1,\n                                              TRUE ~ 0)) %>%\n      filter(regionno != 1)\nregions <- basque %>% distinct(regionno) %>% pull(regionno)\n\n\ntest_that(\"Data formatting creates NAs correctly\", {\n\n  # drop a time period for unit 17\n  basque %>%\n    filter(regionno != 17 | year != 1970) -> basque_mis\n\n  dat_format <- format_data_stag(quo(gdpcap), quo(trt),\n                                  quo(regionno), quo(year), basque_mis)\n\n  expect_true(is.na(dat_format$X[regions == 17, \"1970\"]))\n})\n\n\ntest_that(\"Non-NA donors are chosen correctly with missing pre-treatment\", {\n\n  # drop a time period for unit 17\n  basque %>%\n    filter(!regionno %in% c(15, 17, 18) | year != 1970) -> basque_mis\n\n  dat_format <- format_data_stag(quo(gdpcap), quo(trt),\n                                  quo(regionno), quo(year), basque_mis)\n  n_lags <- ncol(dat_format$X)\n  n_leads <- ncol(dat_format$y)\n  donors <- get_nona_donors(dat_format$X, dat_format$y, dat_format$trt,\n                            n_lags, n_leads, F)\n\n  expect_true(!all(donors[[1]][regions %in% c(15, 17, 18) ]))\n  expect_true(all(donors[[1]][!regions %in% c(15, 17, 18) ]))\n  expect_true(all(donors[[2]]))\n})\n\ntest_that(\"Non-NA donors are chosen correctly with missing post-treatment\", {\n\n  # drop a time period for unit 17\n  basque %>%\n    filter(!regionno %in% c(15, 17, 18) | !year %in% c(1990)) -> basque_mis\n\n  dat_format <- format_data_stag(quo(gdpcap), quo(trt),\n                                  quo(regionno), quo(year), basque_mis)\n  n_lags <- ncol(dat_format$X)\n  n_leads <- ncol(dat_format$y)\n  donors <- get_nona_donors(dat_format$X, dat_format$y, dat_format$trt,\n                   
        n_lags, n_leads, F)\n\n  expect_true(!all(donors[[1]][regions %in% c(15, 17, 18) ]))\n  expect_true(all(donors[[1]][!regions %in% c(15, 17, 18) ]))\n  expect_true(all(donors[[2]]))\n})\n\n\ntest_that(\"Non-NA donors are chosen correctly with missing pre- and post-treatment\", {\n\n  # drop a time period for unit 17\n  basque %>%\n    filter(!regionno %in% c(15, 17, 18) | !year %in% c(1970, 1990)) -> basque_mis\n\n  dat_format <- format_data_stag(quo(gdpcap), quo(trt),\n                                  quo(regionno), quo(year), basque_mis)\n  n_lags <- ncol(dat_format$X)\n  n_leads <- ncol(dat_format$y)\n  donors <- get_nona_donors(dat_format$X, dat_format$y, dat_format$trt,\n                           n_lags, n_leads, F)\n\n  expect_true(!all(donors[[1]][regions %in% c(15, 17, 18) ]))\n  expect_true(all(donors[[1]][!regions %in% c(15, 17, 18) ]))\n  expect_true(all(donors[[2]]))\n})\n\n\ntest_that(\"Non-NA donors are chosen correctly with missing pre- and post-treatment and not considering all leads and lags\", {\n\n  # drop a time period for unit 17\n  basque %>%\n    filter(!regionno %in% c(15, 17, 18) | !year %in% c(1970, 1990)) -> basque_mis\n\n  dat_format <- format_data_stag(quo(gdpcap), quo(trt),\n                                  quo(regionno), quo(year), basque_mis)\n  n_lags <- ncol(dat_format$X)\n  n_leads <- ncol(dat_format$y)\n  donors <- get_nona_donors(dat_format$X, dat_format$y, dat_format$trt,\n                           5, 5, F)\n\n  expect_true(all(donors[[1]]))\n  expect_true(all(donors[[2]]))\n})\n\ntest_that(\"Separate synth with missing treated unit time drops the time\", {\n\n  # drop a time period for unit 17\n  basque %>%\n    filter(!regionno %in% c(17) | year != 1970) -> basque_mis\n\n  msyn <- multisynth(gdpcap ~ trt, regionno, year, basque_mis,\n                     fixedeff = F,\n                     nu = 0, scm=T, eps_rel=1e-8, eps_abs=1e-8)\n\n  msyn2 <- multisynth(gdpcap ~ trt, regionno, year,\n                      basque 
%>% filter(year != 1970),\n                      fixedeff = F,\n                      nu = 0, scm=T, eps_rel=1e-8, eps_abs=1e-8)\n\n\n  expect_equal(msyn$weights[,2], msyn2$weights[,2], tolerance = 1e-6)\n})\n\n\ntest_that(\"Separate synth with missing control unit time drops control unit\", {\n\n  # drop a time period for unit 17\n  basque %>%\n    filter(!regionno %in% c(18) | year != 1970) -> basque_mis\n  \n  msyn <- multisynth(gdpcap ~ trt, regionno, year, basque_mis, \n                     nu = 0, scm=T, eps_rel=1e-8, eps_abs=1e-8)\n\n  msyn2 <- multisynth(gdpcap ~ trt, regionno, year, \n                      basque %>% filter(regionno != 18),\n                      nu = 0, scm=T, eps_rel=1e-8, eps_abs=1e-8)\n\n  expect_equal(msyn$weights[-17,2], msyn2$weights[,2], tolerance = 1e-6)\n})\n\n\ntest_that(\"Separate synth with missing control unit only in post-treatment period drops control unit\", {\n\n  # drop a time period for unit 17\n  basque %>%\n    filter(!regionno %in% c(18) | year < 1980) -> basque_mis\n\n  dat_format <- format_data_stag(quo(gdpcap), quo(trt), quo(regionno), quo(year), basque_mis)\n\n  expect_true(nrow(dat_format$X) == nrow(dat_format$y))\n  msyn <- multisynth(gdpcap ~ trt, regionno, year, basque_mis, \n                     nu = 0, scm=T, eps_rel=1e-8, eps_abs=1e-8)\n\n  msyn2 <- multisynth(gdpcap ~ trt, regionno, year, \n                      basque %>% filter(regionno != 18),\n                      nu = 0, scm=T, eps_rel=1e-8, eps_abs=1e-8)\n\n  expect_equal(msyn$weights[-17,2], msyn2$weights[,2], tolerance = 1e-6)\n})\n\ntest_that(\"Separate synth with missing control unit only in pre-treatment period drops control unit\", {\n\n  # drop a time period for unit 17\n  basque %>%\n    filter(!regionno %in% c(18) | year >= 1980) -> basque_mis\n\n  dat_format <- format_data_stag(quo(gdpcap), quo(trt), quo(regionno), quo(year), basque_mis)\n\n  expect_true(nrow(dat_format$X) == nrow(dat_format$y))\n\n  msyn <- multisynth(gdpcap ~ trt, 
regionno, year, basque_mis, \n                     nu = 0, scm=T, eps_rel=1e-8, eps_abs=1e-8)\n\n  msyn2 <- multisynth(gdpcap ~ trt, regionno, year, \n                      basque %>% filter(regionno != 18),\n                      nu = 0, scm=T, eps_rel=1e-8, eps_abs=1e-8)\n\n  expect_equal(msyn$weights[-17,2], msyn2$weights[,2], tolerance = 1e-6)\n})\n\n\ntest_that(\"Multisynth with unbalanced panels runs\", {\n\n  # drop a time period for unit 17\n  basque %>%\n    filter(!regionno %in% c(15, 17) | year != 1970) -> basque_mis\n\n  msyn <- multisynth(gdpcap ~ trt, regionno, year, basque_mis, \n                     scm=T, eps_rel=1e-8, eps_abs=1e-8)\n\n  expect_error(summary(msyn), NA)\n})\n\n\ntest_that(\"Multisynth with unbalanced panels runs with missing post-treatment\", {\n\n  # drop a time period for unit 17\n  basque %>%\n    filter(!regionno %in% c(15, 17) | year != 1990) -> basque_mis\n\n  msyn <- multisynth(gdpcap ~ trt, regionno, year, basque_mis, \n                     scm=T, eps_rel=1e-8, eps_abs=1e-8)\n\n  expect_error(summary(msyn), NA)\n})\n\n\n\ntest_that(\"Multisynth with unbalanced panels runs\", {\n\n  # drop a time period for unit 17\n  basque %>%\n    filter(!regionno %in% c(15) | year != 1985) -> basque_mis\n\n  msyn <- multisynth(gdpcap ~ trt, regionno, year, basque_mis, \n                     scm=T, eps_rel=1e-8, eps_abs=1e-8)\n\n  expect_error(summary(msyn), NA)\n})"
  },
  {
    "path": "tests/testthat.R",
    "content": "library(testthat)\nlibrary(augsynth)\n\ntest_check(\"augsynth\")\n"
  },
  {
    "path": "vignettes/.gitignore",
    "content": "*.html\n*.R\n"
  },
  {
    "path": "vignettes/multi-outcomes-vignette.Rmd",
    "content": "---\noutput: rmarkdown::html_vignette\nvignette: >\n  %\\VignetteIndexEntry{Multi Outcomes AugSynth Vignette}\n  %\\VignetteEngine{knitr::rmarkdown}\n  %\\VignetteEncoding{UTF-8}\n---\n\n```{r setup, include = FALSE}\nknitr::opts_chunk$set(\n  collapse = TRUE,\n  comment = \"#>\"\n  )\nlibrary(kableExtra)\n```\n\n# `augsynth`: Estimating multiple outcome effects\n\n### The data\nTo demonstrate `augsynth` with multiple outcomes, we'll use data on the impact of personal income tax cuts in Kansas that comes with the `AugSynth` package. Our interest is in estimating the effect of income tax cuts on gross state product (GSP) per capita, wages, establishment counts, and other macroeconomic indicators.\n\n```{r load_data, results=\"hide\", warning=F, message=F}\nlibrary(magrittr)\nlibrary(dplyr)\nlibrary(augsynth)\ndata(kansas)\n```\n\nThe `kansas` dataset contains the GSP per capita (the outcome measure) `lngdpcapita` for all 50 states from the first quarter of 1990 to the first quarter of 2016.\n\nTo run `augsynth`, we need to include a treatment status column that indicates which region was treated and at what time. The table in `kansas` contains the column `treated` to denote this. In the original study, the second quarter of 2012 was the implementation of the tax cut in Kansas.\n\n```{r treated_units}\nkansas %>% select(year, qtr, year_qtr, state, treated, gdp, lngdpcapita) %>% filter(state == \"Kansas\" & year_qtr >= 2012 & year_qtr < 2013) \n```\n\n\n### Using the Synthetic Controls Method\nWe will begin by running the synthetic controls method on GDP per capita, wages, and the number of establishments. 
To run the vanilla synthetic controls method using `augsynth`, set `progfunc` to `None` and `scm` to `TRUE`.\n\n#### Single outcomes\nFirst, we will examine each outcome variable separately, beginning with log GDP per capita `lngdpcapita`.\n\n```{r lngdpcapita_syn}\nsyn_lngdpcapita <- augsynth(lngdpcapita ~ treated, fips, year_qtr, kansas, progfunc=\"None\", scm=T)\nsummary(syn_lngdpcapita)\n```\n\n```{r lngdpcapita_syn_plot}\nplot(syn_lngdpcapita)\n```\n\nNext we will examine the log total wages per capita. Since this column doesn't already exist in the dataframe, we will create a `lntotalwagescapita` column.\n```{r lntotalwagescapita_syn}\nkansas$lntotalwagescapita <- log(kansas$totalwagescapita)\nsyn_lntotalwagescapita <- augsynth(lntotalwagescapita ~ treated, fips, year_qtr, kansas, progfunc=\"None\", scm=T)\nsummary(syn_lntotalwagescapita)\n```\n\n```{r lntotalwagescapita_syn_plot}\nplot(syn_lntotalwagescapita)\n```\n\nLastly, we will examine the number of establishments per capita, `estabscapita`.\n```{r estabscapita_syn}\nsyn_estabscapita <- augsynth(estabscapita ~ treated, fips, year_qtr, kansas, progfunc=\"None\", scm=T)\nsummary(syn_estabscapita)\n```\n\n```{r single_estabscapita_syn_plot}\nplot(syn_estabscapita)\n```\n\n#### Multiple outcomes\nNow we will combine our outcome variables into one study. To add more outcome variables, we add them to the LHS of the formula. \n```{r multi_outcome_syn}\nsyn_multi <- augsynth(lngdpcapita + lntotalwagescapita + estabscapita ~ treated, fips, year_qtr, kansas, progfunc=\"None\", scm=T)\nsummary(syn_multi)\n```\n\n```{r multi_outcome_syn_plot}\nplot(syn_multi)\n```\n\n\n### Using the Augmented Synthetic Controls Method\nWe will now repeat the study using the Augmented Synthetic Controls Method with ridge regression. 
In ASCM, we first fit the SCM weights and then combine them with a ridge regression, so we set `progfunc=\"Ridge\", scm=T`.\n\n```{r lngdpcapita_asyn}\nasyn_lngdpcapita <- augsynth(lngdpcapita ~ treated, fips, year_qtr, kansas, progfunc=\"Ridge\", scm=T)\nsummary(asyn_lngdpcapita)\n```\n\n```{r lngdpcapita_asyn_plot}\nplot(asyn_lngdpcapita)\n```\n\n```{r lntotalwagescapita_asyn}\nasyn_lntotalwagescapita <- augsynth(lntotalwagescapita ~ treated, fips, year_qtr, kansas, progfunc=\"Ridge\", scm=T)\nsummary(asyn_lntotalwagescapita)\n```\n\n```{r lntotalwagescapita_asyn_plot}\nplot(asyn_lntotalwagescapita)\n```\n\n```{r estabscapita_asyn}\nasyn_estabscapita <- augsynth(estabscapita ~ treated, fips, year_qtr, kansas, progfunc=\"Ridge\", scm=T)\nsummary(asyn_estabscapita)\n```\n\n```{r single_estabscapita_asyn_plot}\nplot(asyn_estabscapita)\n```\n\n#### Multiple outcomes\nNow we will combine our outcome variables into one study. To add more outcome variables, we add them to the left-hand side of the formula. \n```{r multi_outcome_asyn}\nasyn_multi <- augsynth(lngdpcapita + lntotalwagescapita + estabscapita ~ treated, \n                       fips, year_qtr, kansas, progfunc=\"Ridge\", scm=T, lambda = 1e-4)\nsummary(asyn_multi)\n```\n\n```{r multi_outcome_asyn_plot}\nplot(asyn_multi)\n```\n"
  },
  {
    "path": "vignettes/multisynth-vignette.Rmd",
    "content": "---\noutput: rmarkdown::html_vignette\nvignette: >\n  %\\VignetteIndexEntry{MultiSynth Vignette}\n  %\\VignetteEngine{knitr::rmarkdown}\n  %\\VignetteEncoding{UTF-8}\n---\n\n```{r setup, include = FALSE}\nknitr::opts_chunk$set(\n  collapse = TRUE,\n  comment = \"#>\"\n  )\nlibrary(kableExtra)\n```\n\n\n# `augsynth`: Estimating treatment effects with staggered adoption\n\n### The data\n\nTo show the features of the `multisynth` function, we will use data on the effects of states implementing mandatory collective bargaining agreements for public sector unions [(Paglayan, 2018)](https://onlinelibrary.wiley.com/doi/full/10.1111/ajps.12388).\n\n```{r results=\"hide\", warning=F, message=F}\nlibrary(magrittr)\nlibrary(dplyr)\nlibrary(augsynth)\n```\n\n```{r }\ndata <- read.csv(\"https://dataverse.harvard.edu/api/access/datafile/:persistentId?persistentId=doi:10.7910/DVN/WGWMAV/3UHTLP\", sep=\"\\t\")\n```\n\nThe dataset contains several important variables that we'll use:\n\n- `year`, `State`: The year and state of the measurement\n- `YearCBrequired`: The year that the state adopted mandatory collective bargaining\n- `lnppexpend`: Log per pupil expenditures in constant 2010 $\n\n```{r echo = F}\ndata %>% \n    filter(year == 1960) %>% \n    select(year, State, YearCBrequired, lnppexpend) %>%\n    head() %>%\n    kable() %>%\n    kable_styling(bootstrap_options =c(\"hover\", \"responsive\"))\n```\n\nTo run `multisynth`, we need to include a treatment status column that indicates which state is treated in a given year; we call this `cbr` below. We also restrict to the years 1959-1997, where we have yearly measurements of expenditures, and drop Washington D.C. 
and Wisconsin from the analysis.\n\n```{r }\ndata %>%\n    filter(!State %in% c(\"DC\", \"WI\"),\n           year >= 1959, year <= 1997) %>%\n    mutate(YearCBrequired = ifelse(is.na(YearCBrequired), \n                                   Inf, YearCBrequired),\n           cbr = 1 * (year >= YearCBrequired)) -> analysis_df\n```\n\n## Partially pooled SCM with an intercept\n\nTo fit partially pooled synthetic controls, we need to give `multisynth` a formula of the form `outcome ~ treatment`, point it to the unit and time variables, and choose the level of partial pooling `nu`. Setting `nu = 0` fits a separate synthetic control for each treated unit and setting `nu = 1` fits fully pooled synthetic controls. If we don't set `nu`, `multisynth` will choose a heuristic value based on how well separate synthetic controls balance the overall average.\nBy default, `multisynth` includes an intercept shift along with the weights; we can exclude the intercept shift by setting `fixedeff = F`.\nWe can also set the number of pre-treatment time periods (lags) that we want to balance with the `n_lags` argument and the number of post-treatment time periods (leads) that we want to estimate with the `n_leads` argument. By default `multisynth` sets `n_lags` and `n_leads` to the number of pre-treatment and post-treatment periods for the last treated unit, respectively.\n\n```{r }\n# with a choice of nu\nppool_syn <- multisynth(lnppexpend ~ cbr, State, year, \n                        nu = 0.5, analysis_df)\n# with default nu\nppool_syn <- multisynth(lnppexpend ~ cbr, State, year, \n                        analysis_df)\n\nprint(ppool_syn$nu)\n\nppool_syn\n```\n\nUsing the `summary` function, we'll compute the treatment effects and standard errors and confidence intervals for all treated units as well as the average via the wild bootstrap. 
(This takes a bit of time, so we'll store the output.) We can also change the significance level associated with the confidence intervals by setting the `alpha` argument; by default, `alpha = 0.05`.\n\n```{r}\nppool_syn_summ <- summary(ppool_syn)\n```\n\nWe can then report the level of global and individual balance as well as estimates for the average.\n\n```{r }\nppool_syn_summ\n```\n\n`ppool_syn_summ$att` is a dataframe that contains all of the point estimates, standard errors, and lower/upper confidence limits. `Time = NA` denotes the effect averaged across the post-treatment periods.\n\n```{r echo = F}\nppool_syn_summ$att %>%\n  filter(Time >= 0) %>%\n  head() %>%\n  kable() %>%\n  kable_styling(bootstrap_options =c(\"hover\", \"responsive\"))\n```\n\nWe can also visually display both the pre-treatment balance and the estimated treatment effects.\n\n```{r ppool_syn_plot, fig.width=8, fig.height=4.5, fig.align=\"center\", warning=F, message=F}\nplot(ppool_syn_summ)\n```\n\nAnd again we can home in on the average effects.\n\n```{r ppool_syn_plot_avg, fig.width=8, fig.height=4.5, fig.align=\"center\", warning=F, message=F}\nplot(ppool_syn_summ, levels = \"Average\")\n```\n\n\n### Collapsing into time cohorts\n\nWe can also collapse treated units with the same treatment time into _time cohorts_, and find one synthetic control per time cohort by setting `time_cohort = TRUE`. 
When the number of distinct treatment times is much smaller than the number of treated units, this will run significantly faster.\n\n```{r }\n# with default nu\nppool_syn_time <- multisynth(lnppexpend ~ cbr, State, year,\n                        analysis_df, time_cohort = TRUE)\n\nprint(ppool_syn_time$nu)\n\nppool_syn_time\n```\n\nWe can then compute effects for the overall average as well as for each treatment time cohort, rather than individual units.\n\n```{r}\nppool_syn_time_summ <- summary(ppool_syn_time)\nppool_syn_time_summ\n```\n\n```{r echo = F}\nppool_syn_time_summ$att %>%\n  filter(Time >= 0) %>%\n  head() %>%\n  kable() %>%\n  kable_styling(bootstrap_options =c(\"hover\", \"responsive\"))\n```\n\nAgain we can plot the effects.\n\n```{r ppool_syn_time_plot, fig.width=8, fig.height=4.5, fig.align=\"center\", warning=F, message=F}\nplot(ppool_syn_time_summ)\n```\n\n\n### Including auxiliary covariates\n\nWe can also include an additional set of covariates to balance along with the pre-treatment outcomes. First, let's create a data frame with the 1959 values of two covariates, per capita income and the student-teacher ratio:\n\n```{r cov_data}\n\ndata %>%\n  select(State, year, agr, pnwht, purban, perinc, studteachratio) %>%\n  group_by(State) %>%\n  summarise(perinc_1959 = perinc[year == 1959],\n            studteachratio_1959 = studteachratio[year == 1959]) %>% \n  # filter to lower 48 where we have data\n  filter(!State %in% c(\"AK\", \"HI\"))  -> cov_data\n\nanalysis_df %>%\n  inner_join(cov_data, by = \"State\") -> analysis_df_covs\n\n```\n\nTo include auxiliary covariates, we can add them to the formula after `|`. This will balance the auxiliary covariates and the pre-treatment outcomes simultaneously. If the covariates vary during the pre-treatment periods, `multisynth` will use the average pre-treatment value. 
We can change this behavior by including our own custom aggregation function via the `cov_agg` argument.\n```{r cov_syn}\n# with default nu\nppool_syn_cov <- multisynth(lnppexpend ~ cbr | perinc_1959 + studteachratio_1959,\n                            State, year, analysis_df_covs)\n\nprint(ppool_syn_cov$nu)\n\nppool_syn_cov\n```\n\nAgain we can compute effects, along with their standard errors and confidence intervals, and plot.\n```{r}\nppool_syn_cov_summ <- summary(ppool_syn_cov)\nppool_syn_cov_summ\n```\n\n```{r echo = F}\nppool_syn_cov_summ$att %>%\n  filter(Time >= 0) %>%\n  head() %>%\n  kable() %>%\n  kable_styling(bootstrap_options =c(\"hover\", \"responsive\"))\n```\n\nAgain we can plot the effects.\n```{r ppool_syn_cov_plot, fig.width=8, fig.height=4.5, fig.align=\"center\", warning=F, message=F}\nplot(ppool_syn_cov_summ, levels = \"Average\")\n```"
  },
  {
    "path": "vignettes/multisynth-vignette.md",
    "content": "---\noutput: rmarkdown::html_vignette\nvignette: >\n  %\\VignetteIndexEntry{MultiSynth Vignette}\n  %\\VignetteEngine{knitr::rmarkdown}\n  %\\VignetteEncoding{UTF-8}\n---\n\n\n\n\n# `augsynth`: Estimating treatment effects with staggered adoption\n\n### The data\n\nTo show the features of the `multisynth` function we will use data on the effects of states implementing mandatory collective bargaining agreements for public sector unions [(Paglayan, 2018)](https://onlinelibrary.wiley.com/doi/full/10.1111/ajps.12388)\n\n\n```r\nlibrary(magrittr)\nlibrary(dplyr)\nlibrary(augsynth)\n```\n\n\n```r\ndata <- read.csv(\"https://dataverse.harvard.edu/api/access/datafile/:persistentId?persistentId=doi:10.7910/DVN/WGWMAV/3UHTLP\", sep=\"\\t\")\n```\n\nThe dataset contains several important variables that we'll use:\n\n- `year`, `State`: The state and year of the measurement\n- `YearCBrequired`: The year that the state adopted mandatory collective bargaining\n- `lnppexpend`: Log per pupil expenditures in constant 2010 $\n\n<table class=\"table table-hover table-responsive\" style=\"margin-left: auto; margin-right: auto;\">\n <thead>\n  <tr>\n   <th style=\"text-align:right;\"> year </th>\n   <th style=\"text-align:left;\"> State </th>\n   <th style=\"text-align:right;\"> YearCBrequired </th>\n   <th style=\"text-align:right;\"> lnppexpend </th>\n  </tr>\n </thead>\n<tbody>\n  <tr>\n   <td style=\"text-align:right;\"> 1960 </td>\n   <td style=\"text-align:left;\"> AK </td>\n   <td style=\"text-align:right;\"> 1970 </td>\n   <td style=\"text-align:right;\"> 8.325518 </td>\n  </tr>\n  <tr>\n   <td style=\"text-align:right;\"> 1960 </td>\n   <td style=\"text-align:left;\"> AL </td>\n   <td style=\"text-align:right;\"> NA </td>\n   <td style=\"text-align:right;\"> 7.396177 </td>\n  </tr>\n  <tr>\n   <td style=\"text-align:right;\"> 1960 </td>\n   <td style=\"text-align:left;\"> AR </td>\n   <td style=\"text-align:right;\"> NA </td>\n   <td style=\"text-align:right;\"> 
7.385373 </td>\n  </tr>\n  <tr>\n   <td style=\"text-align:right;\"> 1960 </td>\n   <td style=\"text-align:left;\"> AZ </td>\n   <td style=\"text-align:right;\"> NA </td>\n   <td style=\"text-align:right;\"> 7.947127 </td>\n  </tr>\n  <tr>\n   <td style=\"text-align:right;\"> 1960 </td>\n   <td style=\"text-align:left;\"> CA </td>\n   <td style=\"text-align:right;\"> 1976 </td>\n   <td style=\"text-align:right;\"> 8.185162 </td>\n  </tr>\n  <tr>\n   <td style=\"text-align:right;\"> 1960 </td>\n   <td style=\"text-align:left;\"> CO </td>\n   <td style=\"text-align:right;\"> NA </td>\n   <td style=\"text-align:right;\"> 7.952833 </td>\n  </tr>\n</tbody>\n</table>\n\n\n\nTo run `multisynth`, we need to include a treatment status column that indicates which state is treated in a given year, we call this `cbr` below. We also restrict to the years 1959-1997 where we have yearly measurements of expenditures and drop Washington D.C. and Wisconsin from the analysis.\n\n\n```r\ndata %>%\n    filter(!State %in% c(\"DC\", \"WI\"),\n           year >= 1959, year <= 1997) %>%\n    mutate(YearCBrequired = ifelse(is.na(YearCBrequired), \n                                   Inf, YearCBrequired),\n           cbr = 1 * (year >= YearCBrequired)) -> analysis_df\n```\n\n## Partially pooled SCM with an intercept\n\nTo fit partially pooled synthetic controls, we need to give `multisynth` a formula of the form `outcome ~ treatment`, point it to the unit and time variables, and choose the level of partial pooling `nu`. Setting `nu = 0` fits a separate synthetic control for each treated unit and setting `nu = 1` fits fully pooled synthetic controls. 
If we don't set `nu`, `multisynth` will choose a heuristic value based on how well separate synthetic controls balance the overall average.\nBy default, `multisynth` includes an intercept shift along with the weights; we can exclude the intercept shift by setting `fixedeff = F`.\nWe can also set the number of pre-treatment time periods (lags) that we want to balance with the `n_lags` argument and the number of post-treatment time periods (leads) that we want to estimate with the `n_leads` argument. By default `multisynth` sets `n_lags` and `n_leads` to the number of pre-treatment and post-treatment periods for the last treated unit, respectively.\n\n\n```r\n# with a choice of nu\nppool_syn <- multisynth(lnppexpend ~ cbr, State, year, \n                        nu = 0.5, analysis_df)\n# with default nu\nppool_syn <- multisynth(lnppexpend ~ cbr, State, year, \n                        analysis_df)\n\nprint(ppool_syn$nu)\n#> [1] 0.2606793\n\nppool_syn\n#> \n#> Call:\n#> multisynth(form = lnppexpend ~ cbr, unit = State, time = year, \n#>     data = analysis_df)\n#> \n#> Average ATT Estimate: -0.011\n```\n\nUsing the `summary` function, we'll compute the treatment effects and standard errors and confidence intervals for all treated units as well as the average via the wild bootstrap. (This takes a bit of time, so we'll store the output.) We can also change the significance level associated with the confidence intervals by setting the `alpha` argument; by default, `alpha = 0.05`.\n\n\n```r\nppool_syn_summ <- summary(ppool_syn)\n```\n\nWe can then report the level of global and individual balance as well as estimates for the average.\n\n\n```r\nppool_syn_summ\n#> \n#> Call:\n#> multisynth(form = lnppexpend ~ cbr, unit = State, time = year, \n#>     data = analysis_df)\n#> \n#> Average ATT Estimate (Std. 
Error): -0.011  (0.022)\n#> \n#> Global L2 Imbalance: 0.003\n#> Scaled Global L2 Imbalance: 0.019\n#> Percent improvement from uniform global weights: 98.1\n#> \n#> Individual L2 Imbalance: 0.028\n#> Scaled Individual L2 Imbalance: 0.096\n#> Percent improvement from uniform individual weights: 90.4\t\n#> \n#>  Time Since Treatment   Level     Estimate  Std.Error lower_bound upper_bound\n#>                     0 Average -0.004281754 0.02231379 -0.04888183  0.03786032\n#>                     1 Average -0.010856856 0.02099299 -0.05423609  0.02939147\n#>                     2 Average  0.004378813 0.02268842 -0.04268354  0.04896627\n#>                     3 Average  0.001155346 0.02388535 -0.04846624  0.04464696\n#>                     4 Average -0.009305005 0.02529949 -0.06207289  0.03822153\n#>                     5 Average -0.016942988 0.02447144 -0.06935946  0.02695179\n#>                     6 Average -0.018505173 0.02507329 -0.07297111  0.02755436\n#>                     7 Average -0.003866657 0.02817460 -0.06047905  0.05013422\n#>                     8 Average -0.015835730 0.03141197 -0.08179055  0.04231137\n#>                     9 Average -0.031751350 0.02962989 -0.09168791  0.02202697\n#>                    10 Average -0.017839047 0.03314017 -0.08835499  0.04070061\n```\n\n`ppool_syn_summ$att` is a dataframe that contains all of the point estimates, standard errors, and lower/upper confidence limits. 
`Time = NA` denotes the effect averaged across the post treatment periods.\n\n<table class=\"table table-hover table-responsive\" style=\"margin-left: auto; margin-right: auto;\">\n <thead>\n  <tr>\n   <th style=\"text-align:right;\"> Time </th>\n   <th style=\"text-align:left;\"> Level </th>\n   <th style=\"text-align:right;\"> Estimate </th>\n   <th style=\"text-align:right;\"> Std.Error </th>\n   <th style=\"text-align:right;\"> lower_bound </th>\n   <th style=\"text-align:right;\"> upper_bound </th>\n  </tr>\n </thead>\n<tbody>\n  <tr>\n   <td style=\"text-align:right;\"> 0 </td>\n   <td style=\"text-align:left;\"> Average </td>\n   <td style=\"text-align:right;\"> -0.0042818 </td>\n   <td style=\"text-align:right;\"> 0.0223138 </td>\n   <td style=\"text-align:right;\"> -0.0488818 </td>\n   <td style=\"text-align:right;\"> 0.0378603 </td>\n  </tr>\n  <tr>\n   <td style=\"text-align:right;\"> 1 </td>\n   <td style=\"text-align:left;\"> Average </td>\n   <td style=\"text-align:right;\"> -0.0108569 </td>\n   <td style=\"text-align:right;\"> 0.0209930 </td>\n   <td style=\"text-align:right;\"> -0.0542361 </td>\n   <td style=\"text-align:right;\"> 0.0293915 </td>\n  </tr>\n  <tr>\n   <td style=\"text-align:right;\"> 2 </td>\n   <td style=\"text-align:left;\"> Average </td>\n   <td style=\"text-align:right;\"> 0.0043788 </td>\n   <td style=\"text-align:right;\"> 0.0226884 </td>\n   <td style=\"text-align:right;\"> -0.0426835 </td>\n   <td style=\"text-align:right;\"> 0.0489663 </td>\n  </tr>\n  <tr>\n   <td style=\"text-align:right;\"> 3 </td>\n   <td style=\"text-align:left;\"> Average </td>\n   <td style=\"text-align:right;\"> 0.0011553 </td>\n   <td style=\"text-align:right;\"> 0.0238853 </td>\n   <td style=\"text-align:right;\"> -0.0484662 </td>\n   <td style=\"text-align:right;\"> 0.0446470 </td>\n  </tr>\n  <tr>\n   <td style=\"text-align:right;\"> 4 </td>\n   <td style=\"text-align:left;\"> Average </td>\n   <td style=\"text-align:right;\"> -0.0093050 </td>\n  
 <td style=\"text-align:right;\"> 0.0252995 </td>\n   <td style=\"text-align:right;\"> -0.0620729 </td>\n   <td style=\"text-align:right;\"> 0.0382215 </td>\n  </tr>\n  <tr>\n   <td style=\"text-align:right;\"> 5 </td>\n   <td style=\"text-align:left;\"> Average </td>\n   <td style=\"text-align:right;\"> -0.0169430 </td>\n   <td style=\"text-align:right;\"> 0.0244714 </td>\n   <td style=\"text-align:right;\"> -0.0693595 </td>\n   <td style=\"text-align:right;\"> 0.0269518 </td>\n  </tr>\n</tbody>\n</table>\n\n\n\nWe can also visually display both the pre-treatment balance and the estimated treatment effects.\n\n\n```r\nplot(ppool_syn_summ)\n```\n\n<div class=\"figure\" style=\"text-align: center\">\n<img src=\"figure/ppool_syn_plot-1.png\" alt=\"plot of chunk ppool_syn_plot\"  />\n<p class=\"caption\">plot of chunk ppool_syn_plot</p>\n</div>\n\nAnd again we can hone in on the average effects.\n\n\n```r\nplot(ppool_syn_summ, levels = \"Average\")\n```\n\n<div class=\"figure\" style=\"text-align: center\">\n<img src=\"figure/ppool_syn_plot_avg-1.png\" alt=\"plot of chunk ppool_syn_plot_avg\"  />\n<p class=\"caption\">plot of chunk ppool_syn_plot_avg</p>\n</div>\n\n\n### Collapsing into time cohorts\n\nWe can also collapse treated units with the same treatment time into _time cohorts_, and find one synthetic control per time cohort by setting `time_cohort = TRUE`. 
When the number of distinct treatment times is much smaller than the number of treated units, this will run significantly faster.\n\n\n```r\n# with default nu\nppool_syn_time <- multisynth(lnppexpend ~ cbr, State, year,\n                        analysis_df, time_cohort = TRUE)\n\nprint(ppool_syn_time$nu)\n#> [1] 0.3939013\n\nppool_syn_time\n#> \n#> Call:\n#> multisynth(form = lnppexpend ~ cbr, unit = State, time = year, \n#>     data = analysis_df, time_cohort = TRUE)\n#> \n#> Average ATT Estimate: -0.018\n```\n\nWe can then compute effects for the overall average as well as for each treatment time cohort, rather than individual units.\n\n\n```r\nppool_syn_time_summ <- summary(ppool_syn_time)\nppool_syn_time_summ\n#> \n#> Call:\n#> multisynth(form = lnppexpend ~ cbr, unit = State, time = year, \n#>     data = analysis_df, time_cohort = TRUE)\n#> \n#> Average ATT Estimate (Std. Error): -0.018  (0.024)\n#> \n#> Global L2 Imbalance: 0.005\n#> Scaled Global L2 Imbalance: 0.018\n#> Percent improvement from uniform global weights: 98.2\n#> \n#> Individual L2 Imbalance: 0.038\n#> Scaled Individual L2 Imbalance: 0.057\n#> Percent improvement from uniform individual weights: 94.3\t\n#> \n#>  Time Since Treatment   Level      Estimate  Std.Error lower_bound upper_bound\n#>                     0 Average -0.0007756959 0.02443902 -0.04849731  0.04410082\n#>                     1 Average -0.0160616979 0.02455148 -0.06120905  0.03042719\n#>                     2 Average -0.0028471499 0.02521902 -0.05189710  0.04841170\n#>                     3 Average -0.0026721191 0.02742973 -0.05634460  0.05048728\n#>                     4 Average -0.0181312843 0.02798461 -0.07148111  0.03468573\n#>                     5 Average -0.0284898474 0.02644653 -0.07724091  0.02368573\n#>                     6 Average -0.0228343778 0.02673115 -0.07456646  0.02837584\n#>                     7 Average -0.0140789250 0.03200335 -0.07574649  0.04580312\n#>                     8 Average -0.0245472682 
0.03276526 -0.08792999  0.03819451\n#>                     9 Average -0.0476922268 0.03221486 -0.11080383  0.01490279\n#>                    10 Average -0.0216121159 0.03235770 -0.08391841  0.03853317\n```\n\n<table class=\"table table-hover table-responsive\" style=\"margin-left: auto; margin-right: auto;\">\n <thead>\n  <tr>\n   <th style=\"text-align:right;\"> Time </th>\n   <th style=\"text-align:left;\"> Level </th>\n   <th style=\"text-align:right;\"> Estimate </th>\n   <th style=\"text-align:right;\"> Std.Error </th>\n   <th style=\"text-align:right;\"> lower_bound </th>\n   <th style=\"text-align:right;\"> upper_bound </th>\n  </tr>\n </thead>\n<tbody>\n  <tr>\n   <td style=\"text-align:right;\"> 0 </td>\n   <td style=\"text-align:left;\"> Average </td>\n   <td style=\"text-align:right;\"> -0.0007757 </td>\n   <td style=\"text-align:right;\"> 0.0244390 </td>\n   <td style=\"text-align:right;\"> -0.0484973 </td>\n   <td style=\"text-align:right;\"> 0.0441008 </td>\n  </tr>\n  <tr>\n   <td style=\"text-align:right;\"> 1 </td>\n   <td style=\"text-align:left;\"> Average </td>\n   <td style=\"text-align:right;\"> -0.0160617 </td>\n   <td style=\"text-align:right;\"> 0.0245515 </td>\n   <td style=\"text-align:right;\"> -0.0612091 </td>\n   <td style=\"text-align:right;\"> 0.0304272 </td>\n  </tr>\n  <tr>\n   <td style=\"text-align:right;\"> 2 </td>\n   <td style=\"text-align:left;\"> Average </td>\n   <td style=\"text-align:right;\"> -0.0028471 </td>\n   <td style=\"text-align:right;\"> 0.0252190 </td>\n   <td style=\"text-align:right;\"> -0.0518971 </td>\n   <td style=\"text-align:right;\"> 0.0484117 </td>\n  </tr>\n  <tr>\n   <td style=\"text-align:right;\"> 3 </td>\n   <td style=\"text-align:left;\"> Average </td>\n   <td style=\"text-align:right;\"> -0.0026721 </td>\n   <td style=\"text-align:right;\"> 0.0274297 </td>\n   <td style=\"text-align:right;\"> -0.0563446 </td>\n   <td style=\"text-align:right;\"> 0.0504873 </td>\n  </tr>\n  <tr>\n   <td 
style=\"text-align:right;\"> 4 </td>\n   <td style=\"text-align:left;\"> Average </td>\n   <td style=\"text-align:right;\"> -0.0181313 </td>\n   <td style=\"text-align:right;\"> 0.0279846 </td>\n   <td style=\"text-align:right;\"> -0.0714811 </td>\n   <td style=\"text-align:right;\"> 0.0346857 </td>\n  </tr>\n  <tr>\n   <td style=\"text-align:right;\"> 5 </td>\n   <td style=\"text-align:left;\"> Average </td>\n   <td style=\"text-align:right;\"> -0.0284898 </td>\n   <td style=\"text-align:right;\"> 0.0264465 </td>\n   <td style=\"text-align:right;\"> -0.0772409 </td>\n   <td style=\"text-align:right;\"> 0.0236857 </td>\n  </tr>\n</tbody>\n</table>\n\n\n\nAgain we can plot the effects.\n\n\n```r\nplot(ppool_syn_time_summ)\n```\n\n<div class=\"figure\" style=\"text-align: center\">\n<img src=\"figure/ppool_syn_time_plot-1.png\" alt=\"plot of chunk ppool_syn_time_plot\"  />\n<p class=\"caption\">plot of chunk ppool_syn_time_plot</p>\n</div>\n\n\n### Including auxiliary covariates\n\nWe can also include an additional set of covariates to balance along with the pre-treatment outcomes. First, let's create a data frame with the values of some covariates in a few different years:\n\n\n```r\n\ndata %>%\n  select(State, year, agr, pnwht, purban, perinc, studteachratio) %>%\n  group_by(State) %>%\n  summarise(perinc_1959 = perinc[year == 1959],\n            studteachratio_1959 = studteachratio[year == 1959]) %>% \n  # filter to lower 48 where we have data\n  filter(!State %in% c(\"AK\", \"HI\"))  -> cov_data\n\nanalysis_df %>%\n  inner_join(cov_data, by = \"State\") -> analysis_df_covs\n```\n\nTo include auxiliary covariates, we can add them to the formula after `|`. This will balance the auxiliary covariates and the pre-treatment outcomes simultaneously. If the covariates vary during the pre-treatment periods, `multisynth` will use the average pre-treatment value. 
We can change this behavior by including our own custom aggregation function via the `cov_agg` argument.\n\n```r\n# with default nu\nppool_syn_cov <- multisynth(lnppexpend ~ cbr | perinc_1959 + studteachratio_1959,\n                            State, year, analysis_df_covs)\n\nprint(ppool_syn_cov$nu)\n#> [1] 0.2242633\n\nppool_syn_cov\n#> \n#> Call:\n#> multisynth(form = lnppexpend ~ cbr | perinc_1959 + studteachratio_1959, \n#>     unit = State, time = year, data = analysis_df_covs)\n#> \n#> Average ATT Estimate: -0.019\n```\n\nAgain we can compute effects, along with their standard errors and confidence intervals, and plot.\n\n```r\nppool_syn_cov_summ <- summary(ppool_syn_cov)\nppool_syn_cov_summ\n#> \n#> Call:\n#> multisynth(form = lnppexpend ~ cbr | perinc_1959 + studteachratio_1959, \n#>     unit = State, time = year, data = analysis_df_covs)\n#> \n#> Average ATT Estimate (Std. Error): -0.019  (0.016)\n#> \n#> Global L2 Imbalance: 0.004\n#> Scaled Global L2 Imbalance: 0.030\n#> Percent improvement from uniform global weights: 97\n#> \n#> Individual L2 Imbalance: 0.043\n#> Scaled Individual L2 Imbalance: 0.155\n#> Percent improvement from uniform individual weights: 84.5\t\n#> \n#>  Time Since Treatment   Level      Estimate  Std.Error lower_bound upper_bound\n#>                     0 Average -0.0002624529 0.02142663 -0.04534283 0.039477273\n#>                     1 Average -0.0156461424 0.01955742 -0.05138329 0.021858933\n#>                     2 Average  0.0069387257 0.01979857 -0.03246108 0.046934990\n#>                     3 Average -0.0106105517 0.02094953 -0.05241864 0.032678554\n#>                     4 Average -0.0194238312 0.02027608 -0.06026658 0.019006295\n#>                     5 Average -0.0209126517 0.02065713 -0.06053277 0.018478402\n#>                     6 Average -0.0212525401 0.02011174 -0.06076619 0.018093027\n#>                     7 Average -0.0276107046 0.02122581 -0.07010144 0.014753050\n#>                     8 Average -0.0278450111 
0.02282095 -0.07360570 0.017305636\n#>                     9 Average -0.0354977043 0.02341366 -0.07998872 0.009126067\n#>                    10 Average -0.0341083505 0.02709654 -0.08591161 0.017937928\n```\n\n<table class=\"table table-hover table-responsive\" style=\"margin-left: auto; margin-right: auto;\">\n <thead>\n  <tr>\n   <th style=\"text-align:right;\"> Time </th>\n   <th style=\"text-align:left;\"> Level </th>\n   <th style=\"text-align:right;\"> Estimate </th>\n   <th style=\"text-align:right;\"> Std.Error </th>\n   <th style=\"text-align:right;\"> lower_bound </th>\n   <th style=\"text-align:right;\"> upper_bound </th>\n  </tr>\n </thead>\n<tbody>\n  <tr>\n   <td style=\"text-align:right;\"> 0 </td>\n   <td style=\"text-align:left;\"> Average </td>\n   <td style=\"text-align:right;\"> -0.0002625 </td>\n   <td style=\"text-align:right;\"> 0.0214266 </td>\n   <td style=\"text-align:right;\"> -0.0453428 </td>\n   <td style=\"text-align:right;\"> 0.0394773 </td>\n  </tr>\n  <tr>\n   <td style=\"text-align:right;\"> 1 </td>\n   <td style=\"text-align:left;\"> Average </td>\n   <td style=\"text-align:right;\"> -0.0156461 </td>\n   <td style=\"text-align:right;\"> 0.0195574 </td>\n   <td style=\"text-align:right;\"> -0.0513833 </td>\n   <td style=\"text-align:right;\"> 0.0218589 </td>\n  </tr>\n  <tr>\n   <td style=\"text-align:right;\"> 2 </td>\n   <td style=\"text-align:left;\"> Average </td>\n   <td style=\"text-align:right;\"> 0.0069387 </td>\n   <td style=\"text-align:right;\"> 0.0197986 </td>\n   <td style=\"text-align:right;\"> -0.0324611 </td>\n   <td style=\"text-align:right;\"> 0.0469350 </td>\n  </tr>\n  <tr>\n   <td style=\"text-align:right;\"> 3 </td>\n   <td style=\"text-align:left;\"> Average </td>\n   <td style=\"text-align:right;\"> -0.0106106 </td>\n   <td style=\"text-align:right;\"> 0.0209495 </td>\n   <td style=\"text-align:right;\"> -0.0524186 </td>\n   <td style=\"text-align:right;\"> 0.0326786 </td>\n  </tr>\n  <tr>\n   <td 
style=\"text-align:right;\"> 4 </td>\n   <td style=\"text-align:left;\"> Average </td>\n   <td style=\"text-align:right;\"> -0.0194238 </td>\n   <td style=\"text-align:right;\"> 0.0202761 </td>\n   <td style=\"text-align:right;\"> -0.0602666 </td>\n   <td style=\"text-align:right;\"> 0.0190063 </td>\n  </tr>\n  <tr>\n   <td style=\"text-align:right;\"> 5 </td>\n   <td style=\"text-align:left;\"> Average </td>\n   <td style=\"text-align:right;\"> -0.0209127 </td>\n   <td style=\"text-align:right;\"> 0.0206571 </td>\n   <td style=\"text-align:right;\"> -0.0605328 </td>\n   <td style=\"text-align:right;\"> 0.0184784 </td>\n  </tr>\n</tbody>\n</table>\n\n\n\nAgain we can plot the effects.\n\n```r\nplot(ppool_syn_cov_summ, levels = \"Average\")\n```\n\n<div class=\"figure\" style=\"text-align: center\">\n<img src=\"figure/ppool_syn_cov_plot-1.png\" alt=\"plot of chunk ppool_syn_cov_plot\"  />\n<p class=\"caption\">plot of chunk ppool_syn_cov_plot</p>\n</div>\n"
  },
  {
    "path": "vignettes/singlesynth-vignette.Rmd",
    "content": "---\noutput: rmarkdown::html_vignette\nvignette: >\n  %\\VignetteIndexEntry{Single Outcome AugSynth Vignette}\n  %\\VignetteEngine{knitr::rmarkdown}\n  %\\VignetteEncoding{UTF-8}\n---\n\n```{r setup, include = FALSE}\nknitr::opts_chunk$set(\n  collapse = TRUE,\n  comment = \"#>\"\n  )\nlibrary(kableExtra)\n```\n\n# `augsynth`: The Augmented Synthetic Control Method\n\n\n## Installation\n\nYou can install `augsynth` from github using `devtools`.\n\n```{r install, results=\"hide\", message=F, eval=F}\n## Install devtools if noy already installed\ninstall.packages(\"devtools\", repos='http://cran.us.r-project.org')\n## Install augsynth from github\ndevtools::install_github(\"ebenmichael/augsynth\")\n```\n\n## Example: Effects of the 2012 Kansas Tax Cuts \n\n### The data\nTo show the usage and features of `augsynth`, we'll use data on the impact of personal income tax cuts in Kansas that comes with the `AugSynth` package. Our interest is in estimating the effect of income tax cuts on gross state product (GSP) per capita.\n\n```{r load_data, results=\"hide\", warning=F, message=F}\nlibrary(magrittr)\nlibrary(dplyr)\nlibrary(augsynth)\ndata(kansas)\n```\n\nThe `kansas` dataset contains the GSP per capita (the outcome measure) `lngdpcapita` for all 50 states from the first quarter of 1990 to the first quarter of 2016.\n\nTo run `augsynth`, we need to include a treatment status column that indicates which region was treated and at what time. The table in `kansas` contains the column `treated` to denote this. In the original study, the second quarter of 2012 was the implementation of the tax cut in Kansas.\n\n```{r treated_units}\nkansas %>% \n  select(year, qtr, year_qtr, state, treated, gdp, lngdpcapita) %>% \n  filter(state == \"Kansas\" & year_qtr >= 2012 & year_qtr < 2013) \n```\n\n\n### Synth\nNow to find a synthetic control using the entire series of pre-intervention outcomes (and no auxiliary covariates), we can use `augsynth`. 
To do so we just need to give `augsynth` a formula like `outcome ~ treatment`, tell it what the unit and time variables are, optionally provide when the intervention took place (the code will automatically determine this if `t_int` is not provided), and specify that we don't want to fit an outcome model.\n\n```{r fit_synth, message=F, warning=F}\nlibrary(augsynth)\nsyn <- augsynth(lngdpcapita ~ treated, fips, year_qtr, kansas,\n                progfunc = \"None\", scm = T)\n```\n\nWe can then look at the ATT estimates for each post-intervention time period and overall. \nWe'll also see the quality of the synthetic control fit measured by the L2 distance between Kansas and its synthetic control, and the percent improvement over uniform weights.\nBy default, we'll also see pointwise confidence intervals using a [conformal inference procedure](https://arxiv.org/abs/1712.09089).\n\n```{r summ_syn}\nsummary(syn)\n```\n\n\nThe default test statistic is the sum of the absolute treatment effects, `function(x) sum(abs(x))`. We can change the test statistic via the `stat_func` argument. For instance, if we want to perform a one-sided test against positive effects, we can set the test statistic to be the negative sum `function(x) -sum(x)`:\n```{r summ_syn_neg}\nsummary(syn, stat_func = function(x) -sum(x))\n```\nOr if we want to prioritize testing the average post-treatment effect, we can set it to be the absolute sum:\n```{r summ_syn_sum}\nsummary(syn, stat_func = function(x) abs(sum(x)))\n```\n\n\nIt's easier to see this information visually. Below we plot the difference between Kansas and its synthetic control. 
Before the tax cuts (to the left of the dashed line) we expect these to be close, and after the tax cuts we measure the effect (with point-wise confidence intervals).\n\n```{r fig_syn, fig.width=8, fig.height=4.5, echo=T, fig.align=\"center\"}\nplot(syn)\n```\n\nWe can also compute point-wise confidence intervals using the [Jackknife+ procedure](https://arxiv.org/abs/1905.02928) by changing the `inf_type` argument, although this requires additional assumptions.\n\n```{r fig_syn_plus, fig.width=8, fig.height=4.5, echo=T, fig.align=\"center\"}\nplot(syn, inf_type = \"jackknife+\")\n```\n\n\n### Augmenting synth with an outcome model\nIn this example the pre-intervention synthetic control fit has an L2 imbalance of 0.083, about 20% of the imbalance between Kansas and the average of the other states. We can reduce this by _augmenting_ synth with ridge regression. To do this we change `progfunc` to `\"Ridge\"`. We can also choose the ridge hyper-parameter by setting `lambda`; if `lambda` is not specified, it is chosen through cross validation:\n```{r fit_asynth, message=F, warning=F}\nasyn <- augsynth(lngdpcapita ~ treated, fips, year_qtr, kansas,\n                progfunc = \"Ridge\", scm = T)\n```\n\nWe can plot the cross-validation MSE when dropping pre-treatment time periods by setting `cv = T` in the `plot` function:\n\n```{r fig_asyn_cv, fig.width=8, fig.height=4.5, echo=T, fig.align=\"center\"}\nplot(asyn, cv = T)\n```\n\nBy default, the CV procedure chooses the maximal value of `lambda` with MSE within one standard deviation of the minimal MSE. To instead choose the `lambda` that minimizes the cross validation MSE, set `min_1se = FALSE`.\n\n\nWe can look at the summary and plot the results. Now in the summary output we see an estimate of the overall bias of synth; we measure this with the average amount that augmentation changes the synth estimate. 
Notice that the estimates become somewhat larger in magnitude, and the standard errors are tighter.\n```{r summ_asyn}\nsummary(asyn)\n```\n\n```{r fig_asyn, fig.width=8, fig.height=4.5, echo=T, fig.align=\"center\"}\nplot(asyn)\n```\n\nThere are also several auxiliary covariates. We can include these in the augmentation by fitting an outcome model using the auxiliary covariates. To do this we simply add the covariates into the formula after `|`. By default this will create time invariant covariates by averaging the auxiliary covariates over the pre-intervention period, dropping `NA` values. We can use a custom aggregation function by setting the `cov_agg` argument. Then the lagged outcomes and the auxiliary covariates are jointly balanced by SCM and the ridge outcome model includes both.\n\n```{r fit_covsynth, message=F, warning=F}\ncovsyn <- augsynth(lngdpcapita ~ treated | lngdpcapita + log(revstatecapita) +\n                                           log(revlocalcapita) + log(avgwklywagecapita) +\n                                           estabscapita + emplvlcapita,\n                   fips, year_qtr, kansas,\n                   progfunc = \"ridge\", scm = T)\n\n```\n\nAgain we can look at the summary and plot the results.\n```{r summ_cvsyn}\nsummary(covsyn)\n```\n\n```{r fig_covsyn, fig.width=8, fig.height=4.5, echo=T, fig.align=\"center\"}\nplot(covsyn)\n```\n\nNow we can additionally fit ridge ASCM on the residuals, look at the summary, and plot the results.\n```{r fit_covsynth_aug, message=F, warning=F}\n\ncovsyn_resid <- augsynth(lngdpcapita ~ treated | lngdpcapita + log(revstatecapita) +\n                                           log(revlocalcapita) + log(avgwklywagecapita) +\n                                           estabscapita + emplvlcapita,\n                   fips, year_qtr, kansas,\n                   progfunc = \"ridge\", scm = T, lambda = asyn$lambda,\n                   residualize = T)\n```\n\n```{r 
summ_cvsyn_resid}\nsummary(covsyn_resid)\n```\n\n\n```{r fig_covsyn_resid, fig.width=8, fig.height=4.5, echo=T, fig.align=\"center\"}\nplot(covsyn_resid)\n```\n\n\nFinally, we can augment synth with many different outcome models. The simplest outcome model is a unit fixed effect model, which we can include by setting `fixedeff = T`.\n```{r fit_desyn, message=F, warning=F}\n\ndesyn <- augsynth(lngdpcapita ~ treated,\n                   fips, year_qtr, kansas,\n                   progfunc = \"none\", scm = T,\n                   fixedeff = T)\n```\n\n\n```{r summ_desyn}\nsummary(desyn)\n```\n\n\n```{r fig_desyn, fig.width=8, fig.height=4.5, echo=T, fig.align=\"center\"}\nplot(desyn)\n```\n\nWe can incorporate other outcome models by changing the `progfunc`.\nSeveral outcome models are available, including fitting the factor model directly with `gsynth`, general elastic net regression, Bayesian structural time series estimation with `CausalImpact`, and matrix completion with `MCPanel`. For each outcome model you can supply an optional set of parameters; see the documentation for details.\n\n\n"
  },
  {
    "path": "vignettes/singlesynth-vignette.md",
    "content": "---\noutput: rmarkdown::html_vignette\nvignette: >\n  %\\VignetteIndexEntry{Single Outcome AugSynth Vignette}\n  %\\VignetteEngine{knitr::rmarkdown}\n  %\\VignetteEncoding{UTF-8}\n---\n\n\n\n# `augsynth`: The Augmented Synthetic Control Method\n\n\n## Installation\n\nYou can install `augsynth` from github using `devtools`.\n\n\n```r\n## Install devtools if noy already installed\ninstall.packages(\"devtools\", repos='http://cran.us.r-project.org')\n## Install augsynth from github\ndevtools::install_github(\"ebenmichael/augsynth\")\n```\n\n## Example: Effects of the 2012 Kansas Tax Cuts \n\n### The data\nTo show the usage and features of `augsynth`, we'll use data on the impact of personal income tax cuts in Kansas that comes with the `AugSynth` package. Our interest is in estimating the effect of income tax cuts on gross state product (GSP) per capita.\n\n\n```r\nlibrary(magrittr)\nlibrary(dplyr)\nlibrary(augsynth)\ndata(kansas)\n```\n\nThe `kansas` dataset contains the GSP per capita (the outcome measure) `lngdpcapita` for all 50 states from the first quarter of 1990 to the first quarter of 2016.\n\nTo run `augsynth`, we need to include a treatment status column that indicates which region was treated and at what time. The table in `kansas` contains the column `treated` to denote this. In the original study, the second quarter of 2012 was the implementation of the tax cut in Kansas.\n\n\n```r\nkansas %>% \n  select(year, qtr, year_qtr, state, treated, gdp, lngdpcapita) %>% \n  filter(state == \"Kansas\" & year_qtr >= 2012 & year_qtr < 2013) \n#> # A tibble: 4 x 7\n#>    year   qtr year_qtr state  treated    gdp lngdpcapita\n#>   <dbl> <dbl>    <dbl> <chr>    <dbl>  <dbl>       <dbl>\n#> 1  2012     1    2012  Kansas       0 143844        10.8\n#> 2  2012     2    2012. Kansas       1 141518        10.8\n#> 3  2012     3    2012. Kansas       1 138890        10.8\n#> 4  2012     4    2013. 
Kansas       1 139603        10.8\n```\n\n\n### Synth\nNow to find a synthetic control using the entire series of pre-intervention outcomes (and no auxiliary covariates), we can use `augsynth`. To do so we just need to give `augsynth` a formula like `outcome ~ treatment`, tell it what the unit and time variables are, optionally provide when intervention took place (the code will automatically determine this if `t_int` is not provided), and specify that we don't want to fit an outcome model\n\n\n```r\nlibrary(augsynth)\nsyn <- augsynth(lngdpcapita ~ treated, fips, year_qtr, kansas,\n                progfunc = \"None\", scm = T)\n```\n\nWe can then look at the ATT estimates for each post-intervention time period and overall. \nWe'll also see the quality of the synthetic control fit measured by the L2 distance between Kansas and its synthetic control, and the percent improvement over uniform weights.\nBy default, we'll also see pointwise confidence intervals using a [conformal inference procedure](https://arxiv.org/abs/1712.09089).\n\n\n```r\nsummary(syn)\n#> \n#> Call:\n#> single_augsynth(form = form, unit = !!enquo(unit), time = !!enquo(time), \n#>     t_int = t_int, data = data, progfunc = \"None\", scm = ..2)\n#> \n#> Average ATT Estimate (p Value for Joint Null):  -0.029   ( 0.328 )\n#> L2 Imbalance: 0.083\n#> Percent improvement from uniform weights: 79.5%\n#> \n#> Avg Estimated Bias: NA\n#> \n#> Inference type: Conformal inference\n#> \n#>     Time Estimate 95% CI Lower Bound 95% CI Upper Bound p Value\n#>  2012.25   -0.018             -0.045              0.006   0.111\n#>  2012.50   -0.041             -0.070             -0.015   0.022\n#>  2012.75   -0.033             -0.062             -0.007   0.044\n#>  2013.00   -0.019             -0.046              0.005   0.111\n#>  2013.25   -0.029             -0.053             -0.005   0.044\n#>  2013.50   -0.046             -0.073             -0.022   0.022\n#>  2013.75   -0.032             -0.056             -0.010 
  0.022\n#>  2014.00   -0.045             -0.074             -0.018   0.022\n#>  2014.25   -0.043             -0.074             -0.014   0.022\n#>  2014.50   -0.029             -0.061              0.000   0.044\n#>  2014.75   -0.018             -0.053              0.011   0.144\n#>  2015.00   -0.029             -0.066              0.005   0.078\n#>  2015.25   -0.019             -0.051              0.010   0.122\n#>  2015.50   -0.022             -0.056              0.007   0.111\n#>  2015.75   -0.019             -0.055              0.013   0.189\n#>  2016.00   -0.028             -0.067              0.008   0.100\n```\n\n\nThe default test statistic is the sum of the absolute treatment effects, `function(x) sum(abs(x))`. We can change the test statistic via the `stat_func` argument. For instance, if we want to perform a one-sided test against positive effects, we can set the test statistic to be the negative sum `function(x) -sum(x)`:\n\n```r\nsummary(syn, stat_func = function(x) -sum(x))\n#> \n#> Call:\n#> single_augsynth(form = form, unit = !!enquo(unit), time = !!enquo(time), \n#>     t_int = t_int, data = data, progfunc = \"None\", scm = ..2)\n#> \n#> Average ATT Estimate (p Value for Joint Null):  -0.029   ( 0.159 )\n#> L2 Imbalance: 0.083\n#> Percent improvement from uniform weights: 79.5%\n#> \n#> Avg Estimated Bias: NA\n#> \n#> Inference type: Conformal inference\n#> \n#>     Time Estimate 95% CI Lower Bound 95% CI Upper Bound p Value\n#>  2012.25   -0.018             -0.080              0.006   0.067\n#>  2012.50   -0.041             -0.103             -0.015   0.022\n#>  2012.75   -0.033             -0.095             -0.007   0.033\n#>  2013.00   -0.019             -0.081              0.005   0.067\n#>  2013.25   -0.029             -0.091             -0.005   0.033\n#>  2013.50   -0.046             -0.108             -0.022   0.022\n#>  2013.75   -0.032             -0.094             -0.010   0.022\n#>  2014.00   -0.045             -0.107             -0.021   
0.022\n#>  2014.25   -0.043             -0.105             -0.014   0.022\n#>  2014.50   -0.029             -0.091              0.000   0.033\n#>  2014.75   -0.018             -0.080              0.011   0.078\n#>  2015.00   -0.029             -0.091              0.005   0.056\n#>  2015.25   -0.019             -0.081              0.007   0.078\n#>  2015.50   -0.022             -0.084              0.007   0.067\n#>  2015.75   -0.019             -0.081              0.013   0.111\n#>  2016.00   -0.028             -0.090              0.008   0.067\n```\nOr if we want to prioritize testing the average post-treatment effect, we can set it to be the absolute sum:\n\n```r\nsummary(syn, stat_func = function(x) abs(sum(x)))\n#> \n#> Call:\n#> single_augsynth(form = form, unit = !!enquo(unit), time = !!enquo(time), \n#>     t_int = t_int, data = data, progfunc = \"None\", scm = ..2)\n#> \n#> Average ATT Estimate (p Value for Joint Null):  -0.029   ( 0.302 )\n#> L2 Imbalance: 0.083\n#> Percent improvement from uniform weights: 79.5%\n#> \n#> Avg Estimated Bias: NA\n#> \n#> Inference type: Conformal inference\n#> \n#>     Time Estimate 95% CI Lower Bound 95% CI Upper Bound p Value\n#>  2012.25   -0.018             -0.045              0.006   0.111\n#>  2012.50   -0.041             -0.070             -0.015   0.022\n#>  2012.75   -0.033             -0.062             -0.007   0.044\n#>  2013.00   -0.019             -0.046              0.005   0.111\n#>  2013.25   -0.029             -0.053             -0.005   0.044\n#>  2013.50   -0.046             -0.073             -0.022   0.022\n#>  2013.75   -0.032             -0.056             -0.010   0.022\n#>  2014.00   -0.045             -0.074             -0.018   0.022\n#>  2014.25   -0.043             -0.074             -0.014   0.022\n#>  2014.50   -0.029             -0.061              0.000   0.044\n#>  2014.75   -0.018             -0.053              0.011   0.144\n#>  2015.00   -0.029             -0.066              0.005   
0.078\n#>  2015.25   -0.019             -0.051              0.010   0.122\n#>  2015.50   -0.022             -0.056              0.007   0.111\n#>  2015.75   -0.019             -0.055              0.013   0.189\n#>  2016.00   -0.028             -0.067              0.008   0.100\n```\n\n\nIt's easier to see this information visually. Below we plot the difference between Kansas and its synthetic control. Before the tax cuts (to the left of the dashed line) we expect these to be close, and after the tax cuts we measure the effect (with point-wise confidence intervals).\n\n\n```r\nplot(syn)\n```\n\n<img src=\"figure/fig_syn-1.png\" title=\"plot of chunk fig_syn\" alt=\"plot of chunk fig_syn\" style=\"display: block; margin: auto;\" />\n\nWe can also compute point-wise confidence intervals using the [Jackknife+ procedure](https://arxiv.org/abs/1905.02928) by changing the `inf_type` argument, although this requires additional assumptions.\n\n\n```r\nplot(syn, inf_type = \"jackknife+\")\n```\n\n<img src=\"figure/fig_syn_plus-1.png\" title=\"plot of chunk fig_syn_plus\" alt=\"plot of chunk fig_syn_plus\" style=\"display: block; margin: auto;\" />\n\n\n### Augmenting synth with an outcome model\nIn this example the pre-intervention synthetic control fit has an L2 imbalance of 0.083, about 20% of the imbalance between Kansas and the average of the other states. We can reduce this by _augmenting_ synth with ridge regression. To do this we change `progfunc` to `\"Ridge\"`. 
We can also choose the ridge hyperparameter by setting `lambda`; if `lambda` is not specified, it is chosen by cross-validation:\n\n```r\nasyn <- augsynth(lngdpcapita ~ treated, fips, year_qtr, kansas,\n                progfunc = \"Ridge\", scm = T)\n```\n\nWe can plot the cross-validation MSE when dropping pre-treatment time periods by setting `cv = T` in the `plot` function:\n\n\n```r\nplot(asyn, cv = T)\n```\n\n<img src=\"figure/fig_asyn_cv-1.png\" title=\"plot of chunk fig_asyn_cv\" alt=\"plot of chunk fig_asyn_cv\" style=\"display: block; margin: auto;\" />\n\nBy default, the CV procedure chooses the maximal value of `lambda` with MSE within one standard deviation of the minimal MSE. To instead choose the `lambda` that minimizes the cross-validation MSE, set `min_1se = FALSE`.\n\n\nWe can look at the summary and plot the results. The summary output now includes an estimate of the overall bias of synth; we measure this with the average amount that augmentation changes the synth estimate. 
Notice that the estimates become somewhat larger in magnitude, and the standard errors are tighter.\n\n```r\nsummary(asyn)\n#> \n#> Call:\n#> single_augsynth(form = form, unit = !!enquo(unit), time = !!enquo(time), \n#>     t_int = t_int, data = data, progfunc = \"Ridge\", scm = ..2)\n#> \n#> Average ATT Estimate (p Value for Joint Null):  -0.040   ( 0.057 )\n#> L2 Imbalance: 0.062\n#> Percent improvement from uniform weights: 84.7%\n#> \n#> Avg Estimated Bias: 0.011\n#> \n#> Inference type: Conformal inference\n#> \n#>     Time Estimate 95% CI Lower Bound 95% CI Upper Bound p Value\n#>  2012.25   -0.022             -0.044              0.003   0.056\n#>  2012.50   -0.047             -0.076             -0.018   0.022\n#>  2012.75   -0.043             -0.071             -0.010   0.022\n#>  2013.00   -0.030             -0.055             -0.004   0.033\n#>  2013.25   -0.041             -0.067             -0.012   0.022\n#>  2013.50   -0.059             -0.088             -0.030   0.022\n#>  2013.75   -0.045             -0.073             -0.019   0.022\n#>  2014.00   -0.058             -0.090             -0.026   0.022\n#>  2014.25   -0.055             -0.091             -0.020   0.022\n#>  2014.50   -0.041             -0.080             -0.006   0.033\n#>  2014.75   -0.029             -0.068              0.006   0.056\n#>  2015.00   -0.040             -0.082              0.000   0.056\n#>  2015.25   -0.030             -0.066              0.002   0.056\n#>  2015.50   -0.033             -0.072              0.003   0.056\n#>  2015.75   -0.029             -0.071              0.010   0.056\n#>  2016.00   -0.038             -0.087              0.004   0.056\n```\n\n\n```r\nplot(asyn)\n```\n\n<img src=\"figure/fig_asyn-1.png\" title=\"plot of chunk fig_asyn\" alt=\"plot of chunk fig_asyn\" style=\"display: block; margin: auto;\" />\n\nThere are also several auxiliary covariates. 
We can include these in the augmentation by fitting an outcome model using the auxiliary covariates. To do this we simply add the covariates into the formula after `|`. By default this will create time invariant covariates by averaging the auxiliary covariates over the pre-intervention period, dropping `NA` values. We can use a custom aggregation function by setting the `cov_agg` argument. Then the lagged outcomes and the auxiliary covariates are jointly balanced by SCM and the ridge outcome model includes both.\n\n\n```r\ncovsyn <- augsynth(lngdpcapita ~ treated | lngdpcapita + log(revstatecapita) +\n                                           log(revlocalcapita) + log(avgwklywagecapita) +\n                                           estabscapita + emplvlcapita,\n                   fips, year_qtr, kansas,\n                   progfunc = \"ridge\", scm = T)\n```\n\nAgain we can look at the summary and plot the results.\n\n```r\nsummary(covsyn)\n#> \n#> Call:\n#> single_augsynth(form = form, unit = !!enquo(unit), time = !!enquo(time), \n#>     t_int = t_int, data = data, progfunc = \"ridge\", scm = ..2)\n#> \n#> Average ATT Estimate (p Value for Joint Null):  -0.061   ( 0.11 )\n#> L2 Imbalance: 0.054\n#> Percent improvement from uniform weights: 86.6%\n#> \n#> Covariate L2 Imbalance: 0.005\n#> Percent improvement from uniform weights: 97.7%\n#> \n#> Avg Estimated Bias: 0.027\n#> \n#> Inference type: Conformal inference\n#> \n#>     Time Estimate 95% CI Lower Bound 95% CI Upper Bound p Value\n#>  2012.25   -0.021             -0.044              0.002   0.067\n#>  2012.50   -0.047             -0.076             -0.014   0.033\n#>  2012.75   -0.050             -0.083             -0.007   0.033\n#>  2013.00   -0.045             -0.074             -0.012   0.033\n#>  2013.25   -0.055             -0.088             -0.022   0.022\n#>  2013.50   -0.071             -0.105             -0.033   0.022\n#>  2013.75   -0.058             -0.091             -0.025   0.022\n#>  
2014.00   -0.081             -0.119             -0.037   0.022\n#>  2014.25   -0.078             -0.121             -0.034   0.022\n#>  2014.50   -0.065             -0.114             -0.021   0.033\n#>  2014.75   -0.057             -0.110             -0.008   0.044\n#>  2015.00   -0.075             -0.124             -0.022   0.033\n#>  2015.25   -0.063             -0.106             -0.014   0.033\n#>  2015.50   -0.067             -0.106             -0.019   0.022\n#>  2015.75   -0.063             -0.101             -0.009   0.022\n#>  2016.00   -0.078             -0.122             -0.019   0.022\n```\n\n\n```r\nplot(covsyn)\n```\n\n<img src=\"figure/fig_covsyn-1.png\" title=\"plot of chunk fig_covsyn\" alt=\"plot of chunk fig_covsyn\" style=\"display: block; margin: auto;\" />\n\nNow we can additionally fit ridge ASCM on the residuals, look at the summary, and plot the results.\n\n```r\n\ncovsyn_resid <- augsynth(lngdpcapita ~ treated | lngdpcapita + log(revstatecapita) +\n                                           log(revlocalcapita) + log(avgwklywagecapita) +\n                                           estabscapita + emplvlcapita,\n                   fips, year_qtr, kansas,\n                   progfunc = \"ridge\", scm = T, lambda = asyn$lambda,\n                   residualize = T)\n```\n\n\n```r\nsummary(covsyn_resid)\n#> \n#> Call:\n#> single_augsynth(form = form, unit = !!enquo(unit), time = !!enquo(time), \n#>     t_int = t_int, data = data, progfunc = \"ridge\", scm = ..2, \n#>     lambda = ..3, residualize = ..4)\n#> \n#> Average ATT Estimate (p Value for Joint Null):  -0.055   ( 0.288 )\n#> L2 Imbalance: 0.067\n#> Percent improvement from uniform weights: 83.4%\n#> \n#> Covariate L2 Imbalance: 0.000\n#> Percent improvement from uniform weights: 100%\n#> \n#> Avg Estimated Bias: 0.006\n#> \n#> Inference type: Conformal inference\n#> \n#>     Time Estimate 95% CI Lower Bound 95% CI Upper Bound p Value\n#>  2012.25   -0.025             -0.046             
-0.005   0.044\n#>  2012.50   -0.051             -0.076             -0.026   0.011\n#>  2012.75   -0.045             -0.070             -0.020   0.011\n#>  2013.00   -0.044             -0.069             -0.019   0.011\n#>  2013.25   -0.051             -0.077             -0.026   0.011\n#>  2013.50   -0.069             -0.094             -0.044   0.011\n#>  2013.75   -0.051             -0.077             -0.026   0.011\n#>  2014.00   -0.069             -0.095             -0.040   0.011\n#>  2014.25   -0.067             -0.097             -0.037   0.011\n#>  2014.50   -0.053             -0.083             -0.024   0.011\n#>  2014.75   -0.045             -0.075             -0.015   0.022\n#>  2015.00   -0.064             -0.093             -0.034   0.011\n#>  2015.25   -0.051             -0.076             -0.026   0.011\n#>  2015.50   -0.059             -0.089             -0.034   0.011\n#>  2015.75   -0.058             -0.087             -0.028   0.011\n#>  2016.00   -0.074             -0.103             -0.044   0.011\n```\n\n\n\n```r\nplot(covsyn_resid)\n```\n\n<img src=\"figure/fig_covsyn_resid-1.png\" title=\"plot of chunk fig_covsyn_resid\" alt=\"plot of chunk fig_covsyn_resid\" style=\"display: block; margin: auto;\" />\n\n\nFinally, we can augment synth with many different outcome models. 
The simplest outcome model is a unit fixed effect model, which we can include by setting `fixedeff = T`.\n\n```r\n\ndesyn <- augsynth(lngdpcapita ~ treated,\n                   fips, year_qtr, kansas,\n                   progfunc = \"none\", scm = T,\n                   fixedeff = T)\n```\n\n\n\n```r\nsummary(desyn)\n#> \n#> Call:\n#> single_augsynth(form = form, unit = !!enquo(unit), time = !!enquo(time), \n#>     t_int = t_int, data = data, progfunc = \"none\", scm = ..2, \n#>     fixedeff = ..3)\n#> \n#> Average ATT Estimate (p Value for Joint Null):  -0.034   ( 0.319 )\n#> L2 Imbalance: 0.082\n#> Percent improvement from uniform weights: 55.1%\n#> \n#> Avg Estimated Bias: NA\n#> \n#> Inference type: Conformal inference\n#> \n#>     Time Estimate 95% CI Lower Bound 95% CI Upper Bound p Value\n#>  2012.25   -0.022             -0.046              0.006   0.078\n#>  2012.50   -0.046             -0.070             -0.013   0.022\n#>  2012.75   -0.038             -0.062             -0.005   0.044\n#>  2013.00   -0.024             -0.048              0.003   0.078\n#>  2013.25   -0.033             -0.057             -0.006   0.044\n#>  2013.50   -0.050             -0.074             -0.023   0.022\n#>  2013.75   -0.035             -0.056             -0.010   0.022\n#>  2014.00   -0.049             -0.073             -0.019   0.022\n#>  2014.25   -0.047             -0.071             -0.014   0.022\n#>  2014.50   -0.033             -0.057              0.000   0.056\n#>  2014.75   -0.023             -0.047              0.010   0.122\n#>  2015.00   -0.034             -0.061              0.004   0.078\n#>  2015.25   -0.023             -0.047              0.007   0.100\n#>  2015.50   -0.026             -0.053              0.007   0.100\n#>  2015.75   -0.023             -0.050              0.012   0.144\n#>  2016.00   -0.033             -0.066              0.008   0.089\n```\n\n\n\n```r\nplot(desyn)\n```\n\n<img src=\"figure/fig_desyn-1.png\" title=\"plot of chunk 
fig_desyn\" alt=\"plot of chunk fig_desyn\" style=\"display: block; margin: auto;\" />\n\nWe can incorporate other outcome models by changing `progfunc`.\nSeveral outcome models are available, including fitting the factor model directly with `gsynth`, general elastic net regression, Bayesian structural time series estimation with `CausalImpact`, and matrix completion with `MCPanel`. For each outcome model you can supply an optional set of parameters; see the documentation for details.\n\n\n"
  }
]