Repository: ebenmichael/augsynth
Branch: master
Commit: 65c5a6f34f4e
Files: 83
Total size: 373.9 KB

Directory structure:
gitextract_wwvwxodd/

├── .Rbuildignore
├── .gitignore
├── .travis.yml
├── DESCRIPTION
├── LICENSE
├── NAMESPACE
├── R/
│   ├── augsynth.R
│   ├── augsynth_pre.R
│   ├── cv.R
│   ├── data.R
│   ├── eligible_donors.R
│   ├── fit_synth.R
│   ├── format.R
│   ├── globalVariables.R
│   ├── highdim.R
│   ├── inference.R
│   ├── multi_outcomes.R
│   ├── multi_synth_qp.R
│   ├── multisynth_class.R
│   ├── outcome_models.R
│   ├── outcome_multi.R
│   ├── ridge.R
│   ├── ridge_lambda.R
│   └── time_regression_multi.R
├── README.md
├── data/
│   └── kansas.rda
├── data-raw/
│   ├── clean_kansas.R
│   └── kansas_longer2.dta
├── man/
│   ├── augsynth-package.Rd
│   ├── augsynth.Rd
│   ├── augsynth_multiout.Rd
│   ├── check_data_stag.Rd
│   ├── conformal_inf.Rd
│   ├── conformal_inf_linear.Rd
│   ├── conformal_inf_multiout.Rd
│   ├── get_nona_donors.Rd
│   ├── jackknife_se_single.Rd
│   ├── kansas.Rd
│   ├── make_V_matrix.Rd
│   ├── multisynth.Rd
│   ├── plot.augsynth.Rd
│   ├── plot.augsynth_multiout.Rd
│   ├── plot.multisynth.Rd
│   ├── plot.summary.augsynth.Rd
│   ├── plot.summary.augsynth_multiout.Rd
│   ├── plot.summary.multisynth.Rd
│   ├── predict.augsynth.Rd
│   ├── predict.augsynth_multiout.Rd
│   ├── predict.multisynth.Rd
│   ├── print.augsynth.Rd
│   ├── print.augsynth_multiout.Rd
│   ├── print.multisynth.Rd
│   ├── print.summary.augsynth.Rd
│   ├── print.summary.augsynth_multiout.Rd
│   ├── print.summary.multisynth.Rd
│   ├── rdirichlet_b.Rd
│   ├── rmultinom_b.Rd
│   ├── rwild_b.Rd
│   ├── single_augsynth.Rd
│   ├── summary.augsynth.Rd
│   ├── summary.augsynth_multiout.Rd
│   ├── summary.multisynth.Rd
│   ├── time_jackknife_plus.Rd
│   └── time_jackknife_plus_multiout.Rd
├── pkg.Rproj
├── tests/
│   ├── testthat/
│   │   ├── test_augsynth_pre.R
│   │   ├── test_format.R
│   │   ├── test_general.R
│   │   ├── test_lambda.R
│   │   ├── test_load_data.R
│   │   ├── test_multiple_outcomes.R
│   │   ├── test_multisynth.R
│   │   ├── test_multisynth_covariates.R
│   │   ├── test_outcome_models.R
│   │   ├── test_time_cohort.R
│   │   └── test_unbalanced_multisynth.R
│   └── testthat.R
└── vignettes/
    ├── .gitignore
    ├── multi-outcomes-vignette.Rmd
    ├── multisynth-vignette.Rmd
    ├── multisynth-vignette.md
    ├── singlesynth-vignette.Rmd
    └── singlesynth-vignette.md

================================================
FILE CONTENTS
================================================

================================================
FILE: .Rbuildignore
================================================
^data-raw$
^Meta$
^doc$
^\.travis\.yml$
^pkg.Rproj$
figure$
cache$


================================================
FILE: .gitignore
================================================
Meta
doc
inst/doc

## Files
# Emacs autosave files
*~
\#*#
# Don't put data in the repo
*.csv
*.feather

# R stuff
*.Rout
*.Rhistory
*.RData 
*.Rapp.history

# Mac stuff
*.DS_store


# C++ stuff
*.o
*.so
*.dll

test.R

*-vignette.pdf

================================================
FILE: .travis.yml
================================================
# R for travis: see documentation at https://docs.travis-ci.com/user/languages/r

language: r
r:
    - 3.5.1  
sudo: false
cache: packages
warnings_are_errors: false
r_binary_packages:
    - dplyr
    - magrittr
    - ggplot2
    - glmnet
    - plyr
    - kableExtra

================================================
FILE: DESCRIPTION
================================================
Package: augsynth
Title: The Augmented Synthetic Control Method
Version: 0.2.0
Authors@R: person("Eli", "Ben-Michael", email = "ebenmichael@berkeley.edu", role = c("aut", "cre"))
Description: A package implementing the Augmented Synthetic Controls Method.
Depends:
    R (>= 3.5.0)
Imports:
    dplyr,
    tidyr,
    magrittr,
    ggplot2,
    MASS,
    LiblineaR,
    Formula,
    Matrix,
    osqp,
    rlang,
    purrr,
    FNN
Remotes:
    susanathey/MCPanel
License: MIT + file LICENSE
Encoding: UTF-8
LazyData: true
RoxygenNote: 7.2.3
Suggests:
    testthat,
    CausalImpact,
    keras,
    gsynth,
    knitr,
    rmarkdown,
    softImpute,
    MCPanel,
    glmnet,
    randomForest,
    kableExtra,
    ggrepel
VignetteBuilder: knitr


================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2018 Elijahu Ben-Michael

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: NAMESPACE
================================================
# Generated by roxygen2: do not edit by hand

S3method(plot,augsynth)
S3method(plot,augsynth_multiout)
S3method(plot,multisynth)
S3method(plot,summary.augsynth)
S3method(plot,summary.augsynth_multiout)
S3method(plot,summary.multisynth)
S3method(predict,augsynth)
S3method(predict,augsynth_multiout)
S3method(predict,multisynth)
S3method(print,augsynth)
S3method(print,augsynth_multiout)
S3method(print,multisynth)
S3method(print,summary.augsynth)
S3method(print,summary.augsynth_multiout)
S3method(print,summary.multisynth)
S3method(summary,augsynth)
S3method(summary,augsynth_multiout)
S3method(summary,multisynth)
export(augsynth)
export(augsynth_multiout)
export(multisynth)
export(rdirichlet_b)
export(rmultinom_b)
export(rwild_b)
export(single_augsynth)
import(dplyr)
import(tidyr)
importFrom(ggplot2,aes)
importFrom(graphics,plot)
importFrom(magrittr,"%>%")
importFrom(purrr,reduce)
importFrom(stats,coef)
importFrom(stats,delete.response)
importFrom(stats,formula)
importFrom(stats,lm)
importFrom(stats,model.frame)
importFrom(stats,model.matrix)
importFrom(stats,na.omit)
importFrom(stats,poly)
importFrom(stats,predict)
importFrom(stats,sd)
importFrom(stats,terms)
importFrom(stats,update)
importFrom(utils,capture.output)


================================================
FILE: R/augsynth.R
================================================
################################################################################
## Main functions for single-period treatment augmented synthetic controls Method
################################################################################


#' Fit Augmented SCM
#' 
#' @param form outcome ~ treatment | auxiliary covariates
#' @param unit Name of unit column
#' @param time Name of time column
#' @param t_int Time of intervention
#' @param data Panel data as dataframe
#' @param progfunc What function to use to impute control outcomes
#'                 ridge=Ridge regression (allows for standard errors),
#'                 none=No outcome model,
#'                 en=Elastic Net, RF=Random Forest, GSYN=gSynth,
#'                 mcp=MCPanel, 
#'                 cits=Comparative Interrupted Time Series,
#'                 causalimpact=Bayesian structural time series with CausalImpact
#' @param scm Whether the SCM weighting function is used
#' @param fixedeff Whether to include a unit fixed effect, default F 
#' @param cov_agg Covariate aggregation functions, if NULL then use mean with NAs omitted
#' @param ... optional arguments for outcome model
#'
#' @return augsynth object that contains:
#'         \itemize{
#'          \item{"weights"}{Ridge ASCM weights}
#'          \item{"l2_imbalance"}{Imbalance in pre-period outcomes, measured by the L2 norm}
#'          \item{"scaled_l2_imbalance"}{L2 imbalance scaled by L2 imbalance of uniform weights}
#'          \item{"mhat"}{Outcome model estimate}
#'          \item{"data"}{Panel data as matrices}
#'         }
#' @export
single_augsynth <- function(form, unit, time, t_int, data,
                     progfunc = "ridge",
                     scm = TRUE,
                     fixedeff = FALSE,
                     cov_agg=NULL, ...) {
    call_name <- match.call()

    form <- Formula::Formula(form)
    unit <- enquo(unit)
    time <- enquo(time)

    ## format data
    outcome <- terms(formula(form, rhs=1))[[2]]
    trt <- terms(formula(form, rhs=1))[[3]]

    wide <- format_data(outcome, trt, unit, time, t_int, data)
    synth_data <- do.call(format_synth, wide)
    
    treated_units <- data %>% filter(!!trt == 1) %>% distinct(!!unit) %>% pull(!!unit)
    control_units <- data %>% filter(!(!!unit %in% treated_units)) %>% 
                        distinct(!!unit) %>% arrange(!!unit) %>% pull(!!unit)
    ## add covariates
    if(length(form)[2] == 2) {
        Z <- extract_covariates(form, unit, time, t_int, data, cov_agg)
    } else {
        Z <- NULL
    }
    
    # fit augmented SCM
    augsynth <- fit_augsynth_internal(wide, synth_data, Z, progfunc, 
                                      scm, fixedeff, ...)
    
    # add some extra data
    augsynth$data$time <- data %>% distinct(!!time) %>%
                                   arrange(!!time) %>% pull(!!time)
    augsynth$call <- call_name
    augsynth$t_int <- t_int 
    
    augsynth$weights <- matrix(augsynth$weights)
    rownames(augsynth$weights) <- control_units

    return(augsynth)
}
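
## Illustrative sketch (not part of the package API): how a two-sided formula
## such as `y ~ trt` splits into the outcome and treatment symbols that
## single_augsynth() extracts above. single_augsynth() first wraps the formula
## with Formula::Formula() and takes the first right-hand-side part; this
## sketch assumes a plain base R formula with no `|` part. The helper name
## `sketch_formula_parts` is hypothetical.
sketch_formula_parts <- function(form) {
    tt <- stats::terms(form)
    ## a terms object is still a formula call: [[2]] is the left-hand side,
    ## [[3]] is the right-hand side
    list(outcome = tt[[2]], trt = tt[[3]])
}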


#' Internal function to fit augmented SCM
#' @param wide Data formatted from format_data
#' @param synth_data Data formatted from format_synth
#' @param Z Matrix of auxiliary covariates
#' @param progfunc outcome model to use
#' @param scm Whether to fit SCM
#' @param fixedeff Whether to de-mean synth
#' @param V V matrix for Synth, default NULL
#' @param ... Extra args for outcome model
#' 
#' @noRd
#' 
fit_augsynth_internal <- function(wide, synth_data, Z, progfunc,
                                  scm, fixedeff, V = NULL, ...) {

    n <- nrow(wide$X)
    t0 <- ncol(wide$X)
    ttot <- t0 + ncol(wide$y)
    if(fixedeff) {
        demeaned <- demean_data(wide, synth_data)
        fit_wide <- demeaned$wide
        fit_synth_data <- demeaned$synth_data
        mhat <- demeaned$mhat
    } else {
        fit_wide <- wide
        fit_synth_data <- synth_data
        mhat <- matrix(0, n, ttot)
    }
    if (is.null(progfunc)) {
        progfunc = "none"
    }
    progfunc = tolower(progfunc)
    ## fit augsynth
    if(progfunc == "ridge") {
        # Ridge ASCM
        augsynth <- do.call(fit_ridgeaug_formatted,
                            list(wide_data = fit_wide,
                                 synth_data = fit_synth_data,
                                 Z = Z, V = V, scm = scm, ...))
    } else if(progfunc == "none") {
        ## Just SCM
        augsynth <- do.call(fit_ridgeaug_formatted,
                        c(list(wide_data = fit_wide, 
                               synth_data = fit_synth_data,
                               Z = Z, ridge = F, scm = T, V = V, ...)))
    } else {
        ## Other outcome models
        progfuncs = c("ridge", "none", "en", "rf", "gsyn", "mcp",
                      "cits", "causalimpact", "seq2seq")
        if (progfunc %in% progfuncs) {
            augsynth <- fit_augsyn(fit_wide, fit_synth_data, 
                                   progfunc, scm, ...)
        } else {
            stop("progfunc must be one of 'Ridge', 'None', 'EN', 'RF', 'GSYN', 'MCP', 'CITS', 'CausalImpact', 'seq2seq'")
        }
        
    }

    augsynth$mhat <- mhat + cbind(matrix(0, nrow = n, ncol = t0), 
                                  augsynth$mhat)
    augsynth$data <- wide
    augsynth$data$Z <- Z
    augsynth$data$synth_data <- synth_data
    augsynth$progfunc <- progfunc
    augsynth$scm <- scm
    augsynth$fixedeff <- fixedeff
    augsynth$extra_args <- list(...)
    if(progfunc == "ridge") {
        augsynth$extra_args$lambda <- augsynth$lambda
    } else if(progfunc == "gsyn") {
        augsynth$extra_args$r <- ncol(augsynth$params$factor)
        augsynth$extra_args$CV <- 0
    }
    ##format output
    class(augsynth) <- "augsynth"
    return(augsynth)
}

#' Get prediction of ATT or average outcome under control
#' @param object augsynth object
#' @param att If TRUE, return the ATT, if FALSE, return imputed counterfactual
#' @param ... Optional arguments
#'
#' @return Vector of predicted post-treatment control averages
#' @export
predict.augsynth <- function(object, att = F, ...) {
    # if ("att" %in% names(list(...))) {
    #     att <- list(...)$att
    # } else {
    #     att <- F
    # }
    augsynth <- object
    
    X <- augsynth$data$X
    y <- augsynth$data$y
    comb <- cbind(X, y)
    trt <- augsynth$data$trt
    mhat <- augsynth$mhat
    
    m1 <- colMeans(mhat[trt==1,,drop=F])

    resid <- (comb[trt==0,,drop=F] - mhat[trt==0,,drop=F])

    y0 <- m1 + t(resid) %*% augsynth$weights
    if(att) {
        return(colMeans(comb[trt == 1,, drop = F]) - c(y0))
    } else {
        rnames <- rownames(y0)
        y0_vec <- c(y0)
        names(y0_vec) <- rnames
        return(y0_vec)
    }
}
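
## Illustrative sketch (not part of the package API) of the arithmetic in
## predict.augsynth() above: the imputed control outcome is the outcome-model
## prediction for the treated units plus the weighted control residuals, and
## the ATT is the treated average minus that imputation. The helper name and
## its argument layout (rows = units, columns = time periods) are assumptions
## made for this toy version.
sketch_augmented_prediction <- function(comb, mhat, trt, weights) {
    ## average outcome-model prediction across treated units
    m1 <- colMeans(mhat[trt == 1, , drop = FALSE])
    ## control residuals: observed outcomes minus outcome-model predictions
    resid <- comb[trt == 0, , drop = FALSE] - mhat[trt == 0, , drop = FALSE]
    ## imputed outcome under control for each time period
    y0 <- c(m1 + t(resid) %*% weights)
    list(y0 = y0,
         att = colMeans(comb[trt == 1, , drop = FALSE]) - y0)
}
## With mhat set to all zeros this reduces to a plain weighted average of the
## control units' outcomes.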


#' Print function for augsynth
#' @param x augsynth object
#' @param ... Optional arguments
#' @export
print.augsynth <- function(x, ...) {
    augsynth <- x
    
    ## straight from lm
    cat("\nCall:\n", paste(deparse(augsynth$call), sep="\n", collapse="\n"), "\n\n", sep="")

    ## print att estimates
    tint <- ncol(augsynth$data$X)
    ttotal <- tint + ncol(augsynth$data$y)
    att_post <- predict(augsynth, att = T)[(tint + 1):ttotal]

    cat(paste("Average ATT Estimate: ",
              format(round(mean(att_post),3), nsmall = 3), "\n\n", sep=""))
}


#' Plot function for augsynth
#' @importFrom graphics plot
#' 
#' @param x Augsynth object to be plotted
#' @param inf Boolean, whether to get confidence intervals around the point estimates
#' @param cv If True, plot cross validation MSE against hyper-parameter, otherwise plot effects
#' @param ... Optional arguments
#' @export
plot.augsynth <- function(x, inf = T, cv = F, ...) {
    # if ("se" %in% names(list(...))) {
    #     se <- list(...)$se
    # } else {
    #     se <- T
    # }

    augsynth <- x
    
    if (cv == T) {
        errors = data.frame(lambdas = augsynth$lambdas,
                            errors = augsynth$lambda_errors,
                            errors_se = augsynth$lambda_errors_se)
        p <- ggplot2::ggplot(errors, ggplot2::aes(x = lambdas, y = errors)) +
              ggplot2::geom_point(size = 2) + 
              ggplot2::geom_errorbar(
                ggplot2::aes(ymin = errors,
                             ymax = errors + errors_se),
                width=0.2, size = 0.5) 
        p <- p + ggplot2::labs(title = bquote("Cross Validation MSE over " ~ lambda),
                              x = expression(lambda), y = "Cross Validation MSE", 
                              parse = TRUE)
        p <- p + ggplot2::scale_x_log10()
        
        # find minimum and min + 1se lambda to plot
        min_lambda <- choose_lambda(augsynth$lambdas,
                                   augsynth$lambda_errors,
                                   augsynth$lambda_errors_se,
                                   F)
        min_1se_lambda <- choose_lambda(augsynth$lambdas,
                                       augsynth$lambda_errors,
                                       augsynth$lambda_errors_se,
                                       T)
        min_lambda_index <- which(augsynth$lambdas == min_lambda)
        min_1se_lambda_index <- which(augsynth$lambdas == min_1se_lambda)

        p <- p + ggplot2::geom_point(
            ggplot2::aes(x = min_lambda, 
                         y = augsynth$lambda_errors[min_lambda_index]),
            color = "gold")
        p + ggplot2::geom_point(
              ggplot2::aes(x = min_1se_lambda,
                           y = augsynth$lambda_errors[min_1se_lambda_index]),
              color = "gold") +
            ggplot2::theme_bw()
    } else {
        plot(summary(augsynth, ...), inf = inf)
    }
}
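
## Illustrative sketch (not part of the package API): plot.augsynth() marks
## the error-minimizing lambda and a "min + 1se" lambda, both obtained from
## choose_lambda(), which is defined elsewhere in the package. The function
## below is the textbook one-standard-error rule and is only an assumption
## about what that choice might look like; the helper name is hypothetical.
sketch_one_se_rule <- function(lambdas, errors, errors_se) {
    i_min <- which.min(errors)
    ## largest lambda whose CV error is within one SE of the minimum error
    threshold <- errors[i_min] + errors_se[i_min]
    list(min = lambdas[i_min],
         min_1se = max(lambdas[errors <= threshold]))
}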


#' Summary function for augsynth
#' @param object augsynth object
#' @param inf Boolean, whether to get confidence intervals around the point estimates
#' @param inf_type Type of inference algorithm. Options are
#'         \itemize{
#'          \item{"conformal"}{Conformal inference (default)}
#'          \item{"jackknife+"}{Jackknife+ algorithm over time periods}
#'          \item{"jackknife"}{Jackknife over units}
#'         }
#' @param linear_effect Boolean, whether to invert the conformal inference hypothesis test to get confidence intervals for a linear-in-time treatment effect: intercept + slope * time 
#' @param ... Optional arguments for inference, for more details for each `inf_type` see
#'         \itemize{
#'          \item{"conformal"}{`conformal_inf`}
#'          \item{"jackknife+"}{`time_jackknife_plus`}
#'          \item{"jackknife"}{`jackknife_se_single`}
#'         }
#' @export
summary.augsynth <- function(object, inf = T, inf_type = "conformal",
                             linear_effect = F,
                             ...) {
    augsynth <- object
    summ <- list()

    t0 <- ncol(augsynth$data$X)
    t_final <- t0 + ncol(augsynth$data$y)

    if(inf) {
        if(inf_type == "jackknife") {
            att_se <- jackknife_se_single(augsynth)
        } else if(inf_type == "jackknife+") {
            att_se <- time_jackknife_plus(augsynth, ...)
        } else if(inf_type == "conformal") {
          att_se <- conformal_inf(augsynth, ...)
          # get CIs for linear treatment effects
          if(linear_effect) {
            att_linear <- conformal_inf_linear(augsynth, ...)
          }
        } else {
            stop(paste(inf_type, "is not a valid choice of 'inf_type'"))
        }

        att <- data.frame(Time = augsynth$data$time,
                          Estimate = att_se$att[1:t_final])
        if(inf_type == "jackknife") {
          att$Std.Error <- att_se$se[1:t_final]
          att_avg_se <- att_se$se[t_final + 1]
        } else {
          att_avg_se <- NA
        }
        att_avg <- att_se$att[t_final + 1]
        if(inf_type %in% c("jackknife+", "nonpar_bs", "t_dist", "conformal")) {
            att$lower_bound <- att_se$lb[1:t_final]
            att$upper_bound <- att_se$ub[1:t_final]
        }
        if(inf_type == "conformal") {
          att$p_val <- att_se$p_val[1:t_final]
        }

    } else {
        t0 <- ncol(augsynth$data$X)
        t_final <- t0 + ncol(augsynth$data$y)
        att_est <- predict(augsynth, att = T)
        att <- data.frame(Time = augsynth$data$time,
                          Estimate = att_est)
        att$Std.Error <- NA
        att_avg <- mean(att_est[(t0 + 1):t_final])
        att_avg_se <- NA
    }

    summ$att <- att

    if(inf) {
      if(inf_type %in% c("jackknife+")) {
        summ$average_att <- data.frame(Value = "Average Post-Treatment Effect",
                              Estimate = att_avg, Std.Error = att_avg_se)
        summ$average_att$lower_bound <- att_se$lb[t_final + 1]
        summ$average_att$upper_bound <- att_se$ub[t_final + 1]
        summ$alpha <-  att_se$alpha
      }
      if(inf_type == "conformal") {
        # summ$average_att$p_val <- att_se$p_val[t_final + 1]
        # summ$average_att$lower_bound <- att_se$lb[t_final + 1]
        # summ$average_att$upper_bound <- att_se$ub[t_final + 1]
        # summ$alpha <-  att_se$alpha
        if(linear_effect) {
          summ$average_att <- data.frame(
                                Value = c("Average Post-Treatment Effect",
                                          "Treatment Effect Intercept",
                                          "Treatment Effect Slope"),
                                Estimate = c(att_avg, att_linear$est_int,
                                             att_linear$est_slope),
                                Std.Error = c(att_avg_se, NA, NA),
                                p_val = c(att_se$p_val[t_final + 1], NA, NA),
                                lower_bound = c(att_se$lb[t_final + 1],
                                            att_linear$ci_int[1],
                                            att_linear$ci_slope[1]),
                                upper_bound =  c(att_se$ub[t_final + 1],
                                            att_linear$ci_int[2],
                                            att_linear$ci_slope[2])
          )
        } else {
          summ$average_att <- data.frame(
                                Value = c("Average Post-Treatment Effect"),
                                Estimate = att_avg,
                                Std.Error = att_avg_se,
                                p_val = att_se$p_val[t_final + 1],
                                lower_bound = att_se$lb[t_final + 1],
                                upper_bound =  att_se$ub[t_final + 1]
          )

        }
        summ$alpha <-  att_se$alpha
      }
    } else {
              summ$average_att <- data.frame(Value = "Average Post-Treatment Effect",
                              Estimate = att_avg, Std.Error = att_avg_se)
    }
    summ$t_int <- augsynth$t_int
    summ$call <- augsynth$call
    summ$l2_imbalance <- augsynth$l2_imbalance
    summ$scaled_l2_imbalance <- augsynth$scaled_l2_imbalance
    if(!is.null(augsynth$covariate_l2_imbalance)) {
      summ$covariate_l2_imbalance <- augsynth$covariate_l2_imbalance
      summ$scaled_covariate_l2_imbalance <- augsynth$scaled_covariate_l2_imbalance
    }
    ## get estimated bias

    if(tolower(augsynth$progfunc) == "ridge") {
        mhat <- augsynth$ridge_mhat
        w <- augsynth$synw
    } else {
        mhat <- augsynth$mhat
        w <- augsynth$weights
    }
    trt <- augsynth$data$trt
    m1 <- colMeans(mhat[trt==1,,drop=F])

    if(tolower(augsynth$progfunc) == "none" | (!augsynth$scm)) {
        summ$bias_est <- NA
    } else {
      summ$bias_est <- m1 - t(mhat[trt==0,,drop=F]) %*% w
    }
    
    
    summ$inf_type <- if(inf) inf_type else "None"
    class(summ) <- "summary.augsynth"
    return(summ)
}
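
## Illustrative sketch (not part of the package API) of the `bias_est`
## computation in summary.augsynth() above: the estimated bias is the gap
## between the outcome-model prediction for the treated units and the weighted
## outcome-model predictions for the controls. The helper name and the toy
## argument layout (rows = units, columns = time periods) are assumptions.
sketch_bias_estimate <- function(mhat, trt, w) {
    m1 <- colMeans(mhat[trt == 1, , drop = FALSE])
    c(m1 - t(mhat[trt == 0, , drop = FALSE]) %*% w)
}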

#' Print function for summary function for augsynth
#' @param x summary object
#' @param ... Optional arguments
#' @export
print.summary.augsynth <- function(x, ...) {
    summ <- x
    
    ## straight from lm
    cat("\nCall:\n", paste(deparse(summ$call), sep="\n", collapse="\n"), "\n\n", sep="")

    t_final <- nrow(summ$att)

    ## distinction between pre and post treatment
    att_est <- summ$att$Estimate
    t_total <- length(att_est)
    t_int <- summ$att %>% filter(Time <= summ$t_int) %>% nrow()
    
    att_pre <- att_est[1:(t_int-1)]
    att_post <- att_est[t_int:t_total]


    out_msg <- ""


    # print out average post treatment estimate
    att_post <- summ$average_att$Estimate[1]
    se_est <- summ$att$Std.Error
    if(summ$inf_type == "jackknife") {
      se_avg <- summ$average_att$Std.Error

      out_msg <- paste("Average ATT Estimate (Jackknife Std. Error): ",
                        format(round(att_post,3), nsmall=3), 
                        "  (",
                        format(round(se_avg,3)), ")\n")
      inf_type <- "Jackknife over units"
    } else if(summ$inf_type == "conformal") {
      p_val <- summ$average_att$p_val[1]
      out_msg <- paste("Average ATT Estimate (p Value for Joint Null): ",
                        format(att_post, digits = 3), 
                        "  (",
                        format(p_val, digits = 2), ")\n")
      inf_type <- "Conformal inference"
      if("Treatment Effect Slope" %in% summ$average_att$Value) {
        lowers <- summ$average_att$lower_bound[2:3]
        uppers <- summ$average_att$upper_bound[2:3]
        out_msg_line2 <- paste0("Confidence intervals for linear-in-time treatment effects (Intercept + Slope * Time)\n",
        "\tIntercept: [", format(lowers[1], digits = 3), ",",
        format(uppers[1], digits = 3), "]\n",
        "\tSlope: [", format(lowers[2], digits = 3), ",",
        format(uppers[2], digits = 3), "]\n")
        out_msg <- paste0(out_msg, out_msg_line2)
      }
    } else if(summ$inf_type == "jackknife+") {
      out_msg <- paste("Average ATT Estimate: ",
                        format(round(att_post,3), nsmall=3), "\n")
      inf_type <- "Jackknife+ over time periods"
    } else {
      out_msg <- paste("Average ATT Estimate: ",
                        format(round(att_post,3), nsmall=3), "\n")
      inf_type <- "None"
    }


    out_msg <- paste(out_msg, 
              "L2 Imbalance: ",
              format(round(summ$l2_imbalance,3), nsmall=3), "\n",
              "Percent improvement from uniform weights: ",
              format(round(1 - summ$scaled_l2_imbalance,3)*100), "%\n\n",
              sep="")
  if(!is.null(summ$covariate_l2_imbalance)) {

    out_msg <- paste(out_msg,
                     "Covariate L2 Imbalance: ",
                     format(round(summ$covariate_l2_imbalance,3), 
                                  nsmall=3),
                    "\n",
                     "Percent improvement from uniform weights: ",
                     format(round(1 - summ$scaled_covariate_l2_imbalance,3)*100), 
                     "%\n\n",
                     sep="")

  }
  out_msg <- paste(out_msg, 
              "Avg Estimated Bias: ",
              format(round(mean(summ$bias_est), 3),nsmall=3), "\n\n",
              "Inference type: ",
              inf_type,
              "\n\n",
              sep="")
  cat(out_msg)

    if(summ$inf_type == "jackknife") {
      out_att <- summ$att[t_int:t_final,] %>% 
              select(Time, Estimate, Std.Error)
    } else if(summ$inf_type == "conformal") {
      out_att <- summ$att[t_int:t_final,] %>% 
              select(Time, Estimate, lower_bound, upper_bound, p_val)
      names(out_att) <- c("Time", "Estimate", 
                          paste0((1 - summ$alpha) * 100, "% CI Lower Bound"),
                          paste0((1 - summ$alpha) * 100, "% CI Upper Bound"),
                          paste0("p Value"))
    } else if(summ$inf_type == "jackknife+") {
      out_att <- summ$att[t_int:t_final,] %>% 
              select(Time, Estimate, lower_bound, upper_bound)
      names(out_att) <- c("Time", "Estimate", 
                          paste0((1 - summ$alpha) * 100, "% CI Lower Bound"),
                          paste0((1 - summ$alpha) * 100, "% CI Upper Bound"))
    } else {
      out_att <- summ$att[t_int:t_final,] %>% 
              select(Time, Estimate)
    }
    out_att %>%
      mutate_at(vars(-Time), ~ round(., 3)) %>%
      print(row.names = F)

    
}
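
## Illustrative sketch (not part of the package API) of the quantities printed
## above: the L2 imbalance, its value scaled by the imbalance under uniform
## weights, and the resulting "percent improvement from uniform weights".
## The package computes these elsewhere; the Euclidean-norm formula, the
## helper name, and the arguments (X0 = control pre-period outcomes with one
## row per unit, x1 = treated pre-period average) are assumptions based on
## the documentation of `l2_imbalance` and `scaled_l2_imbalance`.
sketch_l2_imbalance <- function(X0, x1, w) {
    l2 <- function(weights) sqrt(sum((x1 - c(t(X0) %*% weights))^2))
    n0 <- nrow(X0)
    imbalance <- l2(w)
    uniform <- l2(rep(1 / n0, n0))
    list(l2_imbalance = imbalance,
         scaled_l2_imbalance = imbalance / uniform,
         pct_improvement = (1 - imbalance / uniform) * 100)
}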

#' Plot function for summary function for augsynth
#' @param x Summary object
#' @param inf Boolean, whether to plot confidence intervals
#' @param ... Optional arguments
#' @export
plot.summary.augsynth <- function(x, inf = T, ...) {
    summ <- x
    # if ("inf" %in% names(list(...))) {
    #     inf <- list(...)$inf
    # } else {
    #     inf <- T
    # }
    
    p <- summ$att %>%
        ggplot2::ggplot(ggplot2::aes(x=Time, y=Estimate))
    if(inf) {
        if(all(is.na(summ$att$lower_bound))) {
            p <- p + ggplot2::geom_ribbon(ggplot2::aes(ymin=Estimate-2*Std.Error,
                        ymax=Estimate+2*Std.Error),
                    alpha=0.2)
        } else {
            p <- p + ggplot2::geom_ribbon(ggplot2::aes(ymin=lower_bound,
                        ymax=upper_bound),
                    alpha=0.2)
        }

    }
    p + ggplot2::geom_line() +
        ggplot2::geom_vline(xintercept=summ$t_int, lty=2) +
        ggplot2::geom_hline(yintercept=0, lty=2) + 
        ggplot2::theme_bw()

}



#' augsynth
#' 
#' @description A package implementing the Augmented Synthetic Controls Method
#' @docType package
#' @name augsynth-package
#' @importFrom magrittr "%>%"
#' @importFrom purrr reduce
#' @import dplyr
#' @import tidyr
#' @importFrom stats terms
#' @importFrom stats formula
#' @importFrom stats update 
#' @importFrom stats delete.response 
#' @importFrom stats model.matrix 
#' @importFrom stats model.frame 
#' @importFrom stats na.omit
NULL


================================================
FILE: R/augsynth_pre.R
================================================
################################################################################
## Main function for the augmented synthetic controls Method
################################################################################


#' Fit Augmented SCM
#' @param form outcome ~ treatment | auxiliary covariates
#' @param unit Name of unit column
#' @param time Name of time column
#' @param data Panel data as dataframe
#' @param t_int Time of intervention (used for single-period treatment only)
#' @param ... Optional arguments
#' \itemize{
#'   \item Single period augsynth with/without multiple outcomes
#'     \itemize{
#'       \item{"progfunc"}{What function to use to impute control outcomes: Ridge=Ridge regression (allows for standard errors), None=No outcome model, EN=Elastic Net, RF=Random Forest, GSYN=gSynth, MCP=MCPanel, CITS=CITS, CausalImpact=Bayesian structural time series with CausalImpact, seq2seq=Sequence to sequence learning with feedforward nets}
#'       \item{"scm"}{Whether the SCM weighting function is used}
#'       \item{"fixedeff"}{Whether to include a unit fixed effect, default F }
#'       \item{"cov_agg"}{Covariate aggregation functions, if NULL then use mean with NAs omitted}
#'     }
#'   \item Multi period (staggered) augsynth
#'    \itemize{
#'          \item{"relative"}{Whether to compute balance by relative time}
#'          \item{"n_leads"}{How long past treatment effects should be estimated for}
#'          \item{"n_lags"}{Number of pre-treatment periods to balance, default is to balance all periods}
#'          \item{"alpha"}{Fraction of balance for individual balance}
#'          \item{"lambda"}{Regularization hyperparameter, default = 0}
#'          \item{"force"}{Include "none", "unit", "time", "two-way" fixed effects. Default: "two-way"}
#'          \item{"n_factors"}{Number of factors for interactive fixed effects, default does CV}
#'         }
#' }
#' 
#' @return augsynth object that contains:
#'         \itemize{
#'          \item{"weights"}{weights}
#'          \item{"data"}{Panel data as matrices}
#'         }
#' @export
#' 
augsynth <- function(form, unit, time, data, t_int=NULL, ...) {

  call_name <- match.call()

  form <- Formula::Formula(form)
  unit_quosure <- enquo(unit)
  time_quosure <- enquo(time)
  

  ## format data
  outcome <- terms(formula(form, rhs=1))[[2]]
  trt <- terms(formula(form, rhs=1))[[3]]

  # check for multiple outcomes
  multi_outcome <- length(outcome) != 1

  ## get first treatment times
  trt_time <- data %>%
      group_by(!!unit_quosure) %>%
      filter(!all(!!trt == 0)) %>%
      summarise(trt_time = min((!!time_quosure)[(!!trt) == 1])) %>%
      mutate(trt_time = replace_na(as.numeric(trt_time), Inf))

  num_trt_years <- sum(is.finite(unique(trt_time$trt_time)))

  if(multi_outcome & num_trt_years > 1) {
    stop("augsynth is not currently implemented for more than one outcome and more than one treatment time")
  } else if(num_trt_years > 1) {
    message("More than one treatment time found. Running multisynth.")
    if("progfunc" %in% names(list(...))) {
      warning("`progfunc` is not an argument for multisynth, so it is ignored")
    }
    return(multisynth(form, !!enquo(unit), !!enquo(time), data, ...)) 
  } else {
    if (is.null(t_int)) {
      t_int <- trt_time %>% filter(is.finite(trt_time)) %>%
        summarise(t_int = max(trt_time)) %>% pull(t_int)
    }
    if(!multi_outcome) {
      message("One outcome and one treatment time found. Running single_augsynth.")
      return(single_augsynth(form, !!enquo(unit), !!enquo(time), t_int,
                             data = data, ...))
    } else {
      message("Multiple outcomes and one treatment time found. Running augsynth_multiout.")
      return(augsynth_multiout(form, !!enquo(unit), !!enquo(time), t_int,
                               data = data, ...))
    }
  }
}


================================================
FILE: R/cv.R
================================================
drop_time_t <- function(wide_data, Z, t_drop) {
  new_wide_data <- list()
  new_wide_data$trt <- wide_data$trt
  
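  if (is.list(wide_data$X)) {
    # TODO: holding out periods for list-valued X (multiple outcomes) is not implemented
    stop("drop_time_t is not implemented when `wide_data$X` is a list")
  } else {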
    new_wide_data$X <- wide_data$X[, -t_drop, drop = F]
    new_wide_data$y <- cbind(wide_data$X[, t_drop, drop = F], 
                             wide_data$y)
    
    X0 <- new_wide_data$X[new_wide_data$trt == 0,, drop = F]
    x1 <- matrix(colMeans(new_wide_data$X[new_wide_data$trt == 1,, drop = F]),
                 ncol=1)
    y0 <- new_wide_data$y[new_wide_data$trt == 0,, drop = F]
    y1 <- colMeans(new_wide_data$y[new_wide_data$trt == 1,, drop = F])
    
    new_synth_data <- list()
    new_synth_data$Z0 <- t(X0)
    new_synth_data$X0 <- t(X0)
    new_synth_data$Z1 <- x1
    new_synth_data$X1 <- x1
    
    return(list(wide_data = new_wide_data,
                synth_data = new_synth_data,
                Z = Z)) 
  }
}

drop_time_and_refit <- function(wide_data, Z, t_drop, progfunc, scm, fixedeff, ...) {
  new_data <- drop_time_t(wide_data, Z, t_drop)
  new_ascm <- do.call(fit_augsynth_internal,
                      c(list(wide = new_data$wide_data,
                             synth_data = new_data$synth_data,
                             Z = new_data$Z,
                             progfunc = progfunc,
                             scm = scm,
                             fixedeff = fixedeff, ...)))
  return(new_ascm)
}

cv_internal <- function(wide_data, Z, progfunc, scm, fixedeff, lambdas, holdout_periods, ...) {
  X <- wide_data$X
  lambda_error_vals <- vapply(lambdas, function(lambda){
    errors <- apply(holdout_periods, 1, function(t_drop){
      new_ascm <- drop_time_and_refit(wide_data, Z, t_drop, progfunc, scm, fixedeff, lambda = lambda, ...)
      err <- sum((predict(new_ascm, att = T)[(ncol(X)-length(t_drop)+1):ncol(X)])^2)
      err
    })
    lambda_error <- mean(errors)
    lambda_error_se <- sd(errors) / sqrt(length(errors))
    c(lambda_error, lambda_error_se)
  }, numeric(2))
  return(list(lambda_errors = lambda_error_vals[1,], lambda_errors_se = lambda_error_vals[2,]))
}

cv_ridge <- function(wide_data, synth_data, Z, progfunc, scm, fixedeff, how = 'time', holdout_length = 1, lambdas = NULL, 
               lambda_min_ratio = 1e-8, n_lambda = 20, lambda_max = NULL, min_1se = T, V = NULL, ...) {
  X <- wide_data$X
  trt <- wide_data$trt
  if (is.null(lambdas)) {
    if(is.null(lambda_max)) {
      X_cent <- apply(X, 2, function(x) x - mean(x[trt==0]))
      X_c <- X_cent[trt==0,,drop=FALSE]
      t0 <- ncol(X_c)
      
      if(is.null(V)) {
        V <- diag(rep(1, t0))
      } else if(is.vector(V)) {
        V <- diag(V)
      } else if(ncol(V) == 1 && nrow(V) == t0) {
        V <- diag(c(V))
      } else if(ncol(V) == t0 && nrow(V) == 1) {
        V <- diag(c(V))
      } else if(nrow(V) == t0) {
        # V already has t0 rows; use it as supplied
      } else {
        stop("`V` must be a vector with t0 elements or a t0xt0 matrix")
      }
      X_c <- X_c %*% V

      if(!is.null(Z)) {
        Z_cent <- apply(Z, 2, function(x) x - mean(x[trt==0]))
        Z_c <- Z_cent[trt==0,,drop=FALSE]
        Xc_hat <- Z_c %*% solve(t(Z_c) %*% Z_c) %*% t(Z_c) %*% X_c
        res_c <- X_c - Xc_hat
        X_c <- res_c
      }
      lambda_max <- svd(X_c)$d[1] ^ 2
    }
    lambdas <- create_lambda_list(lambda_max, lambda_min_ratio, n_lambda)
  }
  
  if (how == 'time') {
    period_starts <- 1:(ncol(X) - holdout_length)
    if (holdout_length == 1) {
      holdout_periods <- matrix(period_starts, nrow = length(period_starts), ncol = 1)
    } else {
      holdout_periods <- t(vapply(period_starts, function(t) t:(t+holdout_length-1), numeric(holdout_length)))
    }
    results <- cv_internal(wide_data, Z, progfunc, scm, fixedeff, lambdas, holdout_periods, ...)
    lambda <- choose_lambda(lambdas, results$lambda_errors, results$lambda_errors_se, min_1se)
    return(list(lambda = lambda, lambdas = lambdas, lambda_errors = results$lambda_errors, lambda_errors_se = results$lambda_errors_se))
  }
}

================================================
FILE: R/data.R
================================================
#' Economic indicators for US states from 1990-2016
#' 
#' 
#' @format A dataframe with 5250 rows and 32 variables:
#' \describe{
#'  \item{fips}{FIPS code for each state}
#'  \item{year}{Year of measurement}
#'  \item{qtr}{Quarter (1-4) of measurement}
#'  \item{state}{Name of State}
#'  \item{gdp}{Gross State Product (millions of $). Values before 2005 are linearly interpolated between years}
#'  \item{revenuepop}{State and local revenue per capita}
#'  \item{rev_state_total}{State total general revenue (millions of $)}
#'  \item{rev_local_total}{Local total general revenue (millions of $)}
#'  \item{popestimate}{Population estimate}
#'  \item{qtrly_estabs_count}{Count of establishments for a given quarter}
#'  \item{month1_emplvl, month2_emplvl, month3_emplvl}{Employment level for first, second, and third months of a given quarter}
#'  \item{total_qtrly_wages}{Total wages for a given quarter}
#'  \item{taxable_qtrly_wage}{Taxable wages for a given quarter}
#'  \item{avg_wkly_wage}{Average weekly wage for a given quarter}
#'  \item{year_qtr}{Year and quarter combined into one continuous variable}
#'  \item{treated}{Whether the state passed tax cuts before the given year and quarter}
#'  \item{lngdpcapita}{Natural log of GDP per capita}
#'  \item{emplvlcapita}{Average employment level per capita}
#'  \item{Xcapita}{Per capita value of X}
#'  \item{abb}{State abbreviation}
#' }
"kansas"

================================================
FILE: R/eligible_donors.R
================================================
##############################################################################
## Code to get eligible donor units based on covariates
##############################################################################

get_donors <- function(X, y, trt, Z, time_cohort, n_lags,
                       n_leads, how = "knn", 
                       exact_covariates = NULL, ...) {

  # first get eligible donors by treatment time
  donors <- get_eligible_donors(trt, time_cohort, n_leads)

  # get donors with no NA values
  nona_donors <- get_nona_donors(X, y, trt, n_lags, n_leads, time_cohort)

  donors <- lapply(1:length(donors),
                     function(j) donors[[j]] & nona_donors[[j]])

  # if Z isn't NULL, further restrict the donors by matching
  if(!is.null(Z)) {
    if(ncol(Z) != 0) {
      donors <- get_matched_donors(trt, Z, donors, how, exact_covariates, ...)
    }
  }

  return(donors)
}


get_eligible_donors <- function(trt, time_cohort, n_leads) {

    # get treatment times
    if(time_cohort) {
        grps <- unique(trt[is.finite(trt)])
    } else {
        grps <- trt[is.finite(trt)]
    }

    J <- length(grps)

    # only allow weights on donors treated after n_leads
    donors <- lapply(1:J, function(j) trt > n_leads + grps[j])

    return(donors)
}

#' Get donors that don't have missing outcomes where treated units have outcomes
get_nona_donors <- function(X, y, trt, n_lags, n_leads, time_cohort) {

  n <- length(trt)
  # find na treatment times
  fulldat <- cbind(X, y)
  is_na <- is.na(fulldat[is.finite(trt), , drop = F])
  # aggregate by time cohort
  if(time_cohort) {
    grps <- unique(trt[is.finite(trt)])
    # if doing a time cohort, convert the boolean mask
    finite_trt <- trt[is.finite(trt)]
    is_na <- t(sapply(grps,
                    function(tj) apply(is_na[finite_trt == tj, , drop = F],
                                       2, all)))
  } else {
      grps <- trt[is.finite(trt)]
  }
  not_na <- !is.na(fulldat)
  J <- length(grps)
  lapply(1:J,
             function(j) {
               idxs <- max(grps[j] - n_lags + 1, 1):min(grps[j] + n_leads,
                                                        ncol(fulldat))
               isna_j <- is_na[j, idxs]
               apply(not_na[, idxs, drop = F][, !isna_j, drop = F], 1, all)
        }) -> donors

  return(donors)
}

get_matched_donors <- function(trt, Z, donors, how, exact_covariates = NULL, k = NULL, ...) {

  J <- sum(is.finite(trt))
  trt_idx <- which(is.finite(trt))
  if(is.null(exact_covariates)) {
    if(how == "exact") {
      return(
        lapply(1:J,
            function(j) donors[[j]] & apply(t(Z) == Z[trt_idx[j], ], 2, all)
        )
      )
    } else if(how == "knn") {
        return(get_knn_donors(trt, Z, donors, k))
    } else {
      stop("Option for exact matching must be in ('exact', 'knn')")
    }
  } else {
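    if(how == "exact") {
      # compare only the exact-match covariates, as in the knn branch below
      return(
        lapply(1:J,
            function(j) donors[[j]] &
              apply(t(Z[, exact_covariates, drop = F]) ==
                Z[trt_idx[j], exact_covariates], 2, all)
        )
      )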
    } else if(how == "knn") {
        donors <- lapply(1:J,
            function(j) { donors[[j]] &
              apply(t(Z[, exact_covariates, drop = F]) == 
                Z[trt_idx[j],exact_covariates], 2, all)
            }
              )
        approx_covs <- which(!colnames(Z) %in% exact_covariates)
        if(length(approx_covs) != 0) {
          return(get_knn_donors(trt, Z[, approx_covs, drop = F], donors, k))
        } else {
          return(donors)
        }
        
    } else {
      stop("Option for exact matching must be in ('exact', 'knn')")
    }
  }

}

get_knn_donors <- function(trt, Z, donors, k) {

  if(is.null(k)) {
    stop("Number of neighbors for knn not selected, please choose k.")
  }
  # knn matching within time cohort
  trt_idxs <- which(is.finite(trt))
  lapply(1:length(trt_idxs), 
        function(j) {
          idx <- trt_idxs[j]
          # idxs for treated units treated at time tj
          Z_tj <- Z[idx, , drop = F]

          # get donors for treated cohort
          donors_tj <- donors[[j]]
          Z_donors_tj <- Z[donors_tj, , drop = F]
          # check that k is less than the number of donors
          # if not, warn and return all donors as matches
          if(k >= nrow(Z_donors_tj)) {
            warning(paste("Number of potential donor units is less than",
                          "the number of required matches,",
                          "returning all donors as matches"))
            return(donors_tj)
          }
          # do knn matching
          nn <- FNN::get.knnx(data = Z_donors_tj, query = Z_tj, k = k)
          # keep track of which indices these are
          donors_j <- logical(length(donors_tj))
          true_idx <- which(donors_tj)[nn$nn.index[1, ]]
          donors_j[true_idx] <- TRUE
          return(donors_j)
         }) -> matches
  names(matches) <- trt_idxs
  return(matches)
}

================================================
FILE: R/fit_synth.R
================================================
#######################################################
# Helper scripts to fit synthetic controls to simulations
#######################################################

#' Make a V matrix from a vector (or null)
make_V_matrix <- function(t0, V) {
  if(is.null(V)) {
    V <- diag(rep(1, t0))
  } else if(is.vector(V)) {
    if(length(V) != t0) {
      stop(paste("`V` must be a vector with", t0, "elements or a", t0,
                 "x", t0, "matrix"))
    }
    V <- diag(V)
  } else if(ncol(V) == 1 && nrow(V) == t0) {
    V <- diag(c(V))
  } else if(ncol(V) == t0 && nrow(V) == 1) {
    V <- diag(c(V))
  } else if(nrow(V) == t0) {
    # V already has t0 rows; use it as supplied
  } else {
    stop(paste("`V` must be a vector with", t0, "elements or a", t0,
               "x", t0, "matrix"))
  }

  return(V)
}

#' Fit synthetic controls on outcomes after formatting data
#' @param synth_data Panel data in format of Synth::dataprep
#' @param V Matrix to scale the objective by
#' @noRd
#' @return \itemize{
#'          \item{"weights"}{Synth weights}
#'          \item{"l2_imbalance"}{Imbalance in pre-period outcomes, measured by the L2 norm}
#'          \item{"scaled_l2_imbalance"}{L2 imbalance scaled by L2 imbalance of uniform weights}
#' }
fit_synth_formatted <- function(synth_data, V = NULL) {


    t0 <- dim(synth_data$Z0)[1]
    ## if no V is supplied, set it equal to the identity matrix

    V <- make_V_matrix(t0, V)

    weights <- synth_qp(synth_data$X1, t(synth_data$X0), V)
    l2_imbalance <- sqrt(sum((synth_data$Z0 %*% weights - synth_data$Z1)^2))

    ## primal objective value scaled by least squares difference for mean
    uni_w <- matrix(1/ncol(synth_data$Z0), nrow=ncol(synth_data$Z0), ncol=1)
    unif_l2_imbalance <- sqrt(sum((synth_data$Z0 %*% uni_w - synth_data$Z1)^2))
    scaled_l2_imbalance <- l2_imbalance / unif_l2_imbalance

    return(list(weights=weights,
                l2_imbalance=l2_imbalance,
                scaled_l2_imbalance=scaled_l2_imbalance))
}

#' Solve the synth QP directly
#' @param X1 Target vector
#' @param X0 Matrix of control outcomes
#' @param V Scaling matrix
#' @noRd
synth_qp <- function(X1, X0, V) {
    
    Pmat <- X0 %*% V %*% t(X0)
    qvec <- - t(X1) %*% V %*% t(X0)

    n0 <- nrow(X0)
    A <- rbind(rep(1, n0), diag(n0))
    l <- c(1, numeric(n0))
    u <- c(1, rep(1, n0))

    settings <- osqp::osqpSettings(verbose = FALSE,
                                   eps_rel = 1e-8,
                                   eps_abs = 1e-8)
    sol <- osqp::solve_osqp(P = Pmat, q = qvec,
                            A = A, l = l, u = u, 
                            pars = settings)

    return(sol$x)
}


================================================
FILE: R/format.R
================================================
################################################################################
## Scripts to format panel data into matrices
################################################################################

#' Format "long" panel data into "wide" program evaluation matrices
#' @param outcome Name of outcome column
#' @param trt Name of treatment column
#' @param unit Name of unit column
#' @param time Name of time column
#' @param t_int Time of intervention
#' @param data Panel data as dataframe
#' @noRd
#' @return \itemize{
#'          \item{"X"}{Matrix of pre-treatment outcomes}
#'          \item{"trt"}{Vector of treatment assignments}
#'          \item{"y"}{Matrix of post-treatment outcomes}
#'         }
format_data <- function(outcome, trt, unit, time, t_int, data) {

    ## pre treatment outcomes
    X <- data %>%
        filter(!!time < t_int) %>%
        select(!!unit, !!time, !!outcome) %>%
        spread(!!time, !!outcome) %>%
        select(-!!unit) %>%
        as.matrix()
    


    ## post treatment outcomes
    y <- data %>%
        filter(!!time >= t_int) %>%
        select(!!unit, !!time, !!outcome) %>%
        spread(!!time, !!outcome) %>%
        select(-!!unit) %>%
        as.matrix()


    ## treatment status
    trt <- data %>%
        select(!!unit, !!trt) %>%
        group_by(!!unit) %>%
        summarise(trt = max(!!trt)) %>%
        ungroup() %>%
        pull(trt)

    return(list(X=X, trt=trt, y=y))
}


#' Format "long" panel data into "wide" program evaluation matrices for multiple outcomes
#' @param outcomes Vector of names of outcome columns
#' @param trt Name of treatment column
#' @param unit Name of unit column
#' @param time Name of time column
#' @param t_int Time of intervention
#' @param data Panel data as dataframe
#' @noRd
#' @return \itemize{
#'          \item{"X"}{List of matrices of pre-treatment outcomes}
#'          \item{"trt"}{Vector of treatment assignments}
#'          \item{"y"}{List of matrices of post-treatment outcomes}
#'         }
format_data_multi <- function(outcomes, trt, unit, time, t_int, data) {


    lapply(outcomes, 
        function(outcome) format_data(outcome, trt, unit, 
                                     time, t_int, data)
          ) -> formats

    # X <- simplify2array(lapply(formats, function(x) x$X))
    # y <- simplify2array(lapply(formats, function(x) x$y))
    # X <- lapply(formats, function(x) t(na.omit(t(x$X))))
    X <- lapply(formats, `[[`, "X")
    y <- lapply(formats, function(x) t(na.omit(t(x$y))))
    trt <- formats[[1]]$trt
    return(list(X = X, trt = trt, y = y))
}




#' Format "long" panel data into "wide" program evaluation matrices with staggered adoption
#' @param outcome Name of outcome column
#' @param trt Name of treatment column
#' @param unit Name of unit column
#' @param time Name of time column
#' @param data Panel data as dataframe
#' @noRd
#' @return \itemize{
#'          \item{"X"}{Matrix of pre-treatment outcomes}
#'          \item{"trt"}{Vector of treatment assignments}
#'          \item{"y"}{Matrix of post-treatment outcomes}
#'         }
format_data_stag <- function(outcome, trt, unit, time, data) {

    # arrange data by time first
    data <- data %>% arrange(!!time)
      
    ## get first treatment times
    trt_time <- data %>%
        group_by(!!unit) %>%
        summarise(trt_time=(!!time)[(!!trt) == 1][1]) %>%
        mutate(trt_time=replace_na(as.numeric(trt_time), Inf))
    

    t_int <- trt_time %>% filter(is.finite(trt_time)) %>%
        summarise(t_int=max(trt_time)) %>% pull(t_int)

    ## ## boolean mask of available data for treatment groups
    ## mask <- data %>% inner_join(trt_time %>%
    ##                             filter(is.finite(trt_time))) %>%
    ##     filter(!!time < t_int) %>%
    ##     mutate(trt=1-!!trt) %>%
    ##     select(!!unit, !!time, trt_time, trt) %>%
    ##     spread(!!time, trt) %>% 
    ##     group_by(trt_time) %>% 
    ##     summarise_all(list(max)) %>%
    ##     arrange(trt_time) %>% 
    ##     select(-trt_time, -!!unit) %>%
    ##     as.matrix()

    ## boolean mask of available data for treatment groups
    mask <- data %>% inner_join(trt_time %>%
                                filter(is.finite(trt_time)),
                                by = rlang::as_name(unit)) %>%
        filter(!!time < t_int) %>%
        mutate(trt=1-!!trt) %>%
        select(!!unit, !!time, trt_time, trt) %>%
        spread(!!time, trt) %>% 
        ## arrange(!!unit) %>% 
        select(-trt_time, -!!unit) %>%
        as.matrix()
    
    # outcomes as a matrix
    Xy <- data %>%
        select(!!unit, !!time, !!outcome) %>%
        spread(!!time, !!outcome) %>%
        select(-!!unit) %>%
        as.matrix()

    pre_times <- data %>% filter(!!time < t_int) %>%
        distinct(!!time) %>% pull(!!time)
    post_times <- data %>% filter(!!time >= t_int) %>%
        distinct(!!time) %>% pull(!!time)
    X <- Xy[, as.character(pre_times), drop = F]
    y <- Xy[, as.character(post_times), drop = F]

    if(nrow(X) != nrow(y)) {
      stop("There are not the same number of units after the last unit is treated as before the last unit is treated")
    }

    t_vec <- data %>% pull(!!time) %>% unique() %>% sort()
    trt <- sapply(trt_time$trt_time, function(x) which(t_vec == x)-1) %>%
        as.numeric() %>%
        replace_na(Inf)
   

    units <- data %>%
        filter(!!time < t_int) %>%
        select(!!unit, !!time, !!outcome) %>%
        spread(!!time, !!outcome) %>%
        pull(!!unit)

    
    return(list(X=X,
                trt=trt,
                y=y,
                mask=mask,
                time = t_vec,
                units=units))
}


#' Format program eval matrices into synth form
#'
#' @param X Matrix of pre-treatment outcomes
#' @param trt Vector of treatment assignments
#' @param y Matrix of post-treatment outcomes
#' @noRd
#' @return List with data formatted as Synth::dataprep
format_synth <- function(X, trt, y) {


    synth_data <- list()

    ## pre-treatment values as covariates
    synth_data$Z0 <- t(X[trt==0,,drop=F])

    ## average treated units together
    synth_data$Z1 <- as.matrix((colMeans(X[trt==1,,drop=F])), ncol=1)

    ## combine everything together also
    synth_data$Y0plot <- t(cbind(X[trt==0,,drop=F], y[trt==0,,drop=F]))
    synth_data$Y1plot <- as.matrix(colMeans(
        cbind(X[trt==1,,drop=F], y[trt==1,,drop=F])), ncol=1)


    ## predictors are pre-period outcomes
    synth_data$X0 <- synth_data$Z0
    synth_data$X1 <- synth_data$Z1

    return(synth_data)
    
}

#' Remove unit means 
#' @param wide_data X, y, trt
#' @param synth_data List with data formatted as Synth::dataprep
#' @noRd
demean_data <- function(wide_data, synth_data) {

    # pre treatment means
    means <- rowMeans(wide_data$X)

    new_wide_data <- list()
    new_X <- wide_data$X - means
    trt <- wide_data$trt

    new_wide_data$X <- new_X
    new_wide_data$y <- wide_data$y - means
    new_wide_data$trt <- trt

    new_synth_data <- list()
    new_synth_data$X0 <- t(new_X[trt == 0,, drop = FALSE])
    new_synth_data$Z0 <- new_synth_data$X0
    new_synth_data$X1 <- as.matrix((colMeans(new_X[trt==1,,drop = F])), 
                                   ncol = 1)
    new_synth_data$Z1 <- new_synth_data$X1


    # estimate post-treatment as pre-treatment means
    mhat <- replicate(ncol(wide_data$X) + ncol(wide_data$y), means)

    return(list(wide = new_wide_data,
                synth_data = new_synth_data,
                mhat = mhat))
}

#' Helper function to extract covariate matrix from data
#' @param form Formula as outcome ~ treatment | covariates
#' @param unit Name of unit column
#' @param time Name of time column
#' @param t_int Time of intervention
#' @param data Panel data as dataframe
#' @param cov_agg Covariate aggregation function
#' @noRd
extract_covariates <- function(form, unit, time, t_int, data, cov_agg) {

    ## if no aggregation functions, use the mean (omitting NAs)
    if(is.null(cov_agg)) {
        cov_agg <- c(function(x) mean(x, na.rm=T))
    }

    cov_form <- update(formula(delete.response(terms(form, rhs=2, data=data))),
                        ~. - 1) ## ensure that there is no intercept

    ## pull out relevant covariates and aggregate
    pre_data <- data %>% 
        filter(!! (time) < t_int)

    model.matrix(cov_form,
                    model.frame(cov_form, pre_data,
                                na.action=NULL) ) %>%
        data.frame() %>%
        mutate(unit=pull(pre_data, !!unit)) %>%
        group_by(unit) %>%
        summarise_all(cov_agg) -> Z

    # recombine with any missing units and convert to matrix
    data %>% distinct(!!unit) %>%
      rename(unit = !!unit) %>%
      left_join(Z, by = "unit") %>%
      arrange(unit) %>%
      select(-unit) %>%
      as.matrix() -> Z
    
    if(nrow(distinct(data, !!unit))  != nrow(Z)) {
      stop("Some units missing all covariate data")
    }

    # check if any covariates have no variation
    Zsds <- apply(Z, 2, sd)

    if(any(Zsds == 0)) {
      zero_covs <- paste(colnames(Z)[Zsds == 0], collapse = ", ")
      stop(paste("The following covariates have no variation across units:",
                 zero_covs))
    }
    return(Z)
}

#' Check that we can actually run multisynth on the data
#' @param wide Output of format_data_stag
#' @param fixedeff Whether to include a unit fixed effect
#' @param n_leads How long past treatment effects should be estimated for, default is number of post treatment periods for last treated unit
#' @param n_lags Number of pre-treatment periods to balance, default is to balance all periods
check_data_stag <- function(wide, fixedeff, n_leads, n_lags) {

  # If there are fewer than 5 pre-treatment outcomes, give a warning
  less_5 <- wide$units[wide$trt < 5]
  warn_msg <- ""
  if(length(less_5) != 0) {
    warn_msg <- paste0(
      warn_msg,
      "The following units have fewer than 5 pre-treatment outcomes: (",
      paste(less_5, collapse = ","),
      "). Be cautious!"
    )
  }

  # check if there are any always treated units
  always_trt <- wide$units[wide$trt == 0]

  # If including a fixed effect, check that there is more than one pretreatment
  # outcome for each unit
  n1 <- wide$units[wide$trt == 1]

  err_msg <- ""
  if(length(always_trt) != 0) {
    err_msg <- paste0(
      err_msg,
      "The following units are always treated and should be removed: (",
      paste(always_trt, collapse = ","),
      ")\n")
  }

  if(length(n1) != 0 && fixedeff) {
    if(nchar(err_msg) > 0) {
      err_msg <- paste0(err_msg, "  Also: ")
    }
    err_msg <- paste0(
      err_msg,
      "You are including a fixed effect with `fixedeff = T`, but the ",
      "following units only have one pre-treatment outcome: (",
      paste(n1, collapse = ","),
      "). Either remove these units or set `fixedeff = F`.\n"
    )
  }
  # check if there are never treated units
  if(max(wide$trt) < ncol(wide$X) + ncol(wide$y)) {
    if(nchar(err_msg) > 0) {
      err_msg <- paste0(err_msg, "  Also: ")
    }
    err_msg <- paste0(
      err_msg,
      "All units are eventually treated. The last treatment time is ",
      wide$time[max(wide$trt)],
      ". To run multisynth, remove all periods after this time. ",
      "Units treated at this time will be considered 'never treated' in the ",
      "narrowed sample.\n"
    )
  }

  if(nchar(warn_msg) > 0) {
    warning(warn_msg)
  }
  if(nchar(err_msg) > 0) {
    stop(err_msg)
  }

}

================================================
FILE: R/globalVariables.R
================================================
utils::globalVariables(c("time", "val", "post", "weight", ".", "Time",
                         "Estimate", "Std.Error", "Level", "last_time",
                         "is_avg", "label", "Outcome", "unit", "obs",
                         "lambdas", "errors_se",
                         "upper_bound", "lower_bound"))

================================================
FILE: R/highdim.R
================================================
################################################################################
## Methods to use flexible outcome models
################################################################################

##### Augmented SCM with general outcome models

#' Use zero weights, do nothing but output everything in the right way
#' @param synth_data Panel data in format of Synth::dataprep
#' @noRd
#' @return \itemize{
#'          \item{"weights"}{Synth weights}
#'          \item{"l2_imbalance"}{Imbalance in pre-period outcomes, measured by the L2 norm}
#'          \item{"scaled_l2_imbalance"}{L2 imbalance scaled by L2 imbalance of uniform weights}
#' }
fit_zero_weights <- function(synth_data) {
    
    ## Imbalance is uniform weights imbalance
    uni_w <- matrix(1/ncol(synth_data$Z0), nrow=ncol(synth_data$Z0), ncol=1)
    unif_l2_imbalance <- sqrt(sum((synth_data$Z0 %*% uni_w - synth_data$Z1)^2))
    scaled_l2_imbalance <- 1
    
    return(list(weights=rep(0, ncol(synth_data$Z0)),
                l2_imbalance=unif_l2_imbalance,
                scaled_l2_imbalance=scaled_l2_imbalance))
}



#' Fit E[Y(0)|X] for each post-period and balance pre-period
#'
#' @param wide_data Output of `format_data`
#' @param synth_data Output of `format_synth`
#' @param fit_progscore Function to fit prognostic score
#' @param fit_weights Function to fit synth weights
#' @param ... optional arguments for outcome model
#' @noRd
#' @return \itemize{
#'          \item{"weights"}{Ridge ASCM weights}
#'          \item{"l2_imbalance"}{Imbalance in pre-period outcomes, measured by the L2 norm}
#'          \item{"scaled_l2_imbalance"}{L2 imbalance scaled by L2 imbalance of uniform weights}
#'          \item{"mhat"}{Outcome model estimate}
#' }
fit_augsyn_formatted <- function(wide_data, synth_data,
                                fit_progscore, fit_weights, ...) {


    X <- wide_data$X
    y <- wide_data$y
    trt <- wide_data$trt
    
    ## fit prognostic scores
    fitout <- do.call(fit_progscore,
                          list(X=X, y=y, trt=trt, ...))
    
    ## fit synth
    syn <- fit_weights(synth_data)

    syn$params <- fitout$params

    syn$mhat <- fitout$y0hat
    
    return(syn)
}


#' Fit outcome model and balance pre-period
#' @param wide_data Output of `format_data`
#' @param synth_data Output of `format_synth`
#' @param progfunc What function to use to impute control outcomes
#'                 EN=Elastic Net, RF=Random Forest, GSYN=gSynth,
#'                 MCP=MCPanel, CITS=CITS,
#'                 CausalImpact=Bayesian structural time series with CausalImpact,
#'                 seq2seq=Sequence to sequence learning with feedforward nets
#' @param scm Whether the SCM weighting function is used
#' @param ... optional arguments for outcome model
#' @noRd
#' @return \itemize{
#'          \item{"weights"}{Ridge ASCM weights}
#'          \item{"l2_imbalance"}{Imbalance in pre-period outcomes, measured by the L2 norm}
#'          \item{"scaled_l2_imbalance"}{L2 imbalance scaled by L2 imbalance of uniform weights}
#'          \item{"mhat"}{Outcome model estimate}
#' }
fit_augsyn <- function(wide_data, synth_data,
                       progfunc=c("EN", "RF", "GSYN", "MCP","CITS", "CausalImpact", "seq2seq"),
                       scm=T, ...) {
    ## prognostic score and weight functions to use
    progfunc <- tolower(progfunc)
    if(progfunc == "en") {
        progf <- fit_prog_reg
    } else if(progfunc == "rf") {
        progf <- fit_prog_rf
    } else if(progfunc == "gsyn"){
        progf <- fit_prog_gsynth
    } else if(progfunc == "mcp"){
        progf <- fit_prog_mcpanel
    } else if(progfunc == "cits") {
        progf <- fit_prog_cits
    } else if(progfunc == "causalimpact") {
        progf <- fit_prog_causalimpact
    } else if(progfunc == "seq2seq"){
        progf <- fit_prog_seq2seq
    } else {
        stop("progfunc must be one of 'EN', 'RF', 'GSYN', 'MCP', 'CITS', 'CausalImpact', 'seq2seq'")
    }

    if(scm) {
        weightf <- fit_synth_formatted
    } else {
        ## still fit synth even if none
        ## TODO: This is a dumb wasteful hack
        weightf <- fit_zero_weights
    }
    return(fit_augsyn_formatted(wide_data, synth_data,
                                progf, weightf, ...))
}



### Combine synth and gsynth by balancing pre-period residuals
#' Fit outcome model and balance residuals
#'
#' @param wide_data Output of `format_data`
#' @param synth_data Output of `format_synth`
#' @param fit_progscore Function to fit prognostic score
#' @param fit_weights Function to fit synth weights
#' @param ... optional arguments for outcome model
#' @noRd
#' @return \itemize{
#'          \item{"weights"}{Ridge ASCM weights}
#'          \item{"l2_imbalance"}{Imbalance in pre-period outcomes, measured by the L2 norm}
#'          \item{"scaled_l2_imbalance"}{L2 imbalance scaled by L2 imbalance of uniform weights}
#'          \item{"mhat"}{Outcome model estimate}
#' }
fit_residaug_formatted <- function(wide_data, synth_data,
                                  fit_progscore, fit_weights, ...) {


    X <- wide_data$X
    y <- wide_data$y
    trt <- wide_data$trt

    ## fit prognostic scores
    fitout <- do.call(fit_progscore, list(X=X, y=y, trt=trt, ...))

    
    y0hat <- fitout$y0hat

    ## get residuals
    ctrl_resids <- fitout$params$ctrl_resids
    trt_resids <- fitout$params$trt_resids
    
    ## replace outcomes with pre-period residuals
    t0 <- dim(X)[2]

    synth_data$Z0 <- ctrl_resids[1:t0, ]
    synth_data$Z1 <- as.matrix(trt_resids[1:t0])
    
    ## fit synth weights
    syn <- fit_weights(synth_data)

    syn$params <- fitout$params    

    ## return predicted values for treatment and control
    syn$mhat <- y0hat
    
    return(syn)
}
#' Fit outcome model and balance residuals
#'
#' @param wide_data Output of `format_data`
#' @param synth_data Output of `format_synth`
#' @param progfunc What function to use to impute control outcomes
#'                 GSYN=gSynth, MCP=MCPanel,
#'                 CITS=Comparative interrupted time series
#'                 CausalImpact=Bayesian structural time series with CausalImpact
#' @param weightfunc What function to use to fit weights
#'                   SCM=Vanilla Synthetic Controls
#'                   NONE=No reweighting, just outcome model
#' @param ... optional arguments for outcome model
#' @noRd
#' @return \itemize{
#'          \item{"weights"}{Ridge ASCM weights}
#'          \item{"l2_imbalance"}{Imbalance in pre-period outcomes, measured by the L2 norm}
#'          \item{"scaled_l2_imbalance"}{L2 imbalance scaled by L2 imbalance of uniform weights}
#'          \item{"mhat"}{Outcome model estimate}
#' }
fit_residaug <- function(wide_data, synth_data,
                        progfunc=c("GSYN", "MCP", "CITS", "CausalImpact"),
                        weightfunc=c("SCM", "NONE"), ...) {

    ## resolve the vector defaults to a single choice
    progfunc <- match.arg(progfunc)
    weightfunc <- match.arg(weightfunc)

    ## prognostic score and weight functions to use
    if(progfunc == "GSYN"){
        progf <- fit_prog_gsynth
    } else if(progfunc == "MCP"){
        progf <- fit_prog_mcpanel
    } else if(progfunc == "CITS") {
        progf <- fit_prog_cits
    } else if(progfunc == "CausalImpact") {
        progf <- fit_prog_causalimpact
    } else {
        stop("progfunc must be one of 'GSYN', 'MCP', 'CITS', 'CausalImpact'")
    }

    
    ## weight function to use
    if(weightfunc == "SCM") {
        weightf <- fit_synth_formatted
    } else if(weightfunc == "NONE") {
        ## still fit synth even if none
        ## TODO: This is a dumb wasteful hack
        weightf <- fit_synth_formatted
    } else {
        stop("weightfunc must be one of 'SCM', 'NONE'")
    }

    return(fit_residaug_formatted(wide_data, synth_data,
                                  progf, weightf, ...))
}



================================================
FILE: R/inference.R
================================================
################################################################################
## Code for inference
################################################################################

#' Jackknife+ algorithm over time
#' @param ascm Fitted `augsynth` object
#' @param alpha Significance level; intervals have nominal 1 - alpha coverage
#' @param conservative Whether to use the conservative jackknife+ procedure
#' @return List that contains:
#'         \itemize{
#'          \item{"att"}{Vector of ATT estimates}
#'          \item{"heldout_att"}{Vector of ATT estimates with the time period held out}
#'          \item{"se"}{Standard error, always NA but returned for compatibility}
#'          \item{"lb"}{Lower bound of 1 - alpha confidence interval}
#'          \item{"ub"}{Upper bound of 1 - alpha confidence interval}
#'          \item{"alpha"}{Level of confidence interval}
#'         }
time_jackknife_plus <- function(ascm, alpha = 0.05, conservative = F) {
    wide_data <- ascm$data
    synth_data <- ascm$data$synth_data
    n <- nrow(wide_data$X)
    n_c <- dim(synth_data$Z0)[2]
    Z <- wide_data$Z

    t0 <- dim(synth_data$Z0)[1]
    tpost <- ncol(wide_data$y)
    t_final <- dim(synth_data$Y0plot)[1]

    jack_ests <- lapply(1:t0, 
        function(tdrop) {
            # drop pre-treatment time period tdrop
            new_data <- drop_time_t(wide_data, Z, tdrop)
            # refit
            new_ascm <- do.call(fit_augsynth_internal,
                    c(list(wide = new_data$wide,
                            synth_data = new_data$synth_data,
                            Z = new_data$Z,
                            progfunc = ascm$progfunc,
                            scm = ascm$scm,
                            fixedeff = ascm$fixedeff),
                        ascm$extra_args))
            # get ATT estimates and held out error for time t
            # t0 is prediction for held out time
            est <- predict(new_ascm, att = F)[(t0 +1):t_final]
            est <- c(est, mean(est))
            err <- c(colMeans(wide_data$X[wide_data$trt == 1,
                                         tdrop,
                                         drop = F]) -
                    predict(new_ascm, att = F)[t0])
            list(err, rbind(est + abs(err), est - abs(err), est + err, est))
        })
    # get errors and jackknife distribution
    held_out_errs <- vapply(jack_ests, `[[`, numeric(1), 1)
    jack_dist <- vapply(jack_ests, `[[`,
                        matrix(0, nrow = 4, ncol = tpost + 1), 2)

    out <- list()
    att <- predict(ascm, att = T)
    out$att <- c(att, 
                 mean(att[(t0 + 1):t_final]))
    # held out ATT
    out$heldout_att <- c(held_out_errs, 
                          att[(t0 + 1):t_final], 
                          mean(att[(t0 + 1):t_final]))

    # out$se <- rep(NA, 10 + tpost)
    if(conservative) {
        qerr <- stats::quantile(abs(held_out_errs), 1 - alpha)
        out$lb <- c(rep(NA, t0), apply(jack_dist[4,,], 1, min) - qerr)
        out$ub <- c(rep(NA, t0), apply(jack_dist[4,,], 1, max) + qerr)
    } else {
        out$lb <- c(rep(NA, t0), apply(jack_dist[2,,], 1, stats::quantile, alpha / 2))
        out$ub <- c(rep(NA, t0), apply(jack_dist[1,,], 1, stats::quantile, 1 - alpha / 2))
    }
    # shift back to ATT scale
    y1 <- predict(ascm, att = F) + att
    y1 <-  c(y1, mean(y1[(t0 + 1):t_final]))
    shifted_lb <- y1 - out$ub
    shifted_ub <- y1 - out$lb
    out$lb <- shifted_lb
    out$ub <- shifted_ub
    out$alpha <- alpha


    return(out)
}
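
## Illustrative sketch, not part of the package API: the non-conservative
## bounds above, isolated for a single post-treatment period. The helper name
## `example_jackknife_plus` and its arguments are hypothetical. `loo_pred[i]`
## stands for the prediction with pre-period i held out and `loo_err[i]` for
## the held-out error in period i. The result is on the outcome scale, i.e.
## before the shift to the ATT scale.
example_jackknife_plus <- function(loo_pred, loo_err, alpha = 0.05) {
  c(lb = unname(stats::quantile(loo_pred - abs(loo_err), alpha / 2)),
    ub = unname(stats::quantile(loo_pred + abs(loo_err), 1 - alpha / 2)))
}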

#' Drop time period from pre-treatment data
#' @param wide_data (X, y, trt)
#' @param Z Covariates matrix
#' @param t_drop Time to drop
#' @noRd
drop_time_t <- function(wide_data, Z, t_drop) {

        new_wide_data <- list()
        new_wide_data$trt <- wide_data$trt
        new_wide_data$X <- wide_data$X[, -t_drop, drop = F]
        new_wide_data$y <- cbind(wide_data$X[, t_drop, drop = F], 
                                 wide_data$y)

        X0 <- new_wide_data$X[new_wide_data$trt == 0,, drop = F]
        x1 <- matrix(colMeans(new_wide_data$X[new_wide_data$trt == 1,,
                                              drop = F]),
                     ncol=1)
        y0 <- new_wide_data$y[new_wide_data$trt == 0,, drop = F]
        y1 <- colMeans(new_wide_data$y[new_wide_data$trt == 1,, drop = F])

        new_synth_data <- list()
        new_synth_data$Z0 <- t(X0)
        new_synth_data$X0 <- t(X0)
        new_synth_data$Z1 <- x1
        new_synth_data$X1 <- x1

        return(list(wide_data = new_wide_data,
                    synth_data = new_synth_data,
                    Z = Z)) 
}

#' Conformal inference procedure to compute p-values and point-wise confidence intervals
#' @param ascm Fitted `augsynth` object
#' @param alpha Significance level; intervals have nominal 1 - alpha coverage
#' @param stat_func Function to compute test statistic
#' @param type Either "iid" for iid permutations or "block" for moving block permutations; default is "iid"
#' @param q The norm for the test statistic `(sum(abs(x) ^ q) / sqrt(length(x))) ^ (1 / q)`
#' @param ns Number of resamples for "iid" permutations
#' @param grid_size Number of grid points to use when inverting the hypothesis test
#' @return List that contains:
#'         \itemize{
#'          \item{"att"}{Vector of ATT estimates, followed by the average post-treatment ATT}
#'          \item{"lb"}{Lower bound of 1 - alpha confidence interval}
#'          \item{"ub"}{Upper bound of 1 - alpha confidence interval}
#'          \item{"p_val"}{p-value for test of no post-treatment effect}
#'          \item{"alpha"}{Level of confidence interval}
#'         }
conformal_inf <- function(ascm, alpha = 0.05, 
                          stat_func = NULL, type = "iid",
                          q = 1, ns = 1000, grid_size = 50) {
  wide_data <- ascm$data
  synth_data <- ascm$data$synth_data
  n <- nrow(wide_data$X)
  n_c <- dim(synth_data$Z0)[2]
  Z <- wide_data$Z

  t0 <- dim(synth_data$Z0)[1]
  tpost <- ncol(wide_data$y)
  t_final <- dim(synth_data$Y0plot)[1]

  # grid of nulls
  att <- predict(ascm, att = T)
  post_att <- att[(t0 +1):t_final]
  post_sd <- sqrt(mean(post_att ^ 2))
  # iterate over post-treatment periods to get pointwise CIs
  vapply(1:tpost,
         function(j) {
          # fit using t0 + j as a pre-treatment period and get residuals
          new_wide_data <- wide_data
          new_wide_data$X <- cbind(wide_data$X, wide_data$y[, j, drop = TRUE])
          if(tpost > 1) {
            new_wide_data$y <- wide_data$y[, -j, drop = FALSE]
          } else {
            # the post period has to be *something*, so use a placeholder
            new_wide_data$y <- matrix(1, nrow = n, ncol = 1)
          }


          # make a grid around the estimated ATT
          grid <- seq(att[t0 + j] - 2 * post_sd, att[t0 + j] + 2 * post_sd,
                      length.out = grid_size)
          compute_permute_ci(new_wide_data, ascm, grid, 1, alpha, type,
                             q, ns, stat_func)
         },
         numeric(3)) -> cis

  # test a null post-treatment effect
  new_wide_data <- wide_data
  new_wide_data$X <- cbind(wide_data$X, wide_data$y)
  new_wide_data$y <- matrix(1, nrow = n, ncol = 1)
  null_p <- compute_permute_pval(new_wide_data, ascm, 0, ncol(wide_data$y), 
                                 type, q, ns, stat_func)
  out <- list()
  att <- predict(ascm, att = T)
  out$att <- c(att, mean(att[(t0 + 1):t_final]))
  # out$se <- rep(NA, t_final)
  # out$sigma <- NA
  out$lb <- c(rep(NA, t0), cis[1, ], NA)
  out$ub <- c(rep(NA, t0), cis[2, ], NA)
  out$p_val <- c(rep(NA, t0), cis[3, ], null_p)
  out$alpha <- alpha
  return(out)
}


#' Conformal inference procedure to compute a confidence interval for a linear in time effect
#' @param ascm Fitted `augsynth` object
#' @param alpha Significance level; intervals have nominal 1 - alpha coverage
#' @param stat_func Function to compute test statistic
#' @param type Either "iid" for iid permutations or "block" for moving block permutations; default is "iid"
#' @param q The norm for the test statistic `(sum(abs(x) ^ q) / sqrt(length(x))) ^ (1 / q)`
#' @param ns Number of resamples for "iid" permutations
#' @param grid_size Number of grid points per parameter (intercept and slope) to use when inverting the hypothesis test
#' @return List that contains:
#'         \itemize{
#'          \item{"est_int"}{Intercept estimate (grid point with the largest p-value)}
#'          \item{"ci_int"}{Lower and upper bounds of the 1 - alpha confidence interval for the intercept}
#'          \item{"est_slope"}{Slope estimate (grid point with the largest p-value)}
#'          \item{"ci_slope"}{Lower and upper bounds of the 1 - alpha confidence interval for the slope}
#'         }
conformal_inf_linear <- function(ascm, alpha = 0.05, 
                          stat_func = NULL, type = "iid",
                          q = 1, ns = 1000, grid_size = 50) {
  wide_data <- ascm$data
  synth_data <- ascm$data$synth_data
  n <- nrow(wide_data$X)
  n_c <- dim(synth_data$Z0)[2]
  Z <- wide_data$Z

  t0 <- dim(synth_data$Z0)[1]
  tpost <- ncol(wide_data$y)
  t_final <- dim(synth_data$Y0plot)[1]

  # grid of nulls
  att <- predict(ascm, att = T)
  post_att <- att[(t0 +1):t_final]
  post_second <- sqrt(mean(post_att^2))

  # grid for slope
  # use ols to get pilot estimate
  ts <- 1:tpost
  lm_out <- summary(lm(post_att ~ ts))$coefficients
  # grid for intercept
  grid_int <- seq(lm_out[1,1] - 2 * post_second,
                  lm_out[1,1] + 2 * post_second,
                  length.out = grid_size)
  if(tpost == 2) {
    warning(paste0("There are 2 post-treatment time periods, so a linear model has a perfect fit. A confidence interval for the slope may not be reasonable here."))
    grid_slope <- seq(lm_out[2,1] - abs(lm_out[2,1]),
                  lm_out[2,1] + abs(lm_out[2,1]),
                  length.out = grid_size)
  } else if(tpost <= 1) {
    stop("There is only one post-treatment time period, so an intercept and a slope cannot be computed.")
  } else {
    grid_slope <- seq(lm_out[2,1] - 4 * lm_out[2,2] * sqrt(tpost),
                  lm_out[2,1] + 4 * lm_out[2,2] * sqrt(tpost),
                  length.out = grid_size)
  }

  # test a null post-treatment effect
  new_wide_data <- wide_data
  new_wide_data$X <- cbind(wide_data$X, wide_data$y)
  new_wide_data$y <- matrix(1, nrow = n, ncol = 1)
  null_p <- compute_permute_pval(new_wide_data, ascm, 0, ncol(wide_data$y), 
                                 type, q, ns, stat_func)

  # confidence interval for linear in time treatment effects
  cis <- compute_permute_ci_linear(new_wide_data, ascm, grid_int, grid_slope,
                                   ncol(wide_data$y), alpha, type, q, ns, stat_func)

  return(cis)
}


#' Compute conformal test statistics
#' @param wide_data List containing pre- and post-treatment outcomes and outcome vector
#' @param ascm Fitted `augsynth` object
#' @param h0 Null hypothesis to test
#' @param post_length Number of post-treatment periods
#' @param type Either "iid" for iid permutations or "block" for moving block permutations
#' @param q The norm for the test statistic `(sum(abs(x) ^ q) / sqrt(length(x))) ^ (1 / q)`
#' @param ns Number of resamples for "iid" permutations
#' @param stat_func Function to compute test statistic
#' 
#' @return List that contains:
#'         \itemize{
#'          \item{"resids"}{Residuals after enforcing the null}
#'          \item{"test_stats"}{Permutation distribution of test statistics}
#'          \item{"stat_func"}{Test statistic function}
#'         }
#' @noRd
compute_permute_test_stats <- function(wide_data, ascm, h0,
                                       post_length, type,
                                       q, ns, stat_func) {
  # format data
  new_wide_data <- wide_data
  t0 <- ncol(wide_data$X) - post_length
  tpost <- t0 + post_length
  # adjust outcomes for null
  new_wide_data$X[wide_data$trt == 1,(t0 + 1):tpost ] <- new_wide_data$X[wide_data$trt == 1,(t0 + 1):tpost] - h0
  X0 <- new_wide_data$X[new_wide_data$trt == 0,, drop = F]
  x1 <- matrix(colMeans(new_wide_data$X[new_wide_data$trt == 1,, drop = F]),
              ncol=1)

  new_synth_data <- list()
  new_synth_data$Z0 <- t(X0)
  new_synth_data$X0 <- t(X0)
  new_synth_data$Z1 <- x1
  new_synth_data$X1 <- x1

  # fit synth with adjusted data and get residuals
  new_ascm <- do.call(fit_augsynth_internal,
                    c(list(wide = new_wide_data,
                            synth_data = new_synth_data,
                            Z = wide_data$Z,
                            progfunc = ascm$progfunc,
                            scm = ascm$scm,
                            fixedeff = ascm$fixedeff),
                        ascm$extra_args))
  resids <- predict(new_ascm, att = T)[1:tpost]
  # permute residuals and compute test statistic
  if(is.null(stat_func)) {
    stat_func <- function(x) (sum(abs(x) ^ q)  / sqrt(length(x))) ^ (1 / q)
  }
  if(type == "iid") {
    test_stats <- sapply(1:ns, 
                        function(x) {
                          reorder <- sample(resids)
                          stat_func(reorder[(t0 + 1):tpost])
                        })
  } else {
    ## increment time by one step and wrap
    test_stats <- sapply(0:(tpost - 1),
                        function(j) {
                          reorder <- resids[(0:(tpost - 1) + j) %% tpost + 1]
                          stat_func(reorder[(t0 + 1):tpost])
                        })
  }
  
  return(list(resids = resids,
              test_stats = test_stats,
              stat_func = stat_func))
}
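
## Illustrative sketch, not part of the package API: the moving block scheme
## used above, isolated on a plain residual vector. The helper name
## `example_block_shifts` is hypothetical. Column j + 1 of the result is
## `resids` cyclically shifted by j positions, wrapping around the end, for
## j = 0, ..., length(resids) - 1.
example_block_shifts <- function(resids) {
  n <- length(resids)
  vapply(0:(n - 1),
         function(j) resids[(0:(n - 1) + j) %% n + 1],
         numeric(n))
}
## e.g. example_block_shifts(c(10, 20, 30)) has columns
## (10, 20, 30), (20, 30, 10), (30, 10, 20)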


#' Compute conformal p-value
#' @param wide_data List containing pre- and post-treatment outcomes and outcome vector
#' @param ascm Fitted `augsynth` object
#' @param h0 Null hypothesis to test
#' @param post_length Number of post-treatment periods
#' @param type Either "iid" for iid permutations or "block" for moving block permutations
#' @param q The norm for the test statistic `(sum(abs(x) ^ q) / sqrt(length(x))) ^ (1 / q)`
#' @param ns Number of resamples for "iid" permutations
#' @param stat_func Function to compute test statistic
#' 
#' @return Computed p-value
#' @noRd
compute_permute_pval <- function(wide_data, ascm, h0,
                                 post_length, type,
                                 q, ns, stat_func) {
  t0 <- ncol(wide_data$X) - post_length
  tpost <- t0 + post_length
  out <- compute_permute_test_stats(wide_data, ascm, h0,
                                    post_length, type, q, ns, stat_func)
  mean(out$stat_func(out$resids[(t0 + 1):tpost]) <= out$test_stats)
}
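
## Illustrative sketch, not part of the package API: how the p-value above is
## formed once residuals are in hand. It is the share of permuted statistics
## that are at least as large as the observed post-period statistic. The
## helper name `example_perm_pval` is hypothetical; it assumes moving block
## permutations and the default q-norm statistic, with the post-treatment
## periods stored last in `resids`.
example_perm_pval <- function(resids, post_length, q = 1) {
  n <- length(resids)
  post <- (n - post_length + 1):n
  stat <- function(x) (sum(abs(x) ^ q) / sqrt(length(x))) ^ (1 / q)
  perm_stats <- vapply(0:(n - 1),
                       function(j) stat(resids[(0:(n - 1) + j) %% n + 1][post]),
                       numeric(1))
  mean(stat(resids[post]) <= perm_stats)
}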

#' Compute conformal confidence interval by inverting the permutation test
#' @param wide_data List containing pre- and post-treatment outcomes and outcome vector
#' @param ascm Fitted `augsynth` object
#' @param grid Set of null hypotheses to test for inversion
#' @param post_length Number of post-treatment periods
#' @param alpha Significance level for the interval
#' @param type Either "iid" for iid permutations or "block" for moving block permutations
#' @param q The norm for the test statistic `(sum(abs(x) ^ q) / sqrt(length(x))) ^ (1 / q)`
#' @param ns Number of resamples for "iid" permutations
#' @param stat_func Function to compute test statistic
#' 
#' @return (lower bound of interval, upper bound of interval, p-value for null of 0 effect)
#' @noRd
compute_permute_ci <- function(wide_data, ascm, grid,
                               post_length, alpha, type,
                               q, ns, stat_func) {
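  # make sure 0 is in the grid; it is appended as the last entry
  grid <- c(grid, 0)
  ps <- sapply(grid, 
              function(x) {
                compute_permute_pval(wide_data, ascm, x, 
                                     post_length, type, q, ns, stat_func)
              })
  # the last p-value belongs to the appended 0; indexing with `grid == 0`
  # would return several values if the grid already contained 0
  c(min(grid[ps >= alpha]), max(grid[ps >= alpha]), ps[length(ps)])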
}
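
## Illustrative sketch, not part of the package API: the interval above comes
## from test inversion, i.e. keeping every null value on the grid whose
## p-value is at least alpha and reporting the smallest and largest of them.
## The helper name `example_invert_test` is hypothetical, and `pval_fn` is an
## assumed stand-in for a call to `compute_permute_pval` at one null value.
example_invert_test <- function(grid, pval_fn, alpha = 0.05) {
  ps <- vapply(grid, pval_fn, numeric(1))
  keep <- grid[ps >= alpha]
  # no grid point survives the test: report an empty interval as NA
  if (length(keep) == 0) return(c(NA_real_, NA_real_))
  c(min(keep), max(keep))
}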



#' Compute conformal confidence interval for a linear model for effects
#' int + slope * time
#' @param wide_data List containing pre- and post-treatment outcomes and outcome vector
#' @param ascm Fitted `augsynth` object
#' @param grid_int Set of null hypothesis values for the intercept
#' @param grid_slope Set of null hypothesis values for the slope
#' @param post_length Number of post-treatment periods
#' @param alpha Significance level for the intervals
#' @param type Either "iid" for iid permutations or "block" for moving block permutations
#' @param q The norm for the test statistic `(sum(abs(x) ^ q) / sqrt(length(x))) ^ (1 / q)`
#' @param ns Number of resamples for "iid" permutations
#' @param stat_func Function to compute test statistic
#' 
#' @return List with the intercept and slope estimates (`est_int`, `est_slope`) and their confidence intervals (`ci_int`, `ci_slope`)
#' @noRd
compute_permute_ci_linear <- function(wide_data, ascm, grid_int, grid_slope,
                               post_length, alpha, type,
                               q, ns, stat_func) {
  # make sure 0 is in both grids
  # grid_int <- c(grid_int, 0)
  # grid_slope <- c(grid_slope, 0)
  # make the combined grid
  grid_comb <- expand.grid(grid_int, grid_slope)
  grid_comb$p_val <-apply(grid_comb, 1,
              function(x) {
                compute_permute_pval(wide_data, ascm, x[1] + x[2] * (1:post_length), 
                                     post_length, type, q, ns, stat_func)
              })
  ci_int <- c(min(grid_comb[grid_comb$p_val >= alpha, 1]),
              max(grid_comb[grid_comb$p_val >= alpha, 1]))
  ci_slope <- c(min(grid_comb[grid_comb$p_val >= alpha, 2]),
                   max(grid_comb[grid_comb$p_val >= alpha, 2]))
  int_slope_est <- as.numeric(grid_comb[which.max(grid_comb$p_val), 1:2])
  return(list(est_int = int_slope_est[1], ci_int = ci_int,
              est_slope = int_slope_est[2], ci_slope = ci_slope))
}


#' Jackknife+ algorithm over time
#' @param ascm_multi Fitted `augsynth_multiout` object
#' @param alpha Significance level; intervals have nominal 1 - alpha coverage
#' @param conservative Whether to use the conservative jackknife+ procedure
#' @return List that contains:
#'         \itemize{
#'          \item{"att"}{Vector of ATT estimates}
#'          \item{"heldout_att"}{Vector of ATT estimates with the time period held out}
#'          \item{"se"}{Standard error, always NA but returned for compatibility}
#'          \item{"lb"}{Lower bound of 1 - alpha confidence interval}
#'          \item{"ub"}{Upper bound of 1 - alpha confidence interval}
#'          \item{"alpha"}{Level of confidence interval}
#'         }
time_jackknife_plus_multiout <- function(ascm_multi, alpha = 0.05, conservative = F) {
    wide_data <- ascm_multi$data
    data_list <- ascm_multi$data_list

    n <- nrow(wide_data$X)
    k <- length(data_list$X)


    t0 <- min(sapply(data_list$X, ncol))
    tpost <- max(sapply(data_list$y, ncol))
    t_final <- t0 + tpost
    Z <- wide_data$Z

    jack_ests <- lapply(1:t0, 
        function(tdrop) {
            # drop pre-treatment time period tdrop
            new_data_list <- drop_time_t_multiout(data_list, Z, tdrop)
            # refit
            new_ascm <- do.call(fit_augsynth_multiout_internal,
                    c(list(wide_list = new_data_list,
                            combine_method = ascm_multi$combine_method,
                            Z = data_list$Z,
                            progfunc = ascm_multi$progfunc,
                            scm = ascm_multi$scm,
                            fixedeff = ascm_multi$fixedeff,
                            outcomes_str = ascm_multi$outcomes),
                        ascm_multi$extra_args))
            # get ATT estimates and held out error for time t
            # t0 is prediction for held out time
            est <- predict(new_ascm, att = F)[(t0 +1):t_final, , drop = F]
            est <- rbind(est, colMeans(est))
            # err <- c(colMeans(wide_data$X[wide_data$trt == 1,
            #                              tdrop,
            #                              drop = F]) -
            #         predict(new_ascm, att = F)[t0])
            err <- c(predict(new_ascm, att = T)[t0, , drop = F])
            list(err, t(t(est) + abs(err)), t(t(est) - abs(err)), t(t(est) + err), est)
        })
    # get errors and jackknife distribution
    held_out_errs <- matrix(vapply(jack_ests, `[[`, numeric(k), 1), nrow = k)
    jack_dist_high <- vapply(jack_ests, `[[`,
                        matrix(0, nrow = tpost + 1, ncol = k), 2)
    jack_dist_low <- vapply(jack_ests, `[[`,
                        matrix(0, nrow = tpost + 1, ncol = k), 3)
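    # the conservative interval uses the held-out fits themselves (element 5),
    # matching `time_jackknife_plus`
    jack_dist_cons <- vapply(jack_ests, `[[`,
                        matrix(0, nrow = tpost + 1, ncol = k), 5)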

    out <- list()
    att <- predict(ascm_multi, att = T)
    out$att <- rbind(att, 
                      colMeans(att[(t0 + 1):t_final, , drop = F]))
    # held out ATT

    out$heldout_att <- rbind(t(held_out_errs), 
                              att[(t0 + 1):t_final, , drop = F], 
                              colMeans(att[(t0 + 1):t_final, , drop = F]))
    if(conservative) {
        qerr <- apply(abs(held_out_errs), 1, 
                      stats::quantile, 1 - alpha, type = 1)
        out$lb <- rbind(matrix(NA, nrow = t0, ncol = k),
                        t(t(apply(jack_dist_cons, 1:2, min)) - qerr))
        out$ub <- rbind(matrix(NA, nrow = t0, ncol = k),
                        t(t(apply(jack_dist_cons, 1:2, max)) + qerr))

    } else {
        out$lb <- rbind(matrix(NA, nrow = t0, ncol = k),
                        apply(jack_dist_low, 1:2,
                              stats::quantile, alpha, type = 1))
        out$ub <- rbind(matrix(NA, nrow = t0, ncol = k), 
                        apply(jack_dist_high, 1:2, 
                              stats::quantile, 1 - alpha, type = 1))
    }
    # shift back to ATT scale
    y1 <- predict(ascm_multi, att = F) + att
    y1 <-  rbind(y1, colMeans(y1[(t0 + 1):t_final, , drop = F]))
    shifted_lb <- y1 - out$ub
    shifted_ub <- y1 - out$lb
    out$lb <- shifted_lb
    out$ub <- shifted_ub
    out$alpha <- alpha


    return(out)
}

#' Drop time period from pre-treatment data
#' @param data_list List of per-outcome data (X, y, trt)
#' @param Z Covariates matrix
#' @param t_drop Time to drop
#' @noRd
drop_time_t_multiout <- function(data_list, Z, t_drop) {

        new_data_list <- list()
        new_data_list$trt <- data_list$trt
        new_data_list$X <- lapply(data_list$X,
                                  function(x) x[, -t_drop, drop = F])
        new_data_list$y <- lapply(1:length(data_list$y),
                                  function(k) {
                                    cbind(data_list$X[[k]][, t_drop, drop = F], 
                                          data_list$y[[k]])
                                  })
        return(new_data_list)
}


#' Conformal inference procedure to compute p-values and point-wise confidence intervals
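#' @param ascm_multi Fitted `augsynth_multiout` object
#' @param alpha Significance level; intervals have nominal 1 - alpha coverage
#' @param stat_func Function to compute test statistic
#' @param type Either "iid" for iid permutations or "block" for moving block permutations; default is "iid"
#' @param q The norm for the test statistic `(sum(abs(x) ^ q) / sqrt(length(x))) ^ (1 / q)`
#' @param ns Number of resamples for "iid" permutations
#' @param grid_size Number of grid points to use when inverting the hypothesis test (default is 1, in which case the pointwise intervals are not computed and only the null of no effect is tested)
#' @param lin_h0 Optional; passed through to `compute_permute_ci_multiout`. When non-NULL, a single grid shared across outcomes is used instead of one grid per outcome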
#' @return List that contains:
#'         \itemize{
#'          \item{"att"}{Vector of ATT estimates}
#'          \item{"heldout_att"}{Vector of ATT estimates with the time period held out}
#'          \item{"se"}{Standard error, always NA but returned for compatibility}
#'          \item{"lb"}{Lower bound of 1 - alpha confidence interval}
#'          \item{"ub"}{Upper bound of 1 - alpha confidence interval}
#'          \item{"p_val"}{p-value for test of no post-treatment effect}
#'          \item{"alpha"}{Level of confidence interval}
#'         }
conformal_inf_multiout <- function(ascm_multi, alpha = 0.05, 
                                    stat_func = NULL, type = "iid",
                                    q = 1, ns = 1000, grid_size = 1,
                                    lin_h0 = NULL) {
  wide_data <- ascm_multi$data
  data_list <- ascm_multi$data_list

  n <- nrow(wide_data$X)
  k <- length(data_list$X)


  t0 <- min(sapply(data_list$X, ncol))
  tpost <- max(sapply(data_list$y, ncol))
  t_final <- t0 + tpost

  # grid of nulls
  att <- predict(ascm_multi, att = T)
  post_att <- att[(t0 +1):t_final,, drop = F]
  post_sd <- apply(post_att, 2, function(x) sqrt(mean(x ^ 2, na.rm = T)))
  # iterate over post-treatment periods to get pointwise CIs
  
  vapply(1:tpost,
         function(j) {
          # fit using t0 + j as a pre-treatment period and get residuals
          new_data_list <- data_list
          new_data_list$X <- lapply(1:k,
              function(i) {
                Xi <- cbind(data_list$X[[i]], data_list$y[[i]][, j, drop = TRUE])
                colnames(Xi) <- c(colnames(data_list$X[[i]]),
                                  colnames(data_list$y[[i]])[j])
                Xi
          })
          
          
          if(tpost > 1) {
            new_data_list$y <- lapply(1:k,
              function(i) {
                data_list$y[[i]][, -j, drop = FALSE]
            })
          } else {
            # the post period has to be *something*, so use a placeholder
            new_data_list$y <- lapply(1:k,
              function(i) {
                x <- matrix(1, nrow = n, ncol = 1)
                colnames(x) <- max(as.numeric(colnames(data_list$y[[i]]))) + 1
                x
            })
          }


          # make a grid around the estimated ATT
          if(is.null(lin_h0)) {
            grid <- lapply(1:k, 
            function(i) {
              seq(att[t0 + j, i] - 2 * post_sd[i], att[t0 + j, i] + 2 * post_sd[i],
                    length.out = grid_size)
            })
          } else {
            grid <- seq(min(att[t0 + j, ]) - 2 * max(post_sd),
                max(att[t0 + j, ]) + 2 * max(post_sd),
                length.out = grid_size)
          }
          if(grid_size > 1) {
            compute_permute_ci_multiout(new_data_list, ascm_multi, grid, 1,
                                    alpha, type, q, ns, lin_h0, stat_func)
          } else {
            rbind(matrix(0, ncol = k, nrow = 2),
              compute_permute_pval_multiout(new_data_list, ascm_multi, numeric(k),
                                          1, type, q, ns, stat_func))
          }
          
         },
         matrix(0, ncol = k, nrow=3)) -> cis
  # test a null post-treatment effect

  new_data_list <- data_list
  new_data_list$X <- lapply(1:k,
      function(i) {
        Xi <- cbind(data_list$X[[i]], data_list$y[[i]])
        colnames(Xi) <- c(colnames(data_list$X[[i]]),
                          colnames(data_list$y[[i]]))
        Xi
      })
  # set post treatment to be *something*
  new_data_list$y <- lapply(1:k,
      function(i) {
        data_list$y[[i]][, 1, drop = FALSE]
    })
  null_p <- compute_permute_pval_multiout(new_data_list, ascm_multi,
                                          numeric(k), 
                                          tpost, type, q, ns, stat_func)
  if(is.null(lin_h0)) {
    grid <- lapply(1:k, 
            function(i) {
              seq(min(att[(t0 + 1):t_final, i]) - 4 * post_sd[i],
                  max(att[(t0 + 1):t_final, i]) + 4 * post_sd[i],
                    length.out = grid_size)
            })
  } else {
    grid <- seq(min(att[t0 + 1, ]) - 3 * max(post_sd),
                max(att[t0 + 1, ]) + 3 * max(post_sd),
                length.out = grid_size)
  }
  null_ci <- compute_permute_ci_multiout(new_data_list, ascm_multi, grid,
                                          tpost, alpha, type, q, ns,
                                          lin_h0, stat_func)
  out <- list()
  att <- predict(ascm_multi, att = T)
  out$att <- rbind(att, apply(att[(t0 + 1):t_final, , drop = F], 2, mean))
  out$lb <- rbind(matrix(NA, nrow = t0, ncol = k),
                  t(matrix(cis[1, ,], nrow = k)),
                  # rep(NA, k)
                  null_ci[1,]
                  )
  colnames(out$lb) <- ascm_multi$outcomes
  out$ub <- rbind(matrix(NA, nrow = t0, ncol = k),
                  t(matrix(cis[2, ,], nrow = k)),
                  # rep(NA, k)
                  null_ci[2,]
                  )
  colnames(out$ub) <- ascm_multi$outcomes
  out$p_val <- rbind(matrix(NA, nrow = t0, ncol = k),
                  t(matrix(cis[3, ,], nrow = k)),
                  # rep(null_p, k)
                  null_ci[3,])
  colnames(out$p_val) <- ascm_multi$outcomes
  out$alpha <- alpha
  return(out)
}



#' Compute conformal test statistics
#' @param data_list List of per-outcome data (X, y, trt)
#' @param ascm_multi Fitted `augsynth_multiout` object
#' @param h0 Vector of null hypothesis values to test, one per outcome
#' @param post_length Number of post-treatment periods
#' @param type Either "iid" for iid permutations or "block" for moving block permutations
#' @param q The norm for the test statistic `(sum(abs(x) ^ q) / sqrt(length(x))) ^ (1 / q)`
#' @param ns Number of resamples for "iid" permutations
#' @param stat_func Function to compute test statistic
#' 
#' @return List that contains:
#'         \itemize{
#'          \item{"resids"}{Residuals after enforcing the null}
#'          \item{"test_stats"}{Permutation distribution of test statistics}
#'          \item{"stat_func"}{Test statistic function}
#'         }
#' @noRd
compute_permute_test_stats_multiout <- function(data_list, ascm_multi, h0,
                                              post_length, type,
                                              q, ns, stat_func) {
  # format data
  new_data_list <- data_list
  t0 <- ncol(data_list$X[[1]]) - post_length
  tpost <- t0 + post_length
  k <- length(data_list$X)
  # adjust outcomes for null
  for(i in 1:k) {
    # adjust outcome i (not the last outcome k) by its own null value
    new_data_list$X[[i]][data_list$trt == 1,(t0 + 1):tpost ] <- new_data_list$X[[i]][data_list$trt == 1,(t0 + 1):tpost] - h0[i]
  }
  # fit synth with adjusted data and get residuals
  new_ascm <- do.call(fit_augsynth_multiout_internal,
                    c(list(wide_list = new_data_list,
                            combine_method = ascm_multi$combine_method,
                            Z = data_list$Z,
                            progfunc = ascm_multi$progfunc,
                            scm = ascm_multi$scm,
                            fixedeff = ascm_multi$fixedeff,
                            outcomes_str = ascm_multi$outcomes),
                        ascm_multi$extra_args))

  resids <- predict(new_ascm, att = T)[1:tpost, , drop = F]

  # permute residuals and compute test statistic
  if(is.null(stat_func)) {
    stat_func <- function(x) {
      x <- na.omit(x)
      (sum(abs(x) ^ q)  / sqrt(length(x))) ^ (1 / q)
    }
  }
  if(type == "iid") {
    test_stats <- sapply(1:ns, 
                        function(x) {
                          idxs <- sample(1:nrow(resids))
                          reorder <- resids[idxs, , drop = F]
                          apply(reorder[(t0 + 1):tpost, ,drop = F], 2, stat_func)
                        })
  } else {
    ## increment time by one step and wrap
    test_stats <- sapply(0:(tpost - 1),
                        function(j) {
                          reorder <- resids[(0:(tpost -1) + j) %% tpost + 1, ,drop = F]
                          if(!all(dim(reorder) == dim(resids))) {
                            stop("Error in block resampling")
                          }
                          apply(reorder[(t0 + 1):tpost, , drop = F], 2, stat_func)
                        })
  }
  
  return(list(resids = resids,
              test_stats = matrix(test_stats, nrow = k),
              stat_func = stat_func))
}
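
# Illustrative sketch, not part of the package API: the "block" branch above
# builds each permutation by cyclically shifting the time index by j steps.
# `block_shift_sketch` is a hypothetical helper name used only to show the
# index arithmetic; it assumes tpost >= 1 and an integer shift j >= 0.
block_shift_sketch <- function(tpost, j) {
  # shift the zero-based positions 0, ..., tpost - 1 by j, wrap around with
  # the modulus, then convert back to R's one-based indexing
  (0:(tpost - 1) + j) %% tpost + 1
}
# For example, block_shift_sketch(5, 2) gives 3 4 5 1 2, so rows 4 and 5 of
# the reordered residuals hold the residuals from periods 1 and 2.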


#' Compute conformal p-value
#' @param data_list List with per-outcome outcome matrices `X` (pre- and post-treatment periods), treatment indicator `trt`, and covariates `Z`
#' @param ascm_multi Fitted `augsynth_multiout` object
#' @param h0 Vector of null hypotheses to test, one per outcome
#' @param post_length Number of post-treatment periods
#' @param type Either "iid" for iid permutations or "block" for moving block permutations
#' @param q The norm for the test statistic `(sum(abs(x) ^ q) / sqrt(length(x))) ^ (1/q)`
#' @param ns Number of resamples for "iid" permutations
#' @param stat_func Function to compute test statistic
#' 
#' @return Computed p-value
#' @noRd
compute_permute_pval_multiout <- function(data_list, ascm_multi, h0,
                                        post_length, type,
                                        q, ns, stat_func) {
  t0 <- ncol(data_list$X[[1]]) - post_length
  tpost <- t0 + post_length

  out <- compute_permute_test_stats_multiout(data_list, ascm_multi, h0,
                                          post_length, type, q, ns, stat_func)
  k <- length(data_list$X)

  comb_stat <- mean(apply(out$resids[(t0 + 1):tpost, , drop = F], 2, out$stat_func), na.rm = TRUE)
  comb_test_stats <- apply(out$test_stats, 2, mean, na.rm = TRUE)
  # if(h0 == 0) {
  #   hist(comb_test_stats)
  #   abline(v = comb_stat)
  #   print(mean(comb_stat <= comb_test_stats))
  #   print(1 - mean(comb_stat > comb_test_stats))
  # }
  1 - mean(comb_stat > comb_test_stats)
}

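#' Compute conformal confidence intervals by inverting the permutation test
#' @param data_list List with per-outcome outcome matrices `X` (pre- and post-treatment periods), treatment indicator `trt`, and covariates `Z`
#' @param ascm_multi Fitted `augsynth_multiout` object
#' @param grid List of per-outcome grids of null hypotheses to test for inversion (a single vector when `lin_h0` is given)
#' @param post_length Number of post-treatment periods
#' @param alpha Significance level; the interval keeps grid values with p-value >= alpha
#' @param type Either "iid" for iid permutations or "block" for moving block permutations
#' @param q The norm for the test statistic `(sum(abs(x) ^ q) / sqrt(length(x))) ^ (1/q)`
#' @param ns Number of resamples for "iid" permutations
#' @param lin_h0 Optional vector defining a linear hypothesis; if supplied, the one-dimensional `grid` is scaled by each entry of `lin_h0`
#' @param stat_func Function to compute test statistic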
#' 
#' @return (lower bound of interval, upper bound of interval, p-value for null of 0 effect)
#' @noRd
compute_permute_ci_multiout <- function(data_list, ascm_multi, grid,
                                      post_length, alpha, type,
                                      q, ns, lin_h0 = NULL, stat_func) {
  # make sure 0 is in the grid
  if(is.null(lin_h0)) {
    grid <- lapply(grid, function(x) c(x, 0))
    k <- length(grid)
    # get all combinations of grid
    grid <- expand.grid(grid)
    grid_low <- NULL
  } else {
    k <- length(lin_h0)
    # keep track of low dimensional grid
    grid_low <- c(grid, 0)
    # transform into high dimensional grid with linear hypothesis
    grid <- sapply(lin_h0, function(x) x * grid_low)
  }
  ps <- apply(grid, 1,
              function(x) {
                compute_permute_pval_multiout(data_list, ascm_multi, x, 
                                     post_length, type, q, ns, stat_func)
              })
  sapply(1:k, 
    function(i) c(min(grid[ps >= alpha, i]), 
                  max(grid[ps >= alpha, i]), 
                  ps[apply(grid == 0, 1, all)]))
}




#' Drop unit i from data
#' @param wide_data (X, y, trt)
#' @param Z Covariates matrix
#' @param i Unit to drop
#' @noRd
drop_unit_i <- function(wide_data, Z, i) {

        new_wide_data <- list()
        new_wide_data$trt <- wide_data$trt[-i]
        new_wide_data$X <- wide_data$X[-i,, drop = F]
        new_wide_data$y <- wide_data$y[-i,, drop = F]

        X0 <- new_wide_data$X[new_wide_data$trt == 0,, drop = F]
        x1 <- matrix(colMeans(new_wide_data$X[new_wide_data$trt == 1,, drop = F]),
                     ncol=1)
        y0 <- new_wide_data$y[new_wide_data$trt == 0,, drop = F]
        y1 <- colMeans(new_wide_data$y[new_wide_data$trt == 1,, drop = F])

        new_synth_data <- list()
        new_synth_data$Z0 <- t(X0)
        new_synth_data$X0 <- t(X0)
        new_synth_data$Z1 <- x1
        new_synth_data$X1 <- x1
        new_Z <- if(!is.null(Z)) Z[-i, , drop = F] else NULL

        return(list(wide_data = new_wide_data,
                    synth_data = new_synth_data,
                    Z = new_Z))
}

#' Drop unit i from multi-outcome data
#' @param wide_list List with `X` and `y` (lists of per-outcome matrices) and `trt`
#' @param Z Covariates matrix
#' @param i Unit to drop
#' @noRd
drop_unit_i_multiout <- function(wide_list, Z, i) {

        new_wide_data <- list()
        new_wide_data$trt <- wide_list$trt[-i]
        new_wide_data$X <- lapply(wide_list$X, function(x) x[-i,, drop = F])
        new_wide_data$y <- lapply(wide_list$y, function(x) x[-i,, drop = F])
        new_Z <- if(!is.null(Z)) Z[-i, , drop = F] else NULL

        return(list(wide_list = new_wide_data,
                    Z = new_Z))
}


#' Estimate standard errors for single ASCM with the jackknife
#' Do this for ridge-augmented synth
#' @param ascm Fitted augsynth object
#' 
#' @return List that contains:
#'         \itemize{
#'          \item{"att"}{Vector of ATT estimates, followed by the average post-treatment ATT}
#'          \item{"se"}{Jackknife standard error estimates (NA for pre-treatment periods)}
#'         }
jackknife_se_single <- function(ascm) {

    wide_data <- ascm$data
    synth_data <- ascm$data$synth_data
    n <- nrow(wide_data$X)
    n_c <- dim(synth_data$Z0)[2]
    Z <- wide_data$Z

    t0 <- dim(synth_data$Z0)[1]
    tpost <- ncol(wide_data$y)
    t_final <- dim(synth_data$Y0plot)[1]
    errs <- matrix(0, n_c, t_final - t0)


    # only drop out control units with non-zero weights
    nnz_weights <- numeric(n)
    nnz_weights[wide_data$trt == 0] <- round(ascm$weights, 3) != 0
    # if more than one unit is treated, include them in the jackknife
    if(sum(wide_data$trt) > 1) {
      nnz_weights[wide_data$trt == 1] <- 1
    }

    trt_idxs <- (1:n)[as.logical(nnz_weights)]


    # jackknife estimates
    ests <- vapply(trt_idxs,
                   function(i) {
                       # drop unit i
                       new_data <- drop_unit_i(wide_data, Z, i)
                       # refit
                       new_ascm <- do.call(fit_augsynth_internal,
                                # use the full element name rather than relying on partial matching of `$wide`
                                c(list(wide = new_data$wide_data,
                                       synth_data = new_data$synth_data,
                                       Z = new_data$Z,
                                       progfunc = ascm$progfunc,
                                       scm = ascm$scm,
                                       fixedeff = ascm$fixedeff),
                                  ascm$extra_args))
                       # get ATT estimates
                       est <- predict(new_ascm, att = T)[(t0 + 1):t_final]
                       c(est, mean(est))
                   },
                   numeric(tpost + 1))
    # convert to matrix
    ests <- matrix(ests, nrow = tpost + 1, ncol = length(trt_idxs))
    ## standard errors
    se2 <- apply(ests, 1,
                 function(x) (n - 1) / n * sum((x - mean(x, na.rm = T)) ^ 2))
    se <- sqrt(se2)

    out <- list()
    att <- predict(ascm, att = T)
    out$att <- c(att, mean(att[(t0 + 1):t_final]))

    out$se <- c(rep(NA, t0), se)
    # out$sigma <- NA
    return(out)
}
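
# Illustrative sketch, not part of the package API: the standard error above
# scales the squared deviations of the leave-one-out estimates by (n - 1) / n.
# `jackknife_se_sketch` is a hypothetical helper; `ests` is assumed to be a
# numeric vector of leave-one-out estimates and `n` the number of units.
jackknife_se_sketch <- function(ests, n = length(ests)) {
  # jackknife variance: (n - 1) / n * sum of squared deviations from the mean
  se2 <- (n - 1) / n * sum((ests - mean(ests)) ^ 2)
  sqrt(se2)
}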


#' Compute standard errors using the jackknife
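#' @param multisynth fitted multisynth object
#' @param relative Whether to compute effects according to relative time
#' @param alpha Significance level for the normal-approximation confidence interval
#' @param att_weight Weights passed through to `predict` when computing ATT estimates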
#' @noRd
jackknife_se_multi <- function(multisynth, relative=NULL, alpha = 0.05, att_weight = NULL) {
    ## get info from the multisynth object
    if(is.null(relative)) {
        relative <- multisynth$relative
    }
    n_leads <- multisynth$n_leads
    n <- nrow(multisynth$data$X)
    att <- predict(multisynth, att=T, att_weight = att_weight)
    outddim <- nrow(att)

    J <- length(multisynth$grps)

    ## drop each unit and estimate overall treatment effect
    jack_est <- vapply(1:n,
                       function(i) {
                           msyn_i <- drop_unit_i_multi(multisynth, i)
                           pred <- predict(msyn_i[[1]], relative=relative, att=T, att_weight = att_weight)
                           if(nrow(pred) < outddim) {
                               pred <- rbind(
                                   pred[1:(nrow(pred)-1), ],
                                   matrix(NA, nrow=outddim-nrow(pred), ncol=ncol(pred)),
                                   pred[nrow(pred), ]
                               )
                           }
                           if(length(msyn_i[[2]]) != 0) {
                               out <- matrix(NA, nrow=nrow(pred), ncol=(J+1))
                               out[,-(msyn_i[[2]]+1)] <- pred
                           } else {
                               out <- pred
                           }
                           out
                       },
                       matrix(0, nrow=outddim,ncol=(J+1)))

    se2 <- apply(jack_est, c(1,2),
                function(x) (n-1) / n * sum((x - mean(x,na.rm=T))^2, na.rm=T))
    lower_bound <- att - qnorm(1 - alpha / 2) * sqrt(se2)
    upper_bound <- att + qnorm(1 - alpha / 2) * sqrt(se2)
    return(list(att = att, se = sqrt(se2),
                lower_bound = lower_bound, upper_bound = upper_bound))

}

#' Helper function to drop unit i and refit
#' @param msyn multisynth_object
#' @param i Unit to drop
#' @noRd
drop_unit_i_multi <- function(msyn, i) {

    n <- nrow(msyn$data$X)
    time_cohort <- msyn$time_cohort
    which_t <- (1:n)[is.finite(msyn$data$trt)]

    not_miss_j <- which_t %in% setdiff(which_t, i)

    # drop unit i from data
    drop_i <- msyn$data
    drop_i$X <- msyn$data$X[-i, , drop = F]
    drop_i$y <- msyn$data$y[-i, , drop = F]
    drop_i$trt <- msyn$data$trt[-i]
    drop_i$mask <- msyn$data$mask[not_miss_j,, drop = F]

    if(!is.null(msyn$data$Z)) {
      drop_i$Z <- msyn$data$Z[-i, , drop = F]
    } else {
      drop_i$Z <- NULL
    }

    long_df <- msyn$long_df
    unit <- colnames(long_df)[1]
    # make alphabetical, because the ith unit is the index in alphabetical ordering
    long_df <- long_df[order(long_df[, unit, drop = TRUE]),]
    ith_unit <- unique(long_df[,unit, drop = TRUE])[i]
    long_df <- long_df[long_df[,unit, drop = TRUE] != ith_unit,]

    # re-fit everything
    args_list <- list(wide = drop_i, relative = msyn$relative,
                      n_leads = msyn$n_leads, n_lags = msyn$n_lags,
                      nu = msyn$nu, lambda = msyn$lambda,
                      V = msyn$V,
                      force = msyn$force, n_factors = msyn$n_factors,
                      scm = msyn$scm, time_w = msyn$time_w,
                      lambda_t = msyn$lambda_t,
                      fit_resids = msyn$fit_resids,
                      time_cohort = msyn$time_cohort, long_df = long_df,
                      how_match = msyn$how_match)
    msyn_i <- do.call(multisynth_formatted, c(args_list, msyn$extra_pars))

    # check for dropped treated units/time periods
    if(time_cohort) {
        dropped <- which(!msyn$grps %in% msyn_i$grps)
    } else {
        dropped <- which(!not_miss_j)
    }
    return(list(msyn_i,
                dropped))
}


#' Estimate standard errors for multi outcome ascm with jackknife
#' @param ascm Fitted augsynth object
#' @noRd
jackknife_se_multiout <- function(ascm) {

    wide_data <- ascm$data
    wide_list <- ascm$data_list
    n <- nrow(wide_data$X)
    Z <- wide_data$Z


    # only drop out control units with non-zero weights
    nnz_weights <- numeric(n)
    nnz_weights[wide_data$trt == 0] <- round(ascm$weights, 3) != 0

    trt_idxs <- (1:n)[as.logical(nnz_weights)]

    # jackknife estimates
    ests <- lapply(trt_idxs,
                   function(i) {
                       # drop unit i
                       new_data <- drop_unit_i_multiout(wide_list, Z, i)
                       # refit
                       new_ascm <- do.call(fit_augsynth_multiout_internal,
                                # name the argument and list element in full instead of relying on partial matching
                                c(list(wide_list = new_data$wide_list,
                                       combine_method = ascm$combine_method,
                                       Z = new_data$Z,
                                       progfunc = ascm$progfunc,
                                       scm = ascm$scm,
                                       fixedeff = ascm$fixedeff,
                                       outcomes_str = ascm$outcomes),
                                  ascm$extra_args))
                        new_ascm$outcomes <- ascm$outcomes
                        new_ascm$data_list <- ascm$data_list
                        new_ascm$data$time <- ascm$data$time
                       # get ATT estimates
                       est <- predict(new_ascm, att = T)
                       est <- est[as.numeric(rownames(est)) >= ascm$t_int,, drop = F]
                       rbind(est, colMeans(est, na.rm = T))
                   })
    ests <- simplify2array(ests)
    ## standard errors
    se2 <- apply(ests, c(1, 2),
                 function(x) (n - 1) / n * sum((x - mean(x, na.rm = T)) ^ 2))
    se <- sqrt(se2)
    out <- list()
    att <- predict(ascm, att = T)
    att_post <- colMeans(att[as.numeric(rownames(att)) >= ascm$t_int,, drop = F],
                         na.rm = T)
    out$att <- rbind(att, att_post)
    t0 <- sum(as.numeric(rownames(att)) < ascm$t_int)
    out$se <- rbind(matrix(NA, t0, ncol(se)), se)
    out$sigma <- NA
    return(out)
}



#' Compute the weighted bootstrap distribution
#' @param multisynth fitted multisynth object
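#' @param rweight Function to draw random weights as a function of n (e.g., rweight(n))
#' @param n_boot Number of bootstrap draws
#' @param alpha Significance level for the confidence interval
#' @param att_weight Weights passed through to `predict` when computing ATT estimates
#' @param relative Whether to compute effects according to relative time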
#' @noRd
weighted_bootstrap_multi <- function(multisynth,
                                    rweight = rwild_b,
                                    n_boot = 1000,
                                    alpha = 0.05,
                                    att_weight = NULL,
                                    relative=NULL) {
  ## get info from the multisynth object
  if(is.null(relative)) {
      relative <- multisynth$relative
  }

  n <- nrow(multisynth$data$X)
  att <- predict(multisynth, att=T, att_weight = att_weight)
  outddim <- nrow(att)
  n1 <- sum(is.finite(multisynth$data$trt))
  J <- length(multisynth$grps)


  # draw random weights to get bootstrap distribution
  bs_est <- vapply(1:n_boot,
                      function(i) {
                        Z <- rweight(n)# / sqrt(n1)

                        predict(multisynth, att=T, att_weight = att_weight, bs_weight = Z) - sum(Z) / n1 * att
                      },
                      matrix(0, nrow=outddim,ncol=(J+1)))

  se2 <- apply(bs_est, c(1,2),
              function(x) mean((x - mean(x))^2, na.rm=T))
  bias <- apply(bs_est, c(1,2),
              function(x) mean(x, na.rm=T))
  upper_bound <- att - apply(bs_est, c(1,2),
              function(x) quantile(x, alpha / 2, na.rm = T))
  
  lower_bound <- att - apply(bs_est, c(1,2),
              function(x) quantile(x, 1 - alpha / 2, na.rm = T))

  return(list(att = att,
              bias = bias,
              se = sqrt(se2),
              upper_bound = upper_bound,
              lower_bound = lower_bound))

}

#' Bayesian bootstrap
#' @param n Number of units
#' @export
rdirichlet_b <- function(n) {
  Z <- as.numeric(rgamma(n, 1, 1))
  return(Z / sum(Z) * n)
}

#' Non-parametric bootstrap
#' @param n Number of units
#' @export
rmultinom_b <- function(n) as.numeric(rmultinom(1, n, rep(1 / n, n)))

#' Wild bootstrap (Mammen 1993)
#' @param n Number of units
#' @export
rwild_b <- function(n) {
  sample(c(-(sqrt(5) - 1) / 2, (sqrt(5) + 1) / 2 ), n,
         replace = TRUE,
         prob = c((sqrt(5) + 1)/ (2 * sqrt(5)), (sqrt(5) - 1) / (2 * sqrt(5))))
}
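
# Illustrative sketch, not part of the package API: the two-point weights drawn
# by rwild_b have mean 0 and variance 1, which can be checked directly from the
# support points and probabilities. `mammen_moments_sketch` is a hypothetical
# helper that returns these two moments.
mammen_moments_sketch <- function() {
  vals <- c(-(sqrt(5) - 1) / 2, (sqrt(5) + 1) / 2)
  probs <- c((sqrt(5) + 1) / (2 * sqrt(5)), (sqrt(5) - 1) / (2 * sqrt(5)))
  # first moment and centered second moment of the two-point distribution
  m <- sum(vals * probs)
  c(mean = m, var = sum((vals - m) ^ 2 * probs))
}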

================================================
FILE: R/multi_outcomes.R
================================================
#' Fit Augmented SCM with multiple outcomes
#' @param form outcome ~ treatment | auxiliary covariates
#' @param unit Name of unit column
#' @param time Name of time column
#' @param t_int Time of intervention
#' @param data Panel data as dataframe
#' @param progfunc What function to use to impute control outcomes
#'                 Ridge=Ridge regression (allows for standard errors),
#'                 None=No outcome model,
#' @param scm Whether the SCM weighting function is used
#' @param fixedeff Whether to include a unit fixed effect, default F 
#' @param cov_agg Covariate aggregation functions, if NULL then use mean with NAs omitted
#' @param combine_method How to combine outcomes: `concat` concatenates outcomes, `avg` averages them, and `avg_concat` balances both the averaged and concatenated outcomes; default: 'avg'
#' @param ... optional arguments for outcome model
#'
#' @return augsynth object that contains:
#'         \itemize{
#'          \item{"weights"}{Ridge ASCM weights}
#'          \item{"l2_imbalance"}{Imbalance in pre-period outcomes, measured by the L2 norm}
#'          \item{"scaled_l2_imbalance"}{L2 imbalance scaled by L2 imbalance of uniform weights}
#'          \item{"mhat"}{Outcome model estimate}
#'          \item{"data"}{Panel data as matrices}
#'         }
#' @export
augsynth_multiout <- function(form, unit, time, t_int, data,
                              progfunc=c("Ridge", "None"),
                              scm=T,
                              fixedeff = FALSE,
                              cov_agg=NULL,
                              combine_method = "avg",
                              ...) {
    call_name <- match.call()

    form <- Formula::Formula(form)
    unit <- enquo(unit)
    time <- enquo(time)
    
    ## format data
    outcome <- terms(formula(form, rhs=1))[[2]]
    trt <- terms(formula(form, rhs=1))[[3]]

    outcomes_str <- all.vars(outcome)
    outcomes <- sapply(outcomes_str, quo)
    # get outcomes as a list
    wide_list <- format_data_multi(outcomes, trt, unit, time, t_int, data)
    

    

    ## add covariates
    if(length(form)[2] == 2) {
        cov_form <- paste(deparse(terms(formula(form, rhs = 2))[[3]]), collapse = "")
        new_form <- as.formula(paste("~", cov_form))
        Z <- extract_covariates(new_form, unit, time, t_int, data, cov_agg)
    } else {
        Z <- NULL
    }

    # only allow ridge augmentation
    # the default is the vector c("Ridge", "None"), so keep only the first entry
    progfunc <- progfunc[1]
    if(! tolower(progfunc) %in% c("none", "ridge")) {
      stop(paste(progfunc, "is not a valid augmentation function with multiple outcomes. Only `none` or `ridge` are allowable options for `progfunc`"))
    }

    # fit augmented SCM
    augsynth <- fit_augsynth_multiout_internal(wide_list, combine_method, Z,
                                               progfunc, scm,
                                               fixedeff, outcomes_str, ...)

    # add some extra data
    augsynth$data$time <- data %>% distinct(!!time) %>% pull(!!time)
    augsynth$call <- call_name
    augsynth$t_int <- t_int 
    augsynth$combine_method <- combine_method

    treated_units <- data %>% filter(!!trt == 1) %>% distinct(!!unit) %>% pull(!!unit)
    control_units <- data %>% filter(!(!!unit %in% treated_units)) %>% 
                        distinct(!!unit) %>% pull(!!unit)
    augsynth$weights <- matrix(augsynth$weights)
    rownames(augsynth$weights) <- control_units

    return(augsynth)
}


#' Internal function to fit augmented SCM with multiple outcomes
#' @param wide_list List of matrices for each outcome formatted from format_data
#' @param combine_method How to combine outcomes
#' @param Z Matrix of auxiliary covariates
#' @param progfunc outcome model to use
#' @param scm Whether to fit SCM
#' @param fixedeff Whether to de-mean synth
#' @param outcomes_str Character vector of outcome names
#' @param ... Extra args for outcome model
#' @noRd
fit_augsynth_multiout_internal <- function(wide_list, combine_method, Z,
                                           progfunc, scm, fixedeff, 
                                           outcomes_str, ...) {


    # combine into a matrix for fitting and balancing
    out <- combine_outcomes(wide_list, combine_method, fixedeff, ...)
    wide_bal <- out$wide_bal
    mhat <- out$mhat
    V <- out$V
    synth_data <- do.call(format_synth, wide_bal)

    # set Y1 and Y0plot to be raw concatenated outcomes
    X <- do.call(cbind, wide_list$X)
    y <- do.call(cbind, wide_list$y)
    trt <- wide_list$trt
    synth_data$Y0plot <- t(cbind(X, y)[trt == 0,, drop = F])
    synth_data$Y1plot <- colMeans(cbind(X, y)[trt == 1,, drop = F])


    augsynth <- fit_augsynth_internal(wide_bal, synth_data, Z, progfunc, 
                                      scm, fixedeff = F, V = V, ...)

    # potentially add back in fixed effects
    augsynth$mhat <- mhat# + augsynth$mhat

    augsynth$data <- list(X = X, trt = trt, y = y, Z = Z)
    augsynth$data_list <- wide_list
    augsynth$outcomes <- outcomes_str
    # change fixedeff flag to match input (rather than fixedeff = F in fit_augsynth_internal)
    augsynth$fixedeff <- fixedeff
    ##format output
    class(augsynth) <- c("augsynth_multiout", "augsynth")
    return(augsynth)
}

#' Helper function to combine multiple outcomes into a single balance matrix
#' @param wide_list List of lists of pre/post treatment data for each outcome
#' @param combine_method How to combine outcomes
#' @param fixedeff Whether to take out unit fixed effects or not
#' @param nu Weighting between concatenated and averaged objectives
#' @param ... Extra arguments for combination
#' @noRd
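#' @return \itemize{
#'          \item{"wide_bal"}{List with combined pre-treatment outcomes `X`, combined post-treatment outcomes `y`, and treatment assignments `trt`}
#'          \item{"mhat"}{Matrix of unit fixed effects for each outcome and time period (zeros if `fixedeff` is FALSE)}
#'          \item{"V"}{Scaling for the columns of `X` used when balancing}
#'         }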
combine_outcomes <- function(wide_list, combine_method, fixedeff,
                             nu = NULL, ...) {
  n_outs <- length(wide_list$X)
  n_units <- Map(nrow, wide_list$X) %>% Reduce(max, .)
  # take out unit fixed effects
  demean_j <- function(j) {
    means <- rowMeans(wide_list$X[[j]], na.rm = TRUE)

    new_wide_data <- list()
    new_X <- wide_list$X[[j]] - means
    new_y <- wide_list$y[[j]] - means

    new_wide_data$X <- new_X
    new_wide_data$y <- new_y
    new_wide_data$mhat_pre <- replicate(ncol(wide_list$X[[j]]),
                                        means)
    new_wide_data$mhat_post <- replicate(ncol(wide_list$y[[j]]),
                                        means)
    return(new_wide_data)
  }
  if(fixedeff) {
    new_wide_list <- lapply(1:n_outs, demean_j)
    wide_list$X <- lapply(new_wide_list, function(x) x$X)
    wide_list$y <- lapply(new_wide_list, function(x) x$y)
    mhat_pre <- lapply(new_wide_list, function(x) x$mhat_pre)
    mhat_post <- lapply(new_wide_list, function(x) x$mhat_post)
  } else {
    mhat_pre <- lapply(
      1:n_outs,
      function(j) matrix(0, nrow = n_units, ncol = ncol(wide_list$X[[j]])))
    mhat_post <- lapply(
      1:n_outs,
      function(j) matrix(0, nrow = n_units, ncol = ncol(wide_list$y[[j]])))
  }

    # combine outcomes
    if(combine_method == "concat") {
      # center X and scale by overall variance for outcome
      # X <- lapply(wide_list$X, function(x) t(t(x) - colMeans(x)) / sd(x))
      wide_bal <- list(X = do.call(cbind, lapply(wide_list$X, function(x) t(na.omit(t(x))))),
                        y = do.call(cbind, lapply(wide_list$y, function(x) t(na.omit(t(x))))),
                        trt = wide_list$trt)

      # V matrix scales by inverse variance for outcome and number of periods
      V <- do.call(c, 
          lapply(wide_list$X, 
            function(x) rep(1 / (sqrt(nrow(na.omit(t(x)))) * 
                    sd(x[wide_list$trt == 0, , drop = F], na.rm=T)),
                    nrow(na.omit(t(x))))))

    # } else if(combine_method == "svd") {
    #     wide_bal <- list(X = do.call(cbind, wide_list$X),
    #                      y = do.call(cbind, wide_list$y),
    #                      trt = wide_list$trt)

    #     # first get the standard deviations of the outcomes to put on the same scale
    #     sds <- do.call(c, 
    #         lapply(wide_list$X, 
    #             function(x) rep((sqrt(ncol(x)) * sd(x, na.rm=T)), ncol(x))))

    #     # do an SVD on centered and scaled outcomes
    #     X0 <- wide_bal$X[wide_bal$trt == 0, , drop = FALSE]
    #     X0 <- t((t(X0) - colMeans(X0)) / sds)
    #     k <- if(is.null(k)) ncol(X0) else k
    #     V <- diag(1 / sds) %*% svd(X0)$v[, 1:k, drop = FALSE]
  } else if(combine_method == "avg") {
      # average pre-treatment outcomes, dividing by standard deviation and removing missing values
      X_avg <- rowMeans(simplify2array(lapply(wide_list$X,
                                  function(x) (x - mean(x[wide_list$trt == 0,], na.rm = TRUE)) / sd(x[wide_list$trt == 0,], na.rm = TRUE))), dims = 2, na.rm = TRUE)
      # remove any time periods with NAs
      X_avg <- t(na.omit(t(X_avg)))
    wide_bal <- list(X = X_avg,
        y = rowMeans(simplify2array(wide_list$y), dims = 2, na.rm = TRUE),
        trt = wide_list$trt)

    V <- diag(ncol(wide_bal$X))

  } else if(combine_method == "avg_concat") {
      # average pre-treatment outcomes, dividing by standard deviation and removing missing values
      # standardize the outcomes
      X_list_std<- lapply(wide_list$X,function(x) (x - mean(x[wide_list$trt == 0,], na.rm = TRUE)) / sd(x[wide_list$trt == 0,], na.rm = TRUE))


      X_avg <- rowMeans(simplify2array(X_list_std), dims = 2, na.rm = TRUE)
      # remove any time periods with NAs
      X_avg <- t(na.omit(t(X_avg)))

      X_concat <- do.call(cbind, lapply(X_list_std, function(x) t(na.omit(t(x)))))

      # V matrix assigns weight nu to the averaged objective and (1 - nu) to the concatenated objective
      # V <- c(rep(sqrt(nu), ncol(X_avg)),
      #         sqrt(1 - nu) / sqrt(n_outs) * do.call(c, 
      #             lapply(wide_list$X,
      #               function(x) rep(1 / (sqrt(nrow(na.omit(t(x)))) * 
      #                       sd(x[wide_list$trt == 0, , drop = F], na.rm=T)),
      #                       nrow(na.omit(t(x))))))
      # )
      V <- c(rep(sqrt(nu), ncol(X_avg)), rep(sqrt(1 - nu) / sqrt(n_outs), ncol(X_concat)))
      wide_bal <- list(
        X = cbind(X_avg, X_concat),
        y = do.call(cbind, lapply(wide_list$y, function(x) t(na.omit(t(x))))),
        trt = wide_list$trt
      )
    } else {
        stop(paste("combine_method should be one of ('avg', 'concat', 'avg_concat'),", 
            combine_method, " is not a valid combining option"))
    }

  mhat_pre <- do.call(cbind, mhat_pre)
  mhat_post <- do.call(cbind, mhat_post)
  mhat <- cbind(mhat_pre, mhat_post)

  return(list(wide_bal = wide_bal, mhat = mhat, V = V))

}
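
# Illustrative sketch, not part of the package API: the "avg" option above
# standardizes each outcome by the mean and standard deviation of the control
# units and then averages across outcomes. `avg_outcomes_sketch` is a
# hypothetical helper; `X_list` is assumed to be a list of at least two
# units-by-time matrices with identical dimensions and `trt` a 0/1 vector.
avg_outcomes_sketch <- function(X_list, trt) {
  X_std <- lapply(X_list, function(x) {
    ctrl <- x[trt == 0, , drop = FALSE]
    (x - mean(ctrl, na.rm = TRUE)) / sd(ctrl, na.rm = TRUE)
  })
  # stack into a units x time x outcomes array and average over outcomes
  rowMeans(simplify2array(X_std), dims = 2, na.rm = TRUE)
}
# Because each outcome is standardized first, outcomes measured on different
# scales contribute equally to the averaged balance matrix.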

#' Get prediction of ATT or average outcome under control
#' @param object augsynth_multiout object
#' @param ... Optional arguments, including \itemize{\item{"att"}{Whether to return the ATT or average outcome under control}}
#'
#' @return Matrix of predictions (ATT estimates if `att = TRUE`) with one row per time period and one column per outcome
#' @export
predict.augsynth_multiout <- function(object, ...) {
    if ("att" %in% names(list(...))) {
        att <- list(...)$att
    } else {
        att <- F
    }

    # call augsynth predict
    pred <- NextMethod()

    # separate out by outcome
    n_outs <- length(object$data_list$X)
    max_t <- max(sapply(1:n_outs, 
      function(k) ncol(object$data_list$X[[k]]) + ncol(object$data_list$y[[k]])))
    pred_reshape <- matrix(NA, ncol = n_outs, 
                               nrow = max_t)
    colnames <- lapply(1:n_outs, 
      function(k) colnames(cbind(object$data_list$X[[k]], 
                                 object$data_list$y[[k]])))
    rownames(pred_reshape) <- colnames[[which.max(sapply(colnames, length))]]
    colnames(pred_reshape) <- object$outcomes
    # get outcome names for predictions

    pre_outs <- do.call(c, 
                        sapply(1:n_outs, 
                               function(j) {
                                   rep(object$outcomes[j],
                                       ncol(object$data_list$X[[j]]))
                               }, simplify = FALSE))
    
    post_outs <- do.call(c,
                         sapply(1:n_outs, 
                                function(j) {
                                    rep(object$outcomes[j],
                                        ncol(object$data_list$y[[j]]))
                               }, simplify = FALSE))
    # print(pred)
    # print(cbind(names(pred), c(pre_outs, post_outs)))
    
    pred_reshape[cbind(names(pred), c(pre_outs, post_outs))] <- pred
    return(pred_reshape)
}


#' Print function for augsynth
#' @param x augsynth_multiout object
#' @param ... Optional arguments
#' @export
print.augsynth_multiout <- function(x, ...) {
    ## straight from lm
    cat("\nCall:\n", paste(deparse(x$call), sep="\n", collapse="\n"), "\n\n", sep="")

    ## print att estimates
    att <- predict(x, att = T)
    att_post <- data.frame(
        colMeans(att[as.numeric(rownames(att)) >= x$t_int,, drop = F]))
    names(att_post) <- c("")
    cat("Average ATT Estimate:\n")
    print(att_post)
    cat("\n\n")
}

#' Summary function for augsynth
#' @param object augsynth_multiout object
#' @param inf whether or not to perform inference
#' @param inf_type Type of inference, default is "conformal"
#' @param grid_size Grid to compute prediction intervals over, default is 1 and only p-values are computed
#' @param ... Optional arguments, including \itemize{\item{"se"}{Whether to plot standard error}}
#' @export
summary.augsynth_multiout <- function(object, inf = T, inf_type = "conformal", grid_size = 1, ...) {
    

    summ <- list()

    if(inf) {
        if(inf_type == "conformal") {
          if(grid_size > 1) {
            # end the message with a newline so later output starts on its own line
            cat(paste0("A grid size of ", grid_size, " will require ",
                           grid_size, "^", length(object$outcomes),
                           " = ", grid_size ^ length(object$outcomes),
                           " evaluations. This could take a while...\n"))
          }
          att_se <- conformal_inf_multiout(object, grid_size = grid_size, ...)
        } else {
          stop("Only conformal inference is supported for multiple outcomes")
        }
        # if(inf_type == "jackknife") {
        #     att_se <- jackknife_se_multiout(object)
        # } else if(inf_type == "jackknife+") {
        #   att_se <- time_jackknife_plus_multiout(object, ...)
        # } else if(inf_type == "conformal") {
        #   att_se <- conformal_inf_multiout(object, ...)
        # } else {
        #     stop(paste(inf_type, "is not a valid choice of 'inf_type'"))
        # }

        t_final <- nrow(att_se$att)

        att_df <- data.frame(att_se$att[1:(t_final - 1),, drop=F])
        names(att_df) <- object$outcomes
        att_df$Time <- object$data$time
        att_df <- att_df %>% gather(Outcome, Estimate, -Time)

        # if(inf_type == "jackknife") {
        #   se_df <- data.frame(att_se$se[1:(t_final - 1),, drop=F])
        #   names(se_df) <- object$outcomes
        #   se_df$Time <- object$data$time
        #   se_df <- se_df %>% gather(Outcome, Std.Error, -Time)

        #   att <- inner_join(att_df, se_df, by = c("Time", "Outcome"))
        # } else if(inf_type %in% c("conformal", "jackknife+")) {
          
        lb_df <- data.frame(att_se$lb[1:(t_final - 1),, drop=F])
        names(lb_df) <- object$outcomes
        lb_df$Time <- object$data$time
        lb_df <- lb_df %>% gather(Outcome, lower_bound, -Time)

        ub_df <- data.frame(att_se$ub[1:(t_final - 1),, drop=F])
        names(ub_df) <- object$outcomes
        ub_df$Time <- object$data$time
        ub_df <- ub_df %>% gather(Outcome, upper_bound, -Time)

        att <- inner_join(att_df, lb_df, by = c("Time", "Outcome")) %>%
            inner_join(ub_df, by = c("Time", "Outcome")) 
          # if(inf_type == "conformal") {

          pval_df <- data.frame(att_se$p_val[1:(t_final - 1),, drop=F])
          names(pval_df) <- object$outcomes
          pval_df$Time <- object$data$time
          pval_df <- pval_df %>% gather(Outcome, p_val, -Time)
          att <- inner_join(att, pval_df, by = c("Time", "Outcome")) 
          # }
        # }
        if(grid_size == 1) {
          att <- att %>% mutate(lower_bound = NA, upper_bound = NA)
        }

        att_avg <- data.frame(att_se$att[t_final,, drop = F])
        names(att_avg) <- object$outcomes
        att_avg <- gather(att_avg, Outcome, Estimate)

        # if(inf_type == "jackknife") {
        #   att_avg_se <- data.frame(att_se$se[t_final,, drop = F])
        #   names(att_avg_se) <- object$outcomes
        #   att_avg_se <- gather(att_avg_se, Outcome, Std.Error)
        #   average_att <- inner_join(att_avg, att_avg_se, by="Outcome")
        # } else if(inf_type %in% c("conformal", "jackknife+")){
        att_avg_lb <- data.frame(att_se$lb[t_final,, drop = F])
        names(att_avg_lb) <- object$outcomes
        att_avg_lb <- gather(att_avg_lb, Outcome, lower_bound)

        att_avg_ub <- data.frame(att_se$ub[t_final,, drop = F])
        names(att_avg_ub) <- object$outcomes
        att_avg_ub <- gather(att_avg_ub, Outcome, upper_bound)
        

        average_att <- inner_join(att_avg, att_avg_lb, by="Outcome") %>%
            inner_join(att_avg_ub, by = "Outcome")
          
          # if(inf_type == "conformal") {
        att_avg_pval <- data.frame(att_se$p_val[t_final,, drop = F])
        names(att_avg_pval) <- object$outcomes
        att_avg_pval <- gather(att_avg_pval, Outcome, p_val)

        average_att <- inner_join(average_att, att_avg_pval, by = "Outcome")

        if(grid_size == 1) {
          average_att <- average_att %>% mutate(lower_bound = NA, upper_bound = NA)
        }
          # }
        # } else {
        #   average_att <- gather(att_avg, Outcome, Estimate)
        # }
        

    } else {
        att_est <- predict(object, att = T)
        att_df <- data.frame(att_est)
        names(att_df) <- object$outcomes
        att_df$Time <- object$data$time
        att <- att_df %>% gather(Outcome, Estimate, -Time)
        att$Std.Error <- NA
        t_int <- min(sapply(object$data_list$X, ncol))
        att_avg <- data.frame(t(colMeans(
            att_est[t_int:nrow(att_est),, drop = F])))
        names(att_avg) <- object$outcomes
        average_att <- gather(att_avg, Outcome, Estimate)
        average_att$Std.Error <- NA
    }

      # get average of all outcomes
    sds <- data.frame(Outcome = object$outcomes,
                      sdo = sapply(object$data_list$X,
                                    function(x)  sd(x[object$data_list$trt == 0,], na.rm = TRUE)))

    att %>%
      inner_join(sds, by = "Outcome") %>%
      mutate(Estimate = Estimate / sdo) %>%
      group_by(Time) %>%
      summarise(Estimate = mean(Estimate, na.rm = TRUE)) %>%
      mutate(Outcome = "Average") %>%
      bind_rows(att, .) -> att

    summ$att <- att
    summ$average_att <- average_att
    summ$t_int <- object$t_int
    summ$call <- object$call
    summ$l2_imbalance <- object$l2_imbalance
    summ$scaled_l2_imbalance <- object$scaled_l2_imbalance
    summ$inf_type <- inf_type
    ## get estimated bias

    if(object$progfunc == "Ridge") {
        mhat <- object$ridge_mhat
        w <- object$synw
    } else {
        mhat <- object$mhat
        w <- object$weights
    }
    trt <- object$data$trt
    m1 <- colMeans(mhat[trt==1,,drop=F])

    summ$bias_est <- m1 - t(mhat[trt==0,,drop=F]) %*% w
    if(object$progfunc == "None" | (!object$scm)) {
        summ$bias_est <- NA
    }

    

    
    class(summ) <- "summary.augsynth_multiout"
    return(summ)
}


#' Print function for summary function for augsynth
#' @param x summary.augsynth_multiout object
#' @param ... Optional arguments
#' @export
print.summary.augsynth_multiout <- function(x, ...) {
    ## straight from lm
    cat("\nCall:\n", paste(deparse(x$call), sep="\n", collapse="\n"), "\n\n", sep="")
    
    att_est <- x$att$Estimate
    ## get pre-treatment fit by outcome
    imbal <- x$att %>% 
        filter(Time < x$t_int) %>%
        group_by(Outcome) %>%
        summarise(Pre.RMSE = sqrt(mean(Estimate ^ 2, na.rm = TRUE)))

    cat(paste("Overall L2 Imbalance (Scaled): ",
              format(round(x$l2_imbalance,3), nsmall=3), "  (",
              format(round(x$scaled_l2_imbalance,3), nsmall=3), ")\n\n",
            #   "Avg Estimated Bias: ",
            #   format(round(mean(summ$bias_est), 3),nsmall=3), "\n\n",
              sep=""))
    cat("Average ATT Estimate:\n")
    print(inner_join(x$average_att, imbal, by = "Outcome"))
    cat("\n\n")
}


#' Plot function for augsynth_multiout
#' @importFrom graphics plot
#' @param x augsynth_multiout object
#' @param inf Boolean, whether to plot uncertainty intervals, default TRUE
#' @param plt_avg Boolean, whether to plot the average of the outcomes, default FALSE
#' @param ... Optional arguments for summary function
#' 
#' @export
plot.augsynth_multiout  <- function(x, inf = T, plt_avg = F, ...) {
  plot(summary(x, ...), inf =  inf, plt_avg = plt_avg)
}

#' Plot function for summary function for augsynth
#' @param x summary.augsynth_multiout object
#' @param inf Boolean, whether to plot uncertainty intervals, default FALSE
#' @param plt_avg Boolean, whether to plot the average of the outcomes, default FALSE
#' @param ... Optional arguments, currently unused
#' 
#' @export
plot.summary.augsynth_multiout <- function(x, inf = F, plt_avg = F, ...) {
    if(plt_avg) {
      p <- x$att %>%
        ggplot2::ggplot(ggplot2::aes(x=Time, y=Estimate))
    } else {
      p <- x$att %>%
        filter(Outcome != "Average") %>% 
        ggplot2::ggplot(ggplot2::aes(x=Time, y=Estimate))
    }
    if(inf) {
      if(x$inf_type == "jackknife") {
        p <- p + ggplot2::geom_ribbon(ggplot2::aes(ymin=Estimate-2*Std.Error,
                        ymax=Estimate+2*Std.Error),
                    alpha=0.2, data = . %>% filter(Outcome != "Average"))
      } else if(x$inf_type %in% c("conformal", "jackknife+")) {
        p <- p + ggplot2::geom_ribbon(ggplot2::aes(ymin=lower_bound,
                        ymax=upper_bound),
                    alpha=0.2, data =  . %>% filter(Outcome != "Average"))
      }

    }
    p + ggplot2::geom_line() +
        ggplot2::geom_vline(xintercept=x$t_int, lty=2) +
        ggplot2::geom_hline(yintercept=0, lty=2) +
        ggplot2::facet_wrap(~ Outcome, scales = "free_y") +
        ggplot2::theme_bw()

}

================================================
FILE: R/multi_synth_qp.R
================================================
################################################################################
## Solve the multisynth problem as a QP
################################################################################


#' Internal function to fit synth with staggered adoption with a QP solver
#' @param X Matrix of pre-final intervention outcomes, or list of such matrices after transformations
#' @param trt Vector of treatment levels/times
#' @param mask Matrix with indicators for observed pre-intervention times for each treatment group
#' @param Z Matrix of auxiliary covariates, default NULL (no covariates)
#' @param n_leads Number of time periods after treatment to impute control values.
#'            For units treated at time T_j, all units treated after T_j + n_leads
#'            will be used as control values. If larger than the number of periods,
#'            only never treated units (pure controls) will be used as comparison units
#' @param n_lags Number of pre-treatment periods to balance, default is to balance all periods
#' @param relative Whether to re-index time according to treatment date, default T
#' @param nu Hyper-parameter that controls trade-off between overall and individual balance.
#'              Larger values of nu place more emphasis on overall (pooled) balance.
#'              Balance measure is
#'                nu ||global|| + (1-nu) ||individual||
#'              Default: 0
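#' @param lambda Regularization hyper-parameter. Default, 0
#' @param V Scaling matrix for the lagged outcomes, passed to make_V_matrix, default NULL
#' @param time_cohort Whether to average synthetic controls into time cohorts
#' @param donors List of indicators for eligible donor units, one element per treated unit (or cohort).
#'               Default NULL computes them with get_eligible_donors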
#' @param norm_pool Normalizing value for pooled objective, default: number of treated units squared
#' @param norm_sep Normalizing value for separate objective, default: number of treated units
#' @param verbose Whether to print logs for osqp
#' @param eps_rel Relative error tolerance for osqp
#' @param eps_abs Absolute error tolerance for osqp
#' @noRd
#' @return \itemize{
#'          \item{"weights"}{Matrix of unit weights}
#'          \item{"imbalance"}{Matrix of overall and group specific imbalance}
#'          \item{"global_l2"}{Imbalance overall}
#'          \item{"ind_l2"}{Matrix of imbalance for each group}
#'         }
multisynth_qp <- function(X, trt, mask, Z = NULL, n_leads=NULL, n_lags=NULL,
                          relative=T, nu=0, lambda=0, V = NULL, time_cohort = FALSE,
                          donors = NULL, norm_pool = NULL, norm_sep = NULL,
                          verbose = FALSE, 
                          eps_rel=1e-4, eps_abs=1e-4) {

    # if Z has no columns then set it to NULL
    if(!is.null(Z)) {
      if(ncol(Z) == 0) {
        Z <- NULL
      }
    }

    n <- if(typeof(X) == "list") dim(X[[1]])[1] else dim(X)[1]
    d <- if(typeof(X) == "list") dim(X[[1]])[2] else dim(X)[2]

    if(is.null(n_leads)) {
        n_leads <- d+1
    } else if(n_leads > d) {
        n_leads <- d+1
    }
    if(is.null(n_lags)) {
        n_lags <- d
    } else if(n_lags > d) {
        n_lags <- d
    }
    V <- make_V_matrix(n_lags, V)
    ## treatment times
    if(time_cohort) {
        grps <- unique(trt[is.finite(trt)])
        which_t <- lapply(grps, function(tj) (1:n)[trt == tj])
        # if doing a time cohort, convert the boolean mask
        mask <- unique(mask)
    } else {
        grps <- trt[is.finite(trt)]
        which_t <- (1:n)[is.finite(trt)]
    }


    J <- length(grps)
    if(is.null(norm_sep)) {
      norm_sep <- 1#J
    }
    if(is.null(norm_pool)) {
      norm_pool <- 1#J ^ 2
    }
    n1 <- sapply(1:J, function(j) length(which_t[[j]]))

    # if no specific donors passed in,
    # then eligible donors are determined from the treatment times and n_leads
    if(is.null(donors)) {
      donors <- get_eligible_donors(trt, time_cohort, n_leads)
    }

    ## handle X differently if it is a list
    if(typeof(X) == "list") {
        x_t <- lapply(1:J, function(j) colSums(X[[j]][which_t[[j]], mask[j,]==1, drop=F]))
        
        # Xc contains pre-treatment data for valid donor units
        Xc <- lapply(1:nrow(mask),
                 function(j) X[[j]][donors[[j]], mask[j,]==1, drop=F])
        
        # std dev of outcomes for first treatment time
        sdx <- sd(X[[1]][is.finite(trt)])
    } else {
        x_t <- lapply(1:J, function(j) colSums(X[which_t[[j]], mask[j,]==1, drop=F]))        
        
        # Xc contains pre-treatment data for valid donor units
        Xc <- lapply(1:nrow(mask),
                 function(j) X[donors[[j]], mask[j,]==1, drop=F])

        # std dev of outcomes
        sdx <- sd(X[is.finite(trt)])
    }

    # get covariates for donors
    if(!is.null(Z)) {
      # scale covariates to have same variance as pure control outcomes
      Z_scale <- sdx * apply(Z, 2, 
        function(z) (z - mean(z[!is.finite(trt)])) / sd(z[!is.finite(trt)]))

      z_t <- lapply(1:J, function(j) colSums(Z_scale[which_t[[j]], , drop = F]))
      Zc <- lapply(1:J, function(j) Z_scale[donors[[j]], , drop = F])
    } else {
      z_t <- lapply(1:J, function(j) c(0))
      Zc <- lapply(1:J, function(j) Matrix::Matrix(0,
                                                   nrow = sum(donors[[j]]),
                                                   ncol = 1))
    }
    dz <- ncol(Zc[[1]])

    # replace NA values with zero
    x_t <- lapply(x_t, function(xtk) tidyr::replace_na(xtk, 0))
    Xc <- lapply(Xc, function(xck) apply(xck, 2, tidyr::replace_na, 0))
    ## make matrices for QP
    n0s <- sapply(Xc, nrow)
    if(any(n0s == 0)) {
      stop("Some treated units have no possible donor units!")
    }
    n0 <- sum(n0s)

    const_mats <- make_constraint_mats(trt, grps, n_leads, n_lags, Xc, Zc, d, n1)
    Amat <- const_mats$Amat
    lvec <- const_mats$lvec
    uvec <- const_mats$uvec

    ## quadratic balance measures

    qvec <- make_qvec(Xc, x_t, z_t, nu, n_lags, d, V, norm_pool, norm_sep)

    Pmat <- make_Pmat(Xc, x_t, dz, nu, n_lags, lambda, d, V, norm_pool, norm_sep)

    ## Optimize
    settings <- do.call(osqp::osqpSettings, 
                        c(list(verbose = verbose, 
                               eps_rel = eps_rel, 
                               eps_abs = eps_abs)))

    out <- osqp::solve_osqp(Pmat, qvec, Amat, lvec, uvec, pars = settings)

    ## get weights
    nj0 <- as.numeric(lapply(Xc, nrow))
    nj0cumsum <- c(0, cumsum(nj0))
    imbalance <- vapply(1:J,
                        function(j) {
                            dj <- length(x_t[[j]])
                            ndim <- min(dj, n_lags)
                            c(numeric(d-ndim),
                            x_t[[j]][(dj-ndim+1):dj] -
                                t(Xc[[j]][,(dj-ndim+1):dj, drop = F]) %*% 
                                out$x[(nj0cumsum[j] + 1):nj0cumsum[j + 1]])
                        },
                        numeric(d))
    avg_imbal <- rowMeans(t(t(imbalance)))

    Vsq <- t(V) %*% V
    global_l2 <- c(sqrt(t(avg_imbal[(d - n_lags + 1):d]) %*% Vsq %*%
                          avg_imbal[(d - n_lags + 1):d])) / sqrt(d)
    avg_l2 <- mean(apply(imbalance, 2,
                  function(x) c(sqrt(t(x[(d - n_lags + 1):d]) %*% Vsq %*%
                                x[(d - n_lags + 1):d]))))
    ind_l2 <- sqrt(mean(
      apply(imbalance, 2,
      function(x) c(x[(d - n_lags + 1):d] %*% Vsq %*% x[(d - n_lags + 1):d]) /
          sum(x[(d - n_lags + 1):d] != 0))))
    # pad weights with zeros for treated units and divide by number of treated units
    vapply(1:J,
           function(j) {
             weightj <-  numeric(n)
             weightj[donors[[j]]] <- out$x[(nj0cumsum[j] + 1):nj0cumsum[j + 1]]
             weightj
           },
           numeric(n)) -> weights

    weights <- t(t(weights) / n1)
    # manually enforce non-negativity constraint
    # (osqp solver only enforces constraint up to a tolerance)
    weights <- pmax(weights, 0)

    output <- list(weights = weights,
                   imbalance = cbind(avg_imbal, imbalance),
                   global_l2 = global_l2,
                   ind_l2 = ind_l2,
                   avg_l2 = avg_l2,
                   V = V)

    if(!is.null(Z)) {
      # imbalance in auxiliary covariates
      z_t <- sapply(1:J, function(j) colMeans(Z[which_t[[j]], , drop = F]))
      imbal_z <- z_t - t(Z) %*% weights
      avg_imbal_z <- rowSums(t(t(imbal_z) * n1)) / sum(n1)
      global_l2_z <- sqrt(sum(avg_imbal_z ^ 2))
      ind_l2_z <- sum(apply(imbal_z, 2, function(x) sqrt(sum(x ^ 2))))
      imbal_z <- cbind(avg_imbal_z, imbal_z)
      rownames(imbal_z) <- colnames(Z)

      output$imbalance_aux <- imbal_z
      output$global_l2_aux <- global_l2_z
      output$ind_l2_aux <- ind_l2_z
    }
    

    
    

    return(output)
}
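
## Illustration only (not part of the package API): a minimal sketch of the
## kind of QP that osqp solves, assuming a single treated unit and a plain
## simplex-constrained least squares fit with no auxiliary variables.
## `sketch_simplex_qp` and its arguments are hypothetical names; multisynth_qp
## above uses a larger formulation with transformed weights and covariates.
sketch_simplex_qp <- function(X0, x1, eps = 1e-6) {
    n0 <- nrow(X0)
    # objective 0.5 * w' P w + q' w equals ||x1 - X0' w||^2 up to a constant
    Pmat <- 2 * X0 %*% t(X0)
    qvec <- c(-2 * X0 %*% x1)
    # first row: weights sum to one; remaining rows: weights are non-negative
    Amat <- rbind(rep(1, n0), diag(n0))
    lvec <- c(1, rep(0, n0))
    uvec <- c(1, rep(Inf, n0))
    settings <- osqp::osqpSettings(verbose = FALSE, eps_rel = eps, eps_abs = eps)
    out <- osqp::solve_osqp(Pmat, qvec, Amat, lvec, uvec, pars = settings)
    # osqp enforces constraints only up to a tolerance, so clip at zero
    pmax(out$x, 0)
}
## Possible usage, with three donors (rows) and two pre-treatment periods:
##   sketch_simplex_qp(rbind(c(0, 0), c(2, 0), c(0, 2)), c(1, 1))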


#' Create constraint matrices for multisynth QP
#' @param trt Vector of treatment levels/times
#' @param grps Treatment times
#' @param n_leads Number of time periods after treatment to impute control values.
#' @param n_lags Number of pre-treatment periods to balance
#' @param Xc List of outcomes for possible comparison units
#' @param Zc List of scaled auxiliary covariates for possible comparison units
#' @param d Max number of lagged outcomes
#' @param n1 Vector of number of treated units per cohort
#' @noRd
#' @return 
#'         \itemize{
#'          \item{"Amat"}{Linear constraint matrix}
#'          \item{"lvec"}{Lower bounds for linear constraints}
#'          \item{"uvec"}{Upper bounds for linear constraints}
#'         }
make_constraint_mats <- function(trt, grps, n_leads, n_lags, Xc, Zc, d, n1) {

    J <- length(grps)

    ## sum to n1 constraint
    A1 <- Matrix::bdiag(lapply(1:J, function(j) rep(1, nrow(Xc[[j]]))))

    ## stack the sum-to-n1 rows on top of the non-negativity rows
    Amat <- Matrix::rbind2(Matrix::t(A1), Matrix::Diagonal(nrow(A1)))

    dz <- ncol(Zc[[1]])
    # constraints for transformed weights
    A_trans1 <- do.call(Matrix::bdiag,
                       lapply(1:J,
                        function(j)  {
                            dj <- ncol(Xc[[j]])
                            ndim <- min(dj, n_lags)
                            max_dim <- min(d, n_lags)
                            mat <- Xc[[j]][, (dj - ndim + 1):dj, drop = F]
                            n0 <- nrow(mat)
                            zero_mat <- Matrix::Matrix(0, n0, max_dim - ndim)
                            Matrix::t(cbind(zero_mat, mat))
                       }))

    # sum of total number of pre-periods
    sum_tj <- min(d, n_lags) * J
    A_trans2 <- - Matrix::Diagonal(sum_tj)
    A_trans <- Matrix::cbind2(
      Matrix::cbind2(A_trans1, A_trans2),
      Matrix::Matrix(0, nrow = nrow(A_trans1), ncol = dz * J))

    # constraints for transformed weights on auxiliary covariates
    A_transz <- Matrix::t(Matrix::bdiag(Zc))
    A_transz <- Matrix::cbind2(
      Matrix::cbind2(A_transz, 
                     Matrix::Matrix(0, nrow = nrow(A_transz), ncol = sum_tj)),
      -Matrix::Diagonal(dz * J))

    # add in zero columns for transformed weights
    Amat <- Matrix::cbind2(Amat, 
                           Matrix::Matrix(0,
                                          nrow = nrow(Amat),
                                          ncol = sum_tj + dz * J))
    Amat <- Matrix::rbind2(Matrix::rbind2(Amat, A_trans), A_transz)

    lvec <- c(n1, # sum to n1 constraint
              numeric(nrow(A1)), # lower bound by zero
              numeric(sum_tj), # constrain transformed weights
              numeric(dz * J) # constrain transformed weights
             ) 
    
    uvec <- c(n1, #sum to n1 constraint
              rep(Inf, nrow(A1)),
              numeric(sum_tj), # constrain transformed weights
              numeric(dz * J) # constrain transformed weights
              )


    return(list(Amat = Amat, lvec = lvec, uvec = uvec))
}
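
## Illustration only: a small sketch of the block structure of the first two
## constraint groups, assuming J treated groups with n0s[j] eligible donors
## each. The helper name `sketch_sum_constraints` is hypothetical, and it
## leaves out the transformed-weight and covariate blocks that
## make_constraint_mats appends.
sketch_sum_constraints <- function(n0s, n1) {
    # one column of ones per group, placed block-diagonally
    A1 <- Matrix::bdiag(lapply(n0s,
                               function(n0) matrix(1, nrow = n0, ncol = 1)))
    # row j sums group j's weights; identity rows bound each weight below by zero
    Amat <- Matrix::rbind2(Matrix::t(A1), Matrix::Diagonal(nrow(A1)))
    list(Amat = Amat,
         lvec = c(n1, numeric(nrow(A1))),
         uvec = c(n1, rep(Inf, nrow(A1))))
}
## Possible usage: two groups with 2 and 3 donors and 1 and 4 treated units
##   sketch_sum_constraints(n0s = c(2, 3), n1 = c(1, 4))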

#' Make the vector in the QP
#' @param Xc List of outcomes for possible comparison units
#' @param x_t List of outcomes for treated units
#' @param z_t List of scaled auxiliary covariate sums for treated units
#' @param nu Hyperparameter between global and individual balance
#' @param n_lags Number of lags to balance
#' @param d Largest number of pre-intervention time periods
#' @param V Scaling matrix
#' @param norm_pool Normalizing value for pooled objective
#' @param norm_sep Normalizing value for separate objective
#' @noRd
make_qvec <- function(Xc, x_t, z_t, nu, n_lags, d, V, norm_pool, norm_sep) {

    J <- length(x_t)
    Vsq <- t(V) %*% V
    qvec <- lapply(1:J,
                   function(j) {
                       dj <- length(x_t[[j]])
                       ndim <- min(dj, n_lags)
                       max_dim <- min(d, n_lags)
                       vec <- x_t[[j]][(dj - ndim + 1):dj] / ndim
                       Vsq %*% c(numeric(max_dim - ndim), vec)
                   })

    avg_target_vec <- lapply(x_t,
                            function(xtk) {
                                dk <- length(xtk)
                                ndim <- min(dk, n_lags)
                                max_dim <- min(d, n_lags)
                                c(numeric(max_dim - ndim), 
                                    xtk[(dk - ndim + 1):dk])
                            }) %>% reduce(`+`) %*% Vsq
    qvec_avg <- rep(avg_target_vec, J)
    # qvec <- - (nu * qvec_avg / n_lags + (1 - nu) * reduce(qvec, c))
    # qvec <- - (nu * qvec_avg / (J ^ 2 * n_lags) +
    #           (1 - nu) * reduce(qvec, c) / J)
    qvec <- - (nu * qvec_avg / (norm_pool * n_lags * J ^ 2) +
               (1 - nu) * reduce(qvec, c) / (norm_sep * J))

    qvec_avg_z <- z_t %>% reduce(`+`)
    qvec_avg_z <- rep(qvec_avg_z, J)
    # qvec_z <- - (nu * qvec_avg_z + (1 - nu) * reduce(z_t, c)) / length(z_t[[1]])
    # qvec_z <- - (nu * qvec_avg_z / J ^2 +
    #              (1 - nu) * reduce(z_t, c) / J) / length(z_t[[1]])
    qvec_z <- - (nu * qvec_avg_z / (norm_pool * J ^ 2) +
                 (1 - nu) * reduce(z_t, c) / (norm_sep * J)) / length(z_t[[1]])

    total_ctrls <- lapply(Xc, nrow) %>% reduce(`+`)
    return(c(numeric(total_ctrls), qvec, qvec_z))
}


#' Make the matrix in the QP
#' @param Xc List of outcomes for possible comparison units
#' @param x_t List of outcomes for treated units
#' @param dz Number of auxiliary covariate columns per group
#' @param nu Hyperparameter between global and individual balance
#' @param n_lags Number of lags to balance
#' @param lambda Regularization hyperparameter
#' @param d Largest number of pre-intervention time periods
#' @param V Scaling matrix
#' @param norm_pool Normalizing value for pooled objective
#' @param norm_sep Normalizing value for separate objective
#' @noRd
make_Pmat <- function(Xc, x_t, dz, nu, n_lags, lambda, d, V,
                      norm_pool, norm_sep) {

    J <- length(x_t)

    Vsq <- t(V) %*% V
    ndims <- vapply(1:J,
                    function(j) min(length(x_t[[j]]), n_lags),
                    numeric(1))
    max_dim <- min(d, n_lags)
    total_dim <- max_dim * J
    V1 <- Matrix::bdiag(lapply(ndims, function(ndim) Vsq / ndim))
    tile_sparse <- function(j) {
        kronecker(Matrix::Matrix(1, nrow = j, ncol = j), Vsq)
    }
    tile_sparse_cov <- function(d, j) {
        kronecker(Matrix::Matrix(1, nrow = j, ncol = j),
                  Matrix::Diagonal(d))
    }
    V2 <- tile_sparse(J) / n_lags
    # Pmat <- nu * V2 + (1 - nu) * V1
    # Pmat <- nu * V2 / J ^ 2 + (1 - nu) * V1 / J
    Pmat <- nu * V2 / (norm_pool * J ^ 2) + (1 - nu) * V1 / (norm_sep * J)
    V1_z <- Matrix::Diagonal(dz * J, 1 / dz)
    V2_z <- tile_sparse_cov(dz, J) / dz
    # Pmat_z <- nu * V2_z + (1 - nu) * V1_z
    # Pmat_z <- nu * V2_z / J ^ 2 + (1 - nu) * V1_z / J
    Pmat_z <- nu * V2_z / (norm_pool * J ^ 2) + (1 - nu) * V1_z / (norm_sep * J)
    # combine
    total_ctrls <- lapply(Xc, nrow) %>% reduce(`+`)
    Pmat <- Matrix::bdiag(Matrix::Matrix(0, nrow = total_ctrls,
                                         ncol = total_ctrls),
                          Pmat, Pmat_z)
    I0 <- Matrix::bdiag(Matrix::Diagonal(total_ctrls),
                        Matrix::Matrix(0, nrow = total_dim + dz * J,
                                          ncol = total_dim + dz * J))
    return(Pmat + lambda * I0)

}
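
## Illustration only: a numerical check of why the pooled term can be written
## as a Kronecker product. Stacking J imbalance vectors v_1, ..., v_J into v,
## the quadratic form v' (1 1' %x% Vsq) v equals the Vsq-weighted squared norm
## of their sum. `sketch_pooled_identity` is a hypothetical helper name and is
## not called anywhere in the package.
sketch_pooled_identity <- function(vs, Vsq) {
    J <- length(vs)
    v <- unlist(vs)
    # J x J grid of copies of Vsq, as in tile_sparse above (dense here)
    tiled <- kronecker(matrix(1, nrow = J, ncol = J), Vsq)
    v_sum <- Reduce(`+`, vs)
    c(tiled = c(t(v) %*% tiled %*% v),
      pooled = c(t(v_sum) %*% Vsq %*% v_sum))
}
## Possible usage: both entries of the result should agree
##   sketch_pooled_identity(list(c(1, 2), c(3, -1)), diag(2))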

================================================
FILE: R/multisynth_class.R
================================================
################################################################################
## Fitting, plotting, summarizing staggered synth
################################################################################

#' Fit staggered synth
#' @param form outcome ~ treatment | weighting covariates | approximate matching covariates | exact matching covariates
#' \itemize{
#'    \item{outcome}{Name of the outcome of interest}
#'    \item{treatment}{Name of the treatment assignment variable}
#'    \item{weighting covariates}{Auxiliary covariates to weight on}
#'    \item{approximate matching covariates}{Auxiliary covariates to approximately match on before weighting}
#'    \item{exact matching covariates}{Auxiliary covariates to exactly match on before weighting}
#' }
#' If covariates are time-varying, their average value before the first unit is treated will be used. This can be changed by supplying a custom aggregation function to cov_agg.
#' @param unit Name of unit column
#' @param time Name of time column
#' @param data Panel data as dataframe
#' @param n_leads How long past treatment effects should be estimated for, default is number of post treatment periods for last treated unit
#' @param n_lags Number of pre-treatment periods to balance, default is to balance all periods
#' @param nu Hyperparameter trading off pooled and individual balance; larger values place more weight on pooled balance
#' @param lambda Regularization hyperparameter, default = 0
#' @param V Scaling matrix for synth optimization, default NULL is identity
#' @param fixedeff Whether to include a unit fixed effect, default TRUE
#' @param n_factors Number of factors for interactive fixed effects, setting to NULL fits with CV, default is 0
#' @param scm Whether to fit scm weights
#' @param time_cohort Whether to average synthetic controls into time cohorts, default FALSE
#' @param how_match Matching method passed on to multisynth_formatted, default "knn"
#' @param cov_agg Covariate aggregation function
#' @param eps_abs Absolute error tolerance for osqp
#' @param eps_rel Relative error tolerance for osqp
#' @param verbose Whether to print logs for osqp
#' @param ... Extra arguments
#' 
#' @return multisynth object that contains:
#'         \itemize{
#'          \item{"weights"}{weights matrix where each column is a set of weights for a treated unit}
#'          \item{"data"}{Panel data as matrices}
#'          \item{"imbalance"}{Matrix of treatment minus synthetic control for pre-treatment time periods, each column corresponds to a treated unit}
#'          \item{"global_l2"}{L2 imbalance for the pooled synthetic control}
#'          \item{"scaled_global_l2"}{L2 imbalance for the pooled synthetic control, scaled by the imbalance for uniform weights}
#'          \item{"ind_l2"}{Average L2 imbalance for the individual synthetic controls}
#'          \item{"scaled_ind_l2"}{Average L2 imbalance for the individual synthetic controls, scaled by the imbalance for uniform weights}
#'          \item{"n_leads", "n_lags"}{Number of post treatment outcomes (leads) and pre-treatment outcomes (lags) to include in the analysis}
#'          \item{"nu"}{Hyperparameter trading off pooled and individual balance}
#'          \item{"lambda"}{Regularization hyperparameter}
#'          \item{"scm"}{Whether to fit scm weights}
#'          \item{"grps"}{Time periods for treated units}
#'          \item{"y0hat"}{Pilot estimates of control outcomes}
#'          \item{"residuals"}{Difference between the observed outcomes and the pilot estimates}
#'          \item{"n_factors"}{Number of factors for interactive fixed effects}
#'         }
#' @export
multisynth <- function(form, unit, time, data,
                       n_leads=NULL, n_lags=NULL,
                       nu=NULL, lambda=0, V = NULL,
                       fixedeff = TRUE,
                       n_factors=0,
                       scm=T,
                       time_cohort = F,
                       how_match = "knn",
                       cov_agg = NULL,
                       eps_abs = 1e-4,
                       eps_rel = 1e-4,
                       verbose = FALSE, ...) {
    call_name <- match.call()

    form <- Formula::Formula(form)
    unit <- enquo(unit)
    time <- enquo(time)

    ## format data
    outcome <- terms(formula(form, rhs=1))[[2]]
    trt <- terms(formula(form, rhs=1))[[3]]
    wide <- format_data_stag(outcome, trt, unit, time, data)

    check_data_stag(wide, fixedeff, n_leads, n_lags)

    force <- if(fixedeff) 3 else 2

    # get covariates
    if(length(form)[2] == 2) {
      Z <- extract_covariates(form, unit, time, wide$time[min(wide$trt) + 1],
                              data, cov_agg)
    } else if(length(form)[2] == 3) {
      app_form <- Formula::Formula(formula(form, rhs = 1:2))
      Z_weight <- extract_covariates(app_form, unit, time,
                                  wide$time[min(wide$trt) + 1],
                                  data, cov_agg)
      exact_form <- Formula::Formula(formula(form, rhs = c(1,3)))
      Z_match<- extract_covariates(exact_form, unit, time,
                                  wide$time[min(wide$trt) + 1],
                                  data, cov_agg)
      Z <- cbind(Z_weight, Z_match)
      wide$match_covariates <- colnames(Z_match)
    } else if(length(form)[2] == 4) {
      if(time_cohort) {
        stop("Aggregating by time cohort and matching on covariates are not ",
             "implemented together. If matching then you cannot aggregate ",
             "by time cohort.")
      }
      weight_form <- Formula::Formula(formula(form, rhs = c(1,2)))
      Z_weight <- extract_covariates(weight_form, unit, time,
                                      wide$time[min(wide$trt) + 1],
                                      data, cov_agg)
      app_form <- Formula::Formula(formula(form, rhs = c(1,3)))
      Z_app <- extract_covariates(app_form, unit, time,
                                  wide$time[min(wide$trt) + 1],
                                  data, cov_agg)
      exact_form <- Formula::Formula(formula(form, rhs = c(1,4)))
      Z_exact <- extract_covariates(exact_form, unit, time,
                                  wide$time[min(wide$trt) + 1],
                                  data, cov_agg)
      Z <- cbind(Z_weight, Z_app, Z_exact)
      wide$exact_covariates <- colnames(Z_exact)
      wide$match_covariates <- c(colnames(Z_app), wide$exact_covariates)
    } else {
        Z <- NULL
    }
    wide$Z <- Z

    # if n_leads is NULL set it to be the largest possible number of leads
    # for the last treated unit
    if(is.null(n_leads)) {
        n_leads <- ncol(wide$y)
    } else if(n_leads > max(apply(1-wide$mask, 1, sum, na.rm = T)) +
                                                              ncol(wide$y)) {
        n_leads <- max(apply(1-wide$mask, 1, sum, na.rm = T)) + ncol(wide$y)
    }

    ## if n_lags is NULL set it to the largest number of pre-treatment periods
    if(is.null(n_lags)) {
        n_lags <- ncol(wide$X)
    } else if(n_lags > ncol(wide$X)) {
        n_lags <- ncol(wide$X)
    }

    long_df <- data[c(quo_name(unit), quo_name(time), as.character(trt), as.character(outcome))]

    msynth <- multisynth_formatted(wide = wide, relative = T,
                                n_leads = n_leads, n_lags = n_lags,
                                nu = nu, lambda = lambda, V = V,
                                force = force, n_factors = n_factors,
                                scm = scm, time_cohort = time_cohort,
                                time_w = F, lambda_t = 0,
                                fit_resids = TRUE, eps_abs = eps_abs,
                                eps_rel = eps_rel, verbose = verbose, long_df = long_df, 
                                how_match = how_match, ...)
    
   
    units <- data %>% arrange(!!unit) %>% distinct(!!unit) %>% pull(!!unit)
    rownames(msynth$weights) <- units


    if(scm) {
        ## Get imbalance for uniform weights on raw data

        ## TODO: Get rid of this stupid hack of just fitting the weights again with big lambda
        unif <- multisynth_qp(X=wide$X, ## X=residuals[,1:ncol(wide$X)],
                            trt=wide$trt,
                            mask=wide$mask,
                            Z = Z[, ! colnames(Z) %in% wide$match_covariates,
                                  drop = F],
                            n_leads=n_leads,
                            n_lags=n_lags,
                            relative=T,
                            nu=0, lambda=1e10,
                            V = V,
                            time_cohort = time_cohort,
                            donors = msynth$donors,
                            eps_rel = eps_rel, 
                            eps_abs = eps_abs,
                            verbose = verbose)
        ## scaled global balance
        ## msynth$scaled_global_l2 <- msynth$global_l2  / sqrt(sum(unif$imbalance[,1]^2))
        msynth$scaled_global_l2 <- msynth$global_l2  / unif$global_l2

        ## balance for individual estimates
        ## msynth$scaled_ind_l2 <- msynth$ind_l2  / sqrt(sum(unif$imbalance[,-1]^2))
        msynth$scaled_ind_l2 <- msynth$ind_l2  / unif$ind_l2
    }

    msynth$call <- call_name

    return(msynth)

}


#' Internal function to fit staggered synth with formatted data
#' @param wide List containing data elements
#' @param relative Whether to compute balance by relative time
#' @param n_leads How long past treatment effects should be estimated for
#' @param n_lags Number of pre-treatment periods to balance, default is to balance all periods
#' @param nu Fraction of balance for individual balance
#' @param lambda Regularization hyperparameter, default = 0
#' @param V Scaling matrix for synth optimization, default NULL is identity
#' @param force c(0,1,2,3) what type of fixed effects to include
#' @param n_factors Number of factors for interactive fixed effects, default does CV
#' @param scm Whether to fit scm weights
#' @param time_cohort Whether to average synthetic controls into time cohorts
#' @param time_w Whether to fit time weights
#' @param lambda_t Regularization for time regression
#' @param fit_resids Whether to fit SCM on the residuals or not
#' @param eps_abs Absolute error tolerance for osqp
#' @param eps_rel Relative error tolerance for osqp
#' @param verbose Whether to print logs for osqp
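#' @param long_df A long dataframe with 4 columns in the order unit, time, trt, outcome
#' @param how_match How to select eligible donor units on the matching covariates, passed as `how` to `get_donors`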
#' @param ... Extra arguments
#' @noRd
#' @return multisynth object
multisynth_formatted <- function(wide, relative=T, n_leads, n_lags,
                       nu, lambda, V,
                       force,
                       n_factors,
                       scm, time_cohort, 
                       time_w, lambda_t,
                       fit_resids,
                       eps_abs, eps_rel,
                       verbose, long_df, 
                       how_match, ...) {
    ## average together treatment groups
    ## grps <- unique(wide$trt) %>% sort()
    if(time_cohort) {
        grps <- unique(wide$trt[is.finite(wide$trt)])
    } else {
        grps <- wide$trt[is.finite(wide$trt)]
    }
    J <- length(grps)

    ## fit outcome models
    if(time_w) {
        # Autoregressive model
        out <- fit_time_reg(cbind(wide$X, wide$y), wide$trt,
                            n_leads, lambda_t, ...)
        y0hat <- out$y0hat
        residuals <- out$residuals
        params <- out$time_weights
    } else if(is.null(n_factors)) {
        out <- tryCatch({
            fit_gsynth_multi(long_df, cbind(wide$X, wide$y), wide$trt, force=force)
        }, error = function(error_condition) {
            stop("Cannot run CV because there are too few pre-treatment periods.")
        })

        y0hat <- out$y0hat
        params <- out$params
        n_factors <- ncol(params$factor)
        ## get residuals from outcome model
        residuals <- cbind(wide$X, wide$y) - y0hat
        
    } else if (n_factors != 0) {
        ## if number of factors is provided don't do CV
        out <- fit_gsynth_multi(long_df, cbind(wide$X, wide$y), wide$trt,
                                r=n_factors, CV=0, force=force)
        y0hat <- out$y0hat
        params <- out$params        
        
        ## get residuals from outcome model
        residuals <- cbind(wide$X, wide$y) - y0hat
    } else if(force == 0 && n_factors == 0) {
        # if no fixed effects or factors, just take out 
        # control averages at each time point
        # time fixed effects from pure controls
        pure_ctrl <- cbind(wide$X, wide$y)[!is.finite(wide$trt), , drop = F]
        y0hat <- matrix(colMeans(pure_ctrl, na.rm = TRUE),
                          nrow = nrow(wide$X), ncol = ncol(pure_ctrl), 
                          byrow = T)
        residuals <- cbind(wide$X, wide$y) - y0hat
        params <- NULL
    } else {
        ## take out pre-treatment averages
        fullmask <- cbind(wide$mask, matrix(0, nrow=nrow(wide$mask),
                                            ncol=ncol(wide$y)))
        out <- fit_feff(cbind(wide$X, wide$y), wide$trt, fullmask, force, time_cohort)
        y0hat <- out$y0hat
        residuals <- out$residuals
        params <- NULL
    }

    ## balance the residuals
    if(fit_resids) {
        if(time_w) {
            # fit scm on residuals after taking out unit fixed effects
            fullmask <- cbind(wide$mask, matrix(0, nrow=nrow(wide$mask),
                                            ncol=ncol(wide$y)))
            out <- fit_feff(cbind(wide$X, wide$y), wide$trt, fullmask, force, time_cohort)
            bal_mat <- lapply(out$residuals, function(x) x[,1:ncol(wide$X)])
        } else if(typeof(residuals) == "list") {
            bal_mat <- lapply(residuals, function(x) x[,1:ncol(wide$X)])
        } else {
            bal_mat <- residuals[,1:ncol(wide$X)]
        }
    } else {
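        # if not balancing residuals, compute the lagged outcomes centered by
        # the pure-control average at each time
        ctrl_avg <- matrix(colMeans(wide$X[!is.finite(wide$trt), , drop = F]),
                          nrow = nrow(wide$X), ncol = ncol(wide$X), byrow = T)
        bal_mat <- wide$X - ctrl_avg
        # NOTE: the centered matrix is immediately overwritten here, so the raw
        # lagged outcomes are what actually get balanced
        bal_mat <- wide$X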
    }
    

    if(scm) {

        # get eligible set of donor units based on covariates
        donors <- get_donors(wide$X, wide$y, wide$trt,
                             wide$Z[, colnames(wide$Z) %in% 
                                      wide$match_covariates, drop = F],
                             time_cohort, n_lags, n_leads, how = how_match,
                             exact_covariates = wide$exact_covariates, ...)
        # run separate synth for scaling
        sep_fit <- multisynth_qp(X=bal_mat,
                                    trt=wide$trt,
                                    mask=wide$mask,
                                    Z = wide$Z[, !colnames(wide$Z) %in%
                                                  wide$match_covariates,
                                                  drop = F],
                                    n_leads=n_leads,
                                    n_lags=n_lags,
                                    relative=relative,
                                    nu=0, lambda=lambda,
                                    V = V,
                                    time_cohort = time_cohort,
                                    donors = donors,
                                    eps_rel = eps_rel,
                                    eps_abs = eps_abs,
                                    verbose = verbose)
        # if no nu value is provided, use default based on
        # global and individual imbalance for separate synth
        if(is.null(nu)) {
            # select nu by triangle inequality ratio
            glbl <- sep_fit$global_l2 * sqrt(nrow(sep_fit$imbalance))
            ind <- sep_fit$avg_l2
            nu <- glbl / ind

        }

        msynth <- multisynth_qp(X=bal_mat,
                                trt=wide$trt,
                                mask=wide$mask,
                                Z = wide$Z[, !colnames(wide$Z) %in%
                                                  wide$match_covariates,
                                                  drop = F],
                                n_leads=n_leads,
                                n_lags=n_lags,
                                relative=relative,
                                nu=nu, lambda=lambda,
                                V = V,
                                time_cohort = time_cohort,
                                donors = donors,
                                norm_pool = sep_fit$global_l2 ^ 2,
                                norm_sep = sep_fit$ind_l2 ^ 2,
                                eps_rel = eps_rel,
                                eps_abs = eps_abs,
                                verbose = verbose)
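    } else {
        ## no SCM fit, so no donor set is computed; define it here so that
        ## the assignment to msynth$donors below does not fail
        donors <- NULL
        msynth <- list(weights = matrix(0, nrow = nrow(wide$X), ncol = J),
                       imbalance=NA,
                       global_l2=NA,
                       ind_l2=NA)
    }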

    ## put in data and hyperparams
    msynth$data <- wide
    msynth$relative <- relative
    msynth$n_leads <- n_leads
    msynth$n_lags <- n_lags
    msynth$nu <- nu
    msynth$lambda <- lambda
    msynth$scm <- scm
    msynth$time_cohort <- time_cohort


    msynth$grps <- grps
    msynth$y0hat <- y0hat
    msynth$residuals <- residuals

    msynth$n_factors <- n_factors
    msynth$force <- force


    ## outcome model parameters
    msynth$params <- params

    # more arguments
    msynth$time_w <- time_w
    msynth$lambda_t <- lambda_t
    msynth$fit_resids <- fit_resids
    msynth$extra_pars <- c(list(eps_abs = eps_abs, 
                                eps_rel = eps_rel, 
                                verbose = verbose), 
                           list(...))
    msynth$long_df <- long_df

    msynth$how_match <- how_match
    msynth$donors <- donors
    ##format output
    class(msynth) <- "multisynth"
    return(msynth)
}
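
## Illustrative sketch (not part of the fitting code) of the default `nu` rule
## applied in multisynth_formatted when `nu` is NULL. Assumptions that this
## file does not confirm: the imbalance matrix has one row per balanced lag,
## its first column is the pooled imbalance and the remaining columns are
## per-unit imbalances; `global_l2` is the pooled L2 norm divided by
## sqrt(number of rows); `avg_l2` is the mean of the per-unit L2 norms.
## All object names below are hypothetical. Wrapped in `if (FALSE)` so that it
## is never run when the package is loaded.
if (FALSE) {
    ind_imbal <- cbind(c(1, -1, 2), c(-1, 1, 0))     # per-unit imbalances
    pooled_imbal <- rowMeans(ind_imbal)              # pooled = average over units
    global_l2 <- sqrt(sum(pooled_imbal^2)) / sqrt(length(pooled_imbal))
    avg_l2 <- mean(apply(ind_imbal, 2, function(x) sqrt(sum(x^2))))
    ## same arithmetic as in multisynth_formatted
    nu_default <- global_l2 * sqrt(length(pooled_imbal)) / avg_l2
    ## by the triangle inequality the norm of an average is at most the
    ## average of the norms, so under these assumptions nu_default is in [0, 1]
    stopifnot(nu_default >= 0, nu_default <= 1)
}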






#' Get prediction of average outcome under control or ATT
#' @param object Fit multisynth object
#' @param att If TRUE, return the ATT, if FALSE, return imputed counterfactual
#' @param att_weight Weights to place on individual units/cohorts when averaging
#' @param bs_weight Weight to perturb units by for weighted bootstrap
#' @param ... Optional arguments
#'
#' @return Matrix of predicted post-treatment control outcomes for each treated unit
#' @export
predict.multisynth <- function(object, att = F, att_weight = NULL, bs_weight = NULL, ...) {

    multisynth <- object
    ## estimates are always re-indexed relative to treatment time
    relative <- T
    
    time_cohort <- multisynth$time_cohort
    n_leads <- multisynth$n_leads
    d <- ncol(multisynth$data$X)
    n <- nrow(multisynth$data$X)
    fulldat <- cbind(multisynth$data$X, multisynth$data$y)
    ttot <- ncol(fulldat)
    grps <- multisynth$grps
    J <- length(grps)

    if(is.null(bs_weight)) {
      # bs_weight <- rep(1 / sqrt(sum(is.finite(multisynth$data$trt))), n)
      bs_weight <- rep(1, n)
    }

    if(time_cohort) {
        which_t <- lapply(grps, 
                          function(tj) (1:n)[multisynth$data$trt == tj])
        mask <- unique(multisynth$data$mask)
    } else {
        which_t <- (1:n)[is.finite(multisynth$data$trt)]
        mask <- multisynth$data$mask
    }
    

    n1 <- sapply(1:J, function(j) length(which_t[[j]]))

    fullmask <- cbind(mask, matrix(0, nrow = J, ncol = (ttot - d)))
    

    ## estimate the post-treatment values to get att estimates
    mu1hat <- vapply(1:J,
                     function(j) colMeans((bs_weight * fulldat)[which_t[[j]],
                                                              , drop=FALSE]),
                     numeric(ttot))



    ## get average outcome model estimates and reweight residuals
    if(typeof(multisynth$y0hat) == "list") {
        mu0hat <- vapply(1:J,
                        function(j) {
                            y0hat <- colMeans(
                              (bs_weight * multisynth$y0hat[[j]])[which_t[[j]],
                                                                , drop=FALSE])
                            weightsj <- multisynth$weights[,j] * bs_weight
                            resj <- multisynth$residuals[[j]][weightsj != 0,, drop = F]
                            y0hat + t(resj) %*% weightsj[weightsj != 0]
                        }
                       , numeric(ttot)
                        )
    } else {
        mu0hat <- vapply(1:J,
                        function(j) {
                            y0hat <- colMeans(
                              (bs_weight * multisynth$y0hat)[which_t[[j]],
                                                              , drop=FALSE])
                            weightsj <- multisynth$weights[, j] * bs_weight
                            resj <- multisynth$residuals[weightsj != 0,, drop = F]
                            y0hat + t(resj) %*% weightsj[weightsj != 0]
                        }
                       , numeric(ttot)
                        )
    }

    tauhat <- mu1hat - mu0hat

    if(is.null(att_weight)) {
      att_weight <- rep(1, J)
    }
    ## re-index time if relative to treatment
    if(relative) {
        total_len <- min(d + n_leads, ttot + d - min(grps)) ## total length of predictions
        mu0hat <- vapply(1:J,
                         function(j) {
                             vec <- c(rep(NA, d-grps[j]),
                                      mu0hat[1:grps[j],j],
                                      mu0hat[(grps[j]+1):(min(grps[j] + n_leads, ttot)), j])
                             ## last row is post-treatment average
                             c(vec,
                               rep(NA, total_len - length(vec)),
                               mean(mu0hat[(grps[j]+1):(min(grps[j] + n_leads, ttot)), j]))
                               
                         },
                         numeric(total_len +1
                                 ))
        
        tauhat <- vapply(1:J,
                         function(j) {
                             vec <- c(rep(NA, d-grps[j]),
                                      tauhat[1:grps[j],j],
                                      tauhat[(grps[j]+1):(min(grps[j] + n_leads, ttot)), j])
                             ## last row is post-treatment average
                             c(vec,
                               rep(NA, total_len - length(vec)),
                               mean(tauhat[(grps[j]+1):(min(grps[j] + n_leads, ttot)), j]))
                         },
                         numeric(total_len +1
                                 ))
        # re-index unit weights if they change over time
        if(is.null(dim(att_weight))) {
          if(J == 1) {
            att_weight <- matrix(replicate(total_len + 1, att_weight), ncol = 1)
          } else {
            att_weight <- t(replicate(total_len + 1, att_weight))
          }
        }
        att_weight_new <- vapply(1:J,
                        function(j) {
                            vec <- c(rep(NA, d-grps[j]),
                                    att_weight[1:grps[j],j],
                                    att_weight[(grps[j]+1):(min(grps[j] + n_leads, ttot)), j])
                            ## last row is post-treatment average
                            c(vec,
                              rep(NA, total_len - length(vec)),
                              mean(att_weight[(grps[j]+1):(min(grps[j] + n_leads, ttot)), j]))
                              
                        },
                        numeric(total_len +1
                                ))
          
        ## get the overall average estimate, weighting each group by its size
        ## and its att_weight and skipping groups with no estimate at that time
        avg <- sapply(1:nrow(mu0hat),
        function(k)  {
          sum(n1 * mu0hat[k,] * att_weight_new[k,], na.rm=T) /
            sum(n1 * (!is.na(mu0hat[k,])) * att_weight_new[k, ], na.rm = T)
        })
        mu0hat <- cbind(avg, mu0hat)

        avg <- sapply(1:nrow(tauhat),
        function(k)  {
          sum(n1 * tauhat[k,] * att_weight_new[k,], na.rm=T) /
            sum(n1 * (!is.na(tauhat[k,])) * att_weight_new[k, ], na.rm = T)
        })
        tauhat <- cbind(avg, tauhat)
        
    } else {

        ## remove all estimates for t > T_j + n_leads
        vapply(1:J,
               function(j) c(mu0hat[1:min(grps[j]+n_leads, ttot),j],
                             rep(NA, max(0, ttot-(grps[j] + n_leads)))),
               numeric(ttot)) -> mu0hat

        vapply(1:J,
               function(j) c(tauhat[1:min(grps[j]+n_leads, ttot),j],
                             rep(NA, max(0, ttot-(grps[j] + n_leads)))),
               numeric(ttot)) -> tauhat

        
        ## only average currently treated units
        avg1 <- rowSums(t(fullmask) *  mu0hat * n1) /
                rowSums(t(fullmask) *  n1)
        avg2 <- rowSums(t(1-fullmask) *  mu0hat * n1) /
            rowSums(t(1-fullmask) *  n1)
        avg <- replace_na(avg1, 0) * apply(fullmask, 2, min) +
            replace_na(avg2,0) * apply(1-fullmask, 2, max)
        cbind(avg, mu0hat) -> mu0hat

        ## only average currently treated units
        avg1 <- rowSums(t(fullmask) *  tauhat * n1) /
            rowSums(t(fullmask) *  n1)
        avg2 <- rowSums(t(1-fullmask) *  tauhat * n1) /
            rowSums(t(1-fullmask) *  n1)
        avg <- replace_na(avg1, 0) * apply(fullmask, 2, min) +
            replace_na(avg2,0) * apply(1 - fullmask, 2, max)
        cbind(avg, tauhat) -> tauhat
    }
    

    if(att) {
        return(tauhat)
    } else {
        return(mu0hat)
    }
}
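
## Illustrative sketch of the relative-time re-indexing done in
## predict.multisynth: each group's calendar-time estimates are padded with
## leading NAs so that treatment times line up, truncated to `n_leads`
## post-treatment periods, and followed by one extra row holding the
## post-treatment average. The toy inputs (`est`, `grps`, `d`, `ttot`,
## `n_leads`) are made up for illustration. Wrapped in `if (FALSE)` so that it
## is never run when the package is loaded.
if (FALSE) {
    d <- 4; ttot <- 6; n_leads <- 2
    grps <- c(2, 4)                  # last pre-treatment column for each group
    est <- cbind(1:6, 11:16)         # calendar-time estimates, one column per group
    total_len <- min(d + n_leads, ttot + d - min(grps))
    aligned <- vapply(seq_along(grps), function(j) {
        post <- (grps[j] + 1):min(grps[j] + n_leads, ttot)
        vec <- c(rep(NA, d - grps[j]), est[1:grps[j], j], est[post, j])
        ## last element is the post-treatment average
        c(vec, rep(NA, total_len - length(vec)), mean(est[post, j]))
    }, numeric(total_len + 1))
    aligned[, 1]   # NA NA 1 2 3 4 3.5
    aligned[, 2]   # 11 12 13 14 15 16 15.5
}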


#' Print function for multisynth
#' @param x multisynth object
#' @param att_weight Weights to place on individual units/cohorts when averaging
#' @param ... Optional arguments
#' @export
print.multisynth <- function(x, att_weight = NULL, ...) {
    multisynth <- x
    
    ## straight from lm
    cat("\nCall:\n", paste(deparse(multisynth$call), 
        sep="\n", collapse="\n"), "\n\n", sep="")

    # print att estimates
    att_post <- predict(multisynth, att=T, att_weight = att_weight)[,1]
    att_post <- att_post[length(att_post)]

    cat(paste("Average ATT Estimate: ",
              format(round(mean(att_post),3), nsmall = 3), "\n\n", sep=""))
}



#' Plot function for multisynth
#' @importFrom graphics plot
#' @param x multisynth object to be plotted
#' @param inf_type Type of inference to perform:
#'  \itemize{
#'    \item{bootstrap}{Wild bootstrap, the default option}
#'    \item{jackknife}{Jackknife}
#' }
#' @param inf Whether to compute and plot confidence intervals
#' @param levels Which units/groups to plot, default is every group
#' @param label Whether to label the individual levels
#' @param weights Whether to plot the weights, default = FALSE
#' @param ... Optional arguments
#' @export
plot.multisynth <- function(x, inf_type = "bootstrap", inf = T,
                            levels = NULL, label = T, 
                            weights = FALSE, ...) {

    if(weights) {
      ever_trt <- x$data$units[is.finite(x$data$trt)]
      never_trt <- x$data$units[!is.finite(x$data$trt)]
      weights <- data.frame(x$weights)

      colnames(weights) <- ever_trt
      weights$unit <- factor(rownames(weights),
                             levels = sort(rownames(weights), decreasing = TRUE))

      # plotting the weights
      weights %>%
        tidyr::pivot_longer(-unit, names_to = "trt_unit", values_to = "weight") %>%
        ggplot2::ggplot(aes(x = trt_unit, y = unit, fill = round(weight, 3))) + 
        ggplot2::geom_tile(color = "white", size=.5) +
        ggplot2::scale_fill_gradient(low = "white", high = "black", limits=c(-0.01,1.01)) +
        ggplot2::guides(fill = "none") +
        ggplot2::xlab("Treated Unit") +
        ggplot2::ylab("Donor Unit") +
        ggplot2::theme_bw() + 
        ggplot2::theme(axis.ticks.x = ggplot2::element_blank(),
              axis.ticks.y = ggplot2::element_blank())
    }
    else {
      plot(summary(x, inf_type = inf_type, ...),
         inf = inf, levels = levels, label = label)
    }
}





#' Summary function for multisynth
#' @param object multisynth object
#' @param inf_type Type of inference to perform:
#'  \itemize{
#'    \item{bootstrap}{Wild bootstrap, the default option}
#'    \item{jackknife}{Jackknife}
#' }
#' @param att_weight Weights to place on individual units/cohorts when averaging
#' @param ... Optional arguments
#' 
#' @return summary.multisynth object that contains:
#'         \itemize{
#'          \item{"att"}{Dataframe with ATT estimates, standard errors for each treated unit}
#'          \item{"global_l2"}{L2 imbalance for the pooled synthetic control}
#'          \item{"scaled_global_l2"}{L2 imbalance for the pooled synthetic control, scaled by the imbalance for uniform weights}
#'          \item{"ind_l2"}{Average L2 imbalance for the individual synthetic controls}
#'          \item{"scaled_ind_l2"}{Average L2 imbalance for the individual synthetic controls, scaled by the imbalance for uniform weights}
#'         \item{"n_leads", "n_lags"}{Number of post treatment outcomes (leads) and pre-treatment outcomes (lags) to include in the analysis}
#'         }
#' @export
summary.multisynth <- function(object, inf_type = "bootstrap", att_weight = NULL, ...) {

    multisynth <- object
    
    relative <- T

    n_leads <- multisynth$n_leads
    d <- ncol(multisynth$data$X)
    n <- nrow(multisynth$data$X)
    ttot <- d + ncol(multisynth$data$y)

    trt <- multisynth$data$trt
    time_cohort <- multisynth$time_cohort
    if(time_cohort) {
        grps <- unique(trt[is.finite(trt)])
        which_t <- lapply(grps, function(tj) (1:n)[trt == tj])
    } else {
        grps <- trt[is.finite(trt)]
        which_t <- (1:n)[is.finite(trt)]
    }
    
    # grps <- unique(multisynth$data$trt) %>% sort()
    J <- length(grps)
    
    # which_t <- (1:n)[is.finite(multisynth$data$trt)]
    times <- multisynth$data$time
    
    summ <- list()
    ## post treatment estimate for each group and overall
    # att <- predict(multisynth, relative, att=T)
    
    if(inf_type == "jackknife") {
        attse <- jackknife_se_multi(multisynth, relative, att_weight = att_weight, ...)
    } else if(inf_type == "bootstrap") {
        if(object$force == 2) {
          warning("Wild bootstrap without including a unit fixed effect ",
                  "is likely to be very conservative!")
        }
        attse <- weighted_bootstrap_multi(multisynth, att_weight = att_weight, ...)
    } else {
        att <- predict(multisynth, att=T, att_weight = att_weight)
        attse <- list(att = att,
                      se = matrix(NA, nrow(att), ncol(att)),
                      upper_bound = matrix(NA, nrow(att), ncol(att)),
                      lower_bound = matrix(NA, nrow(att), ncol(att)))
    }
    

    if(relative) {
        att <- data.frame(cbind(c(-(d-1):min(n_leads, ttot-min(grps)), NA),
                                attse$att))
        if(time_cohort) {
            col_names <- c("Time", "Average", 
                            as.character(times[grps + 1]))
        } else {
            col_names <- c("Time", "Average", 
                            as.character(multisynth$data$units[which_t]))
        }
        names(att) <- col_names
        att %>% gather(Level, Estimate, -Time) %>%
            rename("Time"=Time) %>%
            mutate(Time=Time-1) -> att

        se <- data.frame(cbind(c(-(d-1):min(n_leads, ttot-min(grps)), NA),
                               attse$se))                            
        names(se) <- col_names
        se %>% gather(Level, Std.Error, -Time) %>%
            rename("Time"=Time) %>%
            mutate(Time=Time-1) -> se
        lower_bound <- data.frame(cbind(c(-(d-1):min(n_leads, ttot-min(grps)), NA),
                                  attse$lower_bound))
        names(lower_bound) <- col_names
        lower_bound %>% gather(Level, lower_bound, -Time) %>%
          rename("Time"=Time) %>%
          mutate(Time=Time-1) -> lower_bound

        upper_bound <- data.frame(cbind(c(-(d-1):min(n_leads, ttot-min(grps)), NA),
                                        attse$upper_bound))
        names(upper_bound) <- col_names
        upper_bound %>% gather(Level, upper_bound, -Time) %>%
          rename("Time"=Time) %>%
            mutate(Time=Time-1) -> upper_bound

    } else {
        att <- data.frame(cbind(times, attse$att))
        names(att) <- c("Time", "Average", times[grps[1:J]])        
        att %>% gather(Level, Estimate, -Time) -> att

        se <- data.frame(cbind(times, attse$se))
        names(se) <- c("Time", "Average", times[grps[1:J]])        
        se %>% gather(Level, Std.Error, -Time) -> se

    }

    summ$att <- inner_join(att, se, by = c("Time", "Level")) %>%
      inner_join(lower_bound, by = c("Time", "Level")) %>%
        inner_join(upper_bound, by = c("Time", "Level"))

    summ$relative <- relative
    summ$grps <- grps
    summ$call <- multisynth$call
    summ$global_l2 <- multisynth$global_l2
    summ$scaled_global_l2 <- multisynth$scaled_global_l2

    summ$ind_l2 <- multisynth$ind_l2
    summ$scaled_ind_l2 <- multisynth$scaled_ind_l2

    summ$n_leads <- multisynth$n_leads
    summ$n_lags <- multisynth$n_lags

    class(summ) <- "summary.multisynth"
    return(summ)
}
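
## Usage sketch for the S3 methods defined in this file. `panel_df` and its
## columns (`outcome`, `treated`, `unit_id`, `period`) are hypothetical, and
## the call assumes the exported multisynth(form, unit, time, data, ...)
## interface; adjust the names to the data at hand. Wrapped in `if (FALSE)` so
## that it is never run when the package is loaded.
if (FALSE) {
    fit <- multisynth(outcome ~ treated, unit_id, period, panel_df, n_leads = 5)
    print(fit)                                    # overall average ATT
    summ <- summary(fit, inf_type = "jackknife")  # or "bootstrap", the default
    print(summ, level = "Average")
    plot(summ, inf = TRUE)
    att_mat <- predict(fit, att = TRUE)           # first column is the average
}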

#' Print function for summary function for multisynth
#' @param x summary object
#' @param level Which unit/group to print results for, default is the overall average
#' @param ... Optional arguments
#' @export
print.summary.multisynth <- function(x, level = "Average", ...) {

    summ <- x
    
    ## straight from lm
    cat("\nCall:\n", paste(deparse(summ$call), sep="\n", collapse="\n"), "\n\n", sep="")

    first_lvl <- summ$att %>% filter(Level != "Average") %>% pull(Level) %>% min()
    
    ## get ATT estimates for treatment level, post treatment
    if(summ$relative) {
        summ$att %>%
            filter(Time >= 0, Level==level) %>%
            rename("Time Since Treatment"=Time) -> att_est
    } else if(level == "Average") {
        summ$att %>% filter(Time > first_lvl, Level=="Average") -> att_est
    } else {
        summ$att %>% filter(Time > level, Level==level) -> att_est
    }

    cat(paste("Average ATT Estimate (Std. Error): ",
              summ$att %>%
                  filter(Level == level, is.na(Time)) %>%
                  pull(Estimate) %>%
                  round(3) %>% format(nsmall=3),
              "  (",
              summ$att %>%
                  filter(Level == level, is.na(Time)) %>%
                  pull(Std.Error) %>%
                  round(3) %>% format(nsmall=3),
              ")\n\n", sep=""))
    
    cat(paste("Global L2 Imbalance: ",
              format(round(summ$global_l2,3), nsmall=3), "\n",
              "Scaled Global L2 Imbalance: ",
              format(round(summ$scaled_global_l2,3), nsmall=3), "\n",
              "Percent improvement from uniform global weights: ", 
              format(round(1-summ$scaled_global_l2,3)*100), "\n\n",
              "Individual L2 Imbalance: ",
              format(round(summ$ind_l2,3), nsmall=3), "\n",
              "Scaled Individual L2 Imbalance: ", 
              format(round(summ$scaled_ind_l2,3), nsmall=3), "\n",
              "Percent improvement from uniform individual weights: ", 
              format(round(1-summ$scaled_ind_l2,3)*100), "\n\n",
              sep=""))


    print(att_est, row.names=F)

}
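
## Sketch of the "percent improvement" lines printed by
## print.summary.multisynth, using made-up numbers: the scaled imbalance is the
## fitted imbalance divided by the imbalance under uniform weights, and the
## improvement is one minus that ratio, shown as a percentage. Wrapped in
## `if (FALSE)` so that it is never run when the package is loaded.
if (FALSE) {
    global_l2 <- 0.02        # hypothetical imbalance of the fitted weights
    unif_global_l2 <- 0.08   # hypothetical imbalance of uniform weights
    scaled_global_l2 <- global_l2 / unif_global_l2
    format(round(1 - scaled_global_l2, 3) * 100)   # "75"
}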

#' Plot function for summary function for multisynth
#' @importFrom ggplot2 aes
#' 
#' @param x summary object
#' @param inf Whether to plot confidence intervals
#' @param levels Which units/groups to plot, default is every group
#' @param label Whether to label the individual levels
#' @param weights Whether to plot the weights, default = FALSE
#' @param ... Optional arguments
#' @export
plot.summary.multisynth <- function(x, inf = T, levels = NULL, label = T,
                                    weights = FALSE, ...) {


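    ## NOTE: a summary.multisynth object carries no `data` or `weights` element,
    ## and the plot built in this branch is neither printed nor returned; use
    ## plot(<multisynth object>, weights = TRUE) to plot the weights
    if(weights) {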
        ever_trt <- x$data$units[is.finite(x$data$trt)]
        never_trt <- x$data$units[!is.finite(x$data$trt)]
        weights <- data.frame(x$weights)

        colnames(weights) <- ever_trt
        weights$unit <- factor(rownames(weights),
                              levels = sort(rownames(weights), decreasing = TRUE))

        # plotting the weights
        weights %>%
          tidyr::pivot_longer(-unit, names_to = "trt_unit", values_to = "weight") %>%
          ggplot2::ggplot(aes(x = trt_unit, y = unit, fill = round(weight, 3))) + 
          ggplot2::geom_tile(color = "white", size=.5) +
          ggplot2::scale_fill_gradient(low = "white", high = "black", limits=c(-0.01,1.01)) +
          ggplot2::guides(fill = "none") +
          ggplot2::xlab("Treated Unit") +
          ggplot2::ylab("Donor Unit") +
          ggplot2::theme_bw() + 
          ggplot2::theme(axis.ticks.x = ggplot2::element_blank(),
                axis.ticks.y = ggplot2::element_blank())
    }

    summ <- x
    
    ## get the last time period for each level
    summ$att %>%
        filter(!is.na(Estimate),
               Time >= -summ$n_lags,
               Time <= summ$n_leads) %>%
        group_by(Level) %>%
        summarise(last_time = max(Time)) -> last_times

    if(is.null(levels)) levels <- unique(summ$att$Level)

    summ$att %>% inner_join(last_times, by = "Level") %>%
        filter(Level %in% levels) %>%
        mutate(label = ifelse(Time == last_time, Level, NA),
               is_avg = ifelse(("Average" %in% levels) * (Level == "Average"),
                               "A", "B")) %>%
        ggplot2::ggplot(ggplot2::aes(x = Time, y = Estimate,
                                     group = Level,
                                     color = is_avg,
                                     alpha = is_avg)) +
            ggplot2::geom_line(size = 1) +
            ggplot2::geom_point(size = 1) -> p
            
        if(label) {
          p <- p + ggrepel::geom_label_repel(ggplot2::aes(label = label),
                                      nudge_x = 1, na.rm = T)
        } 
        p <- p + ggplot2::geom_hline(yintercept = 0, lty = 2)

    if(summ$relative) {
        p <- p + ggplot2::geom_vline(xintercept = 0, lty = 2) +
            ggplot2::xlab("Time Relative to Treatment")
    } else {
        p <- p + ggplot2::geom_vline(aes(xintercept = as.numeric(Level)),
                                     lty = 2, alpha = 0.5,
                                     summ$att %>% filter(Level != "Average"))
    }

    ## add ses
    if(inf) {
        max_time <- max(summ$att$Time, na.rm = T)
        if(max_time == 0) {
          error_plt <- ggplot2::geom_errorbar
          clr <- "black"
          alph <- 1
        } else {
          error_plt <- ggplot2::geom_ribbon
          clr <- NA
          alph <- 0.2
        }
        if("Average" %in% levels) {
            p <- p + error_plt(
                ggplot2::aes(ymin=lower_bound,
                             ymax=upper_bound),
                alpha = alph, color=clr,
                data = summ$att %>% 
                        filter(Level == "Average",
                               Time >= 0))
            
        } else {
            p <- p + error_plt(
                ggplot2::aes(ymin=lower_bound,
                             ymax=upper_bound),
                             data = . %>% filter(Time >= 0),
                alpha = alph, color = clr)
        }
    }

    p <- p + ggplot2::scale_alpha_manual(values=c(1, 0.5)) +
        ggplot2::scale_color_manual(values=c("#333333", "#818181")) +
        ggplot2::guides(alpha = "none", color = "none") + 
        ggplot2::theme_bw()
    return(p)

}


================================================
FILE: R/outcome_models.R
================================================
################################################################################
## Code to fit various outcome models
################################################################################

#' Use a separate regularized regression for each post period
#' to fit E[Y(0)|X]
#' @importFrom stats poly
#' @importFrom stats coef
#'
#' @param X Matrix of covariates/lagged outcomes
#' @param y Matrix of post-period outcomes
#' @param trt Vector of treatment indicator
#' @param alpha Mixing between L1 and L2, default: 1 (LASSO)
#' @param lambda Regularization hyperparameter, if null then CV
#' @param poly_order Order of polynomial to fit, default 1
#' @param type How to fit outcome model(s)
#'             \itemize{
#'              \item{sep }{Separate outcome models}
#'              \item{avg }{Average responses into 1 outcome}
#'              \item{multi }{Use multi response regression in glmnet}}
#' @param ... optional arguments for outcome model
#' @noRd
#' @return \itemize{
#'           \item{y0hat }{Predicted outcome under control}
#'           \item{params }{Regression parameters}}
fit_prog_reg <- function(X, y, trt, alpha=1, lambda=NULL,
                         poly_order=1, type="sep", ...) {
    if(!requireNamespace("glmnet", quietly = TRUE)) {
        stop("In order to fit an elastic net outcome model, you must install the glmnet package.")
    }
    
    extra_params = list(...)
    if (length(extra_params) > 0) {
        warning("Unused parameters when using elastic net: ", paste(names(extra_params), collapse = ", "))
    }
    
    X <- matrix(poly(matrix(X),degree=poly_order), nrow=dim(X)[1])

    ## helper function to fit regression with CV
    outfit <- function(x, y) {
        if(is.null(lambda)) {
            lam <- glmnet::cv.glmnet(x, y, alpha=alpha, grouped=FALSE)$lambda.min
        } else {
            lam <- lambda
        }
        fit <- glmnet::glmnet(x, y, alpha=alpha,
                              lambda=lam)
        
        return(as.matrix(coef(fit)))
    }

    if(type=="avg") {
        ## if fitting the average post period value, stack post periods together
        stacky <- c(y)
        stackx <- do.call(rbind,
                          lapply(1:dim(y)[2],
                                 function(x) X))
        stacktrt <- rep(trt, dim(y)[2])
        regweights <- outfit(stackx[stacktrt==0,],
                             stacky[stacktrt==0])
    } else if(type=="sep"){
        ## fit separate regressions for each post period
        regweights <- apply(as.matrix(y), 2,
                            function(yt) outfit(X[trt==0,],
                                                yt[trt==0]))
    } else {
        ## fit multi response regression
        lam <- glmnet::cv.glmnet(X, y, family="mgaussian",
                                 alpha=alpha, grouped=FALSE)$lambda.min
        fit <- glmnet::glmnet(X, y, family="mgaussian",
                              alpha=alpha,
                              lambda=lam)
        regweights <- as.matrix(do.call(cbind, coef(fit)))
    }


    ## Get predicted values
    y0hat <- cbind(rep(1, dim(X)[1]),
                   X) %*% regweights

    return(list(y0hat = y0hat,
                params  = regweights))
}
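
## -----------------------------------------------------------------------------
## Illustrative sketch, not part of the original file. It mirrors the
## prediction step at the end of fit_prog_reg: coefficients are stored with
## the intercept in the first row and one column per post period, so the
## predicted control outcomes are cbind(1, X) %*% coefs.
## Assumptions: the name `sketch_sep_prediction` and the toy data are
## hypothetical, and plain least squares (stats::lm.fit) stands in for glmnet
## so that the sketch has no extra dependency.
## -----------------------------------------------------------------------------
sketch_sep_prediction <- function() {
    set.seed(1)
    n <- 8
    X <- matrix(rnorm(n * 2), nrow = n)       # 2 lagged outcomes per unit
    ## 2 noiseless post periods, each linear in X
    y <- cbind(1 + X %*% c(2, -1), 3 + X %*% c(0.5, 0.5))
    trt <- c(rep(0, n - 1), 1)                # last unit is treated

    ## one regression per post period, fit on control units only
    coefs <- apply(y, 2, function(yt) {
        stats::lm.fit(cbind(1, X[trt == 0, , drop = FALSE]),
                      yt[trt == 0])$coefficients
    })

    ## predict the control outcome for every unit, treated unit included
    y0hat <- cbind(1, X) %*% coefs
    list(coefs = coefs, y0hat = y0hat, y = y)
}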



#' Use a separate random forest regression for each post period
#' to fit E[Y(0)|X]
#'
#' @param X Matrix of covariates/lagged outcomes
#' @param y Matrix of post-period outcomes
#' @param trt Vector of treatment indicator
#' @param avg Predict the average post-treatment outcome
#' @param ... optional arguments for outcome model
#' @noRd
#' @return \itemize{
#'           \item{y0hat }{Predicted outcome under control}
#'           \item{params }{Regression parameters}}
fit_prog_rf <- function(X, y, trt, avg=FALSE, ...) {

    if(!requireNamespace("randomForest", quietly = TRUE)) {
        stop("In order to fit a random forest outcome model, you must install the randomForest package.")
    }
    
    extra_params = list(...)
    if (length(extra_params) > 0) {
        warning("Unused parameters when using random forest: ", paste(names(extra_params), collapse = ", "))
    }

    
    ## helper function to fit RF
    outfit <- function(x, y) {
            fit <- randomForest::randomForest(x, y)
            return(fit)
    }


    if(avg || dim(y)[2] == 1) {
        ## if fitting the average post period value, stack post periods together
        stacky <- c(y)
        stackx <- do.call(rbind,
                          lapply(1:dim(y)[2],
                                 function(x) X))
        stacktrt <- rep(trt, dim(y)[2])
        fit <- outfit(stackx[stacktrt==0,],
                      stacky[stacktrt==0])

        ## predict outcome
        y0hat <- matrix(predict(fit, X), ncol=1)

        
        ## keep feature importances
        imports <- randomForest::importance(fit)

        
    } else {
        ## fit separate regressions for each post period
        fits <- apply(as.matrix(y), 2,
                      function(yt) outfit(X[trt==0,],
                                          yt[trt==0]))
        
        ## predict outcome
        y0hat <- lapply(fits, function(fit) as.matrix(predict(fit,X))) %>%
            bind_rows() %>%
            as.matrix()

        
        ## keep feature importances
        imports <- lapply(fits, function(fit) randomForest::importance(fit)) %>%
            bind_rows() %>%
            as.matrix()

    }


    return(list(y0hat=y0hat,
                params=imports))
    
}
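
## -----------------------------------------------------------------------------
## Illustrative sketch, not part of the original file. It shows the stacking
## used above when a single model is fit to all post periods: c(y) flattens
## the outcomes column by column, X is repeated once per post period, and
## the treatment indicator is repeated to match before keeping controls.
## Assumptions: the name `sketch_stack_post_periods` and the toy matrices are
## hypothetical.
## -----------------------------------------------------------------------------
sketch_stack_post_periods <- function() {
    X <- matrix(1:6, nrow = 3)                         # 3 units, 2 lags
    y <- matrix(c(10, 20, 30, 11, 21, 31), nrow = 3)   # 3 units, 2 post periods
    trt <- c(0, 0, 1)                                  # unit 3 is treated

    stacky <- c(y)                                     # period 1, then period 2
    stackx <- do.call(rbind,
                      lapply(seq_len(ncol(y)), function(j) X))
    stacktrt <- rep(trt, ncol(y))

    ## keep only control rows, as the model is fit to controls
    list(x = stackx[stacktrt == 0, , drop = FALSE],
         y = stacky[stacktrt == 0])
}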


#' Use gsynth to fit factor model for E[Y(0)|X]
#'
#' @param X Matrix of covariates/lagged outcomes
#' @param y Matrix of post-period outcomes
#' @param trt Vector of treatment indicator
#' @param r Number of factors to use (or start with if CV==1)
#' @param r.end Max number of factors to consider if CV==1
#' @param force Fixed effects (0=none, 1=unit, 2=time, 3=two-way)
#' @param CV Whether to do CV (0=no CV, 1=yes CV)
#' @param ... optional arguments for outcome model
#' @noRd
#' @return \itemize{
#'           \item{y0hat }{Predicted outcome under control}
#'           \item{params }{Regression parameters}}
fit_prog_gsynth <- function(X, y, trt, r=0, r.end=5, force=3, CV=1, ...) {
    if(!requireNamespace("gsynth", quietly = TRUE)) {
        stop("In order to fit generalized synthetic controls, you must install the gsynth package.")
    }
    extra_params = list(...)
    if (length(extra_params) > 0) {
        warning("Unused parameters when using gSynth: ", paste(names(extra_params), collapse = ", "))
    }
    
    df_x = data.frame(X, check.names=FALSE)
    df_x$unit = rownames(df_x)
    df_x$trt = rep(0, nrow(df_x))
    df_x <- df_x %>% select(unit, trt, everything())
    long_df_x = gather(df_x, time, obs, -c(unit,trt))

    df_y = data.frame(y, check.names=FALSE)
    df_y$unit = rownames(df_y)
    df_y$trt = trt
    df_y <- df_y %>% select(unit, trt, everything())
    long_df_y = gather(df_y, time, obs, -c(unit,trt))
    long_df = rbind(long_df_x, long_df_y)

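    ## NOTE: transform() returns a modified copy; these results are not
    ## assigned, so long_df is passed to gsynth unchanged
    transform(long_df, time = as.numeric(time))
    transform(long_df, unit = as.numeric(unit))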
    gsyn <- gsynth::gsynth(data = long_df, Y = "obs", D = "trt", 
                           index = c("unit", "time"), force = force, CV = CV, r = r)

    t0 <- dim(X)[2]
    t_final <- t0 + dim(y)[2]
    n <- dim(X)[1]
    ## get predicted outcomes
    y0hat <- matrix(0, nrow=n, ncol=(t_final-t0))
    y0hat[trt==0,]  <- t(gsyn$Y.co[(t0+1):t_final,,drop=FALSE] -
                             gsyn$est.co$residuals[(t0+1):t_final,,drop=FALSE])

    y0hat[trt==1,] <- gsyn$Y.ct[(t0+1):t_final,]

    ## add treated prediction for whole pre-period
    gsyn$est.co$Y.ct <- gsyn$Y.ct

    ## control and treated residuals
    gsyn$est.co$ctrl_resids <- gsyn$est.co$residuals
    gsyn$est.co$trt_resids <- colMeans(cbind(X[trt==1,,drop=FALSE],
                                             y[trt==1,,drop=FALSE])) -
        rowMeans(gsyn$est.co$Y.ct)
    return(list(y0hat=y0hat,
                params=gsyn$est.co))
}
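
## -----------------------------------------------------------------------------
## Illustrative sketch, not part of the original file. It shows the
## wide-to-long reshape that fit_prog_gsynth performs before calling gsynth,
## written in base R rather than with tidyr::gather, and converts the index
## columns to numeric. transform() returns a new data frame, so its result
## has to be assigned for the conversion to take effect.
## Assumptions: the name `sketch_wide_to_long` and the toy matrix are
## hypothetical; row and column names are assumed to be numeric-like strings.
## -----------------------------------------------------------------------------
sketch_wide_to_long <- function() {
    X <- matrix(1:4, nrow = 2,
                dimnames = list(c("1", "2"), c("2000", "2001")))

    ## c(X) is column-major, so units cycle fastest and times slowest
    long <- data.frame(unit = rep(rownames(X), times = ncol(X)),
                       time = rep(colnames(X), each = nrow(X)),
                       obs = c(X),
                       stringsAsFactors = FALSE)

    ## assign the result so the numeric conversion is kept
    long <- transform(long,
                      time = as.numeric(time),
                      unit = as.numeric(unit))
    long
}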


#' Use Athey (2017) matrix completion panel data code
#'
#' @param X Matrix of covariates/lagged outcomes
#' @param y Matrix of post-period outcomes
#' @param trt Vector of treatment indicator
#' @param unit_fixed Whether to estimate unit fixed effects
#' @param time_fixed Whether to estimate time fixed effects
#' @param ... optional arguments for outcome model
#' @noRd
#' @return \itemize{
#'           \item{y0hat }{Predicted outcome under control}
#'           \item{params }{Regression parameters}}
fit_prog_mcpanel <- function(X, y, trt, unit_fixed=1, time_fixed=1, ...) {


    if(!requireNamespace("MCPanel", quietly = TRUE)) {
        stop("In order to fit matrix completion, you must install the MCPanel package.")
    }
    
    extra_params = list(...)
    if (length(extra_params) > 0) {
        warning("Unused parameters when using MCPanel: ", paste(names(extra_params), collapse = ", "))
    }
    
    ## create matrix and missingness matrix

    t0 <- dim(X)[2]
    t_final <- t0 + dim(y)[2]
    n <- dim(X)[1]    
    
    fullmat <- cbind(X, y)
    maskmat <- matrix(1, nrow=nrow(fullmat), ncol=ncol(fullmat))
    maskmat[trt==1, (t0+1):t_final] <- 0

    ## estimate matrix
    mcp <- MCPanel::mcnnm_cv(fullmat, maskmat,
                             to_estimate_u=unit_fixed, to_estimate_v=time_fixed)
    
    ## impute matrix
    imp_mat <- mcp$L +
        sweep(matrix(0, nrow=nrow(fullmat), ncol=ncol(fullmat)), 1, mcp$u, "+") + # unit fixed
        sweep(matrix(0, nrow=nrow(fullmat), ncol=ncol(fullmat)), 2, mcp$v, "+") # time fixed
    
    
    trtmat <- matrix(0, ncol=n, nrow=t_final)
    trtmat[t0:t_final, trt == 1] <- 1

    ## get predicted outcomes
    y0hat <- imp_mat[,(t0+1):t_final,drop=FALSE]
    params <- mcp

    params$trt_resids <- colMeans(cbind(X[trt==1,,drop=FALSE],
                                        y[trt==1,,drop=FALSE])) -
        rowMeans(imp_mat[trt==1,,drop=FALSE])

    params$ctrl_resids <- t(cbind(X[trt==0,,drop=FALSE],
                                y[trt==0,,drop=FALSE]) - imp_mat[trt==0,,drop=FALSE])
    params$Y.ct <- t(imp_mat[trt==1,,drop=FALSE])
    return(list(y0hat=y0hat,
                params=params))
    
}
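
## -----------------------------------------------------------------------------
## Illustrative sketch, not part of the original file. It shows two steps of
## fit_prog_mcpanel on toy inputs: building the observation mask (0 marks the
## treated units' post-period cells, which are to be imputed) and adding unit
## and time fixed effects with sweep(). The two sweeps together give
## u[i] + v[j] in cell (i, j), the same as outer(u, v, "+").
## Assumptions: the name `sketch_mask_and_fixed_effects` and the values of
## u and v are hypothetical; no MCPanel call is made here.
## -----------------------------------------------------------------------------
sketch_mask_and_fixed_effects <- function() {
    n <- 3
    t0 <- 2
    t_final <- 4
    trt <- c(0, 0, 1)                       # unit 3 is treated

    ## 1 = observed, 0 = masked (treated unit, post-treatment periods)
    mask <- matrix(1, nrow = n, ncol = t_final)
    mask[trt == 1, (t0 + 1):t_final] <- 0

    u <- c(1, 2, 3)                         # unit fixed effects
    v <- c(10, 20, 30, 40)                  # time fixed effects
    zeros <- matrix(0, nrow = n, ncol = t_final)
    fe <- sweep(zeros, 1, u, "+") + sweep(zeros, 2, v, "+")

    list(mask = mask, fe = fe)
}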


#' Fit a Comparative interrupted time series
#' to fit E[Y(0)|X]
#' @importFrom stats lm
#' @im
Condensed preview — 83 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (401K chars).
[
  {
    "path": ".Rbuildignore",
    "chars": 67,
    "preview": "^data-raw$\n^Meta$\n^doc$\n^\\.travis\\.yml$\n^pkg.Rproj$\nfigure$\ncache$\n"
  },
  {
    "path": ".gitignore",
    "chars": 233,
    "preview": "Meta\ndoc\ninst/doc\n\n## Files\n# Emacs autosave files\n*~\n\\#*#\n# Don't put data in the repo\n*.csv\n*.feather\n\n# R stuff\n*.Rou"
  },
  {
    "path": ".travis.yml",
    "chars": 266,
    "preview": "# R for travis: see documentation at https://docs.travis-ci.com/user/languages/r\n\nlanguage: r\nr:\n    - 3.5.1  \nsudo: fal"
  },
  {
    "path": "DESCRIPTION",
    "chars": 741,
    "preview": "Package: augsynth\nTitle: The Augmented Synthetic Control Method\nVersion: 0.2.0\nAuthors@R: person(\"Eli\", \"Ben-Michael\", e"
  },
  {
    "path": "LICENSE",
    "chars": 1076,
    "preview": "MIT License\n\nCopyright (c) 2018 Elijahu Ben-Michael\n\nPermission is hereby granted, free of charge, to any person obtaini"
  },
  {
    "path": "NAMESPACE",
    "chars": 1232,
    "preview": "# Generated by roxygen2: do not edit by hand\n\nS3method(plot,augsynth)\nS3method(plot,augsynth_multiout)\nS3method(plot,mul"
  },
  {
    "path": "R/augsynth.R",
    "chars": 21716,
    "preview": "################################################################################\n## Main functions for single-period tre"
  },
  {
    "path": "R/augsynth_pre.R",
    "chars": 3864,
    "preview": "################################################################################\n## Main function for the augmented synt"
  },
  {
    "path": "R/cv.R",
    "chars": 3987,
    "preview": "drop_time_t <- function(wide_data, Z, t_drop) {\n  new_wide_data <- list()\n  new_wide_data$trt <- wide_data$trt\n  \n  if ("
  },
  {
    "path": "R/data.R",
    "chars": 1410,
    "preview": "#' Economic indicators for US states from 1990-2016\n#' \n#' \n#' @format A dataframe with 5250 rows and 32 variables:\n#' \\"
  },
  {
    "path": "R/eligible_donors.R",
    "chars": 4979,
    "preview": "##############################################################################\n## Code to get eligible donor units based"
  },
  {
    "path": "R/fit_synth.R",
    "chars": 2659,
    "preview": "#######################################################\n# Helper scripts to fit synthetic controls to simulations\n######"
  },
  {
    "path": "R/format.R",
    "chars": 11471,
    "preview": "################################################################################\n## Scripts to format panel data into ma"
  },
  {
    "path": "R/globalVariables.R",
    "chars": 317,
    "preview": "utils::globalVariables(c(\"time\", \"val\", \"post\", \"weight\", \".\", \"Time\",\n                         \"Estimate\", \"Std.Error\","
  },
  {
    "path": "R/highdim.R",
    "chars": 7802,
    "preview": "################################################################################\n## Methods to use flexible outcome mode"
  },
  {
    "path": "R/inference.R",
    "chars": 47117,
    "preview": "################################################################################\n## Code for inference\n#################"
  },
  {
    "path": "R/multi_outcomes.R",
    "chars": 22446,
    "preview": "#' Fit Augmented SCM with multiple outcomes\n#' @param form outcome ~ treatment | auxillary covariates\n#' @param unit Nam"
  },
  {
    "path": "R/multi_synth_qp.R",
    "chars": 16216,
    "preview": "################################################################################\n## Solve the multisynth problem as a QP"
  },
  {
    "path": "R/multisynth_class.R",
    "chars": 39714,
    "preview": "################################################################################\n## Fitting, plotting, summarizing stagg"
  },
  {
    "path": "R/outcome_models.R",
    "chars": 18480,
    "preview": "################################################################################\n## Code to fit various outcome models\n#"
  },
  {
    "path": "R/outcome_multi.R",
    "chars": 3827,
    "preview": "################################################################################\n## Fitting outcome models for multiple "
  },
  {
    "path": "R/ridge.R",
    "chars": 13898,
    "preview": "################################################################################\n## Ridge-augmented SCM\n################"
  },
  {
    "path": "R/ridge_lambda.R",
    "chars": 2307,
    "preview": "################################################################################\n## Function to calculate error on diffe"
  },
  {
    "path": "R/time_regression_multi.R",
    "chars": 9227,
    "preview": "##############################################################################\n## Outcome regression with multiple treat"
  },
  {
    "path": "README.md",
    "chars": 1371,
    "preview": "# augsynth: Augmented Synthetic Control Method\n[![Build Status](https://travis-ci.org/ebenmichael/augsynth.svg?branch=ma"
  },
  {
    "path": "data-raw/clean_kansas.R",
    "chars": 2115,
    "preview": "library(haven)\nlibrary(tidyverse)\n\nkansas <- read_dta(\"kansas_longer2.dta\")\nstate_abb <- read_csv(\"us-state-ansi-fips.cs"
  },
  {
    "path": "man/augsynth-package.Rd",
    "chars": 252,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/augsynth.R\n\\docType{package}\n\\name{augsynt"
  },
  {
    "path": "man/augsynth.Rd",
    "chars": 2013,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/augsynth_pre.R\n\\name{augsynth}\n\\alias{augs"
  },
  {
    "path": "man/augsynth_multiout.Rd",
    "chars": 1568,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/multi_outcomes.R\n\\name{augsynth_multiout}\n"
  },
  {
    "path": "man/check_data_stag.Rd",
    "chars": 670,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/format.R\n\\name{check_data_stag}\n\\alias{che"
  },
  {
    "path": "man/conformal_inf.Rd",
    "chars": 1432,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/inference.R\n\\name{conformal_inf}\n\\alias{co"
  },
  {
    "path": "man/conformal_inf_linear.Rd",
    "chars": 1461,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/inference.R\n\\name{conformal_inf_linear}\n\\a"
  },
  {
    "path": "man/conformal_inf_multiout.Rd",
    "chars": 1504,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/inference.R\n\\name{conformal_inf_multiout}\n"
  },
  {
    "path": "man/get_nona_donors.Rd",
    "chars": 390,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/eligible_donors.R\n\\name{get_nona_donors}\n\\"
  },
  {
    "path": "man/jackknife_se_single.Rd",
    "chars": 782,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/inference.R\n\\name{jackknife_se_single}\n\\al"
  },
  {
    "path": "man/kansas.Rd",
    "chars": 1563,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/data.R\n\\docType{data}\n\\name{kansas}\n\\alias"
  },
  {
    "path": "man/make_V_matrix.Rd",
    "chars": 268,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/fit_synth.R\n\\name{make_V_matrix}\n\\alias{ma"
  },
  {
    "path": "man/multisynth.Rd",
    "chars": 3654,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/multisynth_class.R\n\\name{multisynth}\n\\alia"
  },
  {
    "path": "man/plot.augsynth.Rd",
    "chars": 535,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/augsynth.R\n\\name{plot.augsynth}\n\\alias{plo"
  },
  {
    "path": "man/plot.augsynth_multiout.Rd",
    "chars": 618,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/multi_outcomes.R\n\\name{plot.augsynth_multi"
  },
  {
    "path": "man/plot.multisynth.Rd",
    "chars": 841,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/multisynth_class.R\n\\name{plot.multisynth}\n"
  },
  {
    "path": "man/plot.summary.augsynth.Rd",
    "chars": 457,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/augsynth.R\n\\name{plot.summary.augsynth}\n\\a"
  },
  {
    "path": "man/plot.summary.augsynth_multiout.Rd",
    "chars": 589,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/multi_outcomes.R\n\\name{plot.summary.augsyn"
  },
  {
    "path": "man/plot.summary.multisynth.Rd",
    "chars": 692,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/multisynth_class.R\n\\name{plot.summary.mult"
  },
  {
    "path": "man/predict.augsynth.Rd",
    "chars": 548,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/augsynth.R\n\\name{predict.augsynth}\n\\alias{"
  },
  {
    "path": "man/predict.augsynth_multiout.Rd",
    "chars": 597,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/multi_outcomes.R\n\\name{predict.augsynth_mu"
  },
  {
    "path": "man/predict.multisynth.Rd",
    "chars": 776,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/multisynth_class.R\n\\name{predict.multisynt"
  },
  {
    "path": "man/print.augsynth.Rd",
    "chars": 329,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/augsynth.R\n\\name{print.augsynth}\n\\alias{pr"
  },
  {
    "path": "man/print.augsynth_multiout.Rd",
    "chars": 371,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/multi_outcomes.R\n\\name{print.augsynth_mult"
  },
  {
    "path": "man/print.multisynth.Rd",
    "chars": 368,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/multisynth_class.R\n\\name{print.multisynth}"
  },
  {
    "path": "man/print.summary.augsynth.Rd",
    "chars": 394,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/augsynth.R\n\\name{print.summary.augsynth}\n\\"
  },
  {
    "path": "man/print.summary.augsynth_multiout.Rd",
    "chars": 445,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/multi_outcomes.R\n\\name{print.summary.augsy"
  },
  {
    "path": "man/print.summary.multisynth.Rd",
    "chars": 516,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/multisynth_class.R\n\\name{print.summary.mul"
  },
  {
    "path": "man/rdirichlet_b.Rd",
    "chars": 259,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/inference.R\n\\name{rdirichlet_b}\n\\alias{rdi"
  },
  {
    "path": "man/rmultinom_b.Rd",
    "chars": 268,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/inference.R\n\\name{rmultinom_b}\n\\alias{rmul"
  },
  {
    "path": "man/rwild_b.Rd",
    "chars": 264,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/inference.R\n\\name{rwild_b}\n\\alias{rwild_b}"
  },
  {
    "path": "man/single_augsynth.Rd",
    "chars": 1518,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/augsynth.R\n\\name{single_augsynth}\n\\alias{s"
  },
  {
    "path": "man/summary.augsynth.Rd",
    "chars": 1085,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/augsynth.R\n\\name{summary.augsynth}\n\\alias{"
  },
  {
    "path": "man/summary.augsynth_multiout.Rd",
    "chars": 718,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/multi_outcomes.R\n\\name{summary.augsynth_mu"
  },
  {
    "path": "man/summary.multisynth.Rd",
    "chars": 1295,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/multisynth_class.R\n\\name{summary.multisynt"
  },
  {
    "path": "man/time_jackknife_plus.Rd",
    "chars": 915,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/inference.R\n\\name{time_jackknife_plus}\n\\al"
  },
  {
    "path": "man/time_jackknife_plus_multiout.Rd",
    "chars": 948,
    "preview": "% Generated by roxygen2: do not edit by hand\n% Please edit documentation in R/inference.R\n\\name{time_jackknife_plus_mult"
  },
  {
    "path": "pkg.Rproj",
    "chars": 312,
    "preview": "Version: 1.0\n\nRestoreWorkspace: No\nSaveWorkspace: No\nAlwaysSaveHistory: Default\n\nEnableCodeIndexing: Yes\nEncoding: UTF-8"
  },
  {
    "path": "tests/testthat/test_augsynth_pre.R",
    "chars": 4308,
    "preview": "context(\"Testing that top level API runs the right functions\")\n\nlibrary(Synth)\n\n\ntest_that(\"augsynth runs single_synth w"
  },
  {
    "path": "tests/testthat/test_format.R",
    "chars": 3487,
    "preview": "context(\"Test data formatting\")\n\nlibrary(Synth)\ndata(basque)\nbasque <- basque %>% mutate(trt = case_when(year < 1975 ~ 0"
  },
  {
    "path": "tests/testthat/test_general.R",
    "chars": 4775,
    "preview": "context(\"Generally testing the workflow for augsynth\")\n\n\nlibrary(Synth)\ndata(basque)\nbasque <- basque %>% mutate(trt = c"
  },
  {
    "path": "tests/testthat/test_lambda.R",
    "chars": 1604,
    "preview": "context(\"Testing lambda tuning if ridge is true.\")\n\n\nlibrary(Synth)\ndata(basque)\nbasque <- basque %>% mutate(trt = case_"
  },
  {
    "path": "tests/testthat/test_load_data.R",
    "chars": 121,
    "preview": "context(\"Testing that we can load packaged data\")\n\ntest_that(\"kansas data loads\", {\n    expect_error(data(kansas), NA)\n}"
  },
  {
    "path": "tests/testthat/test_multiple_outcomes.R",
    "chars": 7158,
    "preview": "context(\"Generally testing the workflow for synth with multiple outcomes\")\n\nlibrary(Synth)\ndata(basque)\nbasque <- basque"
  },
  {
    "path": "tests/testthat/test_multisynth.R",
    "chars": 7439,
    "preview": "context(\"Generally testing the workflow for multisynth\")\n\nlibrary(Synth)\ndata(basque)\nbasque <- basque %>% mutate(trt = "
  },
  {
    "path": "tests/testthat/test_multisynth_covariates.R",
    "chars": 11918,
    "preview": "context(\"Testing multisynth with covariates\")\nset.seed(1011)\n\nlibrary(Synth)\ndata(basque)\nbasque <- basque %>% mutate(tr"
  },
  {
    "path": "tests/testthat/test_outcome_models.R",
    "chars": 3671,
    "preview": "context(\"Testing that augmenting synth with different models loads and runs\")\n\n\n\nlibrary(Synth)\ndata(basque)\nbasque <- b"
  },
  {
    "path": "tests/testthat/test_time_cohort.R",
    "chars": 1447,
    "preview": "context(\"Test time cohort vs unit level analysis\")\n\nlibrary(Synth)\ndata(basque)\nbasque <- basque %>% mutate(trt = case_w"
  },
  {
    "path": "tests/testthat/test_unbalanced_multisynth.R",
    "chars": 7151,
    "preview": "context(\"Test multisynth for unbalanced panels\")\n\nset.seed(1011)\n\nlibrary(Synth)\ndata(basque)\nbasque <- basque %>% mutat"
  },
  {
    "path": "tests/testthat.R",
    "chars": 60,
    "preview": "library(testthat)\nlibrary(augsynth)\n\ntest_check(\"augsynth\")\n"
  },
  {
    "path": "vignettes/.gitignore",
    "chars": 11,
    "preview": "*.html\n*.R\n"
  },
  {
    "path": "vignettes/multi-outcomes-vignette.Rmd",
    "chars": 4564,
    "preview": "---\noutput: rmarkdown::html_vignette\nvignette: >\n  %\\VignetteIndexEntry{Multi Outcomes AugSynth Vignette}\n  %\\VignetteEn"
  },
  {
    "path": "vignettes/multisynth-vignette.Rmd",
    "chars": 7377,
    "preview": "---\noutput: rmarkdown::html_vignette\nvignette: >\n  %\\VignetteIndexEntry{MultiSynth Vignette}\n  %\\VignetteEngine{knitr::r"
  },
  {
    "path": "vignettes/multisynth-vignette.md",
    "chars": 20171,
    "preview": "---\noutput: rmarkdown::html_vignette\nvignette: >\n  %\\VignetteIndexEntry{MultiSynth Vignette}\n  %\\VignetteEngine{knitr::r"
  },
  {
    "path": "vignettes/singlesynth-vignette.Rmd",
    "chars": 8303,
    "preview": "---\noutput: rmarkdown::html_vignette\nvignette: >\n  %\\VignetteIndexEntry{Single Outcome AugSynth Vignette}\n  %\\VignetteEn"
  },
  {
    "path": "vignettes/singlesynth-vignette.md",
    "chars": 19586,
    "preview": "---\noutput: rmarkdown::html_vignette\nvignette: >\n  %\\VignetteIndexEntry{Single Outcome AugSynth Vignette}\n  %\\VignetteEn"
  }
]

// ... and 2 more files (download for full content)

