[
  {
    "path": ".github/workflows/release.yml",
    "content": "name: Package release\n\non:\n  release:\n    types: [created]\n\n\njobs:\n  deploy_osx:\n    runs-on: ${{ matrix.os }}\n    strategy:\n      matrix:\n        python-version: [\"3.7\", \"3.8\", \"3.9\", \"3.10\"]\n        os: [macos-latest]\n\n    steps:\n      - uses: actions/checkout@v2\n        with:\n          submodules: true\n      - name: Set up Python\n        uses: actions/setup-python@v1\n        with:\n          python-version: ${{ matrix.python-version }}\n      - name: Build and publish\n        env:\n          TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}\n          TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}\n        run: ${GITHUB_WORKSPACE}/.github/workflows/scripts/release_osx.sh\n\n  deploy_linux:\n    strategy:\n      matrix:\n        python-version:\n          - cp37-cp37m\n          - cp38-cp38\n          - cp39-cp39\n          - cp310-cp310\n\n    runs-on: ubuntu-latest\n    container: quay.io/pypa/manylinux2014_x86_64\n    steps:\n      - uses: actions/checkout@v1\n        with:\n          submodules: true\n      - name: Set target Python version PATH\n        run: |\n            echo \"/opt/python/${{ matrix.python-version }}/bin\" >> $GITHUB_PATH\n      - name: Build and publish\n        env:\n          TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}\n          TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}\n        run: ${GITHUB_WORKSPACE}/.github/workflows/scripts/release_linux.sh\n\n  deploy_windows:\n    runs-on: windows-latest\n    strategy:\n      matrix:\n        python-version: [\"3.7\", \"3.8\", \"3.9\", \"3.10\"]\n\n    steps:\n      - uses: actions/checkout@v2\n        with:\n          submodules: true\n      - name: Set up Python ${{ matrix.python-version }}\n        uses: actions/setup-python@v1\n        with:\n          python-version: ${{ matrix.python-version }}\n      - name: Build and publish\n        env:\n          TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}\n          TWINE_PASSWORD: ${{ 
secrets.PYPI_PASSWORD }}\n        run: |\n          ../../.github/workflows/scripts/release_windows.bat\n"
  },
  {
    "path": ".github/workflows/scripts/release_linux.sh",
    "content": "#!/bin/bash\n\nset -e\n\nyum makecache -y\nyum install centos-release-scl -y\nyum-config-manager --enable rhel-server-rhscl-7-rpms\nyum install llvm-toolset-7.0 python3 python3-devel -y\n\n# Python\npython3 -m pip install --upgrade pip\npython3 -m pip install setuptools wheel twine auditwheel\n\n# Publish\npython3 -m pip wheel . -w dist/ --no-deps\ntwine upload --verbose --skip-existing dist/*\n"
  },
  {
    "path": ".github/workflows/scripts/release_osx.sh",
    "content": "#!/bin/sh\n\n# Abort on the first failing command so that a failed build is never followed by an upload.\nset -e\n\nexport MACOSX_DEPLOYMENT_TARGET=10.14\n\n# Install the tooling and build the wheel with the same interpreter.\npython -m pip install --upgrade pip\npython -m pip install setuptools wheel twine auditwheel\n\npython setup.py build bdist_wheel --plat-name macosx_10_14_x86_64 --dist-dir wheel\ntwine upload --skip-existing wheel/*\n"
  },
  {
    "path": ".github/workflows/scripts/release_windows.bat",
    "content": "echo on\n\npython -m pip install --upgrade pip\npip install setuptools wheel twine auditwheel\n\npip wheel . -w wheel/ --no-deps\ntwine upload --skip-existing wheel/*\n"
  },
  {
    "path": ".github/workflows/test.yml",
    "content": "name: CATENets Tests\n    \non:\n  push:\n    branches: [main, release]\n  pull_request:\n    types: [opened, synchronize, reopened]\n  schedule:\n    - cron:  '0 0 * * 0'\n  workflow_dispatch:\n\n\njobs:\n  Linter:\n    runs-on: ${{ matrix.os }}\n    strategy:\n      matrix:\n        python-version: [3.8]\n        os: [ubuntu-latest]\n    steps:\n      - uses: actions/checkout@v2\n        with:\n          submodules: true\n      - name: Set up Python ${{ matrix.python-version }}\n        uses: actions/setup-python@v1\n        with:\n          python-version: ${{ matrix.python-version }}\n      - name: Install dependencies\n        run: pip install .[testing]\n      - name: pre-commit validation\n        run: pre-commit run --files catenets/*\n      - name: Security checks\n        run: |\n            bandit -r catenets/*\n\n  Library:\n    needs: [Linter]\n    runs-on: ${{ matrix.os }}\n    strategy:\n      matrix:\n        python-version: ['3.8', '3.9', \"3.10\"]\n        os: [macos-latest, ubuntu-latest, windows-latest]\n    steps:\n      - uses: actions/checkout@v2\n        with:\n          submodules: true\n      - name: Set up Python ${{ matrix.python-version }}\n        uses: actions/setup-python@v1\n        with:\n          python-version: ${{ matrix.python-version }}\n      - name: Install MacOS dependencies\n        run: |\n            brew install libomp\n        if: ${{ matrix.os == 'macos-latest' }}\n      - name: Install dependencies\n        run: |\n            python -m pip install --upgrade pip\n            pip install .[testing]\n      - name: Test with pytest Unix\n        run: pytest -vvvsx -m \"not slow\"\n        if: ${{ matrix.os != 'windows-latest' }}\n      - name: Test with pytest Windows\n        run: |\n            cd tests\\datasets\n            pytest -vvvsx -m \"not slow\"\n            cd ..\\..\n\n            cd tests\\models\\torch\n            pytest -vvvsx -m \"not slow\"\n        if: ${{ matrix.os == 
'windows-latest' }}\n"
  },
  {
    "path": ".gitignore",
    "content": "*.pyc\n*.xml\n*.iml\n*.csv\n*.xlsx\n*.Rhistory\n.idea/\n.coverage\n.ipynb_checkpoints\n.ipynb_checkpoints/\n*/.ipynb_checkpoints/\n*/bin/\n*/include/\n*/lib/\n*/lib64/\n*/share/\n*.cfg\n.pytest_cache\ndata/\nbuild/\ncatenets.egg-info/\ndist/\ngenerated/\n_build\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "exclude: 'setup.py|^docs'\n\nrepos:\n- repo: https://github.com/pre-commit/pre-commit-hooks\n  rev: v3.4.0\n  hooks:\n  - id: trailing-whitespace\n  - id: check-added-large-files\n  - id: check-ast\n  - id: check-json\n  - id: check-merge-conflict\n  - id: check-xml\n  - id: check-yaml\n  - id: debug-statements\n  - id: check-executables-have-shebangs\n  - id: end-of-file-fixer\n  - id: requirements-txt-fixer\n  - id: mixed-line-ending\n    args: ['--fix=auto']  # replace 'auto' with 'lf' to enforce Linux/Mac line endings or 'crlf' for Windows\n\n- repo: https://github.com/pycqa/isort\n  rev: 5.12.0\n  hooks:\n  - id: isort\n\n- repo: https://github.com/psf/black\n  rev: 22.3.0\n  hooks:\n  - id: black\n    language_version: python3\n- repo: https://github.com/pycqa/flake8\n  rev: 3.9.1\n  hooks:\n  - id: flake8\n    args: [\n        \"--max-line-length=140\",\n        \"--extend-ignore=E203,W503\"\n    ]\n- repo: https://github.com/pre-commit/mirrors-mypy\n  rev: v0.812\n  hooks:\n  - id: mypy\n    args: [\n          \"--ignore-missing-imports\",\n          \"--scripts-are-modules\",\n          \"--disallow-incomplete-defs\",\n          \"--no-implicit-optional\",\n          \"--warn-unused-ignores\",\n          \"--warn-redundant-casts\",\n          \"--strict-equality\",\n          \"--warn-unreachable\",\n          \"--disallow-untyped-defs\",\n          \"--disallow-untyped-calls\",\n      ]\n- repo: local\n  hooks:\n  - id: flynt\n    name: flynt\n    entry: flynt\n    args: [--fail-on-change]\n    types: [python]\n    language: python\n    additional_dependencies:\n        - flynt\n"
  },
  {
    "path": "LICENSE",
    "content": "BSD 3-Clause License\n\nCopyright (c) 2021, Alicia Curth\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are met:\n\n1. Redistributions of source code must retain the above copyright notice, this\n   list of conditions and the following disclaimer.\n\n2. Redistributions in binary form must reproduce the above copyright notice,\n   this list of conditions and the following disclaimer in the documentation\n   and/or other materials provided with the distribution.\n\n3. Neither the name of the copyright holder nor the names of its\n   contributors may be used to endorse or promote products derived from\n   this software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\nAND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\nIMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\nDISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\nFOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\nDAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\nSERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\nCAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\nOR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n"
  },
  {
    "path": "README.md",
    "content": "# CATENets - Conditional Average Treatment Effect Estimation Using Neural Networks\n\n[![CATENets Tests](https://github.com/AliciaCurth/CATENets/actions/workflows/test.yml/badge.svg)](https://github.com/AliciaCurth/CATENets/actions/workflows/test.yml)\n[![Documentation Status](https://readthedocs.org/projects/catenets/badge/?version=latest)](https://catenets.readthedocs.io/en/latest/?badge=latest)\n[![License](https://img.shields.io/badge/License-BSD%203--Clause-blue.svg)](https://github.com/AliciaCurth/CATENets/blob/main/LICENSE)\n\n\nCode Author: Alicia Curth (amc253@cam.ac.uk)\n\nThis repo contains Jax-based, sklearn-style implementations of Neural Network-based Conditional\nAverage Treatment Effect (CATE) Estimators, which were used in the AISTATS21 paper\n['Nonparametric Estimation of Heterogeneous Treatment Effects: From Theory to Learning\nAlgorithms']( https://arxiv.org/abs/2101.10943) (Curth & vd Schaar, 2021a) as well as the follow up\nNeurIPS21 paper [\"On Inductive Biases for Heterogeneous Treatment Effect Estimation\"](https://arxiv.org/abs/2106.03765) (Curth & vd\nSchaar, 2021b) and the NeurIPS21 Datasets & Benchmarks track paper [\"Really Doing Great at Estimating CATE? A Critical Look at ML Benchmarking Practices in Treatment Effect Estimation\"](https://openreview.net/forum?id=FQLzQqGEAH) (Curth et al, 2021).\n\nWe implement the SNet-class we introduce in Curth & vd Schaar (2021a), as well as FlexTENet and\nOffsetNet as discussed in Curth & vd Schaar (2021b), and re-implement a number of\nNN-based algorithms from existing literature (Shalit et al (2017), Shi et al (2019), Hassanpour\n& Greiner (2020)). 
We also provide Neural Network (NN)-based instantiations of a number of so-called\nmeta-learners for CATE estimation, including two-step pseudo-outcome regression estimators (the\nDR-learner (Kennedy, 2020) and single-robust propensity-weighted (PW) and regression-adjusted (RA) learners), Nie & Wager (2017)'s R-learner and Kuenzel et al (2019)'s X-learner. The jax implementations in ``catenets.models.jax`` were used in all papers listed; additionally, pytorch versions of some models (``catenets.models.torch``) were contributed by [Bogdan Cebere](https://github.com/bcebere).\n\n### Interface\nThe repo contains a package ``catenets``, which contains all general code used for modeling and evaluation, and a folder ``experiments``, in which the code for replicating experimental results is contained. All implemented learning algorithms in ``catenets`` (``SNet, FlexTENet, OffsetNet, TNet, SNet1 (TARNet), SNet2\n(DragonNet), SNet3, DRNet, RANet, PWNet, RNet, XNet``) come with a sklearn-style wrapper,  implementing a ``.fit(X, y, w)`` and a ``.predict(X)`` method, where predict returns CATE by default. 
All hyperparameters are documented in detail in the respective files in ``catenets.models`` folder.\n\nExample usage:\n\n```python\nfrom catenets.models.jax import TNet, SNet\nfrom catenets.experiment_utils.simulation_utils import simulate_treatment_setup\n\n# simulate some data (here: unconfounded, 10 prognostic variables and 5 predictive variables)\nX, y, w, p, cate = simulate_treatment_setup(n=2000, n_o=10, n_t=5, n_c=0)\n\n# estimate CATE using TNet\nt = TNet()\nt.fit(X, y, w)\ncate_pred_t = t.predict(X)  # without potential outcomes\ncate_pred_t, po0_pred_t, po1_pred_t = t.predict(X, return_po=True)  # predict potential outcomes too\n\n# estimate CATE using SNet\ns = SNet(penalty_orthogonal=0.01)\ns.fit(X, y, w)\ncate_pred_s = s.predict(X)\n\n```\n\nAll experiments in Curth & vd Schaar (2021a) can be replicated using this repository; the necessary\ncode is in ``experiments.experiments_AISTATS21``. To do so from shell, clone the repo, create a new\nvirtual environment and run\n```shell\npip install catenets # install the library from PyPI\n# OR\npip install . 
# install the library from the local repository\n\n# Run the experiments\npython run_experiments_AISTATS.py\n```\n```shell\nOptions:\n--experiment # defaults to 'simulation', 'ihdp' will run ihdp experiments\n--setting # different simulation settings in synthetic experiments (can be 1-5)\n--models # defaults to None which will train all models considered in paper,\n         # can be string of model name (e.g 'TNet'), 'plug' for all plugin models,\n         # 'pseudo' for all pseudo-outcome regression models\n\n--file_name # base file name to write to, defaults to 'results'\n--n_repeats # number of experiments to run for each configuration, defaults to 10 (should be set to 100 for IHDP)\n```\n\nSimilarly, the experiments in Curth & vd Schaar (2021b) can be replicated using the code in\n``experiments.experiments_inductivebias_NeurIPS21`` (or from shell using ```python\nrun_experiments_inductive_bias_NeurIPS.py```) and the experiments in Curth et al (2021) can be replicated using the code in ``experiments.experiments_benchmarks_NeurIPS21`` (the catenets experiments can also be run from shell using ``python run_experiments_benchmarks_NeurIPS``).\n\nThe code can also be installed as a python package (``catenets``). 
From a local copy of the repo, run ``python setup.py install``.\n\nNote: jax is currently only supported on macOS and Linux, but can be run from Windows using WSL (the Windows Subsystem for Linux).\n\n\n### Citing\n\nIf you use this software please cite the corresponding paper(s):\n\n```\n@inproceedings{curth2021nonparametric,\n  title={Nonparametric Estimation of Heterogeneous Treatment Effects: From Theory to Learning Algorithms},\n  author={Curth, Alicia and van der Schaar, Mihaela},\n  year={2021},\n  booktitle={Proceedings of the 24th International Conference on Artificial\n  Intelligence and Statistics (AISTATS)},\n  organization={PMLR}\n}\n\n@inproceedings{curth2021inductive,\n  title={On Inductive Biases for Heterogeneous Treatment Effect Estimation},\n  author={Curth, Alicia and van der Schaar, Mihaela},\n  booktitle={Proceedings of the Thirty-Fifth Conference on Neural Information Processing Systems},\n  year={2021}\n}\n\n\n@inproceedings{curth2021really,\n  title={Really Doing Great at Estimating CATE? A Critical Look at ML Benchmarking Practices in Treatment Effect Estimation},\n  author={Curth, Alicia and Svensson, David and Weatherall, James and van der Schaar, Mihaela},\n  booktitle={Proceedings of the Neural Information Processing Systems Track on Datasets and Benchmarks},\n  year={2021}\n}\n\n```\n"
  },
  {
    "path": "catenets/__init__.py",
    "content": "import sys\n\nfrom . import logger  # noqa: F401\nfrom . import datasets, models  # noqa: F401\n\nlogger.add(sink=sys.stderr, level=\"CRITICAL\")\n"
  },
  {
    "path": "catenets/datasets/__init__.py",
    "content": "# stdlib\nimport os\nfrom pathlib import Path\nfrom typing import Any, Tuple\n\nfrom . import dataset_acic2016, dataset_ihdp, dataset_twins\n\nDATA_PATH = Path(os.path.dirname(__file__)) / Path(\"data\")\n\n# Best-effort creation of the data directory: an already existing directory is fine\n# and filesystem errors are tolerated; only OSError is ignored so that\n# KeyboardInterrupt/SystemExit still propagate.\ntry:\n    os.makedirs(DATA_PATH, exist_ok=True)\nexcept OSError:\n    pass\n\n\ndef load(dataset: str, *args: Any, **kwargs: Any) -> Tuple:\n    \"\"\"\n    Load a dataset by name, storing/reading its files under DATA_PATH.\n\n    Input:\n        dataset: the name of the dataset to load (\"twins\", \"ihdp\" or \"acic2016\")\n        *args, **kwargs: forwarded unchanged to the selected dataset module's ``load``\n    Outputs:\n        The tuple produced by the selected dataset module's ``load`` function\n        (train/test features, treatments, observed and potential outcomes; the\n        exact layout is defined by each loader).\n    Raises:\n        ValueError: if ``dataset`` is not one of the supported names.\n    \"\"\"\n    if dataset == \"twins\":\n        return dataset_twins.load(DATA_PATH, *args, **kwargs)\n    if dataset == \"ihdp\":\n        return dataset_ihdp.load(DATA_PATH, *args, **kwargs)\n    if dataset == \"acic2016\":\n        return dataset_acic2016.load(DATA_PATH, *args, **kwargs)\n\n    raise ValueError(\"Unsupported dataset\")\n\n\n__all__ = [\"dataset_ihdp\", \"dataset_twins\", \"dataset_acic2016\", \"load\"]\n"
  },
  {
    "path": "catenets/datasets/dataset_acic2016.py",
    "content": "\"\"\"\nACIC2016 dataset\n\"\"\"\nimport glob\n\n# stdlib\nimport random\nfrom pathlib import Path\nfrom typing import Any, Tuple\n\n# third party\nimport numpy as np\nimport pandas as pd\nfrom sklearn.model_selection import train_test_split\nfrom sklearn.preprocessing import OneHotEncoder, StandardScaler\n\nimport catenets.logger as log\n\nfrom .network import download_if_needed\n\nnp.random.seed(0)\nrandom.seed(0)\n\nFILE_ID = \"0B7pG5PPgj6A3N09ibmFwNWE1djA\"\nPREPROCESSED_FILE_ID = \"1iOfEAk402o3jYBs2Prfiz6oaailwWcR5\"\n\nNUMERIC_COLS = [\n    0,\n    3,\n    4,\n    16,\n    17,\n    18,\n    20,\n    21,\n    22,\n    24,\n    24,\n    25,\n    30,\n    31,\n    32,\n    33,\n    39,\n    40,\n    41,\n    53,\n    54,\n]\nN_NUM_COLS = len(NUMERIC_COLS)\n\n\ndef get_acic_covariates(\n    fn_csv: Path, keep_categorical: bool = False, preprocessed: bool = True\n) -> np.ndarray:\n    X = pd.read_csv(fn_csv)\n    if not keep_categorical:\n        X = X.drop(columns=[\"x_2\", \"x_21\", \"x_24\"])\n    else:\n        # encode categorical features\n        feature_list = []\n        for cols_ in X.columns:\n            if type(X.loc[X.index[0], cols_]) not in [np.int64, np.float64]:\n\n                enc = OneHotEncoder(drop=\"first\")\n\n                enc.fit(np.array(X[[cols_]]).reshape((-1, 1)))\n\n                for k in range(len(list(enc.get_feature_names()))):\n                    X[cols_ + list(enc.get_feature_names())[k]] = enc.transform(\n                        np.array(X[[cols_]]).reshape((-1, 1))\n                    ).toarray()[:, k]\n\n                feature_list.append(cols_)\n\n        X.drop(feature_list, axis=1, inplace=True)\n\n    if preprocessed:\n        X_t = X.values\n    else:\n        scaler = StandardScaler()\n        X_t = scaler.fit_transform(X)\n    return X_t\n\n\ndef preprocess_simu(\n    fn_csv: Path,\n    n_0: int = 2000,\n    n_1: int = 200,\n    n_test: int = 500,\n    error_sd: float = 1,\n    sp_lin: float = 
0.6,\n    sp_nonlin: float = 0.3,\n    prop_gamma: float = 0,\n    prop_omega: float = 0,\n    ate_goal: float = 0,\n    inter: bool = True,\n    i_exp: int = 0,\n    keep_categorical: bool = False,\n    preprocessed: bool = True,\n) -> Tuple:\n    X = get_acic_covariates(\n        fn_csv, keep_categorical=keep_categorical, preprocessed=preprocessed\n    )\n    np.random.seed(i_exp)\n\n    # shuffle indices\n    n_total, n_cov = X.shape\n    ind = np.arange(n_total)\n    np.random.shuffle(ind)\n    ind_test = ind[-n_test:]\n    ind_1 = ind[n_0 : (n_0 + n_1)]\n\n    # create treatment indicator (treatment assignment does not matter in test set)\n    w = np.zeros(n_total).reshape((-1, 1))\n    w[ind_1] = 1\n\n    # create dgp\n    coeffs_ = [0, 1]\n    # sample baseline coefficients\n    beta_0 = np.random.choice(coeffs_, size=n_cov, replace=True, p=[1 - sp_lin, sp_lin])\n    intercept = np.random.choice([x for x in np.arange(-1, 1.25, 0.25)])\n\n    # sample treatment effect coefficients\n    gamma = np.random.choice(\n        coeffs_, size=n_cov, replace=True, p=[1 - prop_gamma, prop_gamma]\n    )\n    omega = np.random.choice(\n        [0, 1], replace=True, size=n_cov, p=[prop_omega, 1 - prop_omega]\n    )\n\n    # simulate mu_0 and mu_1\n    mu_0 = (intercept + np.dot(X, beta_0)).reshape((-1, 1))\n    mu_1 = (intercept + np.dot(X, gamma + beta_0 * omega)).reshape((-1, 1))\n    if sp_nonlin > 0:\n        coefs_sq = [0, 0.1]\n        beta_sq = np.random.choice(\n            coefs_sq, size=N_NUM_COLS, replace=True, p=[1 - sp_nonlin, sp_nonlin]\n        )\n        omega = np.random.choice(\n            [0, 1], replace=True, size=N_NUM_COLS, p=[prop_omega, 1 - prop_omega]\n        )\n        X_sq = X[:, NUMERIC_COLS] ** 2\n        mu_0 = mu_0 + np.dot(X_sq, beta_sq).reshape((-1, 1))\n        mu_1 = mu_1 + np.dot(X_sq, beta_sq * omega).reshape((-1, 1))\n\n        if inter:\n            # randomly add some interactions\n            ind_c = np.arange(n_cov)\n            
np.random.shuffle(ind_c)\n            inter_list = list()\n            for i in range(0, n_cov - 2, 2):\n                inter_list.append(X[:, ind_c[i]] * X[:, ind_c[i + 1]])\n\n            X_inter = np.array(inter_list).T\n            n_inter = X_inter.shape[1]\n            beta_inter = np.random.choice(\n                coefs_sq, size=n_inter, replace=True, p=[1 - sp_nonlin, sp_nonlin]\n            )\n            omega = np.random.choice(\n                [0, 1], replace=True, size=n_inter, p=[prop_omega, 1 - prop_omega]\n            )\n            mu_0 = mu_0 + np.dot(X_inter, beta_inter).reshape((-1, 1))\n            mu_1 = mu_1 + np.dot(X_inter, beta_inter * omega).reshape((-1, 1))\n\n    ate = np.mean(mu_1 - mu_0)\n    mu_1 = mu_1 - ate + ate_goal\n\n    y = (\n        w * mu_1\n        + (1 - w) * mu_0\n        + np.random.normal(0, error_sd, n_total).reshape((-1, 1))\n    )\n\n    X_train, y_train, w_train, mu_0_train, mu_1_train = (\n        X[ind[: (n_0 + n_1)], :],\n        y[ind[: (n_0 + n_1)]],\n        w[ind[: (n_0 + n_1)]],\n        mu_0[ind[: (n_0 + n_1)]],\n        mu_1[ind[: (n_0 + n_1)]],\n    )\n    X_test, y_test, w_test, mu_0_t, mu_1_t = (\n        X[ind_test, :],\n        y[ind_test],\n        w[ind_test],\n        mu_0[ind_test],\n        mu_1[ind_test],\n    )\n\n    return (\n        X_train,\n        w_train,\n        y_train,\n        np.asarray([mu_0_train, mu_1_train]).squeeze().T,\n        X_test,\n        w_test,\n        y_test,\n        np.asarray([mu_0_t, mu_1_t]).squeeze().T,\n    )\n\n\ndef get_acic_orig_filenames(data_path: Path, simu_num: int) -> list:\n    return sorted(\n        glob.glob(\n            (data_path / (\"data_cf_all/\" + str(simu_num) + \"/zymu_*.csv\")).__str__()\n        )\n    )\n\n\ndef get_acic_orig_outcomes(data_path: Path, simu_num: int, i_exp: int) -> Tuple:\n    file_list = get_acic_orig_filenames(data_path=data_path, simu_num=simu_num)\n\n    out = pd.read_csv(file_list[i_exp])\n    w = out[\"z\"]\n  
  y = w * out[\"y1\"] + (1 - w) * out[\"y0\"]\n    mu_0, mu_1 = out[\"mu0\"], out[\"mu1\"]\n    return y.values, w.values, mu_0.values, mu_1.values\n\n\ndef preprocess_acic_orig(\n    fn_csv: Path,\n    data_path: Path,\n    preprocessed: bool = False,\n    keep_categorical: bool = True,\n    simu_num: int = 1,\n    i_exp: int = 0,\n    train_size: int = 4000,\n    random_split: bool = False,\n) -> Tuple:\n    \"\"\"\n    Build a train/test split from the ACIC2016 covariates and one original outcome file.\n\n    Parameters\n    ----------\n    fn_csv: Path\n        Covariate CSV, read through ``get_acic_covariates``.\n    data_path: Path\n        Directory searched for the original outcome files.\n    preprocessed, keep_categorical: bool\n        Forwarded to ``get_acic_covariates``.\n    simu_num: int\n        Simulation folder to read outcomes from.\n    i_exp: int\n        Index of the outcome file within that folder; also seeds the random split.\n    train_size: int\n        Number of training rows.\n    random_split: bool\n        If False the first ``train_size`` rows form the training set, otherwise the\n        rows are split at random with ``train_test_split``.\n\n    Returns\n    -------\n    Tuple of X_train, w_train, y_train, training potential outcomes,\n    X_test, w_test, y_test, testing potential outcomes.\n    \"\"\"\n    X = get_acic_covariates(\n        fn_csv, keep_categorical=keep_categorical, preprocessed=preprocessed\n    )\n\n    y, w, mu_0, mu_1 = get_acic_orig_outcomes(\n        data_path=data_path, simu_num=simu_num, i_exp=i_exp\n    )\n\n    if not random_split:\n        X_train, y_train, w_train, mu_0_train, mu_1_train = (\n            X[:train_size, :],\n            y[:train_size],\n            w[:train_size],\n            mu_0[:train_size],\n            mu_1[:train_size],\n        )\n        X_test, y_test, w_test, mu_0_test, mu_1_test = (\n            X[train_size:, :],\n            y[train_size:],\n            w[train_size:],\n            mu_0[train_size:],\n            mu_1[train_size:],\n        )\n    else:\n        # train_size is an absolute row count, so it is passed as train_size; the\n        # former test_size=1 - train_size produced a negative test size.\n        (\n            X_train,\n            X_test,\n            y_train,\n            y_test,\n            w_train,\n            w_test,\n            mu_0_train,\n            mu_0_test,\n            mu_1_train,\n            mu_1_test,\n        ) = train_test_split(\n            X, y, w, mu_0, mu_1, train_size=train_size, random_state=i_exp\n        )\n\n    return (\n        X_train,\n        w_train,\n        y_train,\n        np.asarray([mu_0_train, mu_1_train]).squeeze().T,\n        X_test,\n        w_test,\n        y_test,\n        np.asarray([mu_0_test, mu_1_test]).squeeze().T,\n    )\n\n\ndef preprocess(\n    fn_csv: Path,\n    data_path: Path,\n    preprocessed: bool = True,\n    original_acic_outcomes: bool = False,\n    **kwargs: Any,\n) -> Tuple:\n    if not original_acic_outcomes:\n        return preprocess_simu(fn_csv=fn_csv, 
preprocessed=preprocessed, **kwargs)\n    else:\n        return preprocess_acic_orig(\n            fn_csv=fn_csv, preprocessed=preprocessed, data_path=data_path, **kwargs\n        )\n\n\ndef load(\n    data_path: Path,\n    preprocessed: bool = True,\n    original_acic_outcomes: bool = False,\n    **kwargs: Any,\n) -> Tuple:\n    \"\"\"\n    ACIC2016 dataset dataloader.\n        - Download the dataset if needed.\n        - Load the dataset.\n        - Preprocess the data.\n        - Return train/test split.\n\n    Parameters\n    ----------\n    data_path: Path\n        Path to the CSV. If it is missing, it will be downloaded.\n    preprocessed: bool\n        Switch between the raw and preprocessed versions of the dataset.\n    original_acic_outcomes: bool\n        Switch between new simulations (Inductive bias paper) and original acic outcomes\n\n    Returns\n    -------\n    train_x: array or pd.DataFrame\n        Features in training data.\n    train_t: array or pd.DataFrame\n        Treatments in training data.\n    train_y: array or pd.DataFrame\n        Observed outcomes in training data.\n    train_potential_y: array or pd.DataFrame\n        Potential outcomes in training data.\n    test_x: array or pd.DataFrame\n        Features in testing data.\n    test_potential_y: array or pd.DataFrame\n        Potential outcomes in testing data.\n    \"\"\"\n    if preprocessed:\n        csv = data_path / \"x_trans.csv\"\n\n        download_if_needed(csv, file_id=PREPROCESSED_FILE_ID)\n    else:\n        arch = data_path / \"data_cf_all.tar.gz\"\n\n        download_if_needed(\n            arch, file_id=FILE_ID, unarchive=True, unarchive_folder=data_path\n        )\n\n        csv = data_path / \"data_cf_all/x.csv\"\n    log.debug(f\"load dataset {csv}\")\n\n    return preprocess(\n        csv,\n        data_path=data_path,\n        preprocessed=preprocessed,\n        original_acic_outcomes=original_acic_outcomes,\n        **kwargs,\n    )\n"
  },
  {
    "path": "catenets/datasets/dataset_ihdp.py",
    "content": "\"\"\"\nIHDP (Infant Health and Development Program) dataset\n\"\"\"\n# stdlib\nimport os\nimport random\nfrom pathlib import Path\nfrom typing import Any, Tuple\n\n# third party\nimport numpy as np\n\nimport catenets.logger as log\n\nfrom .network import download_if_needed\n\nnp.random.seed(0)\nrandom.seed(0)\n\nTRAIN_DATASET = \"ihdp_npci_1-100.train.npz\"\nTEST_DATASET = \"ihdp_npci_1-100.test.npz\"\nTRAIN_URL = \"https://www.fredjo.com/files/ihdp_npci_1-100.train.npz\"\nTEST_URL = \"https://www.fredjo.com/files/ihdp_npci_1-100.test.npz\"\n\n\n# helper functions\ndef load_data_npz(fname: Path, get_po: bool = True) -> dict:\n    \"\"\"\n    Helper function for loading the IHDP data set (adapted from https://github.com/clinicalml/cfrnet)\n\n    Parameters\n    ----------\n    fname: Path\n        Dataset path\n\n    Returns\n    -------\n    data: dict\n        Raw IHDP dict, with X, w, y and yf keys.\n    \"\"\"\n    data_in = np.load(fname)\n    data = {\"X\": data_in[\"x\"], \"w\": data_in[\"t\"], \"y\": data_in[\"yf\"]}\n    try:\n        data[\"ycf\"] = data_in[\"ycf\"]\n    except BaseException:\n        data[\"ycf\"] = None\n\n    if get_po:\n        data[\"mu0\"] = data_in[\"mu0\"]\n        data[\"mu1\"] = data_in[\"mu1\"]\n\n    data[\"HAVE_TRUTH\"] = not data[\"ycf\"] is None\n    data[\"dim\"] = data[\"X\"].shape[1]\n    data[\"n\"] = data[\"X\"].shape[0]\n\n    return data\n\n\ndef prepare_ihdp_data(\n    data_train: dict,\n    data_test: dict,\n    rescale: bool = False,\n    setting: str = \"C\",\n    return_pos: bool = False,\n) -> Tuple:\n    \"\"\"\n    Helper for preprocessing the IHDP dataset.\n\n    Parameters\n    ----------\n    data_train: pd.DataFrame or dict\n        Train dataset\n    data_test: pd.DataFrame or dict\n        Test dataset\n    rescale: bool, default False\n        Rescale the outcomes to have similar scale\n    setting: str, default C\n        Experiment setting\n    return_pos: bool\n        Return 
potential outcomes\n\n    Returns\n    -------\n    X: dict or pd.DataFrame\n        Training Feature set\n    y: pd.DataFrame or list\n        Outcome list\n    t: pd.DataFrame or list\n        Treatment list\n    cate_true_in: pd.DataFrame or list\n        Average treatment effects for the training set\n    X_t: pd.Dataframe or list\n        Test feature set\n    cate_true_out: pd.DataFrame of list\n        Average treatment effects for the testing set\n    \"\"\"\n\n    X, y, w, mu0, mu1 = (\n        data_train[\"X\"],\n        data_train[\"y\"],\n        data_train[\"w\"],\n        data_train[\"mu0\"],\n        data_train[\"mu1\"],\n    )\n\n    X_t, _, _, mu0_t, mu1_t = (\n        data_test[\"X\"],\n        data_test[\"y\"],\n        data_test[\"w\"],\n        data_test[\"mu0\"],\n        data_test[\"mu1\"],\n    )\n    if setting == \"D\":\n        y[w == 1] = y[w == 1] + mu0[w == 1]\n        mu1 = mu0 + mu1\n        mu1_t = mu0_t + mu1_t\n\n    if rescale:\n        # rescale all outcomes to have similar scale of CATEs if sd_cate > 1\n        cate_in = mu0 - mu1\n        sd_cate = np.sqrt(cate_in.var())\n\n        if sd_cate > 1:\n            # training data\n            error = y - w * mu1 - (1 - w) * mu0\n            mu0 = mu0 / sd_cate\n            mu1 = mu1 / sd_cate\n            y = w * mu1 + (1 - w) * mu0 + error\n\n            # test data\n            mu0_t = mu0_t / sd_cate\n            mu1_t = mu1_t / sd_cate\n\n    cate_true_in = mu1 - mu0\n    cate_true_out = mu1_t - mu0_t\n\n    if return_pos:\n        return X, y, w, cate_true_in, X_t, cate_true_out, mu0, mu1, mu0_t, mu1_t\n\n    return X, y, w, cate_true_in, X_t, cate_true_out\n\n\ndef get_one_data_set(D: dict, i_exp: int, get_po: bool = True) -> dict:\n    \"\"\"\n    Helper for getting the IHDP data for one experiment. 
Adapted from https://github.com/clinicalml/cfrnet\n\n    Parameters\n    ----------\n    D: dict or pd.DataFrame\n        All the experiment\n    i_exp: int\n        Experiment number\n\n    Returns\n    -------\n    data: dict or pd.Dataframe\n        dict with the experiment\n    \"\"\"\n    D_exp = {}\n    D_exp[\"X\"] = D[\"X\"][:, :, i_exp - 1]\n    D_exp[\"w\"] = D[\"w\"][:, i_exp - 1 : i_exp]\n    D_exp[\"y\"] = D[\"y\"][:, i_exp - 1 : i_exp]\n    if D[\"HAVE_TRUTH\"]:\n        D_exp[\"ycf\"] = D[\"ycf\"][:, i_exp - 1 : i_exp]\n    else:\n        D_exp[\"ycf\"] = None\n\n    if get_po:\n        D_exp[\"mu0\"] = D[\"mu0\"][:, i_exp - 1 : i_exp]\n        D_exp[\"mu1\"] = D[\"mu1\"][:, i_exp - 1 : i_exp]\n\n    return D_exp\n\n\ndef load(data_path: Path, exp: int = 1, rescale: bool = False, **kwargs: Any) -> Tuple:\n    \"\"\"\n    Get IHDP train/test datasets with treatments and labels.\n\n    Parameters\n    ----------\n    data_path: Path\n        Path to the dataset csv. If the data is missing, it will be downloaded.\n\n\n    Returns\n    -------\n    X: pd.Dataframe or array\n        The training feature set\n    w: pd.DataFrame or array\n        Training treatment assignments.\n    y: pd.Dataframe or array\n        The training labels\n    training potential outcomes: pd.DataFrame or array.\n        Potential outcomes for the training set.\n    X_t: pd.DataFrame or array\n        The testing feature set\n    testing potential outcomes: pd.DataFrame of array\n        Potential outcomes for the testing set.\n    \"\"\"\n    data_train, data_test = load_raw(data_path)\n\n    data_exp = get_one_data_set(data_train, i_exp=exp, get_po=True)\n    data_exp_test = get_one_data_set(data_test, i_exp=exp, get_po=True)\n\n    (\n        X,\n        y,\n        w,\n        cate_true_in,\n        X_t,\n        cate_true_out,\n        mu0,\n        mu1,\n        mu0_t,\n        mu1_t,\n    ) = prepare_ihdp_data(\n        data_exp,\n        data_exp_test,\n        
rescale=rescale,\n        return_pos=True,\n    )\n\n    return (\n        X,\n        w,\n        y,\n        np.asarray([mu0, mu1]).squeeze().T,\n        X_t,\n        np.asarray([mu0_t, mu1_t]).squeeze().T,\n    )\n\n\ndef load_raw(data_path: Path) -> Tuple:\n    \"\"\"\n    Get IHDP raw train/test sets.\n\n    Parameters\n    ----------\n    data_path: Path\n        Path to the dataset csv. If the data is missing, it will be downloaded.\n\n    Returns\n    -------\n\n    data_train: dict or pd.DataFrame\n        Training data\n    data_test: dict or pd.DataFrame\n        Testing data\n    \"\"\"\n\n    try:\n        os.mkdir(data_path)\n    except BaseException:\n        pass\n\n    train_csv = data_path / TRAIN_DATASET\n    test_csv = data_path / TEST_DATASET\n\n    log.debug(f\"load raw dataset {train_csv}\")\n\n    download_if_needed(train_csv, http_url=TRAIN_URL)\n    download_if_needed(test_csv, http_url=TEST_URL)\n\n    data_train = load_data_npz(train_csv, get_po=True)\n    data_test = load_data_npz(test_csv, get_po=True)\n\n    return data_train, data_test\n"
  },
  {
    "path": "catenets/datasets/dataset_twins.py",
    "content": "\"\"\"\nTwins dataset\nLoad real-world individualized treatment effects estimation datasets\n\n- Reference: http://data.nber.org/data/linked-birth-infant-death-data-vital-statistics-data.html\n\"\"\"\n# stdlib\nimport random\nfrom pathlib import Path\nfrom typing import Tuple\n\n# third party\nimport numpy as np\nimport pandas as pd\nfrom sklearn.preprocessing import MinMaxScaler\n\nimport catenets.logger as log\n\nfrom .network import download_if_needed\n\nDATASET = \"Twin_Data.csv.gz\"\nURL = \"https://bitbucket.org/mvdschaar/mlforhealthlabpub/raw/0b0190bcd38a76c405c805f1ca774971fcd85233/data/twins/Twin_Data.csv.gz\"  # noqa: E501\n\n\ndef preprocess(\n    fn_csv: Path,\n    train_ratio: float = 0.8,\n    treatment_type: str = \"rand\",\n    seed: int = 42,\n    treat_prop: float = 0.5,\n) -> Tuple:\n    \"\"\"Helper for preprocessing the Twins dataset.\n\n    Parameters\n    ----------\n    fn_csv: Path\n        Dataset CSV file path.\n    train_ratio: float\n        The ratio of training data.\n    treatment_type: string\n        The treatment selection strategy.\n    seed: float\n        Random seed.\n\n    Returns\n    -------\n    train_x: array or pd.DataFrame\n        Features in training data.\n    train_t: array or pd.DataFrame\n        Treatments in training data.\n    train_y: array or pd.DataFrame\n        Observed outcomes in training data.\n    train_potential_y: array or pd.DataFrame\n        Potential outcomes in training data.\n    test_x: array or pd.DataFrame\n        Features in testing data.\n    test_potential_y: array or pd.DataFrame\n        Potential outcomes in testing data.\n    \"\"\"\n    np.random.seed(seed)\n    random.seed(seed)\n\n    # Load original data (11400 patients, 30 features, 2 dimensional potential outcomes)\n    df = pd.read_csv(fn_csv)\n\n    cleaned_columns = []\n    for col in df.columns:\n        cleaned_columns.append(col.replace(\"'\", \"\").replace(\"’\", \"\"))\n    df.columns = 
cleaned_columns\n\n    feat_list = list(df)\n\n    # 8: factor not on certificate, 9: factor not classifiable --> np.nan --> mode imputation\n    medrisk_list = [\n        \"anemia\",\n        \"cardiac\",\n        \"lung\",\n        \"diabetes\",\n        \"herpes\",\n        \"hydra\",\n        \"hemo\",\n        \"chyper\",\n        \"phyper\",\n        \"eclamp\",\n        \"incervix\",\n        \"pre4000\",\n        \"dtotord\",\n        \"preterm\",\n        \"renal\",\n        \"rh\",\n        \"uterine\",\n        \"othermr\",\n    ]\n    # 99: missing\n    other_list = [\"cigar\", \"drink\", \"wtgain\", \"gestat\", \"dmeduc\", \"nprevist\"]\n\n    other_list2 = [\"pldel\", \"resstatb\"]  # but no samples are missing..\n\n    bin_list = [\"dmar\"] + medrisk_list\n    con_list = [\"dmage\", \"mpcb\"] + other_list\n    cat_list = [\"adequacy\"] + other_list2\n\n    for feat in medrisk_list:\n        df[feat] = df[feat].apply(lambda x: df[feat].mode()[0] if x in [8, 9] else x)\n\n    for feat in other_list:\n        df.loc[df[feat] == 99, feat] = df.loc[df[feat] != 99, feat].mean()\n\n    df_features = df[con_list + bin_list]\n\n    for feat in cat_list:\n        df_features = pd.concat(\n            [df_features, pd.get_dummies(df[feat], prefix=feat)], axis=1\n        )\n\n    # Define features\n    feat_list = [\n        \"dmage\",\n        \"mpcb\",\n        \"cigar\",\n        \"drink\",\n        \"wtgain\",\n        \"gestat\",\n        \"dmeduc\",\n        \"nprevist\",\n        \"dmar\",\n        \"anemia\",\n        \"cardiac\",\n        \"lung\",\n        \"diabetes\",\n        \"herpes\",\n        \"hydra\",\n        \"hemo\",\n        \"chyper\",\n        \"phyper\",\n        \"eclamp\",\n        \"incervix\",\n        \"pre4000\",\n        \"dtotord\",\n        \"preterm\",\n        \"renal\",\n        \"rh\",\n        \"uterine\",\n        \"othermr\",\n        \"adequacy_1\",\n        \"adequacy_2\",\n        \"adequacy_3\",\n        
\"pldel_1\",\n        \"pldel_2\",\n        \"pldel_3\",\n        \"pldel_4\",\n        \"pldel_5\",\n        \"resstatb_1\",\n        \"resstatb_2\",\n        \"resstatb_3\",\n        \"resstatb_4\",\n    ]\n\n    x = np.asarray(df_features[feat_list])\n    y0 = np.asarray(df[[\"outcome(t=0)\"]]).reshape((-1,))\n    y0 = np.array(y0 < 9999, dtype=int)\n\n    y1 = np.asarray(df[[\"outcome(t=1)\"]]).reshape((-1,))\n    y1 = np.array(y1 < 9999, dtype=int)\n\n    # Preprocessing\n    scaler = MinMaxScaler()\n    scaler.fit(x)\n    x = scaler.transform(x)\n\n    no, dim = x.shape\n\n    if treatment_type == \"rand\":\n        # assign with p=0.5\n        prob = np.ones(x.shape[0]) * treat_prop\n    elif treatment_type == \"logistic\":\n        # assign with logistic prob\n        coef = np.random.uniform(-0.1, 0.1, size=[np.shape(x)[1], 1])\n        prob = 1 / (1 + np.exp(-np.matmul(x, coef)))\n\n    w = np.random.binomial(1, prob)\n    y = y1 * w + y0 * (1 - w)\n\n    potential_y = np.vstack((y0, y1)).T\n\n    # Train/test division\n    if train_ratio < 1:\n        idx = np.random.permutation(no)\n        train_idx = idx[: int(train_ratio * no)]\n        test_idx = idx[int(train_ratio * no) :]\n\n        train_x = x[train_idx, :]\n        train_w = w[train_idx]\n        train_y = y[train_idx]\n        train_potential_y = potential_y[train_idx, :]\n\n        test_x = x[test_idx, :]\n        test_potential_y = potential_y[test_idx, :]\n    else:\n        train_x = x\n        train_w = w\n        train_y = y\n        train_potential_y = potential_y\n        test_x = None\n        test_potential_y = None\n\n    return train_x, train_w, train_y, train_potential_y, test_x, test_potential_y\n\n\ndef load(\n    data_path: Path,\n    train_ratio: float = 0.8,\n    treatment_type: str = \"rand\",\n    seed: int = 42,\n    treat_prop: float = 0.5,\n) -> Tuple:\n    \"\"\"\n    Twins dataset dataloader.\n        - Download the dataset if needed.\n        - Load the dataset.\n     
   - Preprocess the data.\n        - Return train/test split.\n\n    Parameters\n    ----------\n    data_path: Path\n        Path to the CSV. If it is missing, it will be downloaded.\n    train_ratio: float\n        Train/test ratio\n    treatment_type: str\n        Treatment generation strategy\n    seed: float\n        Random seed\n    treat_prop: float\n        Treatment proportion\n\n    Returns\n    -------\n    train_x: array or pd.DataFrame\n        Features in training data.\n    train_t: array or pd.DataFrame\n        Treatments in training data.\n    train_y: array or pd.DataFrame\n        Observed outcomes in training data.\n    train_potential_y: array or pd.DataFrame\n        Potential outcomes in training data.\n    test_x: array or pd.DataFrame\n        Features in testing data.\n    test_potential_y: array or pd.DataFrame\n        Potential outcomes in testing data.\n    \"\"\"\n    csv = data_path / DATASET\n\n    download_if_needed(csv, http_url=URL)\n\n    log.debug(f\"load dataset {csv}\")\n\n    return preprocess(\n        csv,\n        train_ratio=train_ratio,\n        treatment_type=treatment_type,\n        seed=seed,\n        treat_prop=treat_prop,\n    )\n"
  },
  {
    "path": "catenets/datasets/network.py",
    "content": "\"\"\"\nUtilities and helpers for retrieving the datasets\n\"\"\"\n# stdlib\nimport tarfile\nimport urllib.request\nfrom pathlib import Path\nfrom typing import Optional\n\nimport gdown\n\n\ndef download_gdrive_if_needed(path: Path, file_id: str) -> None:\n    \"\"\"\n    Helper for downloading a file from Google Drive, if it is now already on the disk.\n\n    Parameters\n    ----------\n    path: Path\n        Where to download the file\n    file_id: str\n        Google Drive File ID. Details: https://developers.google.com/drive/api/v3/about-files\n    \"\"\"\n    path = Path(path)\n\n    if path.exists():\n        return\n\n    gdown.download(id=file_id, output=str(path), quiet=False)\n\n\ndef download_http_if_needed(path: Path, url: str) -> None:\n    \"\"\"\n    Helper for downloading a file, if it is now already on the disk.\n\n    Parameters\n    ----------\n    path: Path\n        Where to download the file.\n    url: URL string\n        HTTP URL for the dataset.\n    \"\"\"\n    path = Path(path)\n\n    if path.exists():\n        return\n\n    if url.lower().startswith(\"http\"):\n        urllib.request.urlretrieve(url, path)  # nosec\n        return\n\n    raise ValueError(f\"Invalid url provided {url}\")\n\n\ndef unarchive_if_needed(path: Path, output_folder: Path) -> None:\n    \"\"\"\n    Helper for uncompressing archives. 
Supports .tar.gz and .tar.\n\n    Parameters\n    ----------\n    path: Path\n        Source archive.\n    output_folder: Path\n        Where to unarchive.\n    \"\"\"\n    if str(path).endswith(\".tar.gz\"):\n        tar = tarfile.open(path, \"r:gz\")\n        tar.extractall(path=output_folder) # nosec\n        tar.close()\n    elif str(path).endswith(\".tar\"):\n        tar = tarfile.open(path, \"r:\")\n        tar.extractall(path=output_folder) # nosec\n        tar.close()\n    else:\n        raise NotImplementedError(f\"archive not supported {path}\")\n\n\ndef download_if_needed(\n    download_path: Path,\n    file_id: Optional[str] = None,  # used for downloading from Google Drive\n    http_url: Optional[str] = None,  # used for downloading from a HTTP URL\n    unarchive: bool = False,  # unzip a downloaded archive\n    unarchive_folder: Optional[Path] = None,  # unzip folder\n) -> None:\n    \"\"\"\n    Helper for retrieving online datasets.\n\n    Parameters\n    ----------\n    download_path: str\n        Where to download the archive\n    file_id: str, optional\n        Set this if you want to download from a public Google drive share\n    http_url: str, optional\n        Set this if you want to download from a HTTP URL\n    unarchive: bool\n        Set this if you want to try to unarchive the downloaded file\n    unarchive_folder: str\n        Mandatory if you set unarchive to True.\n    \"\"\"\n    download_path = Path(download_path)\n    if file_id is not None:\n        download_gdrive_if_needed(download_path, file_id)\n    elif http_url is not None:\n        download_http_if_needed(download_path, http_url)\n    else:\n        raise ValueError(\"Please provide a download URL\")\n\n    if unarchive and unarchive_folder is None:\n        raise ValueError(\"Please provide a folder for the archive\")\n    if unarchive and unarchive_folder is not None:\n        try:\n            unarchive_if_needed(download_path, unarchive_folder)\n        except 
BaseException as e:\n            print(f\"Failed to unpack {download_path}. Error {e}\")\n            download_path.unlink()\n"
  },
  {
    "path": "catenets/experiment_utils/__init__.py",
    "content": ""
  },
  {
    "path": "catenets/experiment_utils/base.py",
    "content": "\"\"\"\nSome utils for experiments\n\"\"\"\n# Author: Alicia Curth\nfrom typing import Callable, Dict, Optional, Union\n\nimport jax.numpy as jnp\n\nfrom catenets.models.jax import (\n    DRNET_NAME,\n    PSEUDOOUT_NAME,\n    RANET_NAME,\n    RNET_NAME,\n    SNET1_NAME,\n    SNET2_NAME,\n    SNET3_NAME,\n    SNET_NAME,\n    T_NAME,\n    XNET_NAME,\n    PseudoOutcomeNet,\n    get_catenet,\n)\nfrom catenets.models.jax.base import check_shape_1d_data\nfrom catenets.models.jax.transformation_utils import (\n    DR_TRANSFORMATION,\n    PW_TRANSFORMATION,\n    RA_TRANSFORMATION,\n)\n\nSEP = \"_\"\n\n\ndef eval_mse_model(\n    inputs: jnp.ndarray,\n    targets: jnp.ndarray,\n    predict_fun: Callable,\n    params: jnp.ndarray,\n) -> jnp.ndarray:\n    # evaluate the mse of a model given its function and params\n    preds = predict_fun(params, inputs)\n    return jnp.mean((preds - targets) ** 2)\n\n\ndef eval_mse(preds: jnp.ndarray, targets: jnp.ndarray) -> jnp.ndarray:\n    preds = check_shape_1d_data(preds)\n    targets = check_shape_1d_data(targets)\n    return jnp.mean((preds - targets) ** 2)\n\n\ndef eval_root_mse(cate_pred: jnp.ndarray, cate_true: jnp.ndarray) -> jnp.ndarray:\n    cate_true = check_shape_1d_data(cate_true)\n    cate_pred = check_shape_1d_data(cate_pred)\n    return jnp.sqrt(eval_mse(cate_pred, cate_true))\n\n\ndef eval_abs_error_ate(cate_pred: jnp.ndarray, cate_true: jnp.ndarray) -> jnp.ndarray:\n    cate_true = check_shape_1d_data(cate_true)\n    cate_pred = check_shape_1d_data(cate_pred)\n    return jnp.abs(jnp.mean(cate_pred) - jnp.mean(cate_true))\n\n\ndef get_model_set(\n    model_selection: Union[str, list] = \"all\", model_params: Optional[dict] = None\n) -> Dict:\n    \"\"\"Helper function to retrieve a set of models\"\"\"\n    # get model selection\n    if type(model_selection) is str:\n        if model_selection == \"snet\":\n            models = get_all_snets()\n        elif model_selection == \"pseudo\":\n            models 
= get_all_pseudoout_models()\n        elif model_selection == \"twostep\":\n            models = get_all_twostep_models()\n        elif model_selection == \"all\":\n            models = dict(**get_all_snets(), **get_all_pseudoout_models())\n        else:\n            models = {model_selection: get_catenet(model_selection)()}  # type: ignore\n    elif type(model_selection) is list:\n        models = {}\n        for model in model_selection:\n            models.update({model: get_catenet(model)()})\n    else:\n        raise ValueError(\"model_selection should be string or list.\")\n\n    # set hyperparameters\n    if model_params is not None:\n        for model in models.values():\n            existing_params = model.get_params()\n            new_params = {\n                key: val\n                for key, val in model_params.items()\n                if key in existing_params.keys()\n            }\n            model.set_params(**new_params)\n\n    return models\n\n\nALL_SNETS = [T_NAME, SNET1_NAME, SNET2_NAME, SNET3_NAME, SNET_NAME]\nALL_PSEUDOOUT_MODELS = [DR_TRANSFORMATION, PW_TRANSFORMATION, RA_TRANSFORMATION]\nALL_TWOSTEP_MODELS = [DRNET_NAME, RANET_NAME, XNET_NAME, RNET_NAME]\n\n\ndef get_all_snets() -> Dict:\n    model_dict = {}\n    for name in ALL_SNETS:\n        model_dict.update({name: get_catenet(name)()})\n    return model_dict\n\n\ndef get_all_pseudoout_models() -> Dict:  # DR, RA, PW learner\n    model_dict = {}\n    for trans in ALL_PSEUDOOUT_MODELS:\n        model_dict.update(\n            {PSEUDOOUT_NAME + SEP + trans: PseudoOutcomeNet(transformation=trans)}\n        )\n    return model_dict\n\n\ndef get_all_twostep_models() -> Dict:  # DR, RA, R, X learner\n    model_dict = {}\n    for name in ALL_TWOSTEP_MODELS:\n        model_dict.update({name: get_catenet(name)()})\n    return model_dict\n"
  },
  {
    "path": "catenets/experiment_utils/simulation_utils.py",
    "content": "\"\"\"\r\nSimulation utils, allowing to flexibly consider different DGPs\r\n\"\"\"\r\n# Author: Alicia Curth\r\nfrom typing import Any, Optional, Tuple\r\n\r\nimport numpy as np\r\nfrom scipy.special import expit\r\n\r\n\r\ndef simulate_treatment_setup(\r\n    n: int,\r\n    d: int = 25,\r\n    n_w: int = 0,\r\n    n_c: int = 0,\r\n    n_o: int = 0,\r\n    n_t: int = 0,\r\n    covariate_model: Any = None,\r\n    covariate_model_params: Optional[dict] = None,\r\n    propensity_model: Any = None,\r\n    propensity_model_params: Optional[dict] = None,\r\n    mu_0_model: Any = None,\r\n    mu_0_model_params: Optional[dict] = None,\r\n    mu_1_model: Any = None,\r\n    mu_1_model_params: Optional[dict] = None,\r\n    error_sd: float = 1,\r\n    seed: int = 42,\r\n) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:\r\n    \"\"\"\r\n    Generic function to flexibly simulate a treatment setup.\r\n\r\n    Parameters\r\n    ----------\r\n    n: int\r\n        Number of observations to generate\r\n    d: int\r\n        dimension of X to generate\r\n    n_o: int\r\n        Dimension of outcome-factor\r\n    n_c: int\r\n        Dimension of confounding factor\r\n    n_t: int\r\n        Dimension of purely predictive variables (support of tau(x)\r\n    n_w: int\r\n        Dimension of treatment assignment factor\r\n    covariate_model:\r\n        Model to generate covariates. 
Default: multivariate normal\r\n    covariate_model_params: dict\r\n        Additional parameters to pass to covariate model\r\n    propensity_model:\r\n        Model to generate propensity scores\r\n    propensity_model_params:\r\n        Additional parameters to pass to propensity model\r\n    mu_0_model:\r\n        Model to generate untreated outcomes\r\n    mu_0_model_params:\r\n        Additional parameters to pass to untreated outcome model\r\n    mu_1_model:\r\n        Model to generate treated outcomes.\r\n    mu_1_model_params:\r\n        Additional parameters to pass to treated outcome model\r\n    error_sd: float, default 1\r\n        Standard deviation of normal errors\r\n    seed: int\r\n        Seed\r\n\r\n    Returns\r\n    -------\r\n        X, y, w, p, t - Covariates, observed outcomes, treatment indicators, propensities, CATE\r\n    \"\"\"\r\n    # input checks\r\n    n_nuisance = d - (n_c + n_o + n_w + n_t)\r\n    if n_nuisance < 0:\r\n        raise ValueError(\"Dimensions should add up to maximally d.\")\r\n\r\n    # set defaults\r\n    if covariate_model is None:\r\n        covariate_model = normal_covariate_model\r\n\r\n    if covariate_model_params is None:\r\n        covariate_model_params = {}\r\n\r\n    if propensity_model is None:\r\n        propensity_model = propensity_AISTATS\r\n\r\n    if propensity_model_params is None:\r\n        propensity_model_params = {}\r\n\r\n    if mu_0_model is None:\r\n        mu_0_model = mu0_AISTATS\r\n\r\n    if mu_0_model_params is None:\r\n        mu_0_model_params = {}\r\n\r\n    if mu_1_model is None:\r\n        mu_1_model = mu1_AISTATS\r\n\r\n    if mu_1_model_params is None:\r\n        mu_1_model_params = {}\r\n\r\n    np.random.seed(seed)\r\n\r\n    # generate data and outcomes\r\n    X = covariate_model(\r\n        n=n,\r\n        n_nuisance=n_nuisance,\r\n        n_c=n_c,\r\n        n_o=n_o,\r\n        n_w=n_w,\r\n        n_t=n_t,\r\n        **covariate_model_params\r\n    )\r\n    mu_0 = 
mu_0_model(X, n_c=n_c, n_o=n_o, n_w=n_w, **mu_0_model_params)\r\n    mu_1 = mu_1_model(\r\n        X, n_c=n_c, n_o=n_o, n_w=n_w, n_t=n_t, mu_0=mu_0, **mu_1_model_params\r\n    )\r\n    t = mu_1 - mu_0\r\n\r\n    # generate treatments\r\n    p = propensity_model(X, n_c=n_c, n_w=n_w, **propensity_model_params)\r\n    w = np.random.binomial(1, p=p)\r\n\r\n    # generate observables\r\n    y = w * mu_1 + (1 - w) * mu_0 + np.random.normal(0, error_sd, n)\r\n\r\n    return X, y, w, p, t\r\n\r\n\r\n# normal covariate model (Adapted from Hassanpour & Greiner, 2020) -------------\r\ndef get_multivariate_normal_params(\r\n    m: int, correlated: bool = False\r\n) -> Tuple[np.ndarray, np.ndarray]:\r\n    # Adapted from Hassanpour & Greiner (2020)\r\n    if correlated:\r\n        mu = np.zeros(m)  # np.random.normal(size=m)/10\r\n        temp = np.random.uniform(size=(m, m))\r\n        temp = 0.5 * (np.transpose(temp) + temp)\r\n        sig = (np.ones((m, m)) - np.eye(m)) * temp / 10 + 0.5 * np.eye(\r\n            m\r\n        )  # (temp + m * np.eye(m)) / 10\r\n\r\n    else:\r\n        mu = np.zeros(m)\r\n        sig = np.eye(m)\r\n\r\n    return mu, sig\r\n\r\n\r\ndef get_set_normal_covariates(m: int, n: int, correlated: bool = False) -> np.ndarray:\r\n    if m == 0:\r\n        return\r\n    mu, sig = get_multivariate_normal_params(m, correlated=correlated)\r\n    return np.random.multivariate_normal(mean=mu, cov=sig, size=n)\r\n\r\n\r\ndef normal_covariate_model(\r\n    n: int,\r\n    n_nuisance: int = 25,\r\n    n_c: int = 0,\r\n    n_o: int = 0,\r\n    n_w: int = 0,\r\n    n_t: int = 0,\r\n    correlated: bool = False,\r\n) -> np.ndarray:\r\n    X_stack: Tuple = ()\r\n    for n_x in [n_w, n_c, n_o, n_t, n_nuisance]:\r\n        if n_x > 0:\r\n            X_stack = (*X_stack, get_set_normal_covariates(n_x, n, correlated))\r\n\r\n    return np.hstack(X_stack)\r\n\r\n\r\ndef propensity_AISTATS(\r\n    X: np.ndarray,\r\n    n_c: int = 0,\r\n    n_w: int = 0,\r\n    xi: float = 
0.5,\r\n    nonlinear: bool = True,\r\n    offset: Any = 0,\r\n    target_prop: Optional[np.ndarray] = None,\r\n) -> np.ndarray:\r\n    if n_c + n_w == 0:\r\n        # constant propensity\r\n        return xi * np.ones(X.shape[0])\r\n    else:\r\n        coefs = np.ones(n_c + n_w)\r\n\r\n        if nonlinear:\r\n            z = np.dot(X[:, : (n_c + n_w)] ** 2, coefs) / (n_c + n_w)\r\n        else:\r\n            z = np.dot(X[:, : (n_c + n_w)], coefs) / (n_c + n_w)\r\n\r\n        if type(offset) is float or type(offset) is int:\r\n            prop = expit(xi * z + offset)\r\n            if target_prop is not None:\r\n                avg_prop = np.average(prop)\r\n                prop = target_prop / avg_prop * prop\r\n            return prop\r\n        elif offset == \"center\":\r\n            # center the propensity scores to median 0.5\r\n            prop = expit(xi * (z - np.median(z)))\r\n            if target_prop is not None:\r\n                avg_prop = np.average(prop)\r\n                prop = target_prop / avg_prop * prop\r\n            return prop\r\n        else:\r\n            raise ValueError(\"Not a valid value for offset\")\r\n\r\n\r\ndef propensity_constant(\r\n    X: np.ndarray, n_c: int = 0, n_w: int = 0, xi: float = 0.5\r\n) -> np.ndarray:\r\n    return xi * np.ones(X.shape[0])\r\n\r\n\r\ndef mu0_AISTATS(\r\n    X: np.ndarray, n_w: int = 0, n_c: int = 0, n_o: int = 0, scale: bool = False\r\n) -> np.ndarray:\r\n    if n_c + n_o == 0:\r\n        return np.zeros((X.shape[0]))\r\n    else:\r\n        if not scale:\r\n            coefs = np.ones(n_c + n_o)\r\n        else:\r\n            coefs = 10 * np.ones(n_c + n_o) / (n_c + n_o)\r\n        return np.dot(X[:, n_w : (n_w + n_c + n_o)] ** 2, coefs)\r\n\r\n\r\ndef mu1_AISTATS(\r\n    X: np.ndarray,\r\n    n_w: int = 0,\r\n    n_c: int = 0,\r\n    n_o: int = 0,\r\n    n_t: int = 0,\r\n    mu_0: Optional[np.ndarray] = None,\r\n    nonlinear: int = 2,\r\n    withbase: bool = True,\r\n    scale: bool = 
False,\r\n) -> np.ndarray:\r\n    if n_t == 0:\r\n        return mu_0\r\n    # use additive effect\r\n    else:\r\n        if scale:\r\n            coefs = 10 * np.ones(n_t) / n_t\r\n        else:\r\n            coefs = np.ones(n_t)\r\n        X_sel = X[:, (n_w + n_c + n_o) : (n_w + n_c + n_o + n_t)]\r\n    if withbase:\r\n        return mu_0 + np.dot(X_sel**nonlinear, coefs)\r\n    else:\r\n        return np.dot(X_sel**nonlinear, coefs)\r\n\r\n\r\n# Other simulation settings not used in AISTATS paper\r\n# uniform covariate model\r\ndef uniform_covariate_model(\r\n    n: int,\r\n    n_nuisance: int = 0,\r\n    n_c: int = 0,\r\n    n_o: int = 0,\r\n    n_w: int = 0,\r\n    n_t: int = 0,\r\n    low: int = -1,\r\n    high: int = 1,\r\n) -> np.ndarray:\r\n    d = n_nuisance + n_c + n_o + n_w + n_t\r\n    return np.random.uniform(low=low, high=high, size=(n, d))\r\n\r\n\r\ndef mu1_additive(\r\n    X: np.ndarray,\r\n    n_w: int = 0,\r\n    n_c: int = 0,\r\n    n_o: int = 0,\r\n    n_t: int = 0,\r\n    mu_0: Optional[np.ndarray] = None,\r\n) -> np.ndarray:\r\n    if n_t == 0:\r\n        return mu_0\r\n    else:\r\n        coefs = np.random.normal(size=n_t)\r\n        return np.dot(X[:, (n_w + n_c + n_o) : (n_w + n_c + n_o + n_t)], coefs) / n_t\r\n\r\n\r\n# regression surfaces from Hassanpour & Greiner\r\ndef mu0_hg(X: np.ndarray, n_w: int = 0, n_c: int = 0, n_o: int = 0) -> np.ndarray:\r\n    if n_c + n_o == 0:\r\n        return np.zeros((X.shape[0]))\r\n    else:\r\n        coefs = np.random.normal(size=n_c + n_o)\r\n        return np.dot(X[:, n_w : (n_w + n_c + n_o)], coefs) / (n_c + n_o)\r\n\r\n\r\ndef mu1_hg(\r\n    X: np.ndarray,\r\n    n_w: int = 0,\r\n    n_c: int = 0,\r\n    n_o: int = 0,\r\n    n_t: int = 0,\r\n    mu_0: Optional[np.ndarray] = None,\r\n) -> np.ndarray:\r\n    if n_c + n_o == 0:\r\n        return np.zeros((X.shape[0]))\r\n    else:\r\n        coefs = np.random.normal(size=n_c + n_o)\r\n        return np.dot(X[:, n_w : (n_w + n_c + n_o)] ** 2, 
coefs) / (n_c + n_o)\r\n\r\n\r\ndef propensity_hg(\r\n    X: np.ndarray, n_c: int = 0, n_w: int = 0, xi: Optional[float] = None\r\n) -> np.ndarray:\r\n    # propensity set-up used in Hassanpour & Greiner (2020)\r\n    if n_c + n_w == 0:\r\n        return 0.5 * np.ones(X.shape[0])\r\n    else:\r\n        if xi is None:\r\n            xi = 1\r\n\r\n        coefs = np.random.normal(size=n_c + n_w)\r\n        z = np.dot(X[:, : (n_c + n_w)], coefs)\r\n        return expit(xi * z)\r\n"
  },
  {
    "path": "catenets/experiment_utils/tester.py",
    "content": "# stdlib\nimport copy\nfrom typing import Any, Tuple\n\n# third party\nimport numpy as np\nimport torch\nfrom sklearn.model_selection import KFold, StratifiedKFold\n\nfrom catenets.experiment_utils.torch_metrics import abs_error_ATE, sqrt_PEHE\n\n\ndef generate_score(metric: np.ndarray) -> Tuple[float, float]:\n    percentile_val = 1.96\n    return (np.mean(metric), percentile_val * np.std(metric) / np.sqrt(len(metric)))\n\n\ndef print_score(score: Tuple[float, float]) -> str:\n    return str(round(score[0], 4)) + \" +/- \" + str(round(score[1], 4))\n\n\ndef evaluate_treatments_model(\n    estimator: Any,\n    X: torch.Tensor,\n    Y: torch.Tensor,\n    Y_full: torch.Tensor,\n    W: torch.Tensor,\n    n_folds: int = 3,\n    seed: int = 0,\n) -> dict:\n    metric_pehe = np.zeros(n_folds)\n    metric_ate = np.zeros(n_folds)\n\n    indx = 0\n    if len(np.unique(Y)) == 2:\n        skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)\n    else:\n        skf = KFold(n_splits=n_folds, shuffle=True, random_state=seed)\n\n    for train_index, test_index in skf.split(X, Y):\n\n        X_train = X[train_index]\n        Y_train = Y[train_index]\n        W_train = W[train_index]\n\n        X_test = X[test_index]\n        Y_full_test = Y_full[test_index]\n\n        model = copy.deepcopy(estimator)\n        model.fit(X_train, Y_train, W_train)\n\n        try:\n            te_pred = model.predict(X_test).detach().cpu().numpy()\n        except BaseException:\n            te_pred = np.asarray(model.predict(X_test))\n\n        metric_ate[indx] = abs_error_ATE(Y_full_test, te_pred)\n        metric_pehe[indx] = sqrt_PEHE(Y_full_test, te_pred)\n        indx += 1\n\n    output_pehe = generate_score(metric_pehe)\n    output_ate = generate_score(metric_ate)\n\n    return {\n        \"raw\": {\n            \"pehe\": output_pehe,\n            \"ate\": output_ate,\n        },\n        \"str\": {\n            \"pehe\": print_score(output_pehe),\n            
\"ate\": print_score(output_ate),\n        },\n    }\n"
  },
  {
    "path": "catenets/experiment_utils/torch_metrics.py",
    "content": "# third party\nimport torch\n\n\ndef sqrt_PEHE(po: torch.Tensor, hat_te: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Precision in Estimation of Heterogeneous Effect(PyTorch version).\n    PEHE reflects the ability to capture individual variation in treatment effects.\n    Args:\n        po: expected outcome.\n        hat_te: estimated outcome.\n    \"\"\"\n    po = torch.Tensor(po)\n    hat_te = torch.Tensor(hat_te)\n    return torch.sqrt(torch.mean(((po[:, 1] - po[:, 0]) - hat_te) ** 2))\n\n\ndef abs_error_ATE(po: torch.Tensor, hat_te: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Average Treatment Effect.\n    ATE measures what is the expected causal effect of the treatment across all individuals in the population.\n    Args:\n        po: expected outcome.\n        hat_te: estimated outcome.\n    \"\"\"\n    po = torch.Tensor(po)\n    hat_te = torch.Tensor(hat_te)\n    return torch.abs(torch.mean(po[:, 1] - po[:, 0]) - torch.mean(hat_te))\n"
  },
  {
    "path": "catenets/logger.py",
    "content": "# stdlib\nimport logging\nimport os\nfrom typing import Any, Callable, NoReturn, TextIO, Union\n\n# third party\nfrom loguru import logger\n\nLOG_FORMAT = \"[{time}][{process.id}][{level}] {message}\"\n\nlogger.remove()\nDEFAULT_SINK = \"catenets_{time}.log\"\n\n\ndef remove() -> None:\n    logger.remove()\n\n\ndef add(\n    sink: Union[None, str, os.PathLike, TextIO, logging.Handler] = None,\n    level: str = \"ERROR\",\n) -> None:\n    sink = DEFAULT_SINK if sink is None else sink\n    try:\n        logger.add(\n            sink=sink,\n            format=LOG_FORMAT,\n            enqueue=True,\n            colorize=False,\n            diagnose=True,\n            backtrace=True,\n            rotation=\"10 MB\",\n            retention=\"1 day\",\n            level=level,\n        )\n    except BaseException:\n        logger.add(\n            sink=sink,\n            format=LOG_FORMAT,\n            colorize=False,\n            diagnose=True,\n            backtrace=True,\n            level=level,\n        )\n\n\ndef traceback_and_raise(e: Any, verbose: bool = False) -> NoReturn:\n    try:\n        if verbose:\n            logger.opt(lazy=True).exception(e)\n        else:\n            logger.opt(lazy=True).critical(e)\n    except BaseException as ex:\n        logger.debug(\"failed to print exception\", ex)\n    if not issubclass(type(e), Exception):\n        e = Exception(e)\n    raise e\n\n\ndef create_log_and_print_function(level: str) -> Callable:\n    def log_and_print(*args: Any, **kwargs: Any) -> None:\n        try:\n            method = getattr(logger.opt(lazy=True), level, None)\n            if method is not None:\n                method(*args, **kwargs)\n            else:\n                logger.debug(*args, **kwargs)\n        except BaseException as e:\n            msg = f\"failed to log exception. {e}\"\n            try:\n                logger.debug(msg)\n            except Exception as e:\n                print(f\"{msg}. 
{e}\")\n\n    return log_and_print\n\n\ndef traceback(*args: Any, **kwargs: Any) -> None:\n    return create_log_and_print_function(level=\"exception\")(*args, **kwargs)\n\n\ndef critical(*args: Any, **kwargs: Any) -> None:\n    return create_log_and_print_function(level=\"critical\")(*args, **kwargs)\n\n\ndef error(*args: Any, **kwargs: Any) -> None:\n    return create_log_and_print_function(level=\"error\")(*args, **kwargs)\n\n\ndef warning(*args: Any, **kwargs: Any) -> None:\n    return create_log_and_print_function(level=\"warning\")(*args, **kwargs)\n\n\ndef info(*args: Any, **kwargs: Any) -> None:\n    return create_log_and_print_function(level=\"info\")(*args, **kwargs)\n\n\ndef debug(*args: Any, **kwargs: Any) -> None:\n    return create_log_and_print_function(level=\"debug\")(*args, **kwargs)\n\n\ndef trace(*args: Any, **kwargs: Any) -> None:\n    return create_log_and_print_function(level=\"trace\")(*args, **kwargs)\n"
  },
  {
    "path": "catenets/models/__init__.py",
    "content": "import catenets.logger as log\n\ntry:\n    from . import jax\nexcept ImportError:\n    log.error(\"JAX models disabled\")\n\ntry:\n    from . import torch\nexcept ImportError:\n    log.error(\"PyTorch models disabled\")\n\n__all__ = [\"jax\", \"torch\"]\n"
  },
  {
    "path": "catenets/models/constants.py",
    "content": "\"\"\"\r\nDefine some constants for initialisation of hyperparameters etc\r\n\"\"\"\r\nimport numpy as np\r\n\r\n# default model architectures\r\nDEFAULT_LAYERS_OUT = 2\r\nDEFAULT_LAYERS_OUT_T = 2\r\nDEFAULT_LAYERS_R = 3\r\nDEFAULT_LAYERS_R_T = 3\r\n\r\nDEFAULT_UNITS_OUT = 100\r\nDEFAULT_UNITS_R = 200\r\nDEFAULT_UNITS_OUT_T = 100\r\nDEFAULT_UNITS_R_T = 200\r\n\r\nDEFAULT_NONLIN = \"elu\"\r\n\r\n# other default hyperparameters\r\nDEFAULT_STEP_SIZE = 0.0001\r\nDEFAULT_STEP_SIZE_T = 0.0001\r\nDEFAULT_N_ITER = 10000\r\nDEFAULT_BATCH_SIZE = 100\r\nDEFAULT_PENALTY_L2 = 1e-4\r\nDEFAULT_PENALTY_DISC = 0\r\nDEFAULT_PENALTY_ORTHOGONAL = 1 / 100\r\nDEFAULT_AVG_OBJECTIVE = True\r\n\r\n# defaults for early stopping\r\nDEFAULT_VAL_SPLIT = 0.3\r\nDEFAULT_N_ITER_MIN = 200\r\nDEFAULT_PATIENCE = 10\r\n\r\n# Defaults for crossfitting\r\nDEFAULT_CF_FOLDS = 2\r\n\r\n# other defaults\r\nDEFAULT_SEED = 42\r\nDEFAULT_N_ITER_PRINT = 50\r\nLARGE_VAL = np.iinfo(np.int32).max\r\n\r\nDEFAULT_UNITS_R_BIG_S = 100\r\nDEFAULT_UNITS_R_SMALL_S = 50\r\n\r\nDEFAULT_UNITS_R_BIG_S3 = 150\r\nDEFAULT_UNITS_R_SMALL_S3 = 50\r\n\r\nN_SUBSPACES = 3\r\nDEFAULT_DIM_S_OUT = 50\r\nDEFAULT_DIM_S_R = 100\r\nDEFAULT_DIM_P_OUT = 50\r\nDEFAULT_DIM_P_R = 100\r\n"
  },
  {
    "path": "catenets/models/jax/__init__.py",
    "content": "\"\"\"\nJAX-based implementations for the CATE estimators.\n\"\"\"\nfrom typing import Any\n\nfrom catenets.models.jax.disentangled_nets import SNet3\nfrom catenets.models.jax.flextenet import FlexTENet\nfrom catenets.models.jax.offsetnet import OffsetNet\nfrom catenets.models.jax.pseudo_outcome_nets import (\n    DRNet,\n    PseudoOutcomeNet,\n    PWNet,\n    RANet,\n)\nfrom catenets.models.jax.representation_nets import DragonNet, SNet1, SNet2, TARNet\nfrom catenets.models.jax.rnet import RNet\nfrom catenets.models.jax.snet import SNet\nfrom catenets.models.jax.tnet import TNet\nfrom catenets.models.jax.xnet import XNet\n\nSNET1_NAME = \"SNet1\"\nT_NAME = \"TNet\"\nSNET2_NAME = \"SNet2\"\nPSEUDOOUT_NAME = \"PseudoOutcomeNet\"\nSNET3_NAME = \"SNet3\"\nSNET_NAME = \"SNet\"\nXNET_NAME = \"XNet\"\nRNET_NAME = \"RNet\"\nDRNET_NAME = \"DRNet\"\nPWNET_NAME = \"PWNet\"\nRANET_NAME = \"RANet\"\nTARNET_NAME = \"TARNet\"\nFLEXTE_NAME = \"FlexTENet\"\nOFFSET_NAME = \"OffsetNet\"\nDRAGON_NAME = \"DragonNet\"\n\nALL_MODELS = [\n    T_NAME,\n    SNET1_NAME,\n    SNET2_NAME,\n    SNET3_NAME,\n    SNET_NAME,\n    PSEUDOOUT_NAME,\n    RNET_NAME,\n    XNET_NAME,\n    DRNET_NAME,\n    PWNET_NAME,\n    RANET_NAME,\n    TARNET_NAME,\n    FLEXTE_NAME,\n    OFFSET_NAME,\n]\nMODEL_DICT = {\n    T_NAME: TNet,\n    SNET1_NAME: SNet1,\n    SNET2_NAME: SNet2,\n    SNET3_NAME: SNet3,\n    SNET_NAME: SNet,\n    PSEUDOOUT_NAME: PseudoOutcomeNet,\n    RNET_NAME: RNet,\n    XNET_NAME: XNet,\n    DRNET_NAME: DRNet,\n    PWNET_NAME: PWNet,\n    RANET_NAME: RANet,\n    TARNET_NAME: TARNet,\n    DRAGON_NAME: DragonNet,\n    OFFSET_NAME: OffsetNet,\n    FLEXTE_NAME: FlexTENet,\n}\n\n__all__ = [\n    T_NAME,\n    SNET1_NAME,\n    SNET2_NAME,\n    SNET3_NAME,\n    SNET_NAME,\n    PSEUDOOUT_NAME,\n    RNET_NAME,\n    XNET_NAME,\n    DRNET_NAME,\n    PWNET_NAME,\n    RANET_NAME,\n    TARNET_NAME,\n    DRAGON_NAME,\n    FLEXTE_NAME,\n    OFFSET_NAME,\n]\n\n\ndef get_catenet(name: str) -> 
Any:\n    if name not in ALL_MODELS:\n        raise ValueError(\n            f\"Model name should be in catenets.models.jax.ALL_MODELS You passed {name}\"\n        )\n    return MODEL_DICT[name]\n"
  },
  {
    "path": "catenets/models/jax/base.py",
    "content": "\"\"\"\r\nBase modules shared across different nets\r\n\"\"\"\r\n# Author: Alicia Curth\r\nimport abc\r\nfrom typing import Any, Callable, List, Optional, Tuple\r\n\r\nimport jax.numpy as jnp\r\nimport numpy as onp\r\nfrom jax import grad, jit, random\r\nfrom jax.example_libraries import optimizers, stax\r\nfrom jax.example_libraries.stax import Dense, Elu, Relu, Sigmoid\r\nfrom sklearn.base import BaseEstimator, RegressorMixin\r\nfrom sklearn.model_selection import ParameterGrid\r\n\r\nimport catenets.logger as log\r\nfrom catenets.models.constants import (\r\n    DEFAULT_BATCH_SIZE,\r\n    DEFAULT_LAYERS_OUT,\r\n    DEFAULT_N_ITER,\r\n    DEFAULT_N_ITER_MIN,\r\n    DEFAULT_N_ITER_PRINT,\r\n    DEFAULT_NONLIN,\r\n    DEFAULT_PATIENCE,\r\n    DEFAULT_PENALTY_L2,\r\n    DEFAULT_SEED,\r\n    DEFAULT_STEP_SIZE,\r\n    DEFAULT_UNITS_OUT,\r\n    DEFAULT_UNITS_R,\r\n    DEFAULT_VAL_SPLIT,\r\n    LARGE_VAL,\r\n)\r\nfrom catenets.models.jax.model_utils import (\r\n    check_shape_1d_data,\r\n    check_X_is_np,\r\n    make_val_split,\r\n)\r\n\r\n\r\ndef ReprBlock(\r\n    n_layers: int = 3, n_units: int = 100, nonlin: str = DEFAULT_NONLIN\r\n) -> Any:\r\n    # Creates a representation block using jax.stax\r\n    # create first layer\r\n    if nonlin == \"elu\":\r\n        NL = Elu\r\n    elif nonlin == \"relu\":\r\n        NL = Relu\r\n    elif nonlin == \"sigmoid\":\r\n        NL = Sigmoid\r\n    else:\r\n        raise ValueError(\"Unknown nonlinearity\")\r\n\r\n    layers: Tuple\r\n    layers = (Dense(n_units), NL)\r\n\r\n    # add required number of layers\r\n    for i in range(n_layers - 1):\r\n        layers = (*layers, Dense(n_units), NL)\r\n\r\n    return stax.serial(*layers)\r\n\r\n\r\ndef OutputHead(\r\n    n_layers_out: int = DEFAULT_LAYERS_OUT,\r\n    n_units_out: int = DEFAULT_UNITS_OUT,\r\n    binary_y: bool = False,\r\n    n_layers_r: int = 0,\r\n    n_units_r: int = DEFAULT_UNITS_R,\r\n    nonlin: str = DEFAULT_NONLIN,\r\n) -> Any:\r\n    # 
Creates an output head using jax.stax\r\n    if nonlin == \"elu\":\r\n        NL = Elu\r\n    elif nonlin == \"relu\":\r\n        NL = Relu\r\n    elif nonlin == \"sigmoid\":\r\n        NL = Sigmoid\r\n    else:\r\n        raise ValueError(\"Unknown nonlinearity\")\r\n\r\n    layers: Tuple = ()\r\n\r\n    # add required number of layers\r\n    for i in range(n_layers_r):\r\n        layers = (*layers, Dense(n_units_r), NL)\r\n\r\n    # add required number of layers\r\n    for i in range(n_layers_out):\r\n        layers = (*layers, Dense(n_units_out), NL)\r\n\r\n    # return final architecture\r\n    if not binary_y:\r\n        return stax.serial(*layers, Dense(1))\r\n    else:\r\n        return stax.serial(*layers, Dense(1), Sigmoid)\r\n\r\n\r\nclass BaseCATENet(BaseEstimator, RegressorMixin, abc.ABC):\r\n    \"\"\"\r\n    Base CATENet class to serve as template for all other nets\r\n    \"\"\"\r\n\r\n    def score(\r\n        self,\r\n        X: jnp.ndarray,\r\n        y: jnp.ndarray,\r\n        sample_weight: Optional[jnp.ndarray] = None,\r\n    ) -> float:\r\n        \"\"\"\r\n        Return the sqrt PEHE error (Oracle metric).\r\n\r\n        Parameters\r\n        ----------\r\n        X: pd.DataFrame or np.array\r\n            Covariate matrix\r\n        y: np.array\r\n            Expected potential outcome vector\r\n        \"\"\"\r\n        X = check_X_is_np(X)\r\n        y = check_X_is_np(y)\r\n        if len(X) != len(y):\r\n            raise ValueError(\"X/y length mismatch for score\")\r\n        if y.shape[-1] != 2:\r\n            raise ValueError(f\"y has invalid shape {y.shape}\")\r\n\r\n        hat_te = self.predict(X)\r\n\r\n        return jnp.sqrt(jnp.mean(((y[:, 1] - y[:, 0]) - hat_te) ** 2))\r\n\r\n    @abc.abstractmethod\r\n    def _get_train_function(self) -> Callable:\r\n        ...\r\n\r\n    def fit(\r\n        self,\r\n        X: jnp.ndarray,\r\n        y: jnp.ndarray,\r\n        w: jnp.ndarray,\r\n        p: Optional[jnp.ndarray] = None,\r\n 
   ) -> \"BaseCATENet\":\r\n        \"\"\"\r\n        Fit method for a CATENet. Takes covariates, outcome variable and treatment indicator as\r\n        input\r\n\r\n        Parameters\r\n        ----------\r\n        X: pd.DataFrame or np.array\r\n            Covariate matrix\r\n        y: np.array\r\n            Outcome vector\r\n        w: np.array\r\n            Treatment indicator\r\n        p: np.array\r\n            Vector of (known) treatment propensities. Currently only supported for TwoStepNets.\r\n        \"\"\"\r\n        # some quick input checks\r\n        if p is not None:\r\n            raise NotImplementedError(\"Only two-step-nets take p as input. \")\r\n        X = check_X_is_np(X)\r\n        self._check_inputs(w, p)\r\n\r\n        train_func = self._get_train_function()\r\n        train_params = self.get_params()\r\n\r\n        self._params, self._predict_funs = train_func(X, y, w, **train_params)\r\n\r\n        return self\r\n\r\n    @abc.abstractmethod\r\n    def _get_predict_function(self) -> Callable:\r\n        ...\r\n\r\n    def predict(\r\n        self, X: jnp.ndarray, return_po: bool = False, return_prop: bool = False\r\n    ) -> jnp.ndarray:\r\n        \"\"\"\r\n        Predict treatment effect estimates using a CATENet. 
Depending on method, can also return\r\n        potential outcome estimate and propensity score estimate.\r\n\r\n        Parameters\r\n        ----------\r\n        X: pd.DataFrame or np.array\r\n            Covariate matrix\r\n        return_po: bool, default False\r\n            Whether to return potential outcome estimate\r\n        return_prop: bool, default False\r\n            Whether to return propensity estimate\r\n\r\n        Returns\r\n        -------\r\n        array of CATE estimates, optionally also potential outcomes and propensity\r\n        \"\"\"\r\n        X = check_X_is_np(X)\r\n        predict_func = self._get_predict_function()\r\n        return predict_func(\r\n            X,\r\n            trained_params=self._params,\r\n            predict_funs=self._predict_funs,\r\n            return_po=return_po,\r\n            return_prop=return_prop,\r\n        )\r\n\r\n    @staticmethod\r\n    def _check_inputs(w: jnp.ndarray, p: jnp.ndarray) -> None:\r\n        if p is not None:\r\n            if onp.sum(p > 1) > 0 or onp.sum(p < 0) > 0:\r\n                raise ValueError(\"p should be in [0,1]\")\r\n\r\n        if not ((w == 0) | (w == 1)).all():\r\n            raise ValueError(\"W should be binary\")\r\n\r\n    def fit_and_select_params(\r\n        self,\r\n        X: jnp.ndarray,\r\n        y: jnp.ndarray,\r\n        w: jnp.ndarray,\r\n        p: Optional[jnp.ndarray] = None,\r\n        param_grid: dict = {},\r\n    ) -> \"BaseCATENet\":\r\n        # some quick input checks\r\n        if param_grid is None:\r\n            raise ValueError(\"No param_grid to evaluate. 
\")\r\n        X = check_X_is_np(X)\r\n        self._check_inputs(w, p)\r\n\r\n        param_grid = ParameterGrid(param_grid)\r\n        self_param_dict = self.get_params()\r\n        train_function = self._get_train_function()\r\n\r\n        models = []\r\n        losses = []\r\n        param_settings: list = []\r\n\r\n        for param_setting in param_grid:\r\n            log.debug(\r\n                \"Testing parameter setting: \"\r\n                + \" \".join(\r\n                    [key + \": \" + str(value) for key, value in param_setting.items()]\r\n                )\r\n            )\r\n            # replace params\r\n            train_param_dict = {\r\n                key: (val if key not in param_setting.keys() else param_setting[key])\r\n                for key, val in self_param_dict.items()\r\n            }\r\n            if p is not None:\r\n                params, funs, val_loss = train_function(\r\n                    X, y, w, p=p, return_val_loss=True, **train_param_dict\r\n                )\r\n            else:\r\n                params, funs, val_loss = train_function(\r\n                    X, y, w, return_val_loss=True, **train_param_dict\r\n                )\r\n\r\n            models.append((params, funs))\r\n            losses.append(val_loss)\r\n\r\n        # save results\r\n        param_settings.extend(param_grid)\r\n        self._selection_results = {\r\n            \"param_settings\": param_settings,\r\n            \"val_losses\": losses,\r\n        }\r\n\r\n        # find lowest loss and set params\r\n        best_idx = jnp.array(losses).argmin()\r\n        self._params, self._predict_funs = models[best_idx]\r\n        self.set_params(**param_settings[best_idx])\r\n\r\n        return self\r\n\r\n\r\ndef train_output_net_only(\r\n    X: jnp.ndarray,\r\n    y: jnp.ndarray,\r\n    binary_y: bool = False,\r\n    n_layers_out: int = DEFAULT_LAYERS_OUT,\r\n    n_units_out: int = DEFAULT_UNITS_OUT,\r\n    n_layers_r: int = 0,\r\n    
n_units_r: int = DEFAULT_UNITS_R,\r\n    penalty_l2: float = DEFAULT_PENALTY_L2,\r\n    step_size: float = DEFAULT_STEP_SIZE,\r\n    n_iter: int = DEFAULT_N_ITER,\r\n    batch_size: int = DEFAULT_BATCH_SIZE,\r\n    val_split_prop: float = DEFAULT_VAL_SPLIT,\r\n    early_stopping: bool = True,\r\n    patience: int = DEFAULT_PATIENCE,\r\n    n_iter_min: int = DEFAULT_N_ITER_MIN,\r\n    n_iter_print: int = DEFAULT_N_ITER_PRINT,\r\n    seed: int = DEFAULT_SEED,\r\n    return_val_loss: bool = False,\r\n    nonlin: str = DEFAULT_NONLIN,\r\n    avg_objective: bool = False,\r\n) -> Any:\r\n    # function to train a single output head\r\n    # input check\r\n    y = check_shape_1d_data(y)\r\n    d = X.shape[1]\r\n    input_shape = (-1, d)\r\n    rng_key = random.PRNGKey(seed)\r\n    onp.random.seed(seed)  # set seed for data generation via numpy as well\r\n\r\n    # get validation split (can be none)\r\n    X, y, X_val, y_val, val_string = make_val_split(\r\n        X, y, val_split_prop=val_split_prop, seed=seed\r\n    )\r\n    n = X.shape[0]  # could be different from before due to split\r\n\r\n    # get output head\r\n    init_fun, predict_fun = OutputHead(\r\n        n_layers_out=n_layers_out,\r\n        n_units_out=n_units_out,\r\n        binary_y=binary_y,\r\n        n_layers_r=n_layers_r,\r\n        n_units_r=n_units_r,\r\n        nonlin=nonlin,\r\n    )\r\n\r\n    # get functions\r\n    if not binary_y:\r\n        # define loss and grad\r\n        @jit\r\n        def loss(\r\n            params: List, batch: Tuple[jnp.ndarray, jnp.ndarray], penalty: float\r\n        ) -> jnp.ndarray:\r\n            # mse loss function\r\n            inputs, targets = batch\r\n            preds = predict_fun(params, inputs)\r\n            weightsq = sum(\r\n                [\r\n                    jnp.sum(params[i][0] ** 2)\r\n                    for i in range(0, 2 * (n_layers_out + n_layers_r) + 1, 2)\r\n                ]\r\n            )\r\n            if not avg_objective:\r\n     
           return jnp.sum((preds - targets) ** 2) + 0.5 * penalty * weightsq\r\n            else:\r\n                return jnp.average((preds - targets) ** 2) + 0.5 * penalty * weightsq\r\n\r\n    else:\r\n        # get loss and grad\r\n        @jit\r\n        def loss(\r\n            params: List, batch: Tuple[jnp.ndarray, jnp.ndarray], penalty: float\r\n        ) -> jnp.ndarray:\r\n            # mse loss function\r\n            inputs, targets = batch\r\n            preds = predict_fun(params, inputs)\r\n            weightsq = sum(\r\n                [\r\n                    jnp.sum(params[i][0] ** 2)\r\n                    for i in range(0, 2 * (n_layers_out + n_layers_r) + 1, 2)\r\n                ]\r\n            )\r\n            if not avg_objective:\r\n                return (\r\n                    -jnp.sum(\r\n                        targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds)\r\n                    )\r\n                    + 0.5 * penalty * weightsq\r\n                )\r\n            else:\r\n                return (\r\n                    -jnp.average(\r\n                        targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds)\r\n                    )\r\n                    + 0.5 * penalty * weightsq\r\n                )\r\n\r\n    # set optimization routine\r\n    # set optimizer\r\n    opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)\r\n\r\n    # set update function\r\n    @jit\r\n    def update(i: int, state: dict, batch: jnp.ndarray, penalty: float) -> jnp.ndarray:\r\n        params = get_params(state)\r\n        g_params = grad(loss)(params, batch, penalty)\r\n        return opt_update(i, g_params, state)\r\n\r\n    # initialise states\r\n    _, init_params = init_fun(rng_key, input_shape)\r\n    opt_state = opt_init(init_params)\r\n\r\n    # calculate number of batches per epoch\r\n    batch_size = batch_size if batch_size < n else n\r\n    n_batches = int(onp.round(n / batch_size)) if 
batch_size < n else 1\r\n    train_indices = onp.arange(n)\r\n\r\n    l_best = LARGE_VAL\r\n    p_curr = 0\r\n\r\n    # do training\r\n    for i in range(n_iter):\r\n        # shuffle data for minibatches\r\n        onp.random.shuffle(train_indices)\r\n        for b in range(n_batches):\r\n            idx_next = train_indices[\r\n                (b * batch_size) : min((b + 1) * batch_size, n - 1)\r\n            ]\r\n            next_batch = X[idx_next, :], y[idx_next, :]\r\n            opt_state = update(i * n_batches + b, opt_state, next_batch, penalty_l2)\r\n\r\n        if (i % n_iter_print == 0) or early_stopping:\r\n            params_curr = get_params(opt_state)\r\n            l_curr = loss(params_curr, (X_val, y_val), penalty_l2)\r\n\r\n        if i % n_iter_print == 0:\r\n            log.info(f\"Epoch: {i}, current {val_string} loss: {l_curr}\")\r\n\r\n        if early_stopping and ((i + 1) * n_batches > n_iter_min):\r\n            # check if loss updated\r\n            if l_curr < l_best:\r\n                l_best = l_curr\r\n                p_curr = 0\r\n            else:\r\n                p_curr = p_curr + 1\r\n\r\n            if p_curr > patience:\r\n                trained_params = get_params(opt_state)\r\n\r\n                if return_val_loss:\r\n                    # return loss without penalty\r\n                    l_final = loss(trained_params, (X_val, y_val), 0)\r\n                    return trained_params, predict_fun, l_final\r\n\r\n                return trained_params, predict_fun\r\n\r\n    # get final parameters\r\n    trained_params = get_params(opt_state)\r\n\r\n    if return_val_loss:\r\n        # return loss without penalty\r\n        l_final = loss(trained_params, (X_val, y_val), 0)\r\n        return trained_params, predict_fun, l_final\r\n\r\n    return trained_params, predict_fun\r\n"
  },
  {
    "path": "catenets/models/jax/disentangled_nets.py",
    "content": "\"\"\"\nClass implements SNet-3, a variation on DR-CFR discussed in\nHassanpour and Greiner (2020) and Wu et al (2020).\n\"\"\"\n# Author: Alicia Curth\nfrom typing import Any, Callable, List, Tuple\n\nimport jax.numpy as jnp\nimport numpy as onp\nfrom jax import grad, jit, random\nfrom jax.example_libraries import optimizers\n\nimport catenets.logger as log\nfrom catenets.models.constants import (\n    DEFAULT_AVG_OBJECTIVE,\n    DEFAULT_BATCH_SIZE,\n    DEFAULT_LAYERS_OUT,\n    DEFAULT_LAYERS_R,\n    DEFAULT_N_ITER,\n    DEFAULT_N_ITER_MIN,\n    DEFAULT_N_ITER_PRINT,\n    DEFAULT_NONLIN,\n    DEFAULT_PATIENCE,\n    DEFAULT_PENALTY_DISC,\n    DEFAULT_PENALTY_L2,\n    DEFAULT_PENALTY_ORTHOGONAL,\n    DEFAULT_SEED,\n    DEFAULT_STEP_SIZE,\n    DEFAULT_UNITS_OUT,\n    DEFAULT_UNITS_R_BIG_S3,\n    DEFAULT_UNITS_R_SMALL_S3,\n    DEFAULT_VAL_SPLIT,\n    LARGE_VAL,\n)\nfrom catenets.models.jax.base import BaseCATENet, OutputHead, ReprBlock\nfrom catenets.models.jax.model_utils import (\n    check_shape_1d_data,\n    heads_l2_penalty,\n    make_val_split,\n)\nfrom catenets.models.jax.representation_nets import mmd2_lin\n\n\n# helper functions to avoid abstract tracer values in jit\ndef _get_absolute_rowsums(mat: jnp.ndarray) -> jnp.ndarray:\n    return jnp.sum(jnp.abs(mat), axis=1)\n\n\ndef _concatenate_representations(reps: jnp.ndarray) -> jnp.ndarray:\n    return jnp.concatenate(reps, axis=1)\n\n\nclass SNet3(BaseCATENet):\n    \"\"\"\n    Class implements SNet-3, which is based on Hassanpour & Greiner (2020)'s DR-CFR (Without\n    propensity weighting), using an orthogonal regularizer to enforce decomposition similar to\n    Wu et al (2020).\n\n    Parameters\n    ----------\n    binary_y: bool, default False\n        Whether the outcome is binary\n    n_layers_out: int\n        Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)\n    n_layers_out_prop: int\n        Number of hypothesis layers for propensity score(n_layers_out x 
n_units_out + 1 x Dense\n        layer)\n    n_units_out: int\n        Number of hidden units in each hypothesis layer\n    n_units_out_prop: int\n        Number of hidden units in each propensity score hypothesis layer\n    n_layers_r: int\n        Number of shared & private representation layers before hypothesis layers\n    n_units_r: int\n        Number of hidden units in representation layer shared by propensity score and outcome\n        function (the 'confounding factor')\n    n_units_r_small: int\n        Number of hidden units in representation layer NOT shared by propensity score and outcome\n        functions (the 'outcome factor' and the 'instrumental factor')\n    penalty_l2: float\n        l2 (ridge) penalty\n    step_size: float\n        learning rate for optimizer\n    n_iter: int\n        Maximum number of iterations\n    batch_size: int\n        Batch size\n    val_split_prop: float\n        Proportion of samples used for validation split (can be 0)\n    early_stopping: bool, default True\n        Whether to use early stopping\n    patience: int\n        Number of iterations to wait before early stopping after decrease in validation loss\n    n_iter_min: int\n        Minimum number of iterations to go through before starting early stopping\n    n_iter_print: int\n        Number of iterations after which to print updates\n    seed: int\n        Seed used\n    reg_diff: bool, default False\n        Whether to regularize the difference between the two potential outcome heads\n    penalty_diff: float\n        l2-penalty for regularizing the difference between output heads. used only if\n        train_separate=False\n    same_init: bool, False\n        Whether to initialise the two output heads with same values\n    nonlin: string, default 'elu'\n        Nonlinearity to use in NN\n    penalty_disc: float, default zero\n        Discrepancy penalty. 
Defaults to zero as this feature is not tested.\n    \"\"\"\n\n    def __init__(\n        self,\n        binary_y: bool = False,\n        n_layers_r: int = DEFAULT_LAYERS_R,\n        n_units_r: int = DEFAULT_UNITS_R_BIG_S3,\n        n_layers_out: int = DEFAULT_LAYERS_OUT,\n        n_units_r_small: int = DEFAULT_UNITS_R_SMALL_S3,\n        n_units_out: int = DEFAULT_UNITS_OUT,\n        n_units_out_prop: int = DEFAULT_UNITS_OUT,\n        n_layers_out_prop: int = DEFAULT_LAYERS_OUT,\n        penalty_l2: float = DEFAULT_PENALTY_L2,\n        penalty_orthogonal: float = DEFAULT_PENALTY_ORTHOGONAL,\n        penalty_disc: float = DEFAULT_PENALTY_DISC,\n        step_size: float = DEFAULT_STEP_SIZE,\n        n_iter: int = DEFAULT_N_ITER,\n        batch_size: int = DEFAULT_BATCH_SIZE,\n        val_split_prop: float = DEFAULT_VAL_SPLIT,\n        early_stopping: bool = True,\n        patience: int = DEFAULT_PATIENCE,\n        n_iter_min: int = DEFAULT_N_ITER_MIN,\n        n_iter_print: int = DEFAULT_N_ITER_PRINT,\n        seed: int = DEFAULT_SEED,\n        nonlin: str = DEFAULT_NONLIN,\n        reg_diff: bool = False,\n        penalty_diff: float = DEFAULT_PENALTY_L2,\n        same_init: bool = False,\n    ) -> None:\n        self.binary_y = binary_y\n\n        self.n_layers_r = n_layers_r\n        self.n_layers_out = n_layers_out\n        self.n_layers_out_prop = n_layers_out_prop\n        self.n_units_r = n_units_r\n        self.n_units_r_small = n_units_r_small\n        self.n_units_out = n_units_out\n        self.n_units_out_prop = n_units_out_prop\n        self.nonlin = nonlin\n\n        self.penalty_l2 = penalty_l2\n        self.penalty_orthogonal = penalty_orthogonal\n        self.penalty_disc = penalty_disc\n        self.reg_diff = reg_diff\n        self.penalty_diff = penalty_diff\n        self.same_init = same_init\n\n        self.step_size = step_size\n        self.n_iter = n_iter\n        self.batch_size = batch_size\n        self.val_split_prop = val_split_prop\n    
    self.early_stopping = early_stopping\n        self.patience = patience\n        self.n_iter_min = n_iter_min\n\n        self.seed = seed\n        self.n_iter_print = n_iter_print\n\n    def _get_predict_function(self) -> Callable:\n        return predict_snet3\n\n    def _get_train_function(self) -> Callable:\n        return train_snet3\n\n\n# SNET-3 -------------------------------------------------------------\ndef train_snet3(\n    X: jnp.ndarray,\n    y: jnp.ndarray,\n    w: jnp.ndarray,\n    binary_y: bool = False,\n    n_layers_r: int = DEFAULT_LAYERS_R,\n    n_units_r: int = DEFAULT_UNITS_R_BIG_S3,\n    n_units_r_small: int = DEFAULT_UNITS_R_SMALL_S3,\n    n_layers_out: int = DEFAULT_LAYERS_OUT,\n    n_units_out: int = DEFAULT_UNITS_OUT,\n    n_units_out_prop: int = DEFAULT_UNITS_OUT,\n    n_layers_out_prop: int = DEFAULT_LAYERS_OUT,\n    penalty_l2: float = DEFAULT_PENALTY_L2,\n    penalty_disc: float = DEFAULT_PENALTY_DISC,\n    penalty_orthogonal: float = DEFAULT_PENALTY_ORTHOGONAL,\n    step_size: float = DEFAULT_STEP_SIZE,\n    n_iter: int = DEFAULT_N_ITER,\n    batch_size: int = DEFAULT_BATCH_SIZE,\n    val_split_prop: float = DEFAULT_VAL_SPLIT,\n    early_stopping: bool = True,\n    n_iter_min: int = DEFAULT_N_ITER_MIN,\n    patience: int = DEFAULT_PATIENCE,\n    n_iter_print: int = DEFAULT_N_ITER_PRINT,\n    seed: int = DEFAULT_SEED,\n    return_val_loss: bool = False,\n    reg_diff: bool = False,\n    penalty_diff: float = DEFAULT_PENALTY_L2,\n    nonlin: str = DEFAULT_NONLIN,\n    avg_objective: bool = DEFAULT_AVG_OBJECTIVE,\n    same_init: bool = False,\n) -> Any:\n    \"\"\"\n    SNet-3, based on the decompostion used in Hassanpour and Greiner (2020)\n    \"\"\"\n    # function to train a net with 3 representations\n    y, w = check_shape_1d_data(y), check_shape_1d_data(w)\n    d = X.shape[1]\n    input_shape = (-1, d)\n    rng_key = random.PRNGKey(seed)\n    onp.random.seed(seed)  # set seed for data generation via numpy as well\n\n    if not 
reg_diff:\n        penalty_diff = penalty_l2\n\n    # get validation split (can be none)\n    X, y, w, X_val, y_val, w_val, val_string = make_val_split(\n        X, y, w, val_split_prop=val_split_prop, seed=seed\n    )\n    n = X.shape[0]  # could be different from before due to split\n\n    # get representation layers\n    init_fun_repr, predict_fun_repr = ReprBlock(\n        n_layers=n_layers_r, n_units=n_units_r, nonlin=nonlin\n    )\n    init_fun_repr_small, predict_fun_repr_small = ReprBlock(\n        n_layers=n_layers_r, n_units=n_units_r_small, nonlin=nonlin\n    )\n\n    # get output head functions (output heads share same structure)\n    init_fun_head_po, predict_fun_head_po = OutputHead(\n        n_layers_out=n_layers_out,\n        n_units_out=n_units_out,\n        binary_y=binary_y,\n        nonlin=nonlin,\n    )\n    # add propensity head\n    init_fun_head_prop, predict_fun_head_prop = OutputHead(\n        n_layers_out=n_layers_out_prop,\n        n_units_out=n_units_out_prop,\n        binary_y=True,\n        nonlin=nonlin,\n    )\n\n    def init_fun_snet3(rng: float, input_shape: Tuple) -> Tuple[Tuple, List]:\n        # chain together the layers\n        # param should look like [repr_c, repr_o, repr_t, po_0, po_1, prop]\n        # initialise representation layers\n        rng, layer_rng = random.split(rng)\n        input_shape_repr, param_repr_c = init_fun_repr(layer_rng, input_shape)\n        rng, layer_rng = random.split(rng)\n        input_shape_repr_small, param_repr_o = init_fun_repr_small(\n            layer_rng, input_shape\n        )\n        rng, layer_rng = random.split(rng)\n        _, param_repr_w = init_fun_repr_small(layer_rng, input_shape)\n\n        # each head gets two representations\n        input_shape_repr = input_shape_repr[:-1] + (\n            input_shape_repr[-1] + input_shape_repr_small[-1],\n        )\n\n        # initialise output heads\n        rng, layer_rng = random.split(rng)\n        if same_init:\n            # 
initialise both on same values\n            input_shape, param_0 = init_fun_head_po(layer_rng, input_shape_repr)\n            input_shape, param_1 = init_fun_head_po(layer_rng, input_shape_repr)\n        else:\n            input_shape, param_0 = init_fun_head_po(layer_rng, input_shape_repr)\n            rng, layer_rng = random.split(rng)\n            input_shape, param_1 = init_fun_head_po(layer_rng, input_shape_repr)\n        rng, layer_rng = random.split(rng)\n        input_shape, param_prop = init_fun_head_prop(layer_rng, input_shape_repr)\n        return input_shape, [\n            param_repr_c,\n            param_repr_o,\n            param_repr_w,\n            param_0,\n            param_1,\n            param_prop,\n        ]\n\n    # Define loss functions\n    # loss functions for the head\n    if not binary_y:\n\n        def loss_head(\n            params: List,\n            batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],\n            penalty: float,\n        ) -> jnp.ndarray:\n            # mse loss function\n            inputs, targets, weights = batch\n            preds = predict_fun_head_po(params, inputs)\n            return jnp.sum(weights * ((preds - targets) ** 2))\n\n    else:\n\n        def loss_head(\n            params: List,\n            batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],\n            penalty: float,\n        ) -> jnp.ndarray:\n            # log loss function\n            inputs, targets, weights = batch\n            preds = predict_fun_head_po(params, inputs)\n            return -jnp.sum(\n                weights\n                * (targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds))\n            )\n\n    def loss_head_prop(\n        params: List,\n        batch: Tuple[jnp.ndarray, jnp.ndarray],\n        penalty: float,\n    ) -> jnp.ndarray:\n        # log loss function for propensities\n        inputs, targets = batch\n        preds = predict_fun_head_prop(params, inputs)\n        return -jnp.sum(targets * 
jnp.log(preds) + (1 - targets) * jnp.log(1 - preds))\n\n    # complete loss function for all parts\n    @jit\n    def loss_snet3(\n        params: List,\n        batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],\n        penalty_l2: float,\n        penalty_orthogonal: float,\n        penalty_disc: float,\n    ) -> jnp.ndarray:\n        # params: list[repr_c, repr_o, repr_t, po_0, po_1, prop]\n        # batch: (X, y, w)\n        X, y, w = batch\n\n        # get representation\n        reps_c = predict_fun_repr(params[0], X)\n        reps_o = predict_fun_repr_small(params[1], X)\n        reps_w = predict_fun_repr_small(params[2], X)\n\n        # concatenate\n        reps_po = _concatenate_representations((reps_c, reps_o))\n        reps_prop = _concatenate_representations((reps_c, reps_w))\n\n        # pass down to heads\n        loss_0 = loss_head(params[3], (reps_po, y, 1 - w), penalty_l2)\n        loss_1 = loss_head(params[4], (reps_po, y, w), penalty_l2)\n\n        # pass down to propensity head\n        loss_prop = loss_head_prop(params[5], (reps_prop, w), penalty_l2)\n        weightsq_prop = sum(\n            [\n                jnp.sum(params[5][i][0] ** 2)\n                for i in range(0, 2 * n_layers_out_prop + 1, 2)\n            ]\n        )\n\n        # which variable has impact on which representation\n        col_c = _get_absolute_rowsums(params[0][0][0])\n        col_o = _get_absolute_rowsums(params[1][0][0])\n        col_w = _get_absolute_rowsums(params[2][0][0])\n        loss_o = penalty_orthogonal * (\n            jnp.sum(col_c * col_o + col_c * col_w + col_w * col_o)\n        )\n\n        # is rep_o balanced between groups?\n        loss_disc = penalty_disc * mmd2_lin(reps_o, w)\n\n        # weight decay on representations\n        weightsq_body = sum(\n            [\n                sum(\n                    [jnp.sum(params[j][i][0] ** 2) for i in range(0, 2 * n_layers_r, 2)]\n                )\n                for j in range(3)\n            ]\n 
       )\n        weightsq_head = heads_l2_penalty(\n            params[3], params[4], n_layers_out, reg_diff, penalty_l2, penalty_diff\n        )\n\n        if not avg_objective:\n            return (\n                loss_0\n                + loss_1\n                + loss_prop\n                + loss_o\n                + loss_disc\n                + 0.5 * (penalty_l2 * (weightsq_body + weightsq_prop) + weightsq_head)\n            )\n        else:\n            n_batch = y.shape[0]\n            return (\n                (loss_0 + loss_1) / n_batch\n                + loss_prop / n_batch\n                + loss_o\n                + loss_disc\n                + 0.5 * (penalty_l2 * (weightsq_body + weightsq_prop) + weightsq_head)\n            )\n\n    # Define optimisation routine\n    opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)\n\n    @jit\n    def update(\n        i: int,\n        state: dict,\n        batch: jnp.ndarray,\n        penalty_l2: float,\n        penalty_orthogonal: float,\n        penalty_disc: float,\n    ) -> jnp.ndarray:\n        # updating function\n        params = get_params(state)\n        return opt_update(\n            i,\n            grad(loss_snet3)(\n                params, batch, penalty_l2, penalty_orthogonal, penalty_disc\n            ),\n            state,\n        )\n\n    # initialise states\n    _, init_params = init_fun_snet3(rng_key, input_shape)\n    opt_state = opt_init(init_params)\n\n    # calculate number of batches per epoch\n    batch_size = batch_size if batch_size < n else n\n    n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1\n    train_indices = onp.arange(n)\n\n    l_best = LARGE_VAL\n    p_curr = 0\n\n    # do training\n    for i in range(n_iter):\n        # shuffle data for minibatches\n        onp.random.shuffle(train_indices)\n        for b in range(n_batches):\n            idx_next = train_indices[\n                (b * batch_size) : min((b + 1) * batch_size, n - 1)\n  
          ]\n            next_batch = X[idx_next, :], y[idx_next, :], w[idx_next]\n            opt_state = update(\n                i * n_batches + b,\n                opt_state,\n                next_batch,\n                penalty_l2,\n                penalty_orthogonal,\n                penalty_disc,\n            )\n\n        if (i % n_iter_print == 0) or early_stopping:\n            params_curr = get_params(opt_state)\n            l_curr = loss_snet3(\n                params_curr,\n                (X_val, y_val, w_val),\n                penalty_l2,\n                penalty_orthogonal,\n                penalty_disc,\n            )\n\n        if i % n_iter_print == 0:\n            log.info(f\"Epoch: {i}, current {val_string} loss {l_curr}\")\n\n        if early_stopping and ((i + 1) * n_batches > n_iter_min):\n            # check if loss updated\n            if l_curr < l_best:\n                l_best = l_curr\n                p_curr = 0\n                params_best = params_curr\n            else:\n                if onp.isnan(l_curr):\n                    # if diverged, return best\n                    return params_best, (\n                        predict_fun_repr,\n                        predict_fun_head_po,\n                        predict_fun_head_prop,\n                    )\n                p_curr = p_curr + 1\n\n            if p_curr > patience:\n                if return_val_loss:\n                    # return loss without penalty\n                    l_final = loss_snet3(params_curr, (X_val, y_val, w_val), 0, 0, 0)\n                    return (\n                        params_curr,\n                        (predict_fun_repr, predict_fun_head_po, predict_fun_head_prop),\n                        l_final,\n                    )\n\n                return params_curr, (\n                    predict_fun_repr,\n                    predict_fun_head_po,\n                    predict_fun_head_prop,\n                )\n\n    # return the parameters\n    
trained_params = get_params(opt_state)\n\n    if return_val_loss:\n        # return loss without penalty\n        l_final = loss_snet3(get_params(opt_state), (X_val, y_val, w_val), 0, 0)\n        return (\n            trained_params,\n            (predict_fun_repr, predict_fun_head_po, predict_fun_head_prop),\n            l_final,\n        )\n\n    return trained_params, (\n        predict_fun_repr,\n        predict_fun_head_po,\n        predict_fun_head_prop,\n    )\n\n\ndef predict_snet3(\n    X: jnp.ndarray,\n    trained_params: dict,\n    predict_funs: list,\n    return_po: bool = False,\n    return_prop: bool = False,\n) -> jnp.ndarray:\n    # unpack inputs\n    predict_fun_repr, predict_fun_head, predict_fun_prop = predict_funs\n    param_repr_c, param_repr_o, param_repr_t = (\n        trained_params[0],\n        trained_params[1],\n        trained_params[2],\n    )\n    param_0, param_1, param_prop = (\n        trained_params[3],\n        trained_params[4],\n        trained_params[5],\n    )\n\n    # get representations\n    rep_c = predict_fun_repr(param_repr_c, X)\n    rep_o = predict_fun_repr(param_repr_o, X)\n    rep_w = predict_fun_repr(param_repr_t, X)\n\n    # concatenate\n    reps_po = jnp.concatenate((rep_c, rep_o), axis=1)\n    reps_prop = jnp.concatenate((rep_c, rep_w), axis=1)\n\n    # get potential outcomes\n    mu_0 = predict_fun_head(param_0, reps_po)\n    mu_1 = predict_fun_head(param_1, reps_po)\n\n    te = mu_1 - mu_0\n    if return_prop:\n        # get propensity\n        prop = predict_fun_prop(param_prop, reps_prop)\n\n    # stack other outputs\n    if return_po:\n        if return_prop:\n            return te, mu_0, mu_1, prop\n        else:\n            return te, mu_0, mu_1\n    else:\n        if return_prop:\n            return te, prop\n        else:\n            return te\n"
  },
  {
    "path": "catenets/models/jax/flextenet.py",
    "content": "\"\"\"\nModule implements FlexTENet, also referred to as the 'flexible approach' in \"On inductive biases\nfor heterogeneous treatment effect estimation\", Curth & vd Schaar (2021).\n\"\"\"\n# Author: Alicia Curth\nfrom typing import Any, Callable, Optional, Tuple\n\nimport jax.numpy as jnp\nimport numpy as onp\nfrom jax import grad, jit, random\nfrom jax.example_libraries import optimizers\nfrom jax.example_libraries.stax import (\n    Dense,\n    Sigmoid,\n    elu,\n    glorot_normal,\n    normal,\n    serial,\n)\n\nimport catenets.logger as log\nfrom catenets.models.constants import (\n    DEFAULT_BATCH_SIZE,\n    DEFAULT_DIM_P_OUT,\n    DEFAULT_DIM_P_R,\n    DEFAULT_DIM_S_OUT,\n    DEFAULT_DIM_S_R,\n    DEFAULT_LAYERS_OUT,\n    DEFAULT_LAYERS_R,\n    DEFAULT_N_ITER,\n    DEFAULT_N_ITER_MIN,\n    DEFAULT_N_ITER_PRINT,\n    DEFAULT_NONLIN,\n    DEFAULT_PATIENCE,\n    DEFAULT_PENALTY_L2,\n    DEFAULT_PENALTY_ORTHOGONAL,\n    DEFAULT_SEED,\n    DEFAULT_STEP_SIZE,\n    DEFAULT_VAL_SPLIT,\n    LARGE_VAL,\n    N_SUBSPACES,\n)\nfrom catenets.models.jax.base import BaseCATENet\nfrom catenets.models.jax.model_utils import check_shape_1d_data, make_val_split\n\n\nclass FlexTENet(BaseCATENet):\n    \"\"\"\n    Module implements FlexTENet, an architecture for treatment effect estimation that allows for\n    both shared and private information in each layer of the network.\n\n    Parameters\n    ----------\n    binary_y: bool, default False\n        Whether the outcome is binary\n    n_layers_out: int\n        Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)\n    n_units_s_out: int\n        Number of hidden units in each shared hypothesis layer\n    n_units_p_out: int\n        Number of hidden units in each private hypothesis layer\n    n_layers_r: int\n        Number of representation layers before hypothesis layers (distinction between\n        hypothesis layers and representation layers is made to match TARNet & SNets)\n    
n_units_s_r: int\n        Number of hidden units in each shared representation layer\n    n_units_p_r: int\n        Number of hidden units in each private representation layer\n    private_out: bool, False\n        Whether the final prediction layer should be fully private, or retain a shared component.\n    penalty_l2: float\n        l2 (ridge) penalty\n    penalty_l2_p: float\n        l2 (ridge) penalty for private layers\n    penalty_orthogonal: float\n        orthogonalisation penalty\n    step_size: float\n        learning rate for optimizer\n    n_iter: int\n        Maximum number of iterations\n    batch_size: int\n        Batch size\n    val_split_prop: float\n        Proportion of samples used for validation split (can be 0)\n    early_stopping: bool, default True\n        Whether to use early stopping\n    patience: int\n        Number of iterations to wait before early stopping after decrease in validation loss\n    n_iter_min: int\n        Minimum number of iterations to go through before starting early stopping\n    n_iter_print: int\n        Number of iterations after which to print updates\n    seed: int\n        Seed used\n    opt: str, default 'adam'\n        Optimizer to use, accepts 'adam' and 'sgd'\n    shared_repr: bool, False\n        Whether to use a shared representation block as TARNet\n    pretrain_shared: bool, False\n        Whether to pretrain the shared component of the network while freezing the private\n        parameters\n    same_init: bool, True\n        Whether to use the same initialisation for all private spaces\n    lr_scale: float\n        Factor by which the learning rate is divided after unfreezing the private components\n        of the network (only used if pretrain_shared=True)\n    normalize_ortho: bool, False\n        Whether to normalize the orthogonality penalty (by depth of network)\n    \"\"\"\n\n    def __init__(\n        self,\n        binary_y: bool = False,\n        n_layers_out: int = DEFAULT_LAYERS_OUT,\n        
n_units_s_out: int = DEFAULT_DIM_S_OUT,\n        n_units_p_out: int = DEFAULT_DIM_P_OUT,\n        n_layers_r: int = DEFAULT_LAYERS_R,\n        n_units_s_r: int = DEFAULT_DIM_S_R,\n        n_units_p_r: int = DEFAULT_DIM_P_R,\n        private_out: bool = False,\n        penalty_l2: float = DEFAULT_PENALTY_L2,\n        penalty_l2_p: float = DEFAULT_PENALTY_L2,\n        penalty_orthogonal: float = DEFAULT_PENALTY_ORTHOGONAL,\n        step_size: float = DEFAULT_STEP_SIZE,\n        n_iter: int = DEFAULT_N_ITER,\n        batch_size: int = DEFAULT_BATCH_SIZE,\n        val_split_prop: float = DEFAULT_VAL_SPLIT,\n        early_stopping: bool = True,\n        patience: int = DEFAULT_PATIENCE,\n        n_iter_min: int = DEFAULT_N_ITER_MIN,\n        n_iter_print: int = DEFAULT_N_ITER_PRINT,\n        seed: int = DEFAULT_SEED,\n        return_val_loss: bool = False,\n        opt: str = \"adam\",\n        shared_repr: bool = False,\n        pretrain_shared: bool = False,\n        same_init: bool = True,\n        lr_scale: float = 10,\n        normalize_ortho: bool = False,\n    ) -> None:\n        self.binary_y = binary_y\n\n        self.n_layers_r = n_layers_r\n        self.n_layers_out = n_layers_out\n        self.n_units_s_out = n_units_s_out\n        self.n_units_p_out = n_units_p_out\n        self.n_units_s_r = n_units_s_r\n        self.n_units_p_r = n_units_p_r\n        self.private_out = private_out\n\n        self.penalty_orthogonal = penalty_orthogonal\n        self.penalty_l2 = penalty_l2\n        self.penalty_l2_p = penalty_l2_p\n        self.step_size = step_size\n        self.n_iter = n_iter\n        self.batch_size = batch_size\n        self.val_split_prop = val_split_prop\n        self.early_stopping = early_stopping\n        self.patience = patience\n        self.n_iter_min = n_iter_min\n        self.opt = opt\n        self.same_init = same_init\n        self.shared_repr = shared_repr\n        self.normalize_ortho = normalize_ortho\n        self.pretrain_shared = 
pretrain_shared\n        self.lr_scale = lr_scale\n\n        self.seed = seed\n        self.n_iter_print = n_iter_print\n        self.return_val_loss = return_val_loss\n\n    def _get_train_function(self) -> Callable:\n        return train_flextenet\n\n    def _get_predict_function(self) -> Callable:\n        return predict_flextenet\n\n\ndef train_flextenet(\n    X: jnp.ndarray,\n    y: jnp.ndarray,\n    w: jnp.ndarray,\n    binary_y: bool = False,\n    n_layers_out: int = DEFAULT_LAYERS_OUT,\n    n_units_s_out: int = DEFAULT_DIM_S_OUT,\n    n_units_p_out: int = DEFAULT_DIM_P_OUT,\n    n_layers_r: int = DEFAULT_LAYERS_R,\n    n_units_s_r: int = DEFAULT_DIM_S_R,\n    n_units_p_r: int = DEFAULT_DIM_P_R,\n    private_out: bool = False,\n    penalty_l2: float = DEFAULT_PENALTY_L2,\n    penalty_l2_p: float = DEFAULT_PENALTY_L2,\n    penalty_orthogonal: float = DEFAULT_PENALTY_ORTHOGONAL,\n    step_size: float = DEFAULT_STEP_SIZE,\n    n_iter: int = DEFAULT_N_ITER,\n    batch_size: int = DEFAULT_BATCH_SIZE,\n    val_split_prop: float = DEFAULT_VAL_SPLIT,\n    early_stopping: bool = True,\n    patience: int = DEFAULT_PATIENCE,\n    n_iter_min: int = DEFAULT_N_ITER_MIN,\n    avg_objective: bool = True,\n    n_iter_print: int = DEFAULT_N_ITER_PRINT,\n    seed: int = DEFAULT_SEED,\n    return_val_loss: bool = False,\n    opt: str = \"adam\",\n    shared_repr: bool = False,\n    pretrain_shared: bool = False,\n    same_init: bool = True,\n    lr_scale: float = 10,\n    normalize_ortho: bool = False,\n    nonlin: str = DEFAULT_NONLIN,\n    n_units_r: Optional[int] = None,\n    n_units_out: Optional[int] = None,\n) -> Tuple:  # TODO incorporate different nonlins here\n    # function to train a single output head\n    # input check\n    y, w = check_shape_1d_data(y), check_shape_1d_data(w)\n    d = X.shape[1]\n    input_shape = (-1, d)\n    rng_key = random.PRNGKey(seed)\n    onp.random.seed(seed)  # set seed for data generation via numpy as well\n\n    # get validation split 
(can be none)\n    X, y, w, X_val, y_val, w_val, val_string = make_val_split(\n        X, y, w, val_split_prop=val_split_prop, seed=seed\n    )\n    n = X.shape[0]  # could be different from before due to split\n\n    # get output head\n    init_fun, predict_fun = FlexTENetArchitecture(\n        n_layers_out=n_layers_out,\n        n_layers_r=n_layers_r,\n        n_units_p_r=n_units_p_r,\n        n_units_p_out=n_units_p_out,\n        n_units_s_r=n_units_s_r,\n        n_units_s_out=n_units_s_out,\n        private_out=private_out,\n        shared_repr=shared_repr,\n        same_init=same_init,\n        binary_y=binary_y,\n    )\n\n    # get functions\n    if not binary_y:\n        # define loss and grad\n        @jit\n        def loss(\n            params: jnp.ndarray,\n            batch: jnp.ndarray,\n            penalty_l2: float,\n            penalty_l2_p: float,\n            penalty_orthogonal: float,\n            mode: int,\n        ) -> jnp.ndarray:\n            # mse loss function\n            inputs, targets = batch\n            preds = predict_fun(params, inputs, mode=mode)\n            penalty = _compute_penalty(\n                params,\n                n_layers_out,\n                n_layers_r,\n                private_out,\n                penalty_l2,\n                penalty_l2_p,\n                penalty_orthogonal,\n                shared_repr,\n                normalize_ortho,\n                mode,\n            )\n            if not avg_objective:\n                return jnp.sum((preds - targets) ** 2) + penalty\n            else:\n                return jnp.average((preds - targets) ** 2) + penalty\n\n    else:\n        # get loss and grad\n        @jit\n        def loss(\n            params: jnp.ndarray,\n            batch: jnp.ndarray,\n            penalty_l2: float,\n            penalty_l2_p: float,\n            penalty_orthogonal: float,\n            mode: int,\n        ) -> jnp.ndarray:\n            # mse loss function\n            inputs, 
targets = batch\n            preds = predict_fun(params, inputs, mode=mode)\n            penalty = _compute_penalty(\n                params,\n                n_layers_out,\n                n_layers_r,\n                private_out,\n                penalty_l2,\n                penalty_l2_p,\n                penalty_orthogonal,\n                shared_repr,\n                normalize_ortho,\n                mode,\n            )\n            if not avg_objective:\n                return (\n                    -jnp.sum(\n                        targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds)\n                    )\n                    + penalty\n                )\n            else:\n                return (\n                    -jnp.average(\n                        targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds)\n                    )\n                    + penalty\n                )\n\n    # set optimization routine\n    # set optimizer\n    if opt == \"adam\":\n        opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)\n    elif opt == \"sgd\":\n        opt_init, opt_update, get_params = optimizers.sgd(step_size=step_size)\n    else:\n        raise ValueError(\"opt should be adam or sgd\")\n\n    # set update function\n    @jit\n    def update(\n        i: int,\n        state: dict,\n        batch: jnp.ndarray,\n        penalty_l2: float,\n        penalty_l2_p: float,\n        penalty_orthogonal: float,\n        mode: int,\n    ) -> jnp.ndarray:\n        params = get_params(state)\n        g_params = grad(loss)(\n            params, batch, penalty_l2, penalty_l2_p, penalty_orthogonal, mode\n        )\n        return opt_update(i, g_params, state)\n\n    # initialise states\n    _, init_params = init_fun(rng_key, input_shape)\n    opt_state = opt_init(init_params)\n\n    # calculate number of batches per epoch\n    batch_size = batch_size if batch_size < n else n\n    n_batches = int(onp.round(n / batch_size)) if 
batch_size < n else 1\n    train_indices = onp.arange(n)\n\n    l_best = LARGE_VAL\n    p_curr = 0\n\n    # do training\n    if not pretrain_shared:  # train entire model together\n        for i in range(n_iter):\n            # shuffle data for minibatches\n            onp.random.shuffle(train_indices)\n            for b in range(n_batches):\n                idx_next = train_indices[\n                    (b * batch_size) : min((b + 1) * batch_size, n - 1)\n                ]\n                next_batch = (X[idx_next, :], w[idx_next]), y[idx_next, :]\n                opt_state = update(\n                    i * n_batches + b,\n                    opt_state,\n                    next_batch,\n                    penalty_l2,\n                    penalty_l2_p,\n                    penalty_orthogonal,\n                    mode=1,\n                )\n\n            if (i % n_iter_print == 0) or early_stopping:\n                params_curr = get_params(opt_state)\n                l_curr = loss(\n                    params_curr,\n                    ((X_val, w_val), y_val),\n                    penalty_l2,\n                    penalty_l2_p,\n                    penalty_orthogonal,\n                    mode=1,\n                )\n\n            if i % n_iter_print == 0:\n                log.debug(f\"Epoch: {i}, current {val_string} loss: {l_curr}\")\n\n            if early_stopping and ((i + 1) * n_batches > n_iter_min):\n                # check if loss updated\n                if l_curr < l_best:\n                    l_best = l_curr\n                    p_curr = 0\n                else:\n                    p_curr = p_curr + 1\n\n                if p_curr > patience:\n                    trained_params = get_params(opt_state)\n\n                    if return_val_loss:\n                        # return loss without penalty\n                        l_final = loss(\n                            trained_params, ((X_val, w_val), y_val), 0, 0, 0, mode=1\n                        )\n   
                     return trained_params, predict_fun, l_final\n\n                    return trained_params, predict_fun\n\n        # get final parameters\n        trained_params = get_params(opt_state)\n\n        if return_val_loss:\n            # return loss without penalty\n            l_final = loss(trained_params, ((X_val, w_val), y_val), 0, 0, 0, mode=1)\n            return trained_params, predict_fun, l_final\n\n        return trained_params, predict_fun\n    else:\n        # Step 1: pretrain only shared bit of network (mode=0)\n        for i in range(n_iter):\n            # shuffle data for minibatches\n            onp.random.shuffle(train_indices)\n            for b in range(n_batches):\n                idx_next = train_indices[\n                    (b * batch_size) : min((b + 1) * batch_size, n - 1)\n                ]\n                next_batch = (X[idx_next, :], w[idx_next]), y[idx_next, :]\n                opt_state = update(\n                    i * n_batches + b,\n                    opt_state,\n                    next_batch,\n                    penalty_l2,\n                    penalty_l2_p,\n                    penalty_orthogonal,\n                    mode=0,\n                )\n\n            if (i % n_iter_print == 0) or early_stopping:\n                params_curr = get_params(opt_state)\n                l_curr = loss(\n                    params_curr,\n                    ((X_val, w_val), y_val),\n                    penalty_l2,\n                    penalty_l2_p,\n                    penalty_orthogonal,\n                    mode=0,\n                )\n\n            if i % n_iter_print == 0:\n                log.debug(\n                    f\"Pre-training epoch: {i}, current {val_string} loss: {l_curr}\"\n                )\n\n            if early_stopping and ((i + 1) * n_batches > n_iter_min):\n                # check if loss updated\n                if l_curr < l_best:\n                    l_best = l_curr\n                    p_curr = 0\n    
            else:\n                    p_curr = p_curr + 1\n\n                if p_curr > patience:\n                    break\n\n        # get final parameters\n        pre_trained_params = get_params(opt_state)\n\n        # Step 2: train also private parts of network (mode=1)\n        # set new optimizer\n        if opt == \"adam\":\n            opt_init2, opt_update2, get_params2 = optimizers.adam(\n                step_size=step_size / lr_scale\n            )\n        elif opt == \"sgd\":\n            opt_init2, opt_update2, get_params2 = optimizers.sgd(\n                step_size=step_size / lr_scale\n            )\n        else:\n            raise ValueError(\"opt should be adam or sgd\")\n\n        # set update function\n        @jit\n        def update2(\n            i: int,\n            state: dict,\n            batch: jnp.ndarray,\n            penalty_l2: float,\n            penalty_l2_p: float,\n            penalty_orthogonal: float,\n            mode: int,\n        ) -> Any:\n            params = get_params(state)\n            g_params = grad(loss)(\n                params, batch, penalty_l2, penalty_l2_p, penalty_orthogonal, mode\n            )\n            return opt_update2(i, g_params, state)\n\n        opt_state = opt_init2(pre_trained_params)\n        l_best = LARGE_VAL\n        p_curr = 0\n\n        # train full\n        for i in range(n_iter):\n            # shuffle data for minibatches\n            onp.random.shuffle(train_indices)\n            for b in range(n_batches):\n                idx_next = train_indices[\n                    (b * batch_size) : min((b + 1) * batch_size, n - 1)\n                ]\n                next_batch = (X[idx_next, :], w[idx_next]), y[idx_next, :]\n                opt_state = update2(\n                    i * n_batches + b,\n                    opt_state,\n                    next_batch,\n                    penalty_l2,\n                    penalty_l2_p,\n                    penalty_orthogonal,\n                   
 mode=1,\n                )\n\n            if (i % n_iter_print == 0) or early_stopping:\n                params_curr = get_params2(opt_state)\n                l_curr = loss(\n                    params_curr,\n                    ((X_val, w_val), y_val),\n                    penalty_l2,\n                    penalty_l2_p,\n                    penalty_orthogonal,\n                    mode=1,\n                )\n\n            if i % n_iter_print == 0:\n                log.debug(f\"Epoch: {i}, current {val_string} loss: {l_curr}\")\n\n            if early_stopping and ((i + 1) * n_batches > n_iter_min):\n                # check if loss updated\n                if l_curr < l_best:\n                    l_best = l_curr\n                    p_curr = 0\n                else:\n                    p_curr = p_curr + 1\n\n                if p_curr > patience:\n                    trained_params = get_params2(opt_state)\n\n                    if return_val_loss:\n                        # return loss without penalty\n                        l_final = loss(\n                            trained_params, ((X_val, w_val), y_val), 0, 0, 0, mode=1\n                        )\n                        return trained_params, predict_fun, l_final\n\n                    return trained_params, predict_fun\n\n        # get final parameters\n        trained_params = get_params2(opt_state)\n\n        if return_val_loss:\n            # return loss without penalty\n            l_final = loss(trained_params, ((X_val, w_val), y_val), 0, 0, 0, mode=1)\n            return trained_params, predict_fun, l_final\n\n        return trained_params, predict_fun\n\n\ndef predict_flextenet(\n    X: jnp.ndarray,\n    trained_params: jnp.ndarray,\n    predict_funs: Callable,\n    return_po: bool = False,\n    return_prop: bool = False,\n) -> Any:\n    # unpack inputs\n    n, _ = X.shape\n\n    W1 = check_shape_1d_data(jnp.ones(n))\n    W0 = check_shape_1d_data(jnp.zeros(n))\n\n    # get potential outcomes\n    
mu_0 = predict_funs(trained_params, (X, W0))\n    mu_1 = predict_funs(trained_params, (X, W1))\n\n    te = mu_1 - mu_0\n    if return_prop:\n        raise ValueError(\"does not have propensity score estimator\")\n\n    # stack other outputs\n    if return_po:\n        return te, mu_0, mu_1\n    else:\n        return te\n\n\n# helper functions for training\ndef _get_cos_reg(\n    params_0: jnp.ndarray, params_1: jnp.ndarray, normalize: bool\n) -> jnp.ndarray:\n    if normalize:\n        params_0 = params_0 / jnp.linalg.norm(params_0, axis=0)\n        params_1 = params_1 / jnp.linalg.norm(params_1, axis=0)\n\n    return jnp.linalg.norm(jnp.dot(jnp.transpose(params_0), params_1), \"fro\") ** 2\n\n\ndef _compute_ortho_penalty_asymmetric(\n    params: jnp.ndarray,\n    n_layers_out: int,\n    n_layers_r: int,\n    private_out: int,\n    penalty_orthogonal: float,\n    shared_repr: bool,\n    normalize_ortho: bool,\n    mode: int = 1,\n) -> float:\n    # where to start counting: is there a fully shared representation?\n    if shared_repr:\n        lb = 2 * n_layers_r\n    else:\n        lb = 0\n\n    n_in = [\n        params[i][0][0].shape[0] for i in range(lb, 2 * (n_layers_out + n_layers_r), 2)\n    ]\n\n    ortho_body = _get_cos_reg(params[lb][1][0], params[lb][2][0], normalize_ortho)\n    ortho_body = ortho_body + sum(\n        [\n            _get_cos_reg(\n                params[i][0][0],\n                params[i][1][0][: n_in[int(i / 2 - lb / 2)], :],\n                normalize_ortho,\n            )\n            + _get_cos_reg(\n                params[i][0][0],\n                params[i][2][0][: n_in[int(i / 2 - lb / 2)], :],\n                normalize_ortho,\n            )\n            for i in range(lb, 2 * (n_layers_out + n_layers_r), 2)\n        ]\n    )\n\n    if not private_out:\n        # add also orthogonal regularization on final layer\n        idx_out = 2 * (n_layers_r + n_layers_out)\n        n_idx = params[idx_out][0][0].shape[0]\n\n        ortho_body 
= (\n            ortho_body\n            + _get_cos_reg(\n                params[idx_out][0][0],\n                params[idx_out][1][0][:n_idx, :],\n                normalize_ortho,\n            )\n            + _get_cos_reg(\n                params[idx_out][0][0], params[idx_out][2][0][:n_idx, :], normalize_ortho\n            )\n        )\n\n    return mode * penalty_orthogonal * ortho_body\n\n\ndef _compute_penalty_l2(\n    params: jnp.ndarray,\n    n_layers_out: int,\n    n_layers_r: int,\n    private_out: int,\n    penalty_l2: float,\n    penalty_l2_p: float,\n    shared_repr: bool,\n    mode: int = 1,\n) -> jnp.ndarray:\n    n_bodys = N_SUBSPACES\n\n    # compute l2 penalty\n    if shared_repr:\n        # get representation and then heads\n        weightsq_body = penalty_l2 * sum(\n            [jnp.sum(params[i][0] ** 2) for i in range(0, 2 * n_layers_r, 2)]\n        )\n        weightsq_body = weightsq_body + penalty_l2 * sum(\n            [\n                jnp.sum(params[i][0][0] ** 2)\n                for i in range(2 * n_layers_r, 2 * (n_layers_out + n_layers_r), 2)\n            ]\n        )\n        weightsq_body = weightsq_body + penalty_l2_p * mode * sum(\n            [\n                sum(\n                    [\n                        jnp.sum(params[i][j][0] ** 2)\n                        for i in range(\n                            2 * n_layers_r, 2 * (n_layers_out + n_layers_r), 2\n                        )\n                    ]\n                )\n                for j in range(1, n_bodys)\n            ]\n        )\n    else:\n        weightsq_body = penalty_l2 * sum(\n            [\n                jnp.sum(params[i][0][0] ** 2)\n                for i in range(0, 2 * (n_layers_out + n_layers_r), 2)\n            ]\n        )\n        weightsq_body = weightsq_body + penalty_l2_p * mode * sum(\n            [\n                sum(\n                    [\n                        jnp.sum(params[i][j][0] ** 2)\n                        for i in range(0, 
2 * (n_layers_out + n_layers_r), 2)\n                    ]\n                )\n                for j in range(1, n_bodys)\n            ]\n        )\n\n    idx_out = 2 * (n_layers_r + n_layers_out)\n    if private_out:\n        weightsq = (\n            weightsq_body\n            + penalty_l2 * jnp.sum(params[idx_out][0][0] ** 2)\n            + jnp.sum(params[idx_out][1][0] ** 2)\n        )\n    else:\n        weightsq = (\n            weightsq_body\n            + penalty_l2 * jnp.sum(params[idx_out][0][0] ** 2)\n            + penalty_l2_p * mode * jnp.sum(params[idx_out][1][0] ** 2)\n            + penalty_l2_p * mode * jnp.sum(params[idx_out][2][0] ** 2)\n        )\n\n    return 0.5 * weightsq\n\n\ndef _compute_penalty(\n    params: jnp.ndarray,\n    n_layers_out: int,\n    n_layers_r: int,\n    private_out: int,\n    penalty_l2: float,\n    penalty_l2_p: float,\n    penalty_orthogonal: float,\n    shared_repr: bool,\n    normalize_ortho: bool,\n    mode: int = 1,\n) -> jnp.ndarray:\n    l2_penalty = _compute_penalty_l2(\n        params,\n        n_layers_out,\n        n_layers_r,\n        private_out,\n        penalty_l2,\n        penalty_l2_p,\n        shared_repr,\n        mode,\n    )\n\n    ortho_penalty = _compute_ortho_penalty_asymmetric(\n        params,\n        n_layers_out,\n        n_layers_r,\n        private_out,\n        penalty_orthogonal,\n        shared_repr,\n        normalize_ortho,\n        mode,\n    )\n\n    return l2_penalty + ortho_penalty\n\n\n# ------------------------------------------------------------\n# construction of FlexTENetlayers/architecture\ndef SplitLayerAsymmetric(\n    n_units_s: int, n_units_p: int, first_layer: bool = False, same_init: bool = True\n) -> Tuple:\n    # create multitask layer has shape [shared, private_0, private_1]\n    init_s, apply_s = Dense(n_units_s)\n    init_p, apply_p = Dense(n_units_p)\n\n    def init_fun(rng: float, input_shape: Tuple) -> Tuple:\n        if first_layer:  # put input shape in 
expected format\n            input_shape = (input_shape, input_shape, input_shape)\n        out_shape = (\n            input_shape[0][:-1] + (n_units_s,),\n            input_shape[1][:-1] + (n_units_p + n_units_s,),\n            input_shape[2][:-1] + (n_units_p + n_units_s,),\n        )\n\n        rng_1, rng_2, rng_3 = random.split(rng, N_SUBSPACES)\n        if same_init:  # use same init for the two private layers\n            return out_shape, (\n                init_s(rng_1, input_shape[0])[1],\n                init_p(rng_2, input_shape[1])[1],\n                init_p(rng_2, input_shape[2])[1],\n            )\n        else:  # initialise all separately\n            return out_shape, (\n                init_s(rng_1, input_shape[0])[1],\n                init_p(rng_2, input_shape[1])[1],\n                init_p(rng_3, input_shape[2])[1],\n            )\n\n    def apply_fun(params: jnp.ndarray, inputs: jnp.ndarray, **kwargs: Any) -> Tuple:\n        mode = kwargs[\"mode\"] if \"mode\" in kwargs.keys() else 1\n        if first_layer:\n            # X is the only input\n            X, W = inputs\n            rep_s = apply_s(params[0], X)\n            rep_p0 = mode * apply_p(params[1], X)\n            rep_p1 = mode * apply_p(params[2], X)\n        else:\n            X_s, X_p0, X_p1, W = inputs\n            rep_s = apply_s(params[0], X_s)\n            rep_p0 = mode * apply_p(params[1], jnp.concatenate([X_s, X_p0], axis=1))\n            rep_p1 = mode * apply_p(params[2], jnp.concatenate([X_s, X_p1], axis=1))\n        return (rep_s, rep_p0, rep_p1, W)\n\n    return init_fun, apply_fun\n\n\ndef TEOutputLayerAsymmetric(private: bool = True, same_init: bool = True) -> Tuple:\n    init_f, apply_f = Dense(1)\n    if private:\n        # the two output layers are private\n        def init_fun(rng: float, input_shape: Tuple) -> Tuple:\n            out_shape = input_shape[1][:-1] + (1,)\n            rng_1, rng_2 = random.split(rng, N_SUBSPACES - 1)\n            return out_shape, 
(\n                init_f(rng_1, input_shape[1])[1],\n                init_f(rng_2, input_shape[2])[1],\n            )\n\n        def apply_fun(params: jnp.ndarray, inputs: Tuple, **kwargs: Any) -> jnp.ndarray:\n            X_s, X_p0, X_p1, W = inputs\n            rep_p0 = apply_f(params[0], jnp.concatenate([X_s, X_p0], axis=1))\n            rep_p1 = apply_f(params[1], jnp.concatenate([X_s, X_p1], axis=1))\n            return (1 - W) * rep_p0 + W * rep_p1\n\n    else:\n        # also have a shared piece of output layer\n        def init_fun(rng: float, input_shape: Tuple) -> Tuple:\n            out_shape = input_shape[1][:-1] + (1,)\n            rng_1, rng_2, rng_3 = random.split(rng, N_SUBSPACES)\n            if same_init:\n                return out_shape, (\n                    init_f(rng_1, input_shape[0])[1],\n                    init_f(rng_2, input_shape[1])[1],\n                    init_f(rng_2, input_shape[2])[1],\n                )\n            else:\n                return out_shape, (\n                    init_f(rng_1, input_shape[0])[1],\n                    init_f(rng_2, input_shape[1])[1],\n                    init_f(rng_3, input_shape[2])[1],\n                )\n\n        def apply_fun(params: jnp.ndarray, inputs: Tuple, **kwargs: Any) -> jnp.ndarray:\n            mode = kwargs[\"mode\"] if \"mode\" in kwargs.keys() else 1\n            X_s, X_p0, X_p1, W = inputs\n            rep_s = apply_f(params[0], X_s)\n            rep_p0 = mode * apply_f(params[1], jnp.concatenate([X_s, X_p0], axis=1))\n            rep_p1 = mode * apply_f(params[2], jnp.concatenate([X_s, X_p1], axis=1))\n            return (1 - W) * rep_p0 + W * rep_p1 + rep_s\n\n    return init_fun, apply_fun\n\n\ndef FlexTENetArchitecture(\n    n_layers_out: int = DEFAULT_LAYERS_OUT,\n    n_units_s_out: int = DEFAULT_DIM_S_OUT,\n    n_units_p_out: int = DEFAULT_DIM_P_OUT,\n    n_layers_r: int = DEFAULT_LAYERS_R,\n    n_units_s_r: int = DEFAULT_DIM_S_R,\n    n_units_p_r: int = 
DEFAULT_DIM_P_R,\n    private_out: bool = False,\n    binary_y: bool = False,\n    shared_repr: bool = False,\n    same_init: bool = True,\n) -> Any:\n    if n_layers_out < 1:\n        raise ValueError(\n            \"FlexTENet needs at least one hidden output layer (else there are no \"\n            \"parameters to be shared)\"\n        )\n\n    Nonlin_Elu = Elu_parallel\n    Layer = SplitLayerAsymmetric\n    Head = TEOutputLayerAsymmetric\n\n    # give broader body (as in e.g. TARNet)\n    has_body = n_layers_r > 0\n\n    layers: Tuple = ()\n    if has_body:\n        # representation block first\n        if shared_repr:  # fully shared representation as in TARNet\n            layers = (DenseW(n_units_s_r), Elu_split)\n\n            # add required number of layers\n            for i in range(n_layers_r - 1):\n                layers = (*layers, DenseW(n_units_s_r), Elu_split)\n\n        else:  # shared AND private representations\n            layers = (\n                Layer(n_units_s_r, n_units_p_r, first_layer=True, same_init=same_init),\n                Nonlin_Elu,\n            )\n\n            # add required number of layers\n            for i in range(n_layers_r - 1):\n                layers = (\n                    *layers,\n                    Layer(n_units_s_r, n_units_p_r, same_init=same_init),\n                    Nonlin_Elu,\n                )\n    else:\n        layers = ()\n\n    # add output layers\n    first_layer = (has_body is False) | (shared_repr is True)\n    layers = (\n        *layers,\n        Layer(\n            n_units_s_out, n_units_p_out, first_layer=first_layer, same_init=same_init\n        ),\n        Nonlin_Elu,\n    )\n\n    if n_layers_out > 1:\n        # add required number of layers\n        for i in range(n_layers_out - 1):\n            layers = (\n                *layers,\n                Layer(n_units_s_out, n_units_p_out, same_init=same_init),\n                Nonlin_Elu,\n            )\n\n    # return final architecture\n    
if not binary_y:\n        return serial(*layers, Head(private=private_out, same_init=same_init))\n    else:\n        return serial(*layers, Head(private=private_out, same_init=same_init), Sigmoid)\n\n\n# ------------------------------------------------\n# rewrite some jax.stax code to allow different input types to be passed\ndef elementwise_split(fun: Callable, **fun_kwargs: Any) -> Tuple:\n    \"\"\"Layer that applies a scalar function elementwise on its inputs. Adapted from original\n    jax.stax to skip treatment indicator.\n\n    Input looks like: X, t = inputs\"\"\"\n\n    def init_fun(rng: float, input_shape: Tuple) -> Tuple:\n        return (input_shape, ())\n\n    def apply_fun(params: jnp.ndarray, inputs: jnp.ndarray, **kwargs: Any) -> Tuple:\n        return fun(inputs[0], **fun_kwargs), inputs[1]\n\n    return init_fun, apply_fun\n\n\nElu_split = elementwise_split(elu)\n\n\ndef elementwise_parallel(fun: Callable, **fun_kwargs: Any) -> Tuple:\n    \"\"\"Layer that applies a scalar function elementwise on its inputs. Adapted from original\n    jax.stax to allow three inputs and to skip treatment indicator.\n\n    Input looks like: X_s, X_p0, X_p1, t = inputs\n    \"\"\"\n\n    def init_fun(rng: float, input_shape: Tuple) -> Tuple:\n        return input_shape, ()\n\n    def apply_fun(params: jnp.ndarray, inputs: jnp.ndarray, **kwargs: Any) -> Tuple:\n        return (\n            fun(inputs[0], **fun_kwargs),\n            fun(inputs[1], **fun_kwargs),\n            fun(inputs[2], **fun_kwargs),\n            inputs[3],\n        )\n\n    return init_fun, apply_fun\n\n\nElu_parallel = elementwise_parallel(elu)\n\n\ndef DenseW(\n    out_dim: int, W_init: Callable = glorot_normal(), b_init: Callable = normal()\n) -> Tuple:\n    \"\"\"Layer constructor function for a dense (fully-connected) layer. 
Adapted to allow passing\n    treatment indicator through layer without using it\"\"\"\n\n    def init_fun(rng: float, input_shape: Tuple) -> Tuple:\n        output_shape = input_shape[:-1] + (out_dim,)\n        k1, k2 = random.split(rng)\n        W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim,))\n        return output_shape, (W, b)\n\n    def apply_fun(\n        params: jnp.ndarray, inputs: jnp.ndarray, **kwargs: Any\n    ) -> jnp.ndarray:\n        W, b = params\n        x, t = inputs\n        return (jnp.dot(x, W) + b, t)\n\n    return init_fun, apply_fun\n"
  },
  {
    "path": "catenets/models/jax/model_utils.py",
    "content": "\"\"\"\r\nModel utils shared across different nets\r\n\"\"\"\r\n# Author: Alicia Curth\r\nfrom typing import Any, Optional\r\n\r\nimport jax.numpy as jnp\r\nimport pandas as pd\r\nfrom sklearn.model_selection import train_test_split\r\n\r\nfrom catenets.models.constants import DEFAULT_SEED, DEFAULT_VAL_SPLIT\r\n\r\nTRAIN_STRING = \"training\"\r\nVALIDATION_STRING = \"validation\"\r\n\r\n\r\ndef check_shape_1d_data(y: jnp.ndarray) -> jnp.ndarray:\r\n    # helper func to ensure that output shape won't clash\r\n    # with jax internally\r\n    shape_y = y.shape\r\n    if len(shape_y) == 1:\r\n        # should be shape (n_obs, 1), not (n_obs,)\r\n        return y.reshape((shape_y[0], 1))\r\n    return y\r\n\r\n\r\ndef check_X_is_np(X: pd.DataFrame) -> jnp.ndarray:\r\n    # function to make sure we are using arrays only\r\n    return jnp.asarray(X)\r\n\r\n\r\ndef make_val_split(\r\n    X: jnp.ndarray,\r\n    y: jnp.ndarray,\r\n    w: Optional[jnp.ndarray] = None,\r\n    val_split_prop: float = DEFAULT_VAL_SPLIT,\r\n    seed: int = DEFAULT_SEED,\r\n    stratify_w: bool = True,\r\n) -> Any:\r\n    if val_split_prop == 0:\r\n        # return original data\r\n        if w is None:\r\n            return X, y, X, y, TRAIN_STRING\r\n\r\n        return X, y, w, X, y, w, TRAIN_STRING\r\n\r\n    # make actual split\r\n    if w is None:\r\n        X_t, X_val, y_t, y_val = train_test_split(\r\n            X, y, test_size=val_split_prop, random_state=seed, shuffle=True\r\n        )\r\n        return X_t, y_t, X_val, y_val, VALIDATION_STRING\r\n\r\n    if stratify_w:\r\n        # split to stratify by group\r\n        X_t, X_val, y_t, y_val, w_t, w_val = train_test_split(\r\n            X,\r\n            y,\r\n            w,\r\n            test_size=val_split_prop,\r\n            random_state=seed,\r\n            stratify=w,\r\n            shuffle=True,\r\n        )\r\n    else:\r\n        X_t, X_val, y_t, y_val, w_t, w_val = train_test_split(\r\n            X, y, w, 
test_size=val_split_prop, random_state=seed, shuffle=True\r\n        )\r\n\r\n    return X_t, y_t, w_t, X_val, y_val, w_val, VALIDATION_STRING\r\n\r\n\r\ndef heads_l2_penalty(\r\n    params_0: jnp.ndarray,\r\n    params_1: jnp.ndarray,\r\n    n_layers_out: jnp.ndarray,\r\n    reg_diff: jnp.ndarray,\r\n    penalty_0: jnp.ndarray,\r\n    penalty_1: jnp.ndarray,\r\n) -> jnp.ndarray:\r\n    # Compute l2 penalty for output heads. Either seperately, or regularizing their difference\r\n\r\n    # get l2-penalty for first head\r\n    weightsq_0 = penalty_0 * sum(\r\n        [jnp.sum(params_0[i][0] ** 2) for i in range(0, 2 * n_layers_out + 1, 2)]\r\n    )\r\n\r\n    # get l2-penalty for second head\r\n    if reg_diff:\r\n        weightsq_1 = penalty_1 * sum(\r\n            [\r\n                jnp.sum((params_1[i][0] - params_0[i][0]) ** 2)\r\n                for i in range(0, 2 * n_layers_out + 1, 2)\r\n            ]\r\n        )\r\n    else:\r\n        weightsq_1 = penalty_1 * sum(\r\n            [jnp.sum(params_1[i][0] ** 2) for i in range(0, 2 * n_layers_out + 1, 2)]\r\n        )\r\n    return weightsq_1 + weightsq_0\r\n"
  },
  {
    "path": "catenets/models/jax/offsetnet.py",
    "content": "\"\"\"\nModule implements OffsetNet, also referred to as the 'reparametrization approach' and 'hard\napproach' in \"On inductive biases for heterogeneous treatment effect estimation\", Curth & vd\nSchaar (2021); modeling the POs using a shared prognostic function and\nan offset (treatment effect)\n\"\"\"\n# Author: Alicia Curth\nfrom typing import Any, Callable, List, Tuple\n\nimport jax.numpy as jnp\nimport numpy as onp\nfrom jax import grad, jit, random\nfrom jax.example_libraries import optimizers\nfrom jax.example_libraries.stax import sigmoid\n\nimport catenets.logger as log\nfrom catenets.models.constants import (\n    DEFAULT_BATCH_SIZE,\n    DEFAULT_LAYERS_OUT,\n    DEFAULT_LAYERS_R,\n    DEFAULT_N_ITER,\n    DEFAULT_N_ITER_MIN,\n    DEFAULT_N_ITER_PRINT,\n    DEFAULT_NONLIN,\n    DEFAULT_PATIENCE,\n    DEFAULT_PENALTY_L2,\n    DEFAULT_SEED,\n    DEFAULT_STEP_SIZE,\n    DEFAULT_UNITS_OUT,\n    DEFAULT_UNITS_R,\n    DEFAULT_VAL_SPLIT,\n    LARGE_VAL,\n)\nfrom catenets.models.jax.base import BaseCATENet, OutputHead\nfrom catenets.models.jax.model_utils import (\n    check_shape_1d_data,\n    heads_l2_penalty,\n    make_val_split,\n)\n\n\nclass OffsetNet(BaseCATENet):\n    \"\"\"\n    Module implements OffsetNet, also referred to as the 'reparametrization approach' and 'hard\n    approach' in Curth & vd Schaar (2021); modeling the POs using a shared prognostic function and\n    an offset (treatment effect).\n\n    Parameters\n    ----------\n    binary_y: bool, default False\n        Whether the outcome is binary\n    n_layers_out: int\n        Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)\n    n_units_out: int\n        Number of hidden units in each hypothesis layer\n    n_layers_r: int\n        Number of representation layers before hypothesis layers (distinction between\n        hypothesis layers and representation layers is made to match TARNet & SNets)\n    n_units_r: int\n        Number of hidden units in each 
representation layer\n    penalty_l2: float\n        l2 (ridge) penalty\n    step_size: float\n        learning rate for optimizer\n    n_iter: int\n        Maximum number of iterations\n    batch_size: int\n        Batch size\n    val_split_prop: float\n        Proportion of samples used for validation split (can be 0)\n    early_stopping: bool, default True\n        Whether to use early stopping\n    patience: int\n        Number of iterations to wait before early stopping after decrease in validation loss\n    n_iter_min: int\n        Minimum number of iterations to go through before starting early stopping\n    n_iter_print: int\n        Number of iterations after which to print updates\n    seed: int\n        Seed used\n    penalty_l2_p: float\n        l2-penalty for regularizing the offset\n    nonlin: string, default 'elu'\n        Nonlinearity to use in NN\n    \"\"\"\n\n    def __init__(\n        self,\n        binary_y: bool = False,\n        n_layers_r: int = DEFAULT_LAYERS_R,\n        n_units_r: int = DEFAULT_UNITS_R,\n        n_layers_out: int = DEFAULT_LAYERS_OUT,\n        n_units_out: int = DEFAULT_UNITS_OUT,\n        penalty_l2: float = DEFAULT_PENALTY_L2,\n        penalty_l2_p: float = DEFAULT_PENALTY_L2,\n        step_size: float = DEFAULT_STEP_SIZE,\n        n_iter: int = DEFAULT_N_ITER,\n        batch_size: int = DEFAULT_BATCH_SIZE,\n        val_split_prop: float = DEFAULT_VAL_SPLIT,\n        early_stopping: bool = True,\n        patience: int = DEFAULT_PATIENCE,\n        n_iter_min: int = DEFAULT_N_ITER_MIN,\n        n_iter_print: int = DEFAULT_N_ITER_PRINT,\n        seed: int = DEFAULT_SEED,\n        nonlin: str = DEFAULT_NONLIN,\n    ):\n        # structure of net\n        self.binary_y = binary_y\n        self.n_layers_r = n_layers_r\n        self.n_layers_out = n_layers_out\n        self.n_units_r = n_units_r\n        self.n_units_out = n_units_out\n        self.nonlin = nonlin\n\n        # penalties\n        self.penalty_l2 = penalty_l2\n  
      self.penalty_l2_p = penalty_l2_p\n\n        # training params\n        self.step_size = step_size\n        self.n_iter = n_iter\n        self.batch_size = batch_size\n        self.n_iter_print = n_iter_print\n        self.seed = seed\n        self.val_split_prop = val_split_prop\n        self.early_stopping = early_stopping\n        self.patience = patience\n        self.n_iter_min = n_iter_min\n\n    def _get_train_function(self) -> Callable:\n        return train_offsetnet\n\n    def _get_predict_function(self) -> Callable:\n        return predict_offsetnet\n\n\ndef predict_offsetnet(\n    X: jnp.ndarray,\n    trained_params: jnp.ndarray,\n    predict_funs: List[Any],\n    return_po: bool = False,\n    return_prop: bool = False,\n) -> jnp.ndarray:\n    if return_prop:\n        raise NotImplementedError(\"OffsetNet does not implement a propensity model.\")\n\n    # unpack inputs\n    predict_fun_head = predict_funs[0]\n    binary_y = predict_funs[1]\n    param_0, param_1 = trained_params[0], trained_params[1]\n\n    # get potential outcomes\n    mu_0 = predict_fun_head(param_0, X)\n    offset = predict_fun_head(param_1, X)\n\n    if not binary_y:\n        if return_po:\n            return offset, mu_0, mu_0 + offset\n        else:\n            return offset\n    else:\n        # still need to sigmoid\n        po_0 = sigmoid(mu_0)\n        po_1 = sigmoid(mu_0 + offset)\n        if return_po:\n            return po_1 - po_0, po_0, po_1\n        else:\n            return po_1 - po_0\n\n\ndef train_offsetnet(\n    X: jnp.ndarray,\n    y: jnp.ndarray,\n    w: jnp.ndarray,\n    binary_y: bool = False,\n    n_layers_r: int = DEFAULT_LAYERS_R,\n    n_units_r: int = DEFAULT_UNITS_R,\n    n_layers_out: int = DEFAULT_LAYERS_OUT,\n    n_units_out: int = DEFAULT_UNITS_OUT,\n    penalty_l2: float = DEFAULT_PENALTY_L2,\n    penalty_l2_p: float = DEFAULT_PENALTY_L2,\n    step_size: float = DEFAULT_STEP_SIZE,\n    n_iter: int = DEFAULT_N_ITER,\n    batch_size: int = 
DEFAULT_BATCH_SIZE,\n    val_split_prop: float = DEFAULT_VAL_SPLIT,\n    early_stopping: bool = True,\n    patience: int = DEFAULT_PATIENCE,\n    n_iter_min: int = DEFAULT_N_ITER_MIN,\n    n_iter_print: int = DEFAULT_N_ITER_PRINT,\n    seed: int = DEFAULT_SEED,\n    return_val_loss: bool = False,\n    nonlin: str = DEFAULT_NONLIN,\n    avg_objective: bool = True,\n) -> Tuple:\n    # input check\n    y, w = check_shape_1d_data(y), check_shape_1d_data(w)\n    d = X.shape[1]\n    input_shape = (-1, d)\n    rng_key = random.PRNGKey(seed)\n    onp.random.seed(seed)  # set seed for data generation via numpy as well\n\n    # get validation split (can be none)\n    X, y, w, X_val, y_val, w_val, val_string = make_val_split(\n        X, y, w, val_split_prop=val_split_prop, seed=seed\n    )\n    n = X.shape[0]  # could be different from before due to split\n\n    # get output head functions (both heads share same structure)\n    init_fun_head, predict_fun_head = OutputHead(\n        n_layers_out=n_layers_out,\n        n_units_out=n_units_out,\n        binary_y=False,\n        n_layers_r=n_layers_r,\n        n_units_r=n_units_r,\n        nonlin=nonlin,\n    )\n\n    def init_fun_offset(rng: float, input_shape: Tuple) -> Tuple:\n        # chain together the layers\n        # param should look like [param_base, param_offset]\n        rng, layer_rng = random.split(rng)\n        _, param_base = init_fun_head(layer_rng, input_shape)\n        rng, layer_rng = random.split(rng)\n        input_shape, param_offset = init_fun_head(layer_rng, input_shape)\n        return input_shape, [param_base, param_offset]\n\n    # Define loss functions\n    if not binary_y:\n\n        @jit\n        def loss_offsetnet(\n            params: jnp.ndarray, batch: jnp.ndarray, penalty: float, penalty_l2_p: float\n        ) -> jnp.ndarray:\n            # params: list[representation, head_0, head_1]\n            # batch: (X, y, w)\n            inputs, targets, w = batch\n            preds_0 = 
predict_fun_head(params[0], inputs)\n            offset = predict_fun_head(params[1], inputs)\n            preds = preds_0 + w * offset\n            weightsq_head = heads_l2_penalty(\n                params[0],\n                params[1],\n                n_layers_out + n_layers_r,\n                False,\n                penalty,\n                penalty_l2_p,\n            )\n            if not avg_objective:\n                return jnp.sum((preds - targets) ** 2) + 0.5 * weightsq_head\n            else:\n                return jnp.average((preds - targets) ** 2) + 0.5 * weightsq_head\n\n    else:\n\n        def loss_offsetnet(\n            params: jnp.ndarray,\n            batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],\n            penalty: float,\n            penalty_l2_p: float,\n        ) -> jnp.ndarray:\n            # params: list[representation, head_0, head_1]\n            # batch: (X, y, w)\n            inputs, targets, w = batch\n            preds_0 = predict_fun_head(params[0], inputs)\n            offset = predict_fun_head(params[1], inputs)\n            preds = sigmoid(preds_0 + w * offset)\n            weightsq_head = heads_l2_penalty(\n                params[0],\n                params[1],\n                n_layers_out + n_layers_r,\n                False,\n                penalty,\n                penalty_l2_p,\n            )\n            if not avg_objective:\n                return (\n                    -jnp.sum(\n                        (targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds))\n                    )\n                    + 0.5 * weightsq_head\n                )\n            else:\n                n_batch = y.shape[0]\n                return (\n                    -jnp.sum(\n                        (targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds))\n                    )\n                    / n_batch\n                    + 0.5 * weightsq_head\n                )\n\n    # Define optimisation routine\n    
opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)\n\n    @jit\n    def update(\n        i: int, state: dict, batch: jnp.ndarray, penalty_l2: float, penalty_l2_p: float\n    ) -> jnp.ndarray:\n        # updating function\n        params = get_params(state)\n        return opt_update(\n            i, grad(loss_offsetnet)(params, batch, penalty_l2, penalty_l2_p), state\n        )\n\n    # initialise states\n    _, init_params = init_fun_offset(rng_key, input_shape)\n    opt_state = opt_init(init_params)\n\n    # calculate number of batches per epoch\n    batch_size = batch_size if batch_size < n else n\n    n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1\n    train_indices = onp.arange(n)\n\n    l_best = LARGE_VAL\n    p_curr = 0\n\n    pred_funs = predict_fun_head, binary_y\n\n    # do training\n    for i in range(n_iter):\n        # shuffle data for minibatches\n        onp.random.shuffle(train_indices)\n        for b in range(n_batches):\n            idx_next = train_indices[\n                (b * batch_size) : min((b + 1) * batch_size, n - 1)\n            ]\n            next_batch = X[idx_next, :], y[idx_next, :], w[idx_next]\n            opt_state = update(\n                i * n_batches + b, opt_state, next_batch, penalty_l2, penalty_l2_p\n            )\n\n        if (i % n_iter_print == 0) or early_stopping:\n            params_curr = get_params(opt_state)\n            l_curr = loss_offsetnet(\n                params_curr, (X_val, y_val, w_val), penalty_l2, penalty_l2_p\n            )\n\n        if i % n_iter_print == 0:\n            log.info(f\"Epoch: {i}, current {val_string} loss {l_curr}\")\n\n        if early_stopping and ((i + 1) * n_batches > n_iter_min):\n            if l_curr < l_best:\n                l_best = l_curr\n                p_curr = 0\n            else:\n                p_curr = p_curr + 1\n\n            if p_curr > patience:\n                if return_val_loss:\n                    # return loss 
without penalty\n                    l_final = loss_offsetnet(params_curr, (X_val, y_val, w_val), 0, 0)\n                    return params_curr, pred_funs, l_final\n\n                return params_curr, pred_funs\n\n    # return the parameters\n    trained_params = get_params(opt_state)\n\n    if return_val_loss:\n        # return loss without penalty\n        l_final = loss_offsetnet(get_params(opt_state), (X_val, y_val, w_val), 0, 0)\n        return trained_params, pred_funs, l_final\n\n    return trained_params, pred_funs\n"
  },
  {
    "path": "catenets/models/jax/pseudo_outcome_nets.py",
    "content": "\"\"\"\nImplements Pseudo-outcome based Two-step Nets, namely the DR-learner, the PW-learner and the\nRA-learner.\n\"\"\"\n# Author: Alicia Curth\nfrom typing import Callable, Optional, Tuple\n\nimport jax.numpy as jnp\nimport numpy as onp\nimport pandas as pd\nfrom sklearn.model_selection import StratifiedKFold\n\nimport catenets.logger as log\nfrom catenets.models.constants import (\n    DEFAULT_AVG_OBJECTIVE,\n    DEFAULT_BATCH_SIZE,\n    DEFAULT_CF_FOLDS,\n    DEFAULT_LAYERS_OUT,\n    DEFAULT_LAYERS_OUT_T,\n    DEFAULT_LAYERS_R,\n    DEFAULT_LAYERS_R_T,\n    DEFAULT_N_ITER,\n    DEFAULT_N_ITER_MIN,\n    DEFAULT_N_ITER_PRINT,\n    DEFAULT_NONLIN,\n    DEFAULT_PATIENCE,\n    DEFAULT_PENALTY_L2,\n    DEFAULT_SEED,\n    DEFAULT_STEP_SIZE,\n    DEFAULT_STEP_SIZE_T,\n    DEFAULT_UNITS_OUT,\n    DEFAULT_UNITS_OUT_T,\n    DEFAULT_UNITS_R,\n    DEFAULT_UNITS_R_T,\n    DEFAULT_VAL_SPLIT,\n)\nfrom catenets.models.jax.base import BaseCATENet, train_output_net_only\nfrom catenets.models.jax.disentangled_nets import predict_snet3, train_snet3\nfrom catenets.models.jax.flextenet import predict_flextenet, train_flextenet\nfrom catenets.models.jax.model_utils import check_shape_1d_data, check_X_is_np\nfrom catenets.models.jax.offsetnet import predict_offsetnet, train_offsetnet\nfrom catenets.models.jax.representation_nets import (\n    predict_snet1,\n    predict_snet2,\n    train_snet1,\n    train_snet2,\n)\nfrom catenets.models.jax.snet import predict_snet, train_snet\nfrom catenets.models.jax.tnet import predict_t_net, train_tnet\nfrom catenets.models.jax.transformation_utils import (\n    DR_TRANSFORMATION,\n    PW_TRANSFORMATION,\n    RA_TRANSFORMATION,\n    _get_transformation_function,\n)\n\nT_STRATEGY = \"T\"\nS1_STRATEGY = \"Tar\"\nS2_STRATEGY = \"S2\"\nS3_STRATEGY = \"S3\"\nS_STRATEGY = \"S\"\nOFFSET_STRATEGY = \"Offset\"\nFLEX_STRATEGY = \"Flex\"\n\nALL_STRATEGIES = [\n    T_STRATEGY,\n    S1_STRATEGY,\n    S2_STRATEGY,\n    S3_STRATEGY,\n    
S_STRATEGY,\n    FLEX_STRATEGY,\n    OFFSET_STRATEGY,\n]\n\n\nclass PseudoOutcomeNet(BaseCATENet):\n    \"\"\"\n    Class implements TwoStepLearners based on pseudo-outcome regression as discussed in\n    Curth &vd Schaar (2021): RA-learner, PW-learner and DR-learner\n\n    Parameters\n    ----------\n    first_stage_strategy: str, default 't'\n        which nuisance estimator to use in first stage\n    first_stage_args: dict\n        Any additional arguments to pass to first stage training function\n    data_split: bool, default False\n        Whether to split the data in two folds for estimation\n    cross_fit: bool, default False\n        Whether to perform cross fitting\n    n_cf_folds: int\n        Number of crossfitting folds to use\n    transformation: str, default 'AIPW'\n        pseudo-outcome to use ('AIPW' for DR-learner, 'HT' for PW learner, 'RA' for RA-learner)\n    binary_y: bool, default False\n        Whether the outcome is binary\n    n_layers_out: int\n        First stage Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)\n    n_units_out: int\n        First stage Number of hidden units in each hypothesis layer\n    n_layers_r: int\n        First stage Number of representation layers before hypothesis layers (distinction between\n        hypothesis layers and representation layers is made to match TARNet & SNets)\n    n_units_r: int\n        First stage Number of hidden units in each representation layer\n    n_layers_out_t: int\n        Second stage Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)\n    n_units_out_t: int\n        Second stage Number of hidden units in each hypothesis layer\n    n_layers_r_t: int\n        Second stage Number of representation layers before hypothesis layers (distinction between\n        hypothesis layers and representation layers is made to match TARNet & SNets)\n    n_units_r_t: int\n        Second stage Number of hidden units in each representation layer\n    
penalty_l2: float\n        First stage l2 (ridge) penalty\n    penalty_l2_t: float\n        Second stage l2 (ridge) penalty\n    step_size: float\n        First stage learning rate for optimizer\n    step_size_t: float\n        Second stage learning rate for optimizer\n    n_iter: int\n        Maximum number of iterations\n    batch_size: int\n        Batch size\n    val_split_prop: float\n        Proportion of samples used for validation split (can be 0)\n    early_stopping: bool, default True\n        Whether to use early stopping\n    patience: int\n        Number of iterations to wait before early stopping after decrease in validation loss\n    n_iter_min: int\n        Minimum number of iterations to go through before starting early stopping\n    n_iter_print: int\n        Number of iterations after which to print updates\n    seed: int\n        Seed used\n    nonlin: string, default 'elu'\n        Nonlinearity to use in NN\n    \"\"\"\n\n    def __init__(\n        self,\n        first_stage_strategy: str = T_STRATEGY,\n        first_stage_args: Optional[dict] = None,\n        data_split: bool = False,\n        cross_fit: bool = False,\n        n_cf_folds: int = DEFAULT_CF_FOLDS,\n        transformation: str = DR_TRANSFORMATION,\n        binary_y: bool = False,\n        n_layers_out: int = DEFAULT_LAYERS_OUT,\n        n_layers_r: int = DEFAULT_LAYERS_R,\n        n_layers_out_t: int = DEFAULT_LAYERS_OUT_T,\n        n_layers_r_t: int = DEFAULT_LAYERS_R_T,\n        n_units_out: int = DEFAULT_UNITS_OUT,\n        n_units_r: int = DEFAULT_UNITS_R,\n        n_units_out_t: int = DEFAULT_UNITS_OUT_T,\n        n_units_r_t: int = DEFAULT_UNITS_R_T,\n        penalty_l2: float = DEFAULT_PENALTY_L2,\n        penalty_l2_t: float = DEFAULT_PENALTY_L2,\n        step_size: float = DEFAULT_STEP_SIZE,\n        step_size_t: float = DEFAULT_STEP_SIZE_T,\n        n_iter: int = DEFAULT_N_ITER,\n        batch_size: int = DEFAULT_BATCH_SIZE,\n        n_iter_min: int = 
DEFAULT_N_ITER_MIN,\n        val_split_prop: float = DEFAULT_VAL_SPLIT,\n        early_stopping: bool = True,\n        patience: int = DEFAULT_PATIENCE,\n        n_iter_print: int = DEFAULT_N_ITER_PRINT,\n        seed: int = DEFAULT_SEED,\n        rescale_transformation: bool = False,\n        nonlin: str = DEFAULT_NONLIN,\n    ) -> None:\n        # settings\n        self.first_stage_strategy = first_stage_strategy\n        self.first_stage_args = first_stage_args\n        self.binary_y = binary_y\n        self.transformation = transformation\n        self.data_split = data_split\n        self.cross_fit = cross_fit\n        self.n_cf_folds = n_cf_folds\n\n        # model architecture hyperparams\n        self.n_layers_out = n_layers_out\n        self.n_layers_out_t = n_layers_out_t\n        self.n_layers_r = n_layers_r\n        self.n_layers_r_t = n_layers_r_t\n        self.n_units_out = n_units_out\n        self.n_units_out_t = n_units_out_t\n        self.n_units_r = n_units_r\n        self.n_units_r_t = n_units_r_t\n        self.nonlin = nonlin\n\n        # other hyperparameters\n        self.penalty_l2 = penalty_l2\n        self.penalty_l2_t = penalty_l2_t\n        self.step_size = step_size\n        self.step_size_t = step_size_t\n        self.n_iter = n_iter\n        self.batch_size = batch_size\n        self.n_iter_print = n_iter_print\n        self.seed = seed\n        self.val_split_prop = val_split_prop\n        self.early_stopping = early_stopping\n        self.patience = patience\n        self.n_iter_min = n_iter_min\n        self.rescale_transformation = rescale_transformation\n\n    def _get_train_function(self) -> Callable:\n        return train_pseudooutcome_net\n\n    def fit(\n        self,\n        X: jnp.ndarray,\n        y: jnp.ndarray,\n        w: jnp.ndarray,\n        p: Optional[jnp.ndarray] = None,\n    ) -> \"PseudoOutcomeNet\":\n        # overwrite super so we can pass p as extra param\n        # some quick input checks\n        X = 
check_X_is_np(X)\n        self._check_inputs(w, p)\n\n        train_func = self._get_train_function()\n        train_params = self.get_params()\n\n        if \"transformation\" not in train_params.keys():\n            train_params.update({\"transformation\": self.transformation})\n\n        if self.rescale_transformation:\n            self._params, self._predict_funs, self._scale_factor = train_func(\n                X, y, w, p, **train_params\n            )\n        else:\n            self._params, self._predict_funs = train_func(X, y, w, p, **train_params)\n\n        return self\n\n    def _get_predict_function(self) -> Callable:\n        # Two step nets do not need this\n        pass\n\n    def predict(\n        self, X: jnp.ndarray, return_po: bool = False, return_prop: bool = False\n    ) -> jnp.ndarray:\n        # check input\n        if return_po:\n            raise NotImplementedError(\n                \"TwoStepNets have no Potential outcome predictors.\"\n            )\n\n        if return_prop:\n            raise NotImplementedError(\"TwoStepNets have no Propensity predictors.\")\n\n        if isinstance(X, pd.DataFrame):\n            X = X.values\n\n        if self.rescale_transformation:\n            return 1 / self._scale_factor * self._predict_funs(self._params, X)\n        else:\n            return self._predict_funs(self._params, X)\n\n\nclass DRNet(PseudoOutcomeNet):\n    \"\"\"Wrapper for DR-learner using PseudoOutcomeNet\"\"\"\n\n    def __init__(\n        self,\n        first_stage_strategy: str = T_STRATEGY,\n        data_split: bool = False,\n        cross_fit: bool = False,\n        n_cf_folds: int = DEFAULT_CF_FOLDS,\n        binary_y: bool = False,\n        n_layers_out: int = DEFAULT_LAYERS_OUT,\n        n_layers_r: int = DEFAULT_LAYERS_R,\n        n_layers_out_t: int = DEFAULT_LAYERS_OUT_T,\n        n_layers_r_t: int = DEFAULT_LAYERS_R_T,\n        n_units_out: int = DEFAULT_UNITS_OUT,\n        n_units_r: int = DEFAULT_UNITS_R,\n        
n_units_out_t: int = DEFAULT_UNITS_OUT_T,\n        n_units_r_t: int = DEFAULT_UNITS_R_T,\n        penalty_l2: float = DEFAULT_PENALTY_L2,\n        penalty_l2_t: float = DEFAULT_PENALTY_L2,\n        step_size: float = DEFAULT_STEP_SIZE,\n        step_size_t: float = DEFAULT_STEP_SIZE_T,\n        n_iter: int = DEFAULT_N_ITER,\n        batch_size: int = DEFAULT_BATCH_SIZE,\n        n_iter_min: int = DEFAULT_N_ITER_MIN,\n        val_split_prop: float = DEFAULT_VAL_SPLIT,\n        early_stopping: bool = True,\n        patience: int = DEFAULT_PATIENCE,\n        n_iter_print: int = DEFAULT_N_ITER_PRINT,\n        seed: int = DEFAULT_SEED,\n        rescale_transformation: bool = False,\n        nonlin: str = DEFAULT_NONLIN,\n        first_stage_args: Optional[dict] = None,\n    ) -> None:\n        super().__init__(\n            first_stage_strategy=first_stage_strategy,\n            data_split=data_split,\n            cross_fit=cross_fit,\n            n_cf_folds=n_cf_folds,\n            transformation=DR_TRANSFORMATION,\n            binary_y=binary_y,\n            n_layers_out=n_layers_out,\n            n_layers_r=n_layers_r,\n            n_layers_out_t=n_layers_out_t,\n            n_layers_r_t=n_layers_r_t,\n            n_units_out=n_units_out,\n            n_units_r=n_units_r,\n            n_units_out_t=n_units_out_t,\n            n_units_r_t=n_units_r_t,\n            penalty_l2=penalty_l2,\n            penalty_l2_t=penalty_l2_t,\n            step_size=step_size,\n            step_size_t=step_size_t,\n            n_iter=n_iter,\n            batch_size=batch_size,\n            n_iter_min=n_iter_min,\n            val_split_prop=val_split_prop,\n            early_stopping=early_stopping,\n            patience=patience,\n            n_iter_print=n_iter_print,\n            seed=seed,\n            nonlin=nonlin,\n            rescale_transformation=rescale_transformation,\n            first_stage_args=first_stage_args,\n        )\n\n\nclass RANet(PseudoOutcomeNet):\n    
\"\"\"Wrapper for RA-learner using PseudoOutcomeNet\"\"\"\n\n    def __init__(\n        self,\n        first_stage_strategy: str = T_STRATEGY,\n        data_split: bool = False,\n        cross_fit: bool = False,\n        n_cf_folds: int = DEFAULT_CF_FOLDS,\n        binary_y: bool = False,\n        n_layers_out: int = DEFAULT_LAYERS_OUT,\n        n_layers_r: int = DEFAULT_LAYERS_R,\n        n_layers_out_t: int = DEFAULT_LAYERS_OUT_T,\n        n_layers_r_t: int = DEFAULT_LAYERS_R_T,\n        n_units_out: int = DEFAULT_UNITS_OUT,\n        n_units_r: int = DEFAULT_UNITS_R,\n        n_units_out_t: int = DEFAULT_UNITS_OUT_T,\n        n_units_r_t: int = DEFAULT_UNITS_R_T,\n        penalty_l2: float = DEFAULT_PENALTY_L2,\n        penalty_l2_t: float = DEFAULT_PENALTY_L2,\n        step_size: float = DEFAULT_STEP_SIZE,\n        step_size_t: float = DEFAULT_STEP_SIZE_T,\n        n_iter: int = DEFAULT_N_ITER,\n        batch_size: int = DEFAULT_BATCH_SIZE,\n        n_iter_min: int = DEFAULT_N_ITER_MIN,\n        val_split_prop: float = DEFAULT_VAL_SPLIT,\n        early_stopping: bool = True,\n        patience: int = DEFAULT_PATIENCE,\n        n_iter_print: int = DEFAULT_N_ITER_PRINT,\n        seed: int = DEFAULT_SEED,\n        rescale_transformation: bool = False,\n        nonlin: str = DEFAULT_NONLIN,\n        first_stage_args: Optional[dict] = None,\n    ) -> None:\n        super().__init__(\n            first_stage_strategy=first_stage_strategy,\n            data_split=data_split,\n            cross_fit=cross_fit,\n            n_cf_folds=n_cf_folds,\n            transformation=RA_TRANSFORMATION,\n            binary_y=binary_y,\n            n_layers_out=n_layers_out,\n            n_layers_r=n_layers_r,\n            n_layers_out_t=n_layers_out_t,\n            n_layers_r_t=n_layers_r_t,\n            n_units_out=n_units_out,\n            n_units_r=n_units_r,\n            n_units_out_t=n_units_out_t,\n            n_units_r_t=n_units_r_t,\n            penalty_l2=penalty_l2,\n       
     penalty_l2_t=penalty_l2_t,\n            step_size=step_size,\n            step_size_t=step_size_t,\n            n_iter=n_iter,\n            batch_size=batch_size,\n            n_iter_min=n_iter_min,\n            val_split_prop=val_split_prop,\n            early_stopping=early_stopping,\n            patience=patience,\n            n_iter_print=n_iter_print,\n            seed=seed,\n            nonlin=nonlin,\n            rescale_transformation=rescale_transformation,\n            first_stage_args=first_stage_args,\n        )\n\n\nclass PWNet(PseudoOutcomeNet):\n    \"\"\"Wrapper for PW-learner using PseudoOutcomeNet\"\"\"\n\n    def __init__(\n        self,\n        first_stage_strategy: str = T_STRATEGY,\n        data_split: bool = False,\n        cross_fit: bool = False,\n        n_cf_folds: int = DEFAULT_CF_FOLDS,\n        binary_y: bool = False,\n        n_layers_out: int = DEFAULT_LAYERS_OUT,\n        n_layers_r: int = DEFAULT_LAYERS_R,\n        n_layers_out_t: int = DEFAULT_LAYERS_OUT_T,\n        n_layers_r_t: int = DEFAULT_LAYERS_R_T,\n        n_units_out: int = DEFAULT_UNITS_OUT,\n        n_units_r: int = DEFAULT_UNITS_R,\n        n_units_out_t: int = DEFAULT_UNITS_OUT_T,\n        n_units_r_t: int = DEFAULT_UNITS_R_T,\n        penalty_l2: float = DEFAULT_PENALTY_L2,\n        penalty_l2_t: float = DEFAULT_PENALTY_L2,\n        step_size: float = DEFAULT_STEP_SIZE,\n        step_size_t: float = DEFAULT_STEP_SIZE_T,\n        n_iter: int = DEFAULT_N_ITER,\n        batch_size: int = DEFAULT_BATCH_SIZE,\n        n_iter_min: int = DEFAULT_N_ITER_MIN,\n        val_split_prop: float = DEFAULT_VAL_SPLIT,\n        early_stopping: bool = True,\n        patience: int = DEFAULT_PATIENCE,\n        n_iter_print: int = DEFAULT_N_ITER_PRINT,\n        seed: int = DEFAULT_SEED,\n        rescale_transformation: bool = False,\n        nonlin: str = DEFAULT_NONLIN,\n        first_stage_args: Optional[dict] = None,\n    ) -> None:\n        super().__init__(\n            
first_stage_strategy=first_stage_strategy,\n            data_split=data_split,\n            cross_fit=cross_fit,\n            n_cf_folds=n_cf_folds,\n            transformation=PW_TRANSFORMATION,\n            binary_y=binary_y,\n            n_layers_out=n_layers_out,\n            n_layers_r=n_layers_r,\n            n_layers_out_t=n_layers_out_t,\n            n_layers_r_t=n_layers_r_t,\n            n_units_out=n_units_out,\n            n_units_r=n_units_r,\n            n_units_out_t=n_units_out_t,\n            n_units_r_t=n_units_r_t,\n            penalty_l2=penalty_l2,\n            penalty_l2_t=penalty_l2_t,\n            step_size=step_size,\n            step_size_t=step_size_t,\n            n_iter=n_iter,\n            batch_size=batch_size,\n            n_iter_min=n_iter_min,\n            val_split_prop=val_split_prop,\n            early_stopping=early_stopping,\n            patience=patience,\n            n_iter_print=n_iter_print,\n            seed=seed,\n            nonlin=nonlin,\n            rescale_transformation=rescale_transformation,\n            first_stage_args=first_stage_args,\n        )\n\n\ndef train_pseudooutcome_net(\n    X: jnp.ndarray,\n    y: jnp.ndarray,\n    w: jnp.ndarray,\n    p: Optional[jnp.ndarray] = None,\n    first_stage_strategy: str = T_STRATEGY,\n    data_split: bool = False,\n    cross_fit: bool = False,\n    n_cf_folds: int = DEFAULT_CF_FOLDS,\n    transformation: str = DR_TRANSFORMATION,\n    binary_y: bool = False,\n    n_layers_out: int = DEFAULT_LAYERS_OUT,\n    n_layers_r: int = DEFAULT_LAYERS_R,\n    n_layers_r_t: int = DEFAULT_LAYERS_R_T,\n    n_layers_out_t: int = DEFAULT_LAYERS_OUT_T,\n    n_units_out: int = DEFAULT_UNITS_OUT,\n    n_units_r: int = DEFAULT_UNITS_R,\n    n_units_out_t: int = DEFAULT_UNITS_OUT_T,\n    n_units_r_t: int = DEFAULT_UNITS_R_T,\n    penalty_l2: float = DEFAULT_PENALTY_L2,\n    penalty_l2_t: float = DEFAULT_PENALTY_L2,\n    step_size: float = DEFAULT_STEP_SIZE,\n    step_size_t: float = 
DEFAULT_STEP_SIZE_T,\n    n_iter: int = DEFAULT_N_ITER,\n    batch_size: int = DEFAULT_BATCH_SIZE,\n    val_split_prop: float = DEFAULT_VAL_SPLIT,\n    early_stopping: bool = True,\n    patience: int = DEFAULT_PATIENCE,\n    n_iter_min: int = DEFAULT_N_ITER_MIN,\n    n_iter_print: int = DEFAULT_N_ITER_PRINT,\n    seed: int = DEFAULT_SEED,\n    rescale_transformation: bool = False,\n    return_val_loss: bool = False,\n    nonlin: str = DEFAULT_NONLIN,\n    avg_objective: bool = DEFAULT_AVG_OBJECTIVE,\n    first_stage_args: Optional[dict] = None,\n) -> Tuple:\n    # get shape of data\n    n, d = X.shape\n\n    if p is not None:\n        p = check_shape_1d_data(p)\n\n    # get transformation function\n    transformation_function = _get_transformation_function(transformation)\n\n    # get strategy name\n    if first_stage_strategy not in ALL_STRATEGIES:\n        raise ValueError(\n            \"Parameter first stage should be in \"\n            \"catenets.models.pseudo_outcome_nets.ALL_STRATEGIES. 
\"\n            \"You passed {}\".format(first_stage_strategy)\n        )\n\n    # split data as wanted\n    if p is None or transformation is not PW_TRANSFORMATION:\n        if not cross_fit:\n            if not data_split:\n                log.debug(\"Training first stage with all data (no data splitting)\")\n                # use all data for both\n                fit_mask = onp.ones(n, dtype=bool)\n                pred_mask = onp.ones(n, dtype=bool)\n            else:\n                log.debug(\"Training first stage with half of the data (data splitting)\")\n                # split data in half\n                fit_idx = onp.random.choice(n, int(onp.round(n / 2)))\n                fit_mask = onp.zeros(n, dtype=bool)\n\n                fit_mask[fit_idx] = 1\n                pred_mask = ~fit_mask\n\n            mu_0, mu_1, pi_hat = _train_and_predict_first_stage(\n                X,\n                y,\n                w,\n                fit_mask,\n                pred_mask,\n                first_stage_strategy=first_stage_strategy,\n                binary_y=binary_y,\n                n_layers_out=n_layers_out,\n                n_layers_r=n_layers_r,\n                n_units_out=n_units_out,\n                n_units_r=n_units_r,\n                penalty_l2=penalty_l2,\n                step_size=step_size,\n                n_iter=n_iter,\n                batch_size=batch_size,\n                val_split_prop=val_split_prop,\n                early_stopping=early_stopping,\n                patience=patience,\n                n_iter_min=n_iter_min,\n                n_iter_print=n_iter_print,\n                seed=seed,\n                nonlin=nonlin,\n                avg_objective=avg_objective,\n                transformation=transformation,\n                first_stage_args=first_stage_args,\n            )\n            if data_split:\n                # keep only prediction data\n                X, y, w = X[pred_mask, :], y[pred_mask, :], w[pred_mask, :]\n\n      
          if p is not None:\n                    p = p[pred_mask, :]\n\n        else:\n            log.debug(f\"Training first stage in {n_cf_folds} folds (cross-fitting)\")\n            # do cross fitting\n            mu_0, mu_1, pi_hat = onp.zeros((n, 1)), onp.zeros((n, 1)), onp.zeros((n, 1))\n            splitter = StratifiedKFold(\n                n_splits=n_cf_folds, shuffle=True, random_state=seed\n            )\n\n            fold_count = 1\n            for train_idx, test_idx in splitter.split(X, w):\n\n                log.debug(f\"Training fold {fold_count}.\")\n                fold_count = fold_count + 1\n\n                pred_mask = onp.zeros(n, dtype=bool)\n                pred_mask[test_idx] = 1\n                fit_mask = ~pred_mask\n\n                (\n                    mu_0[pred_mask],\n                    mu_1[pred_mask],\n                    pi_hat[pred_mask],\n                ) = _train_and_predict_first_stage(\n                    X,\n                    y,\n                    w,\n                    fit_mask,\n                    pred_mask,\n                    first_stage_strategy=first_stage_strategy,\n                    binary_y=binary_y,\n                    n_layers_out=n_layers_out,\n                    n_layers_r=n_layers_r,\n                    n_units_out=n_units_out,\n                    n_units_r=n_units_r,\n                    penalty_l2=penalty_l2,\n                    step_size=step_size,\n                    n_iter=n_iter,\n                    batch_size=batch_size,\n                    val_split_prop=val_split_prop,\n                    early_stopping=early_stopping,\n                    patience=patience,\n                    n_iter_min=n_iter_min,\n                    n_iter_print=n_iter_print,\n                    seed=seed,\n                    nonlin=nonlin,\n                    avg_objective=avg_objective,\n                    transformation=transformation,\n                    first_stage_args=first_stage_args,\n    
            )\n\n    log.debug(\"Training second stage.\")\n\n    if p is not None:\n        # use known propensity score\n        p = check_shape_1d_data(p)\n        pi_hat = p\n\n    # second stage\n    y, w = check_shape_1d_data(y), check_shape_1d_data(w)\n    # transform data and fit on transformed data\n    if transformation is PW_TRANSFORMATION:\n        mu_0 = None\n        mu_1 = None\n\n    pseudo_outcome = transformation_function(y=y, w=w, p=pi_hat, mu_0=mu_0, mu_1=mu_1)\n    if rescale_transformation:\n        scale_factor = onp.std(y) / onp.std(pseudo_outcome)\n        if scale_factor > 1:\n            scale_factor = 1\n        else:\n            pseudo_outcome = scale_factor * pseudo_outcome\n        params, predict_funs = train_output_net_only(\n            X,\n            pseudo_outcome,\n            binary_y=False,\n            n_layers_out=n_layers_out_t,\n            n_units_out=n_units_out_t,\n            n_layers_r=n_layers_r_t,\n            n_units_r=n_units_r_t,\n            penalty_l2=penalty_l2_t,\n            step_size=step_size_t,\n            n_iter=n_iter,\n            batch_size=batch_size,\n            val_split_prop=val_split_prop,\n            early_stopping=early_stopping,\n            patience=patience,\n            n_iter_min=n_iter_min,\n            n_iter_print=n_iter_print,\n            seed=seed,\n            return_val_loss=return_val_loss,\n            nonlin=nonlin,\n            avg_objective=avg_objective,\n        )\n        return params, predict_funs, scale_factor\n    else:\n        return train_output_net_only(\n            X,\n            pseudo_outcome,\n            binary_y=False,\n            n_layers_out=n_layers_out_t,\n            n_units_out=n_units_out_t,\n            n_layers_r=n_layers_r_t,\n            n_units_r=n_units_r_t,\n            penalty_l2=penalty_l2_t,\n            step_size=step_size_t,\n            n_iter=n_iter,\n            batch_size=batch_size,\n            val_split_prop=val_split_prop,\n  
          early_stopping=early_stopping,\n            patience=patience,\n            n_iter_min=n_iter_min,\n            n_iter_print=n_iter_print,\n            seed=seed,\n            return_val_loss=return_val_loss,\n            nonlin=nonlin,\n            avg_objective=avg_objective,\n        )\n\n\ndef _train_and_predict_first_stage(\n    X: jnp.ndarray,\n    y: jnp.ndarray,\n    w: jnp.ndarray,\n    fit_mask: jnp.ndarray,\n    pred_mask: jnp.ndarray,\n    first_stage_strategy: str,\n    binary_y: bool = False,\n    n_layers_out: int = DEFAULT_LAYERS_OUT,\n    n_layers_r: int = DEFAULT_LAYERS_R,\n    n_units_out: int = DEFAULT_UNITS_OUT,\n    n_units_r: int = DEFAULT_UNITS_R,\n    penalty_l2: float = DEFAULT_PENALTY_L2,\n    step_size: float = DEFAULT_STEP_SIZE,\n    n_iter: int = DEFAULT_N_ITER,\n    batch_size: int = DEFAULT_BATCH_SIZE,\n    val_split_prop: float = DEFAULT_VAL_SPLIT,\n    early_stopping: bool = True,\n    patience: int = DEFAULT_PATIENCE,\n    n_iter_min: int = DEFAULT_N_ITER_MIN,\n    n_iter_print: int = DEFAULT_N_ITER_PRINT,\n    seed: int = DEFAULT_SEED,\n    nonlin: str = DEFAULT_NONLIN,\n    avg_objective: bool = False,\n    transformation: str = DR_TRANSFORMATION,\n    first_stage_args: Optional[dict] = None,\n) -> Tuple:\n    if len(w.shape) > 1:\n        w = w.reshape((len(w),))\n\n    if first_stage_args is None:\n        first_stage_args = {}\n\n    # split the data\n    X_fit, y_fit, w_fit = X[fit_mask, :], y[fit_mask], w[fit_mask]\n    X_pred = X[pred_mask, :]\n\n    train_fun: Callable\n    predict_fun: Callable\n\n    if first_stage_strategy == T_STRATEGY:\n        train_fun, predict_fun = train_tnet, predict_t_net\n    elif first_stage_strategy == S_STRATEGY:\n        train_fun, predict_fun = train_snet, predict_snet\n    elif first_stage_strategy == S1_STRATEGY:\n        train_fun, predict_fun = train_snet1, predict_snet1\n    elif first_stage_strategy == S2_STRATEGY:\n        train_fun, predict_fun = train_snet2, 
predict_snet2\n    elif first_stage_strategy == S3_STRATEGY:\n        train_fun, predict_fun = train_snet3, predict_snet3\n    elif first_stage_strategy == OFFSET_STRATEGY:\n        train_fun, predict_fun = train_offsetnet, predict_offsetnet\n    elif first_stage_strategy == FLEX_STRATEGY:\n        train_fun, predict_fun = train_flextenet, predict_flextenet\n    else:\n        raise ValueError(\n            \"{} is not a valid first stage strategy for a PseudoOutcomeNet\".format(\n                first_stage_strategy\n            )\n        )\n\n    log.debug(\"Training PO estimators\")\n    trained_params, pred_fun = train_fun(\n        X_fit,\n        y_fit,\n        w_fit,\n        binary_y=binary_y,\n        n_layers_r=n_layers_r,\n        n_units_r=n_units_r,\n        n_layers_out=n_layers_out,\n        n_units_out=n_units_out,\n        penalty_l2=penalty_l2,\n        step_size=step_size,\n        n_iter=n_iter,\n        batch_size=batch_size,\n        val_split_prop=val_split_prop,\n        early_stopping=early_stopping,\n        patience=patience,\n        n_iter_min=n_iter_min,\n        n_iter_print=n_iter_print,\n        seed=seed,\n        nonlin=nonlin,\n        avg_objective=avg_objective,\n        **first_stage_args,\n    )\n\n    if first_stage_strategy in [S_STRATEGY, S2_STRATEGY, S3_STRATEGY]:\n        _, mu_0, mu_1, pi_hat = predict_fun(\n            X_pred, trained_params, pred_fun, return_po=True, return_prop=True\n        )\n    else:\n        if transformation is not PW_TRANSFORMATION:\n            _, mu_0, mu_1 = predict_fun(\n                X_pred, trained_params, pred_fun, return_po=True\n            )\n        else:\n            mu_0, mu_1 = onp.nan, onp.nan\n\n        if transformation is not RA_TRANSFORMATION:\n            log.debug(\"Training propensity net\")\n            params_prop, predict_fun_prop = train_output_net_only(\n                X_fit,\n                w_fit,\n                binary_y=True,\n                
n_layers_out=n_layers_out,\n                n_units_out=n_units_out,\n                n_layers_r=n_layers_r,\n                n_units_r=n_units_r,\n                penalty_l2=penalty_l2,\n                step_size=step_size,\n                n_iter=n_iter,\n                batch_size=batch_size,\n                val_split_prop=val_split_prop,\n                early_stopping=early_stopping,\n                patience=patience,\n                n_iter_min=n_iter_min,\n                n_iter_print=n_iter_print,\n                seed=seed,\n                nonlin=nonlin,\n                avg_objective=avg_objective,\n            )\n            pi_hat = predict_fun_prop(params_prop, X_pred)\n        else:\n            pi_hat = onp.nan\n\n    return mu_0, mu_1, pi_hat\n"
  },
  {
    "path": "catenets/models/jax/representation_nets.py",
    "content": "\"\"\"\nModule implements SNet1 and SNet2, which are based on  CFRNet/TARNet from Shalit et al (2017) and\nDragonNet from Shi et al (2019), respectively.\n\"\"\"\n# Author: Alicia Curth\nfrom typing import Any, Callable, List, Tuple\n\nimport jax.numpy as jnp\nimport numpy as onp\nfrom jax import grad, jit, random\nfrom jax.example_libraries import optimizers\n\nimport catenets.logger as log\nfrom catenets.models.constants import (\n    DEFAULT_AVG_OBJECTIVE,\n    DEFAULT_BATCH_SIZE,\n    DEFAULT_LAYERS_OUT,\n    DEFAULT_LAYERS_R,\n    DEFAULT_N_ITER,\n    DEFAULT_N_ITER_MIN,\n    DEFAULT_N_ITER_PRINT,\n    DEFAULT_NONLIN,\n    DEFAULT_PATIENCE,\n    DEFAULT_PENALTY_DISC,\n    DEFAULT_PENALTY_L2,\n    DEFAULT_SEED,\n    DEFAULT_STEP_SIZE,\n    DEFAULT_UNITS_OUT,\n    DEFAULT_UNITS_R,\n    DEFAULT_VAL_SPLIT,\n    LARGE_VAL,\n)\nfrom catenets.models.jax.base import BaseCATENet, OutputHead, ReprBlock\nfrom catenets.models.jax.model_utils import (\n    check_shape_1d_data,\n    heads_l2_penalty,\n    make_val_split,\n)\n\n\nclass SNet1(BaseCATENet):\n    \"\"\"\n    Class implements Shalit et al (2017)'s TARNet & CFR (discrepancy regularization is NOT\n    TESTED). 
Also referred to as SNet-1 in our paper.\n\n    Parameters\n    ----------\n    binary_y: bool, default False\n        Whether the outcome is binary\n    n_layers_out: int\n        Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)\n    n_units_out: int\n        Number of hidden units in each hypothesis layer\n    n_layers_r: int\n        Number of shared representation layers before hypothesis layers\n    n_units_r: int\n        Number of hidden units in each representation layer\n    penalty_l2: float\n        l2 (ridge) penalty\n    step_size: float\n        learning rate for optimizer\n    n_iter: int\n        Maximum number of iterations\n    batch_size: int\n        Batch size\n    val_split_prop: float\n        Proportion of samples used for validation split (can be 0)\n    early_stopping: bool, default True\n        Whether to use early stopping\n    patience: int\n        Number of iterations to wait before early stopping after decrease in validation loss\n    n_iter_min: int\n        Minimum number of iterations to go through before starting early stopping\n    n_iter_print: int\n        Number of iterations after which to print updates\n    seed: int\n        Seed used\n    reg_diff: bool, default False\n        Whether to regularize the difference between the two potential outcome heads\n    penalty_diff: float\n        l2-penalty for regularizing the difference between output heads. used only if\n        train_separate=False\n    same_init: bool, False\n        Whether to initialise the two output heads with same values\n    nonlin: string, default 'elu'\n        Nonlinearity to use in NN\n    penalty_disc: float, default zero\n        Discrepancy penalty. 
Defaults to zero as this feature is not tested.\n    \"\"\"\n\n    def __init__(\n        self,\n        binary_y: bool = False,\n        n_layers_r: int = DEFAULT_LAYERS_R,\n        n_units_r: int = DEFAULT_UNITS_R,\n        n_layers_out: int = DEFAULT_LAYERS_OUT,\n        n_units_out: int = DEFAULT_UNITS_OUT,\n        penalty_l2: float = DEFAULT_PENALTY_L2,\n        step_size: float = DEFAULT_STEP_SIZE,\n        n_iter: int = DEFAULT_N_ITER,\n        batch_size: int = DEFAULT_BATCH_SIZE,\n        val_split_prop: float = DEFAULT_VAL_SPLIT,\n        early_stopping: bool = True,\n        patience: int = DEFAULT_PATIENCE,\n        n_iter_min: int = DEFAULT_N_ITER_MIN,\n        n_iter_print: int = DEFAULT_N_ITER_PRINT,\n        seed: int = DEFAULT_SEED,\n        reg_diff: bool = False,\n        penalty_diff: float = DEFAULT_PENALTY_L2,\n        same_init: bool = False,\n        nonlin: str = DEFAULT_NONLIN,\n        penalty_disc: float = DEFAULT_PENALTY_DISC,\n    ) -> None:\n        # structure of net\n        self.binary_y = binary_y\n        self.n_layers_r = n_layers_r\n        self.n_layers_out = n_layers_out\n        self.n_units_r = n_units_r\n        self.n_units_out = n_units_out\n        self.nonlin = nonlin\n\n        # penalties\n        self.penalty_l2 = penalty_l2\n        self.penalty_disc = penalty_disc\n        self.reg_diff = reg_diff\n        self.penalty_diff = penalty_diff\n        self.same_init = same_init\n\n        # training params\n        self.step_size = step_size\n        self.n_iter = n_iter\n        self.batch_size = batch_size\n        self.n_iter_print = n_iter_print\n        self.seed = seed\n        self.val_split_prop = val_split_prop\n        self.early_stopping = early_stopping\n        self.patience = patience\n        self.n_iter_min = n_iter_min\n\n    def _get_train_function(self) -> Callable:\n        return train_snet1\n\n    def _get_predict_function(self) -> Callable:\n        return predict_snet1\n\n\nclass 
TARNet(SNet1):\n    \"\"\"Wrapper for TARNet\"\"\"\n\n    def __init__(\n        self,\n        binary_y: bool = False,\n        n_layers_r: int = DEFAULT_LAYERS_R,\n        n_units_r: int = DEFAULT_UNITS_R,\n        n_layers_out: int = DEFAULT_LAYERS_OUT,\n        n_units_out: int = DEFAULT_UNITS_OUT,\n        penalty_l2: float = DEFAULT_PENALTY_L2,\n        step_size: float = DEFAULT_STEP_SIZE,\n        n_iter: int = DEFAULT_N_ITER,\n        batch_size: int = DEFAULT_BATCH_SIZE,\n        val_split_prop: float = DEFAULT_VAL_SPLIT,\n        early_stopping: bool = True,\n        patience: int = DEFAULT_PATIENCE,\n        n_iter_min: int = DEFAULT_N_ITER_MIN,\n        n_iter_print: int = DEFAULT_N_ITER_PRINT,\n        seed: int = DEFAULT_SEED,\n        reg_diff: bool = False,\n        penalty_diff: float = DEFAULT_PENALTY_L2,\n        same_init: bool = False,\n        nonlin: str = DEFAULT_NONLIN,\n    ):\n        super().__init__(\n            binary_y=binary_y,\n            n_layers_r=n_layers_r,\n            n_units_r=n_units_r,\n            n_layers_out=n_layers_out,\n            n_units_out=n_units_out,\n            penalty_l2=penalty_l2,\n            step_size=step_size,\n            n_iter=n_iter,\n            batch_size=batch_size,\n            val_split_prop=val_split_prop,\n            early_stopping=early_stopping,\n            patience=patience,\n            n_iter_min=n_iter_min,\n            n_iter_print=n_iter_print,\n            seed=seed,\n            reg_diff=reg_diff,\n            penalty_diff=penalty_diff,\n            same_init=same_init,\n            nonlin=nonlin,\n            penalty_disc=0,\n        )\n\n\nclass SNet2(BaseCATENet):\n    \"\"\"\n    Class implements SNet-2, which is based on Shi et al (2019)'s DragonNet (this version does\n    NOT use targeted regularization and has a (possibly deeper) propensity head.\n\n    Parameters\n    ----------\n    binary_y: bool, default False\n        Whether the outcome is binary\n    n_layers_out: 
int\n        Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)\n    n_layers_out_prop: int\n        Number of hypothesis layers for propensity score(n_layers_out x n_units_out + 1 x Dense\n        layer)\n    n_units_out: int\n        Number of hidden units in each hypothesis layer\n    n_units_out_prop: int\n        Number of hidden units in each propensity score hypothesis layer\n    n_layers_r: int\n        Number of shared representation layers before hypothesis layers\n    n_units_r: int\n        Number of hidden units in each representation layer\n    penalty_l2: float\n        l2 (ridge) penalty\n    step_size: float\n        learning rate for optimizer\n    n_iter: int\n        Maximum number of iterations\n    batch_size: int\n        Batch size\n    val_split_prop: float\n        Proportion of samples used for validation split (can be 0)\n    early_stopping: bool, default True\n        Whether to use early stopping\n    patience: int\n        Number of iterations to wait before early stopping after decrease in validation loss\n    n_iter_min: int\n        Minimum number of iterations to go through before starting early stopping\n    n_iter_print: int\n        Number of iterations after which to print updates\n    seed: int\n        Seed used\n    reg_diff: bool, default False\n        Whether to regularize the difference between the two potential outcome heads\n    penalty_diff: float\n        l2-penalty for regularizing the difference between output heads. 
used only if\n        train_separate=False\n    same_init: bool, False\n        Whether to initialise the two output heads with same values\n    nonlin: string, default 'elu'\n        Nonlinearity to use in NN\n    \"\"\"\n\n    def __init__(\n        self,\n        binary_y: bool = False,\n        n_layers_r: int = DEFAULT_LAYERS_R,\n        n_units_r: int = DEFAULT_UNITS_R,\n        n_layers_out: int = DEFAULT_LAYERS_OUT,\n        n_units_out: int = DEFAULT_UNITS_OUT,\n        penalty_l2: float = DEFAULT_PENALTY_L2,\n        n_units_out_prop: int = DEFAULT_UNITS_OUT,\n        n_layers_out_prop: int = DEFAULT_LAYERS_OUT,\n        step_size: float = DEFAULT_STEP_SIZE,\n        n_iter: int = DEFAULT_N_ITER,\n        batch_size: int = DEFAULT_BATCH_SIZE,\n        val_split_prop: float = DEFAULT_VAL_SPLIT,\n        early_stopping: bool = True,\n        patience: int = DEFAULT_PATIENCE,\n        n_iter_min: int = DEFAULT_N_ITER_MIN,\n        n_iter_print: int = DEFAULT_N_ITER_PRINT,\n        seed: int = DEFAULT_SEED,\n        reg_diff: bool = False,\n        same_init: bool = False,\n        penalty_diff: float = DEFAULT_PENALTY_L2,\n        nonlin: str = DEFAULT_NONLIN,\n    ) -> None:\n        self.binary_y = binary_y\n\n        self.n_layers_r = n_layers_r\n        self.n_layers_out = n_layers_out\n        self.n_layers_out_prop = n_layers_out_prop\n        self.n_units_r = n_units_r\n        self.n_units_out = n_units_out\n        self.n_units_out_prop = n_units_out_prop\n        self.nonlin = nonlin\n\n        self.penalty_l2 = penalty_l2\n        self.step_size = step_size\n        self.n_iter = n_iter\n        self.batch_size = batch_size\n        self.val_split_prop = val_split_prop\n        self.early_stopping = early_stopping\n        self.patience = patience\n        self.n_iter_min = n_iter_min\n        self.reg_diff = reg_diff\n        self.penalty_diff = penalty_diff\n        self.same_init = same_init\n\n        self.seed = seed\n        
self.n_iter_print = n_iter_print\n\n    def _get_train_function(self) -> Callable:\n        return train_snet2\n\n    def _get_predict_function(self) -> Callable:\n        return predict_snet2\n\n\nclass DragonNet(SNet2):\n    \"\"\"Wrapper for DragonNet\"\"\"\n\n    def __init__(\n        self,\n        binary_y: bool = False,\n        n_layers_r: int = DEFAULT_LAYERS_R,\n        n_units_r: int = DEFAULT_UNITS_R,\n        n_layers_out: int = DEFAULT_LAYERS_OUT,\n        n_units_out: int = DEFAULT_UNITS_OUT,\n        penalty_l2: float = DEFAULT_PENALTY_L2,\n        n_units_out_prop: int = DEFAULT_UNITS_OUT,\n        n_layers_out_prop: int = 0,\n        step_size: float = DEFAULT_STEP_SIZE,\n        n_iter: int = DEFAULT_N_ITER,\n        batch_size: int = DEFAULT_BATCH_SIZE,\n        val_split_prop: float = DEFAULT_VAL_SPLIT,\n        early_stopping: bool = True,\n        patience: int = DEFAULT_PATIENCE,\n        n_iter_min: int = DEFAULT_N_ITER_MIN,\n        n_iter_print: int = DEFAULT_N_ITER_PRINT,\n        seed: int = DEFAULT_SEED,\n        reg_diff: bool = False,\n        same_init: bool = False,\n        penalty_diff: float = DEFAULT_PENALTY_L2,\n        nonlin: str = DEFAULT_NONLIN,\n    ):\n        super().__init__(\n            binary_y=binary_y,\n            n_layers_r=n_layers_r,\n            n_units_r=n_units_r,\n            n_layers_out=n_layers_out,\n            n_units_out=n_units_out,\n            penalty_l2=penalty_l2,\n            n_units_out_prop=n_units_out_prop,\n            n_layers_out_prop=n_layers_out_prop,\n            step_size=step_size,\n            n_iter=n_iter,\n            batch_size=batch_size,\n            val_split_prop=val_split_prop,\n            early_stopping=early_stopping,\n            patience=patience,\n            n_iter_min=n_iter_min,\n            n_iter_print=n_iter_print,\n            seed=seed,\n            reg_diff=reg_diff,\n            penalty_diff=penalty_diff,\n            same_init=same_init,\n            
nonlin=nonlin,\n        )\n\n\n# Training functions for SNet1 -------------------------------------------------\ndef mmd2_lin(X: jnp.ndarray, w: jnp.ndarray) -> jnp.ndarray:\n    # Squared Linear MMD as implemented in CFR\n    # jax does not support indexing, so this is a workaround with reweighting in means\n    n = w.shape[0]\n    n_t = jnp.sum(w)\n\n    # normalize X so scale matters\n    X = X / jnp.sqrt(jnp.var(X, axis=0))\n\n    mean_control = (n / (n - n_t)) * jnp.mean((1 - w) * X, axis=0)\n    mean_treated = (n / n_t) * jnp.mean(w * X, axis=0)\n\n    return jnp.sum((mean_treated - mean_control) ** 2)\n\n\ndef predict_snet1(\n    X: jnp.ndarray,\n    trained_params: dict,\n    predict_funs: list,\n    return_po: bool = False,\n    return_prop: bool = False,\n) -> jnp.ndarray:\n    if return_prop:\n        raise NotImplementedError(\"SNet1 does not implement a propensity model.\")\n\n    # unpack inputs\n    predict_fun_repr, predict_fun_head = predict_funs\n    param_repr, param_0, param_1 = (\n        trained_params[0],\n        trained_params[1],\n        trained_params[2],\n    )\n\n    # get representation\n    representation = predict_fun_repr(param_repr, X)\n\n    # get potential outcomes\n    mu_0 = predict_fun_head(param_0, representation)\n    mu_1 = predict_fun_head(param_1, representation)\n\n    if return_po:\n        return mu_1 - mu_0, mu_0, mu_1\n    else:\n        return mu_1 - mu_0\n\n\ndef train_snet1(\n    X: jnp.ndarray,\n    y: jnp.ndarray,\n    w: jnp.ndarray,\n    binary_y: bool = False,\n    n_layers_r: int = DEFAULT_LAYERS_R,\n    n_units_r: int = DEFAULT_UNITS_R,\n    n_layers_out: int = DEFAULT_LAYERS_OUT,\n    n_units_out: int = DEFAULT_UNITS_OUT,\n    penalty_l2: float = DEFAULT_PENALTY_L2,\n    penalty_disc: int = DEFAULT_PENALTY_DISC,\n    step_size: float = DEFAULT_STEP_SIZE,\n    n_iter: int = DEFAULT_N_ITER,\n    batch_size: int = DEFAULT_BATCH_SIZE,\n    val_split_prop: float = DEFAULT_VAL_SPLIT,\n    early_stopping: bool = 
True,\n    patience: int = DEFAULT_PATIENCE,\n    n_iter_min: int = DEFAULT_N_ITER_MIN,\n    n_iter_print: int = DEFAULT_N_ITER_PRINT,\n    seed: int = DEFAULT_SEED,\n    return_val_loss: bool = False,\n    reg_diff: bool = False,\n    same_init: bool = False,\n    penalty_diff: float = DEFAULT_PENALTY_L2,\n    nonlin: str = DEFAULT_NONLIN,\n    avg_objective: bool = DEFAULT_AVG_OBJECTIVE,\n) -> Any:\n    # function to train TARNET (Johansson et al) using jax\n    # input check\n    y, w = check_shape_1d_data(y), check_shape_1d_data(w)\n    d = X.shape[1]\n    input_shape = (-1, d)\n    rng_key = random.PRNGKey(seed)\n    onp.random.seed(seed)  # set seed for data generation via numpy as well\n\n    if not reg_diff:\n        penalty_diff = penalty_l2\n\n    # get validation split (can be none)\n    X, y, w, X_val, y_val, w_val, val_string = make_val_split(\n        X, y, w, val_split_prop=val_split_prop, seed=seed\n    )\n    n = X.shape[0]  # could be different from before due to split\n\n    # get representation layer\n    init_fun_repr, predict_fun_repr = ReprBlock(\n        n_layers=n_layers_r, n_units=n_units_r, nonlin=nonlin\n    )\n\n    # get output head functions (both heads share same structure)\n    init_fun_head, predict_fun_head = OutputHead(\n        n_layers_out=n_layers_out,\n        n_units_out=n_units_out,\n        binary_y=binary_y,\n        nonlin=nonlin,\n    )\n\n    def init_fun_snet1(rng: float, input_shape: Tuple) -> Tuple[Tuple, List]:\n        # chain together the layers\n        # param should look like [repr, po_0, po_1]\n        rng, layer_rng = random.split(rng)\n        input_shape_repr, param_repr = init_fun_repr(layer_rng, input_shape)\n        rng, layer_rng = random.split(rng)\n        if same_init:\n            # initialise both on same values\n            input_shape, param_0 = init_fun_head(layer_rng, input_shape_repr)\n            input_shape, param_1 = init_fun_head(layer_rng, input_shape_repr)\n        else:\n            
input_shape, param_0 = init_fun_head(layer_rng, input_shape_repr)\n            rng, layer_rng = random.split(rng)\n            input_shape, param_1 = init_fun_head(layer_rng, input_shape_repr)\n\n        return input_shape, [param_repr, param_0, param_1]\n\n    # Define loss functions\n    # loss functions for the head\n    if not binary_y:\n\n        def loss_head(\n            params: List, batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]\n        ) -> jnp.ndarray:\n            # mse loss function\n            inputs, targets, weights = batch\n            preds = predict_fun_head(params, inputs)\n            return jnp.sum(weights * ((preds - targets) ** 2))\n\n    else:\n\n        def loss_head(\n            params: List, batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]\n        ) -> jnp.ndarray:\n            # log loss function\n            inputs, targets, weights = batch\n            preds = predict_fun_head(params, inputs)\n            return -jnp.sum(\n                weights\n                * (targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds))\n            )\n\n    # complete loss function for all parts\n    @jit\n    def loss_snet1(\n        params: List,\n        batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],\n        penalty_l2: float,\n        penalty_disc: float,\n        penalty_diff: float,\n    ) -> jnp.ndarray:\n        # params: list[representation, head_0, head_1]\n        # batch: (X, y, w)\n        X, y, w = batch\n\n        # get representation\n        reps = predict_fun_repr(params[0], X)\n\n        # get mmd\n        disc = mmd2_lin(reps, w)\n\n        # pass down to two heads\n        loss_0 = loss_head(params[1], (reps, y, 1 - w))\n        loss_1 = loss_head(params[2], (reps, y, w))\n\n        # regularization on representation\n        weightsq_body = sum(\n            [jnp.sum(params[0][i][0] ** 2) for i in range(0, 2 * n_layers_r, 2)]\n        )\n        weightsq_head = heads_l2_penalty(\n            
params[1], params[2], n_layers_out, reg_diff, penalty_l2, penalty_diff\n        )\n        if not avg_objective:\n            return (\n                loss_0\n                + loss_1\n                + penalty_disc * disc\n                + 0.5 * (penalty_l2 * weightsq_body + weightsq_head)\n            )\n        else:\n            n_batch = y.shape[0]\n            return (\n                (loss_0 + loss_1) / n_batch\n                + penalty_disc * disc\n                + 0.5 * (penalty_l2 * weightsq_body + weightsq_head)\n            )\n\n    # Define optimisation routine\n    opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)\n\n    @jit\n    def update(\n        i: int, state: dict, batch: jnp.ndarray, penalty_l2: float, penalty_disc: float\n    ) -> jnp.ndarray:\n        # updating function\n        params = get_params(state)\n        return opt_update(\n            i,\n            grad(loss_snet1)(params, batch, penalty_l2, penalty_disc, penalty_diff),\n            state,\n        )\n\n    # initialise states\n    _, init_params = init_fun_snet1(rng_key, input_shape)\n    opt_state = opt_init(init_params)\n\n    # calculate number of batches per epoch\n    batch_size = batch_size if batch_size < n else n\n    n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1\n    train_indices = onp.arange(n)\n\n    l_best = LARGE_VAL\n    p_curr = 0\n\n    # do training\n    for i in range(n_iter):\n        # shuffle data for minibatches\n        onp.random.shuffle(train_indices)\n        for b in range(n_batches):\n            idx_next = train_indices[\n                (b * batch_size) : min((b + 1) * batch_size, n - 1)\n            ]\n            next_batch = X[idx_next, :], y[idx_next, :], w[idx_next]\n            opt_state = update(\n                i * n_batches + b, opt_state, next_batch, penalty_l2, penalty_disc\n            )\n\n        if (i % n_iter_print == 0) or early_stopping:\n            params_curr = 
get_params(opt_state)\n            l_curr = loss_snet1(\n                params_curr,\n                (X_val, y_val, w_val),\n                penalty_l2,\n                penalty_disc,\n                penalty_diff,\n            )\n\n        if i % n_iter_print == 0:\n            log.info(f\"Epoch: {i}, current {val_string} loss {l_curr}\")\n\n        if early_stopping:\n            if l_curr < l_best:\n                l_best = l_curr\n                p_curr = 0\n                params_best = params_curr\n            else:\n                if onp.isnan(l_curr):\n                    # if diverged, return best\n                    return params_best, (predict_fun_repr, predict_fun_head)\n                p_curr = p_curr + 1\n\n            if p_curr > patience and ((i + 1) * n_batches > n_iter_min):\n                if return_val_loss:\n                    # return loss without penalty\n                    l_final = loss_snet1(params_curr, (X_val, y_val, w_val), 0, 0, 0)\n                    return params_curr, (predict_fun_repr, predict_fun_head), l_final\n\n                return params_curr, (predict_fun_repr, predict_fun_head)\n\n    # return the parameters\n    trained_params = get_params(opt_state)\n\n    if return_val_loss:\n        # return loss without penalty\n        l_final = loss_snet1(get_params(opt_state), (X_val, y_val, w_val), 0, 0, 0)\n        return trained_params, (predict_fun_repr, predict_fun_head), l_final\n\n    return trained_params, (predict_fun_repr, predict_fun_head)\n\n\n# SNET-2 -----------------------------------------------------------------------------------------\ndef train_snet2(\n    X: jnp.ndarray,\n    y: jnp.ndarray,\n    w: jnp.ndarray,\n    binary_y: bool = False,\n    n_layers_r: int = DEFAULT_LAYERS_R,\n    n_units_r: int = DEFAULT_UNITS_R,\n    n_layers_out: int = DEFAULT_LAYERS_OUT,\n    n_units_out: int = DEFAULT_UNITS_OUT,\n    penalty_l2: float = DEFAULT_PENALTY_L2,\n    n_units_out_prop: int = DEFAULT_UNITS_OUT,\n    
n_layers_out_prop: int = DEFAULT_LAYERS_OUT,\n    step_size: float = DEFAULT_STEP_SIZE,\n    n_iter: int = DEFAULT_N_ITER,\n    batch_size: int = DEFAULT_BATCH_SIZE,\n    val_split_prop: float = DEFAULT_VAL_SPLIT,\n    early_stopping: bool = True,\n    patience: int = DEFAULT_PATIENCE,\n    n_iter_min: int = DEFAULT_N_ITER_MIN,\n    n_iter_print: int = DEFAULT_N_ITER_PRINT,\n    seed: int = DEFAULT_SEED,\n    return_val_loss: bool = False,\n    reg_diff: bool = False,\n    penalty_diff: float = DEFAULT_PENALTY_L2,\n    nonlin: str = DEFAULT_NONLIN,\n    avg_objective: bool = DEFAULT_AVG_OBJECTIVE,\n    same_init: bool = False,\n) -> Any:\n    \"\"\"\n    SNet2 corresponds to DragonNet (Shi et al, 2019) [without TMLE regularisation term].\n    \"\"\"\n    y, w = check_shape_1d_data(y), check_shape_1d_data(w)\n    d = X.shape[1]\n    input_shape = (-1, d)\n    rng_key = random.PRNGKey(seed)\n    onp.random.seed(seed)  # set seed for data generation via numpy as well\n\n    if not reg_diff:\n        penalty_diff = penalty_l2\n\n    # get validation split (can be none)\n    X, y, w, X_val, y_val, w_val, val_string = make_val_split(\n        X, y, w, val_split_prop=val_split_prop, seed=seed\n    )\n    n = X.shape[0]  # could be different from before due to split\n\n    # get representation layer\n    init_fun_repr, predict_fun_repr = ReprBlock(\n        n_layers=n_layers_r, n_units=n_units_r, nonlin=nonlin\n    )\n\n    # get output head functions (output heads share same structure)\n    init_fun_head_po, predict_fun_head_po = OutputHead(\n        n_layers_out=n_layers_out,\n        n_units_out=n_units_out,\n        binary_y=binary_y,\n        nonlin=nonlin,\n    )\n    # add propensity head\n    init_fun_head_prop, predict_fun_head_prop = OutputHead(\n        n_layers_out=n_layers_out_prop,\n        n_units_out=n_units_out_prop,\n        binary_y=True,\n        nonlin=nonlin,\n    )\n\n    def init_fun_snet2(rng: float, input_shape: Tuple) -> Tuple[Tuple, List]:\n     
   # chain together the layers\n        # param should look like [repr, po_0, po_1, prop]\n        rng, layer_rng = random.split(rng)\n        input_shape_repr, param_repr = init_fun_repr(layer_rng, input_shape)\n\n        rng, layer_rng = random.split(rng)\n        if same_init:\n            # initialise both on same values\n            input_shape, param_0 = init_fun_head_po(layer_rng, input_shape_repr)\n            input_shape, param_1 = init_fun_head_po(layer_rng, input_shape_repr)\n        else:\n            input_shape, param_0 = init_fun_head_po(layer_rng, input_shape_repr)\n            rng, layer_rng = random.split(rng)\n            input_shape, param_1 = init_fun_head_po(layer_rng, input_shape_repr)\n        rng, layer_rng = random.split(rng)\n        input_shape, param_prop = init_fun_head_prop(layer_rng, input_shape_repr)\n        return input_shape, [param_repr, param_0, param_1, param_prop]\n\n    # Define loss functions\n    # loss functions for the head\n    if not binary_y:\n\n        def loss_head(\n            params: List, batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]\n        ) -> jnp.ndarray:\n            # mse loss function\n            inputs, targets, weights = batch\n            preds = predict_fun_head_po(params, inputs)\n            return jnp.sum(weights * ((preds - targets) ** 2))\n\n    else:\n\n        def loss_head(\n            params: List, batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]\n        ) -> jnp.ndarray:\n            # log loss function\n            inputs, targets, weights = batch\n            preds = predict_fun_head_po(params, inputs)\n            return -jnp.sum(\n                weights\n                * (targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds))\n            )\n\n    def loss_head_prop(\n        params: List, batch: Tuple[jnp.ndarray, jnp.ndarray], penalty: float\n    ) -> jnp.ndarray:\n        # log loss function for propensities\n        inputs, targets = batch\n        preds = 
predict_fun_head_prop(params, inputs)\n\n        return -jnp.sum(targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds))\n\n    # complete loss function for all parts\n    @jit\n    def loss_snet2(\n        params: List,\n        batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],\n        penalty_l2: float,\n        penalty_diff: float,\n    ) -> jnp.ndarray:\n        # params: list[representation, head_0, head_1, head_prop]\n        # batch: (X, y, w)\n        X, y, w = batch\n\n        # get representation\n        reps = predict_fun_repr(params[0], X)\n\n        # pass down to heads\n        loss_0 = loss_head(params[1], (reps, y, 1 - w))\n        loss_1 = loss_head(params[2], (reps, y, w))\n\n        # pass down to propensity head\n        loss_prop = loss_head_prop(params[3], (reps, w), penalty_l2)\n        weightsq_prop = sum(\n            [\n                jnp.sum(params[3][i][0] ** 2)\n                for i in range(0, 2 * n_layers_out_prop + 1, 2)\n            ]\n        )\n\n        weightsq_body = sum(\n            [jnp.sum(params[0][i][0] ** 2) for i in range(0, 2 * n_layers_r, 2)]\n        )\n        weightsq_head = heads_l2_penalty(\n            params[1], params[2], n_layers_out, reg_diff, penalty_l2, penalty_diff\n        )\n\n        if not avg_objective:\n            return (\n                loss_0\n                + loss_1\n                + loss_prop\n                + 0.5 * (penalty_l2 * (weightsq_body + weightsq_prop) + weightsq_head)\n            )\n        else:\n            n_batch = y.shape[0]\n            return (\n                (loss_0 + loss_1) / n_batch\n                + loss_prop / n_batch\n                + 0.5 * (penalty_l2 * (weightsq_body + weightsq_prop) + weightsq_head)\n            )\n\n    # Define optimisation routine\n    opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)\n\n    @jit\n    def update(\n        i: int, state: dict, batch: jnp.ndarray, penalty_l2: float, penalty_diff: float\n  
  ) -> jnp.ndarray:\n        # updating function\n        params = get_params(state)\n        return opt_update(\n            i, grad(loss_snet2)(params, batch, penalty_l2, penalty_diff), state\n        )\n\n    # initialise states\n    _, init_params = init_fun_snet2(rng_key, input_shape)\n    opt_state = opt_init(init_params)\n\n    # calculate number of batches per epoch\n    batch_size = batch_size if batch_size < n else n\n    n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1\n    train_indices = onp.arange(n)\n\n    l_best = LARGE_VAL\n    p_curr = 0\n\n    # do training\n    for i in range(n_iter):\n        # shuffle data for minibatches\n        onp.random.shuffle(train_indices)\n        for b in range(n_batches):\n            idx_next = train_indices[\n                (b * batch_size) : min((b + 1) * batch_size, n - 1)\n            ]\n            next_batch = X[idx_next, :], y[idx_next, :], w[idx_next]\n            opt_state = update(\n                i * n_batches + b, opt_state, next_batch, penalty_l2, penalty_diff\n            )\n\n        if (i % n_iter_print == 0) or early_stopping:\n            params_curr = get_params(opt_state)\n            l_curr = loss_snet2(\n                params_curr, (X_val, y_val, w_val), penalty_l2, penalty_diff\n            )\n\n        if i % n_iter_print == 0:\n            log.info(f\"Epoch: {i}, current {val_string} loss {l_curr}\")\n\n        if early_stopping and ((i + 1) * n_batches > n_iter_min):\n            # check if loss updated\n            if l_curr < l_best:\n                l_best = l_curr\n                p_curr = 0\n                params_best = params_curr\n            else:\n                if onp.isnan(l_curr):\n                    # if diverged, return best\n                    return params_best, (\n                        predict_fun_repr,\n                        predict_fun_head_po,\n                        predict_fun_head_prop,\n                    )\n                p_curr = 
p_curr + 1\n\n            if p_curr > patience:\n                if return_val_loss:\n                    # return loss without penalty\n                    l_final = loss_snet2(params_curr, (X_val, y_val, w_val), 0, 0)\n                    return (\n                        params_curr,\n                        (predict_fun_repr, predict_fun_head_po, predict_fun_head_prop),\n                        l_final,\n                    )\n\n                return params_curr, (\n                    predict_fun_repr,\n                    predict_fun_head_po,\n                    predict_fun_head_prop,\n                )\n\n    # return the parameters\n    trained_params = get_params(opt_state)\n\n    if return_val_loss:\n        # return loss without penalty\n        l_final = loss_snet2(get_params(opt_state), (X_val, y_val, w_val), 0, 0)\n        return (\n            trained_params,\n            (predict_fun_repr, predict_fun_head_po, predict_fun_head_prop),\n            l_final,\n        )\n\n    return trained_params, (\n        predict_fun_repr,\n        predict_fun_head_po,\n        predict_fun_head_prop,\n    )\n\n\ndef predict_snet2(\n    X: jnp.ndarray,\n    trained_params: dict,\n    predict_funs: list,\n    return_po: bool = False,\n    return_prop: bool = False,\n) -> jnp.ndarray:\n    # unpack inputs\n    predict_fun_repr, predict_fun_head, predict_fun_prop = predict_funs\n    param_repr, param_0, param_1, param_prop = (\n        trained_params[0],\n        trained_params[1],\n        trained_params[2],\n        trained_params[3],\n    )\n\n    # get representation\n    representation = predict_fun_repr(param_repr, X)\n\n    # get potential outcomes\n    mu_0 = predict_fun_head(param_0, representation)\n    mu_1 = predict_fun_head(param_1, representation)\n\n    te = mu_1 - mu_0\n    if return_prop:\n        # get propensity\n        prop = predict_fun_prop(param_prop, representation)\n\n    # stack other outputs\n    if return_po:\n        if return_prop:\n    
        return te, mu_0, mu_1, prop\n        else:\n            return te, mu_0, mu_1\n    else:\n        if return_prop:\n            return te, prop\n        else:\n            return te\n"
  },
  {
    "path": "catenets/models/jax/rnet.py",
    "content": "\"\"\"\nImplements NN based on R-learner and U-learner (as discussed in Nie & Wager (2017))\n\"\"\"\n# Author: Alicia Curth\nfrom typing import Any, Callable, Optional\n\nimport jax.numpy as jnp\nimport numpy as onp\nimport pandas as pd\nfrom jax import grad, jit, random\nfrom jax.example_libraries import optimizers\nfrom sklearn.model_selection import StratifiedKFold\n\nimport catenets.logger as log\nfrom catenets.models.constants import (\n    DEFAULT_AVG_OBJECTIVE,\n    DEFAULT_BATCH_SIZE,\n    DEFAULT_CF_FOLDS,\n    DEFAULT_LAYERS_OUT,\n    DEFAULT_LAYERS_OUT_T,\n    DEFAULT_LAYERS_R,\n    DEFAULT_LAYERS_R_T,\n    DEFAULT_N_ITER,\n    DEFAULT_N_ITER_MIN,\n    DEFAULT_N_ITER_PRINT,\n    DEFAULT_NONLIN,\n    DEFAULT_PATIENCE,\n    DEFAULT_PENALTY_L2,\n    DEFAULT_SEED,\n    DEFAULT_STEP_SIZE,\n    DEFAULT_STEP_SIZE_T,\n    DEFAULT_UNITS_OUT,\n    DEFAULT_UNITS_OUT_T,\n    DEFAULT_UNITS_R,\n    DEFAULT_UNITS_R_T,\n    DEFAULT_VAL_SPLIT,\n    LARGE_VAL,\n)\nfrom catenets.models.jax.base import (\n    BaseCATENet,\n    OutputHead,\n    make_val_split,\n    train_output_net_only,\n)\nfrom catenets.models.jax.model_utils import check_shape_1d_data, check_X_is_np\n\nR_STRATEGY_NAME = \"R\"\nU_STRATEGY_NAME = \"U\"\n\n\nclass RNet(BaseCATENet):\n    \"\"\"\n    Class implements R-learner and U-learner using NNs\n\n    Parameters\n    ----------\n    second_stage_strategy: str, default 'R'\n        Which strategy to use in the second stage ('R' for R-learner, 'U' for U-learner)\n    data_split: bool, default False\n        Whether to split the data in two folds for estimation\n    cross_fit: bool, default False\n        Whether to perform cross fitting\n    n_cf_folds: int\n        Number of crossfitting folds to use\n    n_layers_out: int\n        First stage Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)\n    n_units_out: int\n        First stage Number of hidden units in each hypothesis layer\n    n_layers_r: int\n        
First stage Number of representation layers before hypothesis layers (distinction between\n        hypothesis layers and representation layers is made to match TARNet & SNets)\n    n_units_r: int\n        First stage Number of hidden units in each representation layer\n    n_layers_out_t: int\n        Second stage Number of hypothesis layers (n_layers_out_t x n_units_out_t + 1 x Dense layer)\n    n_units_out_t: int\n        Second stage Number of hidden units in each hypothesis layer\n    n_layers_r_t: int\n        Second stage Number of representation layers before hypothesis layers (distinction between\n        hypothesis layers and representation layers is made to match TARNet & SNets)\n    n_units_r_t: int\n        Second stage Number of hidden units in each representation layer\n    penalty_l2: float\n        First stage l2 (ridge) penalty\n    penalty_l2_t: float\n        Second stage l2 (ridge) penalty\n    step_size: float\n        First stage learning rate for optimizer\n    step_size_t: float\n        Second stage learning rate for optimizer\n    n_iter: int\n        Maximum number of iterations\n    batch_size: int\n        Batch size\n    val_split_prop: float\n        Proportion of samples used for validation split (can be 0)\n    early_stopping: bool, default True\n        Whether to use early stopping\n    patience: int\n        Number of iterations to wait before early stopping after decrease in validation loss\n    n_iter_min: int\n        Minimum number of iterations to go through before starting early stopping\n    n_iter_print: int\n        Number of iterations after which to print updates\n    seed: int\n        Seed used\n    nonlin: string, default 'elu'\n        Nonlinearity to use in NN\n    \"\"\"\n\n    def __init__(\n        self,\n        second_stage_strategy: str = R_STRATEGY_NAME,\n        data_split: bool = False,\n        cross_fit: bool = False,\n        n_cf_folds: int = DEFAULT_CF_FOLDS,\n        n_layers_out: int = 
DEFAULT_LAYERS_OUT,\n        n_layers_r: int = DEFAULT_LAYERS_R,\n        n_layers_out_t: int = DEFAULT_LAYERS_OUT_T,\n        n_layers_r_t: int = DEFAULT_LAYERS_R_T,\n        n_units_out: int = DEFAULT_UNITS_OUT,\n        n_units_r: int = DEFAULT_UNITS_R,\n        n_units_out_t: int = DEFAULT_UNITS_OUT_T,\n        n_units_r_t: int = DEFAULT_UNITS_R_T,\n        penalty_l2: float = DEFAULT_PENALTY_L2,\n        penalty_l2_t: float = DEFAULT_PENALTY_L2,\n        step_size: float = DEFAULT_STEP_SIZE,\n        step_size_t: float = DEFAULT_STEP_SIZE_T,\n        n_iter: int = DEFAULT_N_ITER,\n        batch_size: int = DEFAULT_BATCH_SIZE,\n        n_iter_min: int = DEFAULT_N_ITER_MIN,\n        val_split_prop: float = DEFAULT_VAL_SPLIT,\n        early_stopping: bool = True,\n        patience: int = DEFAULT_PATIENCE,\n        n_iter_print: int = DEFAULT_N_ITER_PRINT,\n        seed: int = DEFAULT_SEED,\n        nonlin: str = DEFAULT_NONLIN,\n        binary_y: bool = False,\n    ) -> None:\n        # settings\n        self.binary_y = binary_y\n        self.second_stage_strategy = second_stage_strategy\n        self.data_split = data_split\n        self.cross_fit = cross_fit\n        self.n_cf_folds = n_cf_folds\n\n        # model architecture hyperparams\n        self.n_layers_out = n_layers_out\n        self.n_layers_out_t = n_layers_out_t\n        self.n_layers_r = n_layers_r\n        self.n_layers_r_t = n_layers_r_t\n        self.n_units_out = n_units_out\n        self.n_units_out_t = n_units_out_t\n        self.n_units_r = n_units_r\n        self.n_units_r_t = n_units_r_t\n        self.nonlin = nonlin\n\n        # other hyperparameters\n        self.penalty_l2 = penalty_l2\n        self.penalty_l2_t = penalty_l2_t\n        self.step_size = step_size\n        self.step_size_t = step_size_t\n        self.n_iter = n_iter\n        self.batch_size = batch_size\n        self.n_iter_print = n_iter_print\n        self.seed = seed\n        self.val_split_prop = val_split_prop\n     
   self.early_stopping = early_stopping\n        self.patience = patience\n        self.n_iter_min = n_iter_min\n\n    def _get_train_function(self) -> Callable:\n        return train_r_net\n\n    def fit(\n        self,\n        X: jnp.ndarray,\n        y: jnp.ndarray,\n        w: jnp.ndarray,\n        p: Optional[jnp.ndarray] = None,\n    ) -> \"RNet\":\n        # overwrite super so we can pass p as extra param\n        # some quick input checks\n        X = check_X_is_np(X)\n        self._check_inputs(w, p)\n\n        train_func = self._get_train_function()\n        train_params = self.get_params()\n\n        self._params, self._predict_funs = train_func(X, y, w, p, **train_params)\n\n        return self\n\n    def _get_predict_function(self) -> Callable:\n        # Two step nets do not need this\n        pass\n\n    def predict(\n        self, X: jnp.ndarray, return_po: bool = False, return_prop: bool = False\n    ) -> jnp.ndarray:\n        # check input\n        if return_po:\n            raise NotImplementedError(\n                \"TwoStepNets have no Potential outcome predictors.\"\n            )\n\n        if return_prop:\n            raise NotImplementedError(\"TwoStepNets have no Propensity predictors.\")\n\n        if isinstance(X, pd.DataFrame):\n            X = X.values\n        return self._predict_funs(self._params, X)\n\n\ndef train_r_net(\n    X: jnp.ndarray,\n    y: jnp.ndarray,\n    w: jnp.ndarray,\n    p: Optional[jnp.ndarray] = None,\n    second_stage_strategy: str = R_STRATEGY_NAME,\n    data_split: bool = False,\n    cross_fit: bool = False,\n    n_cf_folds: int = DEFAULT_CF_FOLDS,\n    n_layers_out: int = DEFAULT_LAYERS_OUT,\n    n_layers_r: int = DEFAULT_LAYERS_R,\n    n_layers_r_t: int = DEFAULT_LAYERS_R_T,\n    n_layers_out_t: int = DEFAULT_LAYERS_OUT_T,\n    n_units_out: int = DEFAULT_UNITS_OUT,\n    n_units_r: int = DEFAULT_UNITS_R,\n    n_units_out_t: int = DEFAULT_UNITS_OUT_T,\n    n_units_r_t: int = DEFAULT_UNITS_R_T,\n    
penalty_l2: float = DEFAULT_PENALTY_L2,\n    penalty_l2_t: float = DEFAULT_PENALTY_L2,\n    step_size: float = DEFAULT_STEP_SIZE,\n    step_size_t: float = DEFAULT_STEP_SIZE_T,\n    n_iter: int = DEFAULT_N_ITER,\n    batch_size: int = DEFAULT_BATCH_SIZE,\n    val_split_prop: float = DEFAULT_VAL_SPLIT,\n    early_stopping: bool = True,\n    patience: int = DEFAULT_PATIENCE,\n    n_iter_min: int = DEFAULT_N_ITER_MIN,\n    n_iter_print: int = DEFAULT_N_ITER_PRINT,\n    seed: int = DEFAULT_SEED,\n    return_val_loss: bool = False,\n    nonlin: str = DEFAULT_NONLIN,\n    binary_y: bool = False,\n) -> Any:\n    # get shape of data\n    n, d = X.shape\n\n    if p is not None:\n        p = check_shape_1d_data(p)\n\n    # split data as wanted\n    if not cross_fit:\n        if not data_split:\n            log.debug(\"Training first stage with all data (no data splitting)\")\n            # use all data for both\n            fit_mask = onp.ones(n, dtype=bool)\n            pred_mask = onp.ones(n, dtype=bool)\n        else:\n            log.debug(\"Training first stage with half of the data (data splitting)\")\n            # split data in half\n            fit_idx = onp.random.choice(n, int(onp.round(n / 2)))\n            fit_mask = onp.zeros(n, dtype=bool)\n\n            fit_mask[fit_idx] = 1\n            pred_mask = ~fit_mask\n\n        mu_hat, pi_hat = _train_and_predict_r_stage1(\n            X,\n            y,\n            w,\n            fit_mask,\n            pred_mask,\n            n_layers_out=n_layers_out,\n            n_layers_r=n_layers_r,\n            n_units_out=n_units_out,\n            n_units_r=n_units_r,\n            penalty_l2=penalty_l2,\n            step_size=step_size,\n            n_iter=n_iter,\n            batch_size=batch_size,\n            val_split_prop=val_split_prop,\n            early_stopping=early_stopping,\n            patience=patience,\n            n_iter_min=n_iter_min,\n            n_iter_print=n_iter_print,\n            seed=seed,\n        
    nonlin=nonlin,\n            binary_y=binary_y,\n        )\n        if data_split:\n            # keep only prediction data\n            X, y, w = X[pred_mask, :], y[pred_mask, :], w[pred_mask, :]\n\n            if p is not None:\n                p = p[pred_mask, :]\n\n    else:\n        log.debug(f\"Training first stage in {n_cf_folds} folds (cross-fitting)\")\n        # do cross fitting\n        mu_hat, pi_hat = onp.zeros((n, 1)), onp.zeros((n, 1))\n        splitter = StratifiedKFold(n_splits=n_cf_folds, shuffle=True, random_state=seed)\n\n        fold_count = 1\n        for train_idx, test_idx in splitter.split(X, w):\n            log.debug(f\"Training fold {fold_count}.\")\n            fold_count = fold_count + 1\n\n            pred_mask = onp.zeros(n, dtype=bool)\n            pred_mask[test_idx] = 1\n            fit_mask = ~pred_mask\n\n            mu_hat[pred_mask], pi_hat[pred_mask] = _train_and_predict_r_stage1(\n                X,\n                y,\n                w,\n                fit_mask,\n                pred_mask,\n                n_layers_out=n_layers_out,\n                n_layers_r=n_layers_r,\n                n_units_out=n_units_out,\n                n_units_r=n_units_r,\n                penalty_l2=penalty_l2,\n                step_size=step_size,\n                n_iter=n_iter,\n                batch_size=batch_size,\n                val_split_prop=val_split_prop,\n                early_stopping=early_stopping,\n                patience=patience,\n                n_iter_min=n_iter_min,\n                n_iter_print=n_iter_print,\n                seed=seed,\n                nonlin=nonlin,\n                binary_y=binary_y,\n            )\n\n    log.debug(\"Training second stage.\")\n\n    if p is not None:\n        # use known propensity score\n        p = check_shape_1d_data(p)\n        pi_hat = p\n\n    y, w = check_shape_1d_data(y), check_shape_1d_data(w)\n    w_ortho = w - pi_hat\n    y_ortho = y - mu_hat\n\n    if 
second_stage_strategy == R_STRATEGY_NAME:\n        return train_r_stage2(\n            X,\n            y_ortho,\n            w_ortho,\n            n_layers_out=n_layers_out_t,\n            n_units_out=n_units_out_t,\n            n_layers_r=n_layers_r_t,\n            n_units_r=n_units_r_t,\n            penalty_l2=penalty_l2_t,\n            step_size=step_size_t,\n            n_iter=n_iter,\n            batch_size=batch_size,\n            val_split_prop=val_split_prop,\n            early_stopping=early_stopping,\n            patience=patience,\n            n_iter_min=n_iter_min,\n            n_iter_print=n_iter_print,\n            seed=seed,\n            return_val_loss=return_val_loss,\n            nonlin=nonlin,\n        )\n    elif second_stage_strategy == U_STRATEGY_NAME:\n        return train_output_net_only(\n            X,\n            y_ortho / w_ortho,\n            n_layers_out=n_layers_out_t,\n            n_units_out=n_units_out_t,\n            n_layers_r=n_layers_r_t,\n            n_units_r=n_units_r_t,\n            penalty_l2=penalty_l2_t,\n            step_size=step_size_t,\n            n_iter=n_iter,\n            batch_size=batch_size,\n            val_split_prop=val_split_prop,\n            early_stopping=early_stopping,\n            patience=patience,\n            n_iter_min=n_iter_min,\n            n_iter_print=n_iter_print,\n            seed=seed,\n            return_val_loss=return_val_loss,\n            nonlin=nonlin,\n        )\n    else:\n        raise ValueError(\"R-learner only supports strategies R and U.\")\n\n\ndef _train_and_predict_r_stage1(\n    X: jnp.ndarray,\n    y: jnp.ndarray,\n    w: jnp.ndarray,\n    fit_mask: jnp.ndarray,\n    pred_mask: jnp.ndarray,\n    n_layers_out: int = DEFAULT_LAYERS_OUT,\n    n_units_out: int = DEFAULT_UNITS_OUT,\n    n_layers_r: int = DEFAULT_LAYERS_R,\n    n_units_r: int = DEFAULT_UNITS_R,\n    penalty_l2: float = DEFAULT_PENALTY_L2,\n    step_size: float = DEFAULT_STEP_SIZE,\n    n_iter: int = 
DEFAULT_N_ITER,\n    batch_size: int = DEFAULT_BATCH_SIZE,\n    val_split_prop: float = DEFAULT_VAL_SPLIT,\n    early_stopping: bool = True,\n    patience: int = DEFAULT_PATIENCE,\n    n_iter_min: int = DEFAULT_N_ITER_MIN,\n    n_iter_print: int = DEFAULT_N_ITER_PRINT,\n    seed: int = DEFAULT_SEED,\n    nonlin: str = DEFAULT_NONLIN,\n    binary_y: bool = False,\n) -> Any:\n    if len(w.shape) > 1:\n        w = w.reshape((len(w),))\n\n    # split the data\n    X_fit, y_fit, w_fit = X[fit_mask, :], y[fit_mask], w[fit_mask]\n    X_pred = X[pred_mask, :]\n\n    log.debug(\"Training output Net\")\n    params_out, predict_fun_out = train_output_net_only(\n        X_fit,\n        y_fit,\n        n_layers_out=n_layers_out,\n        n_units_out=n_units_out,\n        n_layers_r=n_layers_r,\n        n_units_r=n_units_r,\n        penalty_l2=penalty_l2,\n        step_size=step_size,\n        n_iter=n_iter,\n        batch_size=batch_size,\n        val_split_prop=val_split_prop,\n        early_stopping=early_stopping,\n        patience=patience,\n        n_iter_min=n_iter_min,\n        n_iter_print=n_iter_print,\n        seed=seed,\n        nonlin=nonlin,\n        binary_y=binary_y,\n    )\n    mu_hat = predict_fun_out(params_out, X_pred)\n\n    log.debug(\"Training propensity net\")\n    params_prop, predict_fun_prop = train_output_net_only(\n        X_fit,\n        w_fit,\n        binary_y=True,\n        n_layers_out=n_layers_out,\n        n_units_out=n_units_out,\n        n_layers_r=n_layers_r,\n        n_units_r=n_units_r,\n        penalty_l2=penalty_l2,\n        step_size=step_size,\n        n_iter=n_iter,\n        batch_size=batch_size,\n        val_split_prop=val_split_prop,\n        early_stopping=early_stopping,\n        patience=patience,\n        n_iter_min=n_iter_min,\n        n_iter_print=n_iter_print,\n        seed=seed,\n        nonlin=nonlin,\n    )\n    pi_hat = predict_fun_prop(params_prop, X_pred)\n\n    return mu_hat, pi_hat\n\n\ndef train_r_stage2(\n    X: 
jnp.ndarray,\n    y_ortho: jnp.ndarray,\n    w_ortho: jnp.ndarray,\n    n_layers_out: int = DEFAULT_LAYERS_OUT,\n    n_units_out: int = DEFAULT_UNITS_OUT,\n    n_layers_r: int = 0,\n    n_units_r: int = DEFAULT_UNITS_R,\n    penalty_l2: float = DEFAULT_PENALTY_L2,\n    step_size: float = DEFAULT_STEP_SIZE,\n    n_iter: int = DEFAULT_N_ITER,\n    batch_size: int = DEFAULT_BATCH_SIZE,\n    val_split_prop: float = DEFAULT_VAL_SPLIT,\n    early_stopping: bool = True,\n    patience: int = DEFAULT_PATIENCE,\n    n_iter_min: int = DEFAULT_N_ITER_MIN,\n    n_iter_print: int = DEFAULT_N_ITER_PRINT,\n    seed: int = DEFAULT_SEED,\n    return_val_loss: bool = False,\n    nonlin: str = DEFAULT_NONLIN,\n    avg_objective: bool = DEFAULT_AVG_OBJECTIVE,\n) -> Any:\n    # function to train a single output head\n    # input check\n    y_ortho, w_ortho = check_shape_1d_data(y_ortho), check_shape_1d_data(w_ortho)\n    d = X.shape[1]\n    input_shape = (-1, d)\n    rng_key = random.PRNGKey(seed)\n    onp.random.seed(seed)  # set seed for data generation via numpy as well\n\n    # get validation split (can be none)\n    X, y_ortho, w_ortho, X_val, y_val, w_val, val_string = make_val_split(\n        X, y_ortho, w_ortho, val_split_prop=val_split_prop, seed=seed, stratify_w=False\n    )\n    n = X.shape[0]  # could be different from before due to split\n\n    # get output head\n    init_fun, predict_fun = OutputHead(\n        n_layers_out=n_layers_out,\n        n_units_out=n_units_out,\n        n_layers_r=n_layers_r,\n        n_units_r=n_units_r,\n        nonlin=nonlin,\n    )\n\n    # define loss and grad\n    @jit\n    def loss(params: dict, batch: jnp.ndarray, penalty: float) -> jnp.ndarray:\n        # mse loss function\n        inputs, ortho_targets, ortho_treats = batch\n        preds = predict_fun(params, inputs)\n        weightsq = sum(\n            [\n                jnp.sum(params[i][0] ** 2)\n                for i in range(0, 2 * (n_layers_out + n_layers_r) + 1, 2)\n            
]\n        )\n        if not avg_objective:\n            return (\n                jnp.sum((ortho_targets - ortho_treats * preds) ** 2)\n                + 0.5 * penalty * weightsq\n            )\n        else:\n            return (\n                jnp.average((ortho_targets - ortho_treats * preds) ** 2)\n                + 0.5 * penalty * weightsq\n            )\n\n    # set optimization routine\n    # set optimizer\n    opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)\n\n    # set update function\n    @jit\n    def update(i: int, state: dict, batch: jnp.ndarray, penalty: float) -> jnp.ndarray:\n        params = get_params(state)\n        g_params = grad(loss)(params, batch, penalty)\n        # g_params = optimizers.clip_grads(g_params, 1.0)\n        return opt_update(i, g_params, state)\n\n    # initialise states\n    _, init_params = init_fun(rng_key, input_shape)\n    opt_state = opt_init(init_params)\n\n    # calculate number of batches per epoch\n    batch_size = batch_size if batch_size < n else n\n    n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1\n    train_indices = onp.arange(n)\n\n    l_best = LARGE_VAL\n    p_curr = 0\n\n    # do training\n    for i in range(n_iter):\n        # shuffle data for minibatches\n        onp.random.shuffle(train_indices)\n        for b in range(n_batches):\n            idx_next = train_indices[\n                (b * batch_size) : min((b + 1) * batch_size, n - 1)\n            ]\n            next_batch = X[idx_next, :], y_ortho[idx_next, :], w_ortho[idx_next, :]\n            opt_state = update(i * n_batches + b, opt_state, next_batch, penalty_l2)\n\n        if (i % n_iter_print == 0) or early_stopping:\n            params_curr = get_params(opt_state)\n            l_curr = loss(params_curr, (X_val, y_val, w_val), penalty_l2)\n\n        if i % n_iter_print == 0:\n            log.debug(f\"Epoch: {i}, current {val_string} loss: {l_curr}\")\n\n        if early_stopping and ((i + 1) * 
n_batches > n_iter_min):\n            # check if loss updated\n            if l_curr < l_best:\n                l_best = l_curr\n                p_curr = 0\n            else:\n                p_curr = p_curr + 1\n\n            if p_curr > patience:\n                trained_params = get_params(opt_state)\n\n                if return_val_loss:\n                    # return loss without penalty\n                    l_final = loss(trained_params, (X_val, y_val, w_val), 0)\n                    return trained_params, predict_fun, l_final\n\n                return trained_params, predict_fun\n\n    # get final parameters\n    trained_params = get_params(opt_state)\n\n    if return_val_loss:\n        # return loss without penalty\n        l_final = loss(trained_params, (X_val, y_val, w_val), 0)\n        return trained_params, predict_fun, l_final\n\n    return trained_params, predict_fun\n"
  },
  {
    "path": "catenets/models/jax/snet.py",
    "content": "\"\"\"\nModule implements SNet class as discussed in Curth & van der Schaar (2021)\n\"\"\"\n# Author: Alicia Curth\nfrom typing import Callable, List, Tuple\n\nimport jax.numpy as jnp\nimport numpy as onp\nfrom jax import grad, jit, random\nfrom jax.example_libraries import optimizers\n\nimport catenets.logger as log\nfrom catenets.models.constants import (\n    DEFAULT_AVG_OBJECTIVE,\n    DEFAULT_BATCH_SIZE,\n    DEFAULT_LAYERS_OUT,\n    DEFAULT_LAYERS_R,\n    DEFAULT_N_ITER,\n    DEFAULT_N_ITER_MIN,\n    DEFAULT_N_ITER_PRINT,\n    DEFAULT_NONLIN,\n    DEFAULT_PATIENCE,\n    DEFAULT_PENALTY_DISC,\n    DEFAULT_PENALTY_L2,\n    DEFAULT_PENALTY_ORTHOGONAL,\n    DEFAULT_SEED,\n    DEFAULT_STEP_SIZE,\n    DEFAULT_UNITS_OUT,\n    DEFAULT_UNITS_R_BIG_S3,\n    DEFAULT_UNITS_R_SMALL_S3,\n    DEFAULT_VAL_SPLIT,\n    LARGE_VAL,\n)\nfrom catenets.models.jax.base import BaseCATENet, OutputHead, ReprBlock\nfrom catenets.models.jax.disentangled_nets import (\n    _concatenate_representations,\n    _get_absolute_rowsums,\n)\nfrom catenets.models.jax.flextenet import _get_cos_reg\nfrom catenets.models.jax.model_utils import (\n    check_shape_1d_data,\n    heads_l2_penalty,\n    make_val_split,\n)\nfrom catenets.models.jax.representation_nets import mmd2_lin\n\nDEFAULT_UNITS_R_BIG_S = 100\nDEFAULT_UNITS_R_SMALL_S = 50\n\n\nclass SNet(BaseCATENet):\n    \"\"\"\n    Class implements SNet as discussed in Curth & van der Schaar (2021). 
Additionally to the\n    version implemented in the AISTATS paper, we also include an implementation that does not\n    have propensity heads (set with_prop=False)\n\n    Parameters\n    ----------\n    with_prop: bool, True\n        Whether to include propensity head\n    binary_y: bool, default False\n        Whether the outcome is binary\n    n_layers_out: int\n        Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)\n    n_layers_out_prop: int\n        Number of hypothesis layers for propensity score(n_layers_out x n_units_out + 1 x Dense\n        layer)\n    n_units_out: int\n        Number of hidden units in each hypothesis layer\n    n_units_out_prop: int\n        Number of hidden units in each propensity score hypothesis layer\n    n_layers_r: int\n        Number of shared & private representation layers before hypothesis layers\n    n_units_r: int\n        If withprop=True: Number of hidden units in representation layer shared by propensity score\n        and outcome  function (the 'confounding factor') and in the ('instrumental factor')\n        If withprop=False: Number of hidden units in representation shared across PO function\n    n_units_r_small: int\n        If withprop=True: Number of hidden units in representation layer of the 'outcome factor'\n        and each PO functions private representation\n        if withprop=False: Number of hidden units in each PO functions private representation\n    penalty_l2: float\n        l2 (ridge) penalty\n    step_size: float\n        learning rate for optimizer\n    n_iter: int\n        Maximum number of iterations\n    batch_size: int\n        Batch size\n    val_split_prop: float\n        Proportion of samples used for validation split (can be 0)\n    early_stopping: bool, default True\n        Whether to use early stopping\n    patience: int\n        Number of iterations to wait before early stopping after decrease in validation loss\n    n_iter_min: int\n        Minimum number of 
iterations to go through before starting early stopping\n    n_iter_print: int\n        Number of iterations after which to print updates\n    seed: int\n        Seed used\n    reg_diff: bool, default False\n        Whether to regularize the difference between the two potential outcome heads\n    penalty_diff: float\n        l2-penalty for regularizing the difference between output heads. used only if\n        train_separate=False\n    same_init: bool, False\n        Whether to initialise the two output heads with same values\n    nonlin: string, default 'elu'\n        Nonlinearity to use in NN\n    penalty_disc: float, default zero\n        Discrepancy penalty. Defaults to zero as this feature is not tested.\n    ortho_reg_type: str, 'abs'\n        Which type of orthogonalization to use. 'abs' uses the (hard) disentanglement described\n        in AISTATS paper, 'fro' uses frobenius norm as in FlexTENet\n    \"\"\"\n\n    def __init__(\n        self,\n        with_prop: bool = True,\n        binary_y: bool = False,\n        n_layers_r: int = DEFAULT_LAYERS_R,\n        n_units_r: int = DEFAULT_UNITS_R_BIG_S,\n        n_layers_out: int = DEFAULT_LAYERS_OUT,\n        n_units_r_small: int = DEFAULT_UNITS_R_SMALL_S,\n        n_units_out: int = DEFAULT_UNITS_OUT,\n        n_units_out_prop: int = DEFAULT_UNITS_OUT,\n        n_layers_out_prop: int = DEFAULT_LAYERS_OUT,\n        penalty_l2: float = DEFAULT_PENALTY_L2,\n        penalty_orthogonal: float = DEFAULT_PENALTY_ORTHOGONAL,\n        penalty_disc: float = DEFAULT_PENALTY_DISC,\n        step_size: float = DEFAULT_STEP_SIZE,\n        n_iter: int = DEFAULT_N_ITER,\n        batch_size: int = DEFAULT_BATCH_SIZE,\n        val_split_prop: float = DEFAULT_VAL_SPLIT,\n        early_stopping: bool = True,\n        patience: int = DEFAULT_PATIENCE,\n        n_iter_min: int = DEFAULT_N_ITER_MIN,\n        n_iter_print: int = DEFAULT_N_ITER_PRINT,\n        reg_diff: bool = False,\n        penalty_diff: float = 
DEFAULT_PENALTY_L2,\n        seed: int = DEFAULT_SEED,\n        nonlin: str = DEFAULT_NONLIN,\n        same_init: bool = False,\n        ortho_reg_type: str = \"abs\",\n    ):\n        self.with_prop = with_prop\n        self.binary_y = binary_y\n\n        self.n_layers_r = n_layers_r\n        self.n_layers_out = n_layers_out\n        self.n_layers_out_prop = n_layers_out_prop\n        self.n_units_r = n_units_r\n        self.n_units_r_small = n_units_r_small\n        self.n_units_out = n_units_out\n        self.n_units_out_prop = n_units_out_prop\n        self.nonlin = nonlin\n\n        self.penalty_l2 = penalty_l2\n        self.penalty_orthogonal = penalty_orthogonal\n        self.penalty_disc = penalty_disc\n        self.reg_diff = reg_diff\n        self.penalty_diff = penalty_diff\n        self.same_init = same_init\n        self.ortho_reg_type = ortho_reg_type\n\n        self.step_size = step_size\n        self.n_iter = n_iter\n        self.batch_size = batch_size\n        self.val_split_prop = val_split_prop\n        self.early_stopping = early_stopping\n        self.patience = patience\n        self.n_iter_min = n_iter_min\n\n        self.seed = seed\n        self.n_iter_print = n_iter_print\n\n    def _get_predict_function(self) -> Callable:\n        if self.with_prop:\n            return predict_snet\n        else:\n            return predict_snet_noprop\n\n    def _get_train_function(self) -> Callable:\n        if self.with_prop:\n            return train_snet\n        else:\n            return train_snet_noprop\n\n\ndef train_snet(\n    X: jnp.ndarray,\n    y: jnp.ndarray,\n    w: jnp.ndarray,\n    binary_y: bool = False,\n    n_layers_r: int = DEFAULT_LAYERS_R,\n    n_units_r: int = DEFAULT_UNITS_R_BIG_S,\n    n_units_r_small: int = DEFAULT_UNITS_R_SMALL_S,\n    n_layers_out: int = DEFAULT_LAYERS_OUT,\n    n_units_out: int = DEFAULT_UNITS_OUT,\n    n_units_out_prop: int = DEFAULT_UNITS_OUT,\n    n_layers_out_prop: int = DEFAULT_LAYERS_OUT,\n    
penalty_l2: float = DEFAULT_PENALTY_L2,\n    penalty_disc: float = DEFAULT_PENALTY_DISC,\n    penalty_orthogonal: float = DEFAULT_PENALTY_ORTHOGONAL,\n    step_size: float = DEFAULT_STEP_SIZE,\n    n_iter: int = DEFAULT_N_ITER,\n    batch_size: int = DEFAULT_BATCH_SIZE,\n    val_split_prop: float = DEFAULT_VAL_SPLIT,\n    early_stopping: bool = True,\n    patience: int = DEFAULT_PATIENCE,\n    n_iter_min: int = DEFAULT_N_ITER_MIN,\n    n_iter_print: int = DEFAULT_N_ITER_PRINT,\n    seed: int = DEFAULT_SEED,\n    return_val_loss: bool = False,\n    reg_diff: bool = False,\n    penalty_diff: float = DEFAULT_PENALTY_L2,\n    nonlin: str = DEFAULT_NONLIN,\n    avg_objective: bool = DEFAULT_AVG_OBJECTIVE,\n    with_prop: bool = True,\n    same_init: bool = False,\n    ortho_reg_type: str = \"abs\",\n) -> Tuple:\n    # function to train a net with 5 representations\n    if not with_prop:\n        raise ValueError(\"train_snet works only withprop=True\")\n    y, w = check_shape_1d_data(y), check_shape_1d_data(w)\n    d = X.shape[1]\n    input_shape = (-1, d)\n    rng_key = random.PRNGKey(seed)\n    onp.random.seed(seed)  # set seed for data generation via numpy as well\n\n    if not reg_diff:\n        penalty_diff = penalty_l2\n\n    # get validation split (can be none)\n    X, y, w, X_val, y_val, w_val, val_string = make_val_split(\n        X, y, w, val_split_prop=val_split_prop, seed=seed\n    )\n    n = X.shape[0]  # could be different from before due to split\n\n    # get representation layers\n    init_fun_repr, predict_fun_repr = ReprBlock(\n        n_layers=n_layers_r, n_units=n_units_r, nonlin=nonlin\n    )\n    init_fun_repr_small, predict_fun_repr_small = ReprBlock(\n        n_layers=n_layers_r, n_units=n_units_r_small, nonlin=nonlin\n    )\n\n    # get output head functions (output heads share same structure)\n    init_fun_head_po, predict_fun_head_po = OutputHead(\n        n_layers_out=n_layers_out,\n        n_units_out=n_units_out,\n        
binary_y=binary_y,\n        nonlin=nonlin,\n    )\n    # add propensity head\n    init_fun_head_prop, predict_fun_head_prop = OutputHead(\n        n_layers_out=n_layers_out_prop,\n        n_units_out=n_units_out_prop,\n        binary_y=True,\n        nonlin=nonlin,\n    )\n\n    def init_fun_snet(rng: float, input_shape: Tuple) -> Tuple[Tuple, List]:\n        # chain together the layers\n        # param should look like [param_repr_c, param_repr_o, param_repr_mu0, param_repr_mu1,\n        #                              param_repr_w, param_0, param_1, param_prop]\n        # initialise representation layers\n        rng, layer_rng = random.split(rng)\n        input_shape_repr, param_repr_c = init_fun_repr(layer_rng, input_shape)\n        rng, layer_rng = random.split(rng)\n        input_shape_repr_small, param_repr_o = init_fun_repr_small(\n            layer_rng, input_shape\n        )\n        rng, layer_rng = random.split(rng)\n        _, param_repr_mu0 = init_fun_repr_small(layer_rng, input_shape)\n        rng, layer_rng = random.split(rng)\n        _, param_repr_mu1 = init_fun_repr_small(layer_rng, input_shape)\n        rng, layer_rng = random.split(rng)\n        _, param_repr_w = init_fun_repr(layer_rng, input_shape)\n\n        # prop and mu_0 each get two representations, mu_1 gets 3\n        input_shape_repr_prop = input_shape_repr[:-1] + (2 * input_shape_repr[-1],)\n        input_shape_repr_mu = input_shape_repr[:-1] + (\n            input_shape_repr[-1] + (2 * input_shape_repr_small[-1]),\n        )\n\n        # initialise output heads\n        rng, layer_rng = random.split(rng)\n        if same_init:\n            # initialise both on same values\n            input_shape, param_0 = init_fun_head_po(layer_rng, input_shape_repr_mu)\n            input_shape, param_1 = init_fun_head_po(layer_rng, input_shape_repr_mu)\n        else:\n            input_shape, param_0 = init_fun_head_po(layer_rng, input_shape_repr_mu)\n            rng, layer_rng = 
random.split(rng)\n            input_shape, param_1 = init_fun_head_po(layer_rng, input_shape_repr_mu)\n\n        rng, layer_rng = random.split(rng)\n        input_shape, param_prop = init_fun_head_prop(layer_rng, input_shape_repr_prop)\n        return input_shape, [\n            param_repr_c,\n            param_repr_o,\n            param_repr_mu0,\n            param_repr_mu1,\n            param_repr_w,\n            param_0,\n            param_1,\n            param_prop,\n        ]\n\n    # Define loss functions\n    # loss functions for the head\n    if not binary_y:\n\n        def loss_head(\n            params: jnp.ndarray,\n            batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],\n            penalty: float,\n        ) -> jnp.ndarray:\n            # mse loss function\n            inputs, targets, weights = batch\n            preds = predict_fun_head_po(params, inputs)\n            return jnp.sum(weights * ((preds - targets) ** 2))\n\n    else:\n\n        def loss_head(\n            params: jnp.ndarray,\n            batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],\n            penalty: float,\n        ) -> jnp.ndarray:\n            # log loss function\n            inputs, targets, weights = batch\n            preds = predict_fun_head_po(params, inputs)\n            return -jnp.sum(\n                weights\n                * (targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds))\n            )\n\n    def loss_head_prop(\n        params: jnp.ndarray, batch: Tuple[jnp.ndarray, jnp.ndarray], penalty: float\n    ) -> jnp.ndarray:\n        # log loss function for propensities\n        inputs, targets = batch\n        preds = predict_fun_head_prop(params, inputs)\n        return -jnp.sum(targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds))\n\n    # define ortho-reg function\n    if ortho_reg_type == \"abs\":\n\n        def ortho_reg(params: jnp.ndarray) -> jnp.ndarray:\n            col_c = _get_absolute_rowsums(params[0][0][0])\n         
   col_o = _get_absolute_rowsums(params[1][0][0])\n            col_mu0 = _get_absolute_rowsums(params[2][0][0])\n            col_mu1 = _get_absolute_rowsums(params[3][0][0])\n            col_w = _get_absolute_rowsums(params[4][0][0])\n            return jnp.sum(\n                col_c * col_o\n                + col_c * col_w\n                + col_c * col_mu1\n                + col_c * col_mu0\n                + col_w * col_o\n                + col_mu0 * col_o\n                + col_o * col_mu1\n                + col_mu0 * col_mu1\n                + col_mu0 * col_w\n                + col_w * col_mu1\n            )\n\n    elif ortho_reg_type == \"fro\":\n\n        def ortho_reg(params: jnp.ndarray) -> jnp.ndarray:\n            return (\n                _get_cos_reg(params[0][0][0], params[1][0][0], False)\n                + _get_cos_reg(params[0][0][0], params[2][0][0], False)\n                + _get_cos_reg(params[0][0][0], params[3][0][0], False)\n                + _get_cos_reg(params[0][0][0], params[4][0][0], False)\n                + _get_cos_reg(params[1][0][0], params[2][0][0], False)\n                + _get_cos_reg(params[1][0][0], params[3][0][0], False)\n                + _get_cos_reg(params[1][0][0], params[4][0][0], False)\n                + _get_cos_reg(params[2][0][0], params[3][0][0], False)\n                + _get_cos_reg(params[2][0][0], params[4][0][0], False)\n                + _get_cos_reg(params[3][0][0], params[4][0][0], False)\n            )\n\n    else:\n        raise NotImplementedError(\n            \"train_snet_noprop supports only orthogonal regularization \"\n            \"using absolute values or frobenious norms.\"\n        )\n\n    # complete loss function for all parts\n    @jit\n    def loss_snet(\n        params: jnp.ndarray,\n        batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],\n        penalty_l2: float,\n        penalty_orthogonal: float,\n        penalty_disc: float,\n    ) -> jnp.ndarray:\n        # params: # param 
should look like [param_repr_c, param_repr_o, param_repr_mu0,\n        #              param_repr_mu1, param_repr_w, param_0, param_1, param_prop]\n        # batch: (X, y, w)\n        X, y, w = batch\n\n        # get representation\n        reps_c = predict_fun_repr(params[0], X)\n        reps_o = predict_fun_repr_small(params[1], X)\n        reps_mu0 = predict_fun_repr_small(params[2], X)\n        reps_mu1 = predict_fun_repr_small(params[3], X)\n        reps_w = predict_fun_repr(params[4], X)\n\n        # concatenate\n        reps_po_0 = _concatenate_representations((reps_c, reps_o, reps_mu0))\n        reps_po_1 = _concatenate_representations((reps_c, reps_o, reps_mu1))\n        reps_prop = _concatenate_representations((reps_c, reps_w))\n\n        # pass down to heads\n        loss_0 = loss_head(params[5], (reps_po_0, y, 1 - w), penalty_l2)\n        loss_1 = loss_head(params[6], (reps_po_1, y, w), penalty_l2)\n\n        # pass down to propensity head\n        loss_prop = loss_head_prop(params[7], (reps_prop, w), penalty_l2)\n\n        # is rep_o balanced between groups?\n        loss_disc = penalty_disc * mmd2_lin(reps_o, w)\n\n        # which variable has impact on which representation -- orthogonal loss\n        loss_o = penalty_orthogonal * ortho_reg(params)\n\n        # weight decay on representations\n        weightsq_body = sum(\n            [\n                sum(\n                    [jnp.sum(params[j][i][0] ** 2) for i in range(0, 2 * n_layers_r, 2)]\n                )\n                for j in range(5)\n            ]\n        )\n        weightsq_head = heads_l2_penalty(\n            params[5], params[6], n_layers_out, reg_diff, penalty_l2, penalty_diff\n        )\n        weightsq_prop = sum(\n            [\n                jnp.sum(params[7][i][0] ** 2)\n                for i in range(0, 2 * n_layers_out_prop + 1, 2)\n            ]\n        )\n\n        if not avg_objective:\n            return (\n                loss_0\n                + loss_1\n         
       + loss_prop\n                + loss_disc\n                + loss_o\n                + 0.5 * (penalty_l2 * (weightsq_body + weightsq_prop) + weightsq_head)\n            )\n        else:\n            n_batch = y.shape[0]\n            return (\n                (loss_0 + loss_1) / n_batch\n                + loss_prop / n_batch\n                + loss_disc\n                + loss_o\n                + 0.5 * (penalty_l2 * (weightsq_body + weightsq_prop) + weightsq_head)\n            )\n\n    # Define optimisation routine\n    opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)\n\n    @jit\n    def update(\n        i: int,\n        state: dict,\n        batch: jnp.ndarray,\n        penalty_l2: float,\n        penalty_orthogonal: float,\n        penalty_disc: float,\n    ) -> jnp.ndarray:\n        # updating function\n        params = get_params(state)\n        return opt_update(\n            i,\n            grad(loss_snet)(\n                params, batch, penalty_l2, penalty_orthogonal, penalty_disc\n            ),\n            state,\n        )\n\n    # initialise states\n    _, init_params = init_fun_snet(rng_key, input_shape)\n    opt_state = opt_init(init_params)\n\n    # calculate number of batches per epoch\n    batch_size = batch_size if batch_size < n else n\n    n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1\n    train_indices = onp.arange(n)\n\n    l_best = LARGE_VAL\n    p_curr = 0\n\n    # do training\n    for i in range(n_iter):\n        # shuffle data for minibatches\n        onp.random.shuffle(train_indices)\n        for b in range(n_batches):\n            idx_next = train_indices[\n                (b * batch_size) : min((b + 1) * batch_size, n - 1)\n            ]\n            next_batch = X[idx_next, :], y[idx_next, :], w[idx_next]\n            opt_state = update(\n                i * n_batches + b,\n                opt_state,\n                next_batch,\n                penalty_l2,\n                
penalty_orthogonal,\n                penalty_disc,\n            )\n\n        if (i % n_iter_print == 0) or early_stopping:\n            params_curr = get_params(opt_state)\n            l_curr = loss_snet(\n                params_curr,\n                (X_val, y_val, w_val),\n                penalty_l2,\n                penalty_orthogonal,\n                penalty_disc,\n            )\n\n        if i % n_iter_print == 0:\n            log.info(f\"Epoch: {i}, current {val_string} loss {l_curr}\")\n\n        if early_stopping and ((i + 1) * n_batches > n_iter_min):\n            # check if loss updated\n            if l_curr < l_best:\n                l_best = l_curr\n                p_curr = 0\n                params_best = params_curr\n            else:\n                if onp.isnan(l_curr):\n                    # if diverged, return best\n                    return params_best, (\n                        predict_fun_repr,\n                        predict_fun_head_po,\n                        predict_fun_head_prop,\n                    )\n                p_curr = p_curr + 1\n\n            if p_curr > patience:\n                if return_val_loss:\n                    # return loss without penalty\n                    l_final = loss_snet(params_curr, (X_val, y_val, w_val), 0, 0, 0)\n                    return (\n                        params_curr,\n                        (predict_fun_repr, predict_fun_head_po, predict_fun_head_prop),\n                        l_final,\n                    )\n\n                return params_curr, (\n                    predict_fun_repr,\n                    predict_fun_head_po,\n                    predict_fun_head_prop,\n                )\n\n    # return the parameters\n    trained_params = get_params(opt_state)\n\n    if return_val_loss:\n        # return loss without penalty\n        l_final = loss_snet(get_params(opt_state), (X_val, y_val, w_val), 0, 0)\n        return (\n            trained_params,\n            (predict_fun_repr, 
predict_fun_head_po, predict_fun_head_prop),\n            l_final,\n        )\n\n    return trained_params, (\n        predict_fun_repr,\n        predict_fun_head_po,\n        predict_fun_head_prop,\n    )\n\n\ndef predict_snet(\n    X: jnp.ndarray,\n    trained_params: jnp.ndarray,\n    predict_funs: list,\n    return_po: bool = False,\n    return_prop: bool = False,\n) -> jnp.ndarray:\n    # unpack inputs\n    predict_fun_repr, predict_fun_head, predict_fun_prop = predict_funs\n    param_0, param_1, param_prop = (\n        trained_params[5],\n        trained_params[6],\n        trained_params[7],\n    )\n\n    reps_c = predict_fun_repr(trained_params[0], X)\n    reps_o = predict_fun_repr(trained_params[1], X)\n    reps_mu0 = predict_fun_repr(trained_params[2], X)\n    reps_mu1 = predict_fun_repr(trained_params[3], X)\n    reps_w = predict_fun_repr(trained_params[4], X)\n\n    # concatenate\n    reps_po_0 = _concatenate_representations((reps_c, reps_o, reps_mu0))\n    reps_po_1 = _concatenate_representations((reps_c, reps_o, reps_mu1))\n    reps_prop = _concatenate_representations((reps_c, reps_w))\n\n    # get potential outcomes\n    mu_0 = predict_fun_head(param_0, reps_po_0)\n    mu_1 = predict_fun_head(param_1, reps_po_1)\n\n    te = mu_1 - mu_0\n    if return_prop:\n        # get propensity\n        prop = predict_fun_prop(param_prop, reps_prop)\n\n    # stack other outputs\n    if return_po:\n        if return_prop:\n            return te, mu_0, mu_1, prop\n        else:\n            return te, mu_0, mu_1\n    else:\n        if return_prop:\n            return te, prop\n        else:\n            return te\n\n\n# SNet without propensity head  ----------------------------------------\ndef train_snet_noprop(\n    X: jnp.ndarray,\n    y: jnp.ndarray,\n    w: jnp.ndarray,\n    binary_y: bool = False,\n    n_layers_r: int = DEFAULT_LAYERS_R,\n    n_units_r: int = DEFAULT_UNITS_R_BIG_S3,\n    n_units_r_small: int = DEFAULT_UNITS_R_SMALL_S3,\n    n_layers_out: int 
= DEFAULT_LAYERS_OUT,\n    n_units_out: int = DEFAULT_UNITS_OUT,\n    n_units_out_prop: int = DEFAULT_UNITS_OUT,\n    n_layers_out_prop: int = DEFAULT_LAYERS_OUT,\n    penalty_l2: float = DEFAULT_PENALTY_L2,\n    penalty_orthogonal: float = DEFAULT_PENALTY_ORTHOGONAL,\n    step_size: float = DEFAULT_STEP_SIZE,\n    n_iter: int = DEFAULT_N_ITER,\n    batch_size: int = DEFAULT_BATCH_SIZE,\n    val_split_prop: float = DEFAULT_VAL_SPLIT,\n    early_stopping: bool = True,\n    n_iter_min: int = DEFAULT_N_ITER_MIN,\n    patience: int = DEFAULT_PATIENCE,\n    n_iter_print: int = DEFAULT_N_ITER_PRINT,\n    seed: int = DEFAULT_SEED,\n    return_val_loss: bool = False,\n    reg_diff: bool = False,\n    penalty_diff: float = DEFAULT_PENALTY_L2,\n    nonlin: str = DEFAULT_NONLIN,\n    avg_objective: bool = DEFAULT_AVG_OBJECTIVE,\n    with_prop: bool = False,\n    same_init: bool = False,\n    ortho_reg_type: str = \"abs\",\n) -> Tuple:\n    \"\"\"\n    SNet but without the propensity head\n    \"\"\"\n    if with_prop:\n        raise ValueError(\"train_snet_noprop works only with_prop=False\")\n    # function to train a net with 3 representations\n    y, w = check_shape_1d_data(y), check_shape_1d_data(w)\n    d = X.shape[1]\n    input_shape = (-1, d)\n    rng_key = random.PRNGKey(seed)\n    onp.random.seed(seed)  # set seed for data generation via numpy as well\n\n    if not reg_diff:\n        penalty_diff = penalty_l2\n\n    # get validation split (can be none)\n    X, y, w, X_val, y_val, w_val, val_string = make_val_split(\n        X, y, w, val_split_prop=val_split_prop, seed=seed\n    )\n    n = X.shape[0]  # could be different from before due to split\n\n    # get representation layers\n    init_fun_repr, predict_fun_repr = ReprBlock(\n        n_layers=n_layers_r, n_units=n_units_r, nonlin=nonlin\n    )\n    init_fun_repr_small, predict_fun_repr_small = ReprBlock(\n        n_layers=n_layers_r, n_units=n_units_r_small, nonlin=nonlin\n    )\n\n    # get output head functions 
(output heads share same structure)\n    init_fun_head_po, predict_fun_head_po = OutputHead(\n        n_layers_out=n_layers_out,\n        n_units_out=n_units_out,\n        binary_y=binary_y,\n        nonlin=nonlin,\n    )\n\n    def init_fun_snet_noprop(rng: float, input_shape: Tuple) -> Tuple[Tuple, List]:\n        # chain together the layers\n        # param should look like [repr_o, repr_p0, repr_p1, po_0, po_1]\n        # initialise representation layers\n        rng, layer_rng = random.split(rng)\n        input_shape_repr, param_repr_o = init_fun_repr(layer_rng, input_shape)\n        rng, layer_rng = random.split(rng)\n        input_shape_repr_small, param_repr_p0 = init_fun_repr_small(\n            layer_rng, input_shape\n        )\n        rng, layer_rng = random.split(rng)\n        _, param_repr_p1 = init_fun_repr_small(layer_rng, input_shape)\n\n        # each head gets two representations\n        input_shape_repr = input_shape_repr[:-1] + (\n            input_shape_repr[-1] + input_shape_repr_small[-1],\n        )\n\n        # initialise output heads\n        rng, layer_rng = random.split(rng)\n        if same_init:\n            # initialise both on same values\n            input_shape, param_0 = init_fun_head_po(layer_rng, input_shape_repr)\n            input_shape, param_1 = init_fun_head_po(layer_rng, input_shape_repr)\n        else:\n            input_shape, param_0 = init_fun_head_po(layer_rng, input_shape_repr)\n            rng, layer_rng = random.split(rng)\n            input_shape, param_1 = init_fun_head_po(layer_rng, input_shape_repr)\n\n        return input_shape, [\n            param_repr_o,\n            param_repr_p0,\n            param_repr_p1,\n            param_0,\n            param_1,\n        ]\n\n    # Define loss functions\n    # loss functions for the head\n    if not binary_y:\n\n        def loss_head(\n            params: jnp.ndarray,\n            batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],\n            penalty: float,\n   
     ) -> jnp.ndarray:\n            # mse loss function\n            inputs, targets, weights = batch\n            preds = predict_fun_head_po(params, inputs)\n            return jnp.sum(weights * ((preds - targets) ** 2))\n\n    else:\n\n        def loss_head(\n            params: jnp.ndarray,\n            batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],\n            penalty: float,\n        ) -> jnp.ndarray:\n            # log loss function\n            inputs, targets, weights = batch\n            preds = predict_fun_head_po(params, inputs)\n            return -jnp.sum(\n                weights\n                * (targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds))\n            )\n\n    # define ortho-reg function\n    if ortho_reg_type == \"abs\":\n\n        def ortho_reg(params: jnp.ndarray) -> jnp.ndarray:\n            col_o = _get_absolute_rowsums(params[0][0][0])\n            col_p0 = _get_absolute_rowsums(params[1][0][0])\n            col_p1 = _get_absolute_rowsums(params[2][0][0])\n            return jnp.sum(col_o * col_p0 + col_o * col_p1 + col_p1 * col_p0)\n\n    elif ortho_reg_type == \"fro\":\n\n        def ortho_reg(params: jnp.ndarray) -> jnp.ndarray:\n            return (\n                _get_cos_reg(params[0][0][0], params[1][0][0], False)\n                + _get_cos_reg(params[0][0][0], params[2][0][0], False)\n                + _get_cos_reg(params[1][0][0], params[2][0][0], False)\n            )\n\n    else:\n        raise NotImplementedError(\n            \"train_snet_noprop supports only orthogonal regularization \"\n            \"using absolute values or frobenious norms.\"\n        )\n\n    # complete loss function for all parts\n    @jit\n    def loss_snet_noprop(\n        params: jnp.ndarray,\n        batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],\n        penalty_l2: float,\n        penalty_orthogonal: float,\n    ) -> jnp.ndarray:\n        # params: list[repr_o, repr_p0, repr_p1, po_0, po_1]\n        # batch: (X, y, 
w)\n        X, y, w = batch\n\n        # get representation\n        reps_o = predict_fun_repr(params[0], X)\n        reps_p0 = predict_fun_repr_small(params[1], X)\n        reps_p1 = predict_fun_repr_small(params[2], X)\n\n        # concatenate\n        reps_po0 = _concatenate_representations((reps_o, reps_p0))\n        reps_po1 = _concatenate_representations((reps_o, reps_p1))\n\n        # pass down to heads\n        loss_0 = loss_head(params[3], (reps_po0, y, 1 - w), penalty_l2)\n        loss_1 = loss_head(params[4], (reps_po1, y, w), penalty_l2)\n\n        # which variable has impact on which representation\n        loss_o = penalty_orthogonal * ortho_reg(params)\n\n        # weight decay on representations\n        weightsq_body = sum(\n            [\n                sum(\n                    [jnp.sum(params[j][i][0] ** 2) for i in range(0, 2 * n_layers_r, 2)]\n                )\n                for j in range(3)\n            ]\n        )\n        weightsq_head = heads_l2_penalty(\n            params[3], params[4], n_layers_out, reg_diff, penalty_l2, penalty_diff\n        )\n        if not avg_objective:\n            return (\n                loss_0\n                + loss_1\n                + loss_o\n                + 0.5 * (penalty_l2 * weightsq_body + weightsq_head)\n            )\n        else:\n            n_batch = y.shape[0]\n            return (\n                (loss_0 + loss_1) / n_batch\n                + loss_o\n                + 0.5 * (penalty_l2 * weightsq_body + weightsq_head)\n            )\n\n    # Define optimisation routine\n    opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)\n\n    @jit\n    def update(\n        i: int,\n        state: dict,\n        batch: jnp.ndarray,\n        penalty_l2: float,\n        penalty_orthogonal: float,\n    ) -> jnp.ndarray:\n        # updating function\n        params = get_params(state)\n        return opt_update(\n            i,\n            grad(loss_snet_noprop)(params, batch, 
penalty_l2, penalty_orthogonal),\n            state,\n        )\n\n    # initialise states\n    _, init_params = init_fun_snet_noprop(rng_key, input_shape)\n    opt_state = opt_init(init_params)\n\n    # calculate number of batches per epoch\n    batch_size = batch_size if batch_size < n else n\n    n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1\n    train_indices = onp.arange(n)\n\n    l_best = LARGE_VAL\n    p_curr = 0\n\n    # do training\n    for i in range(n_iter):\n        # shuffle data for minibatches\n        onp.random.shuffle(train_indices)\n        for b in range(n_batches):\n            idx_next = train_indices[\n                (b * batch_size) : min((b + 1) * batch_size, n - 1)\n            ]\n            next_batch = X[idx_next, :], y[idx_next, :], w[idx_next]\n            opt_state = update(\n                i * n_batches + b, opt_state, next_batch, penalty_l2, penalty_orthogonal\n            )\n\n        if (i % n_iter_print == 0) or early_stopping:\n            params_curr = get_params(opt_state)\n            l_curr = loss_snet_noprop(\n                params_curr, (X_val, y_val, w_val), penalty_l2, penalty_orthogonal\n            )\n\n        if i % n_iter_print == 0:\n            log.info(f\"Epoch: {i}, current {val_string} loss {l_curr}\")\n\n        if early_stopping and ((i + 1) * n_batches > n_iter_min):\n            # check if loss updated\n            if l_curr < l_best:\n                l_best = l_curr\n                p_curr = 0\n                params_best = params_curr\n            else:\n                if onp.isnan(l_curr):\n                    # if diverged, return best\n                    return params_best, (predict_fun_repr, predict_fun_head_po)\n                p_curr = p_curr + 1\n\n            if p_curr > patience:\n                if return_val_loss:\n                    # return loss without penalty\n                    l_final = loss_snet_noprop(params_curr, (X_val, y_val, w_val), 0, 0)\n              
      return params_curr, (predict_fun_repr, predict_fun_head_po), l_final\n\n                return params_curr, (predict_fun_repr, predict_fun_head_po)\n\n    # return the parameters\n    trained_params = get_params(opt_state)\n\n    if return_val_loss:\n        # return loss without penalty\n        l_final = loss_snet_noprop(get_params(opt_state), (X_val, y_val, w_val), 0, 0)\n        return trained_params, (predict_fun_repr, predict_fun_head_po), l_final\n\n    return trained_params, (predict_fun_repr, predict_fun_head_po)\n\n\ndef predict_snet_noprop(\n    X: jnp.ndarray,\n    trained_params: jnp.ndarray,\n    predict_funs: list,\n    return_po: bool = False,\n    return_prop: bool = False,\n) -> jnp.ndarray:\n\n    if return_prop:\n        raise NotImplementedError(\"SNet5 does not have propensity estimator\")\n\n    # unpack inputs\n    predict_fun_repr, predict_fun_head = predict_funs\n    param_repr_o, param_repr_po0, param_repr_po1 = (\n        trained_params[0],\n        trained_params[1],\n        trained_params[2],\n    )\n    param_0, param_1 = trained_params[3], trained_params[4]\n\n    # get representations\n    rep_o = predict_fun_repr(param_repr_o, X)\n    rep_po0 = predict_fun_repr(param_repr_po0, X)\n    rep_po1 = predict_fun_repr(param_repr_po1, X)\n\n    # concatenate\n    reps_po0 = jnp.concatenate((rep_o, rep_po0), axis=1)\n    reps_po1 = jnp.concatenate((rep_o, rep_po1), axis=1)\n\n    # get potential outcomes\n    mu_0 = predict_fun_head(param_0, reps_po0)\n    mu_1 = predict_fun_head(param_1, reps_po1)\n\n    te = mu_1 - mu_0\n\n    # stack other outputs\n    if return_po:\n        return te, mu_0, mu_1\n    else:\n        return te\n"
  },
  {
    "path": "catenets/models/jax/tnet.py",
    "content": "\"\"\"\nImplements a T-Net: T-learner for CATE based on a dense NN\n\"\"\"\n# Author: Alicia Curth\nfrom typing import Any, Callable, List, Tuple\n\nimport jax.numpy as jnp\nimport numpy as onp\nfrom jax import grad, jit, random\nfrom jax.example_libraries import optimizers\n\nimport catenets.logger as log\nfrom catenets.models.constants import (\n    DEFAULT_AVG_OBJECTIVE,\n    DEFAULT_BATCH_SIZE,\n    DEFAULT_LAYERS_OUT,\n    DEFAULT_LAYERS_R,\n    DEFAULT_N_ITER,\n    DEFAULT_N_ITER_MIN,\n    DEFAULT_N_ITER_PRINT,\n    DEFAULT_NONLIN,\n    DEFAULT_PATIENCE,\n    DEFAULT_PENALTY_L2,\n    DEFAULT_SEED,\n    DEFAULT_STEP_SIZE,\n    DEFAULT_UNITS_OUT,\n    DEFAULT_UNITS_R,\n    DEFAULT_VAL_SPLIT,\n    LARGE_VAL,\n)\nfrom catenets.models.jax.base import BaseCATENet, OutputHead, train_output_net_only\nfrom catenets.models.jax.model_utils import (\n    check_shape_1d_data,\n    heads_l2_penalty,\n    make_val_split,\n)\n\n\nclass TNet(BaseCATENet):\n    \"\"\"\n    TNet class -- two separate functions learned for each Potential Outcome function\n\n    Parameters\n    ----------\n    binary_y: bool, default False\n        Whether the outcome is binary\n    n_layers_out: int\n        Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)\n    n_units_out: int\n        Number of hidden units in each hypothesis layer\n    n_layers_r: int\n        Number of representation layers before hypothesis layers (distinction between\n        hypothesis layers and representation layers is made to match TARNet & SNets)\n    n_units_r: int\n        Number of hidden units in each representation layer\n    penalty_l2: float\n        l2 (ridge) penalty\n    step_size: float\n        learning rate for optimizer\n    n_iter: int\n        Maximum number of iterations\n    batch_size: int\n        Batch size\n    val_split_prop: float\n        Proportion of samples used for validation split (can be 0)\n    early_stopping: bool, default True\n        Whether 
to use early stopping\n    patience: int\n        Number of iterations to wait before early stopping after decrease in validation loss\n    n_iter_min: int\n        Minimum number of iterations to go through before starting early stopping\n    n_iter_print: int\n        Number of iterations after which to print updates\n    seed: int\n        Seed used\n    train_separate: bool, default True\n        Whether to train the two output heads completely separately or whether to regularize\n        their difference\n    penalty_diff: float\n        l2-penalty for regularizing the difference between output heads. used only if\n        train_separate=False\n    nonlin: string, default 'elu'\n        Nonlinearity to use in NN\n    \"\"\"\n\n    def __init__(\n        self,\n        binary_y: bool = False,\n        n_layers_out: int = DEFAULT_LAYERS_OUT,\n        n_units_out: int = DEFAULT_UNITS_OUT,\n        n_layers_r: int = DEFAULT_LAYERS_R,\n        n_units_r: int = DEFAULT_UNITS_R,\n        penalty_l2: float = DEFAULT_PENALTY_L2,\n        step_size: float = DEFAULT_STEP_SIZE,\n        n_iter: int = DEFAULT_N_ITER,\n        batch_size: int = DEFAULT_BATCH_SIZE,\n        val_split_prop: float = DEFAULT_VAL_SPLIT,\n        early_stopping: bool = True,\n        patience: int = DEFAULT_PATIENCE,\n        n_iter_min: int = DEFAULT_N_ITER_MIN,\n        n_iter_print: int = DEFAULT_N_ITER_PRINT,\n        seed: int = DEFAULT_SEED,\n        train_separate: bool = True,\n        penalty_diff: float = DEFAULT_PENALTY_L2,\n        nonlin: str = DEFAULT_NONLIN,\n    ) -> None:\n        self.binary_y = binary_y\n        self.n_layers_out = n_layers_out\n        self.n_units_out = n_units_out\n        self.n_layers_r = n_layers_r\n        self.n_units_r = n_units_r\n        self.penalty_l2 = penalty_l2\n        self.step_size = step_size\n        self.n_iter = n_iter\n        self.batch_size = batch_size\n        self.val_split_prop = val_split_prop\n        self.early_stopping = 
early_stopping\n        self.patience = patience\n        self.n_iter_min = n_iter_min\n        self.n_iter_print = n_iter_print\n        self.seed = seed\n        self.train_separate = train_separate\n        self.penalty_diff = penalty_diff\n        self.nonlin = nonlin\n\n    def _get_predict_function(self) -> Callable:\n        return predict_t_net\n\n    def _get_train_function(self) -> Callable:\n        return train_tnet\n\n\ndef train_tnet(\n    X: jnp.ndarray,\n    y: jnp.ndarray,\n    w: jnp.ndarray,\n    binary_y: bool = False,\n    n_layers_out: int = DEFAULT_LAYERS_OUT,\n    n_units_out: int = DEFAULT_UNITS_OUT,\n    n_layers_r: int = DEFAULT_LAYERS_R,\n    n_units_r: int = DEFAULT_UNITS_R,\n    penalty_l2: float = DEFAULT_PENALTY_L2,\n    step_size: float = DEFAULT_STEP_SIZE,\n    n_iter: int = DEFAULT_N_ITER,\n    batch_size: int = DEFAULT_BATCH_SIZE,\n    val_split_prop: float = DEFAULT_VAL_SPLIT,\n    early_stopping: bool = True,\n    patience: int = DEFAULT_PATIENCE,\n    n_iter_min: int = DEFAULT_N_ITER_MIN,\n    n_iter_print: int = DEFAULT_N_ITER_PRINT,\n    seed: int = DEFAULT_SEED,\n    return_val_loss: bool = False,\n    train_separate: bool = True,\n    penalty_diff: float = DEFAULT_PENALTY_L2,\n    nonlin: str = DEFAULT_NONLIN,\n    avg_objective: bool = DEFAULT_AVG_OBJECTIVE,\n) -> Any:\n    # w should be 1-D for indexing\n    if len(w.shape) > 1:\n        w = w.reshape((len(w),))\n\n    if train_separate:\n        # train two heads completely independently\n        log.debug(\"Training PO_0 Net\")\n        out_0 = train_output_net_only(\n            X[w == 0],\n            y[w == 0],\n            binary_y=binary_y,\n            n_layers_out=n_layers_out,\n            n_units_out=n_units_out,\n            n_layers_r=n_layers_r,\n            n_units_r=n_units_r,\n            penalty_l2=penalty_l2,\n            step_size=step_size,\n            n_iter=n_iter,\n            batch_size=batch_size,\n            val_split_prop=val_split_prop,\n   
         early_stopping=early_stopping,\n            patience=patience,\n            n_iter_min=n_iter_min,\n            n_iter_print=n_iter_print,\n            seed=seed,\n            return_val_loss=return_val_loss,\n            nonlin=nonlin,\n            avg_objective=avg_objective,\n        )\n        log.debug(\"Training PO_1 Net\")\n        out_1 = train_output_net_only(\n            X[w == 1],\n            y[w == 1],\n            binary_y=binary_y,\n            n_layers_out=n_layers_out,\n            n_units_out=n_units_out,\n            n_layers_r=n_layers_r,\n            n_units_r=n_units_r,\n            penalty_l2=penalty_l2,\n            step_size=step_size,\n            n_iter=n_iter,\n            batch_size=batch_size,\n            val_split_prop=val_split_prop,\n            early_stopping=early_stopping,\n            patience=patience,\n            n_iter_min=n_iter_min,\n            n_iter_print=n_iter_print,\n            seed=seed,\n            return_val_loss=return_val_loss,\n            nonlin=nonlin,\n            avg_objective=avg_objective,\n        )\n\n        if return_val_loss:\n            params_0, predict_fun_0, loss_0 = out_0\n            params_1, predict_fun_1, loss_1 = out_1\n            return (params_0, params_1), (predict_fun_0, predict_fun_1), loss_1 + loss_0\n\n        params_0, predict_fun_0 = out_0\n        params_1, predict_fun_1 = out_1\n    else:\n        # train jointly by regularizing similarity\n        params, predict_fun = _train_tnet_jointly(\n            X,\n            y,\n            w,\n            binary_y=binary_y,\n            n_layers_out=n_layers_out,\n            n_units_out=n_units_out,\n            n_layers_r=n_layers_r,\n            n_units_r=n_units_r,\n            penalty_l2=penalty_l2,\n            step_size=step_size,\n            n_iter=n_iter,\n            batch_size=batch_size,\n            val_split_prop=val_split_prop,\n            early_stopping=early_stopping,\n            patience=patience,\n 
           n_iter_min=n_iter_min,\n            n_iter_print=n_iter_print,\n            seed=seed,\n            return_val_loss=return_val_loss,\n            penalty_diff=penalty_diff,\n            nonlin=nonlin,\n        )\n        params_0, params_1 = params[0], params[1]\n        predict_fun_0, predict_fun_1 = predict_fun, predict_fun\n\n    return (params_0, params_1), (predict_fun_0, predict_fun_1)\n\n\ndef predict_t_net(\n    X: jnp.ndarray,\n    trained_params: dict,\n    predict_funs: list,\n    return_po: bool = False,\n    return_prop: bool = False,\n) -> jnp.ndarray:\n    if return_prop:\n        raise NotImplementedError(\"TNet does not implement a propensity model.\")\n\n    # return CATE predictions using T-net params\n    params_0, params_1 = trained_params\n    predict_fun_0, predict_fun_1 = predict_funs\n\n    mu_0 = predict_fun_0(params_0, X)\n    mu_1 = predict_fun_1(params_1, X)\n\n    if return_po:\n        return mu_1 - mu_0, mu_0, mu_1\n    else:\n        return mu_1 - mu_0\n\n\ndef _train_tnet_jointly(\n    X: jnp.ndarray,\n    y: jnp.ndarray,\n    w: jnp.ndarray,\n    binary_y: bool = False,\n    n_layers_out: int = DEFAULT_LAYERS_OUT,\n    n_units_out: int = DEFAULT_UNITS_OUT,\n    n_layers_r: int = DEFAULT_LAYERS_R,\n    n_units_r: int = DEFAULT_UNITS_R,\n    penalty_l2: float = DEFAULT_PENALTY_L2,\n    step_size: float = DEFAULT_STEP_SIZE,\n    n_iter: int = DEFAULT_N_ITER,\n    batch_size: int = DEFAULT_BATCH_SIZE,\n    val_split_prop: float = DEFAULT_VAL_SPLIT,\n    early_stopping: bool = True,\n    patience: int = DEFAULT_PATIENCE,\n    n_iter_min: int = DEFAULT_N_ITER_MIN,\n    n_iter_print: int = DEFAULT_N_ITER_PRINT,\n    seed: int = DEFAULT_SEED,\n    return_val_loss: bool = False,\n    same_init: bool = True,\n    penalty_diff: float = DEFAULT_PENALTY_L2,\n    nonlin: str = DEFAULT_NONLIN,\n    avg_objective: bool = DEFAULT_AVG_OBJECTIVE,\n) -> jnp.ndarray:\n    # input check\n    y, w = check_shape_1d_data(y), 
check_shape_1d_data(w)\n\n    d = X.shape[1]\n    input_shape = (-1, d)\n    rng_key = random.PRNGKey(seed)\n    onp.random.seed(seed)  # set seed for data generation via numpy as well\n\n    # get validation split (can be none)\n    X, y, w, X_val, y_val, w_val, val_string = make_val_split(\n        X, y, w, val_split_prop=val_split_prop, seed=seed\n    )\n    n = X.shape[0]  # could be different from before due to split\n\n    # get output head functions (both heads share same structure)\n    init_fun_head, predict_fun_head = OutputHead(\n        n_layers_out=n_layers_out,\n        n_units_out=n_units_out,\n        binary_y=binary_y,\n        n_layers_r=n_layers_r,\n        n_units_r=n_units_r,\n        nonlin=nonlin,\n    )\n\n    # Define loss functions\n    # loss functions for the head\n    if not binary_y:\n\n        def loss_head(\n            params: List, batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]\n        ) -> jnp.ndarray:\n            # mse loss function\n            inputs, targets, weights = batch\n            preds = predict_fun_head(params, inputs)\n            return jnp.sum(weights * ((preds - targets) ** 2))\n\n    else:\n\n        def loss_head(\n            params: List, batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]\n        ) -> jnp.ndarray:\n            # log loss function\n            inputs, targets, weights = batch\n            preds = predict_fun_head(params, inputs)\n            return -jnp.sum(\n                weights\n                * (targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds))\n            )\n\n    @jit\n    def loss_tnet(\n        params: List,\n        batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],\n        penalty_l2: float,\n        penalty_diff: float,\n    ) -> jnp.ndarray:\n        # params: list[head_0, head_1]\n        # batch: (X, y, w)\n        X, y, w = batch\n\n        # pass down to two heads\n        loss_0 = loss_head(params[0], (X, y, 1 - w))\n        loss_1 = 
loss_head(params[1], (X, y, w))\n\n        # regularization\n        weightsq_head = heads_l2_penalty(\n            params[0],\n            params[1],\n            n_layers_r + n_layers_out,\n            True,\n            penalty_l2,\n            penalty_diff,\n        )\n        if not avg_objective:\n            return loss_0 + loss_1 + 0.5 * (weightsq_head)\n        else:\n            n_batch = y.shape[0]\n            return (loss_0 + loss_1) / n_batch + 0.5 * (weightsq_head)\n\n    # Define optimisation routine\n    opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)\n\n    @jit\n    def update(\n        i: int, state: dict, batch: jnp.ndarray, penalty_l2: float, penalty_diff: float\n    ) -> jnp.ndarray:\n        # updating function\n        params = get_params(state)\n        return opt_update(\n            i, grad(loss_tnet)(params, batch, penalty_l2, penalty_diff), state\n        )\n\n    # initialise states\n    if same_init:\n        _, init_head = init_fun_head(rng_key, input_shape)\n        init_params = [init_head, init_head]\n    else:\n        rng_key, rng_key_2 = random.split(rng_key)\n        _, init_head_0 = init_fun_head(rng_key, input_shape)\n        _, init_head_1 = init_fun_head(rng_key_2, input_shape)\n        init_params = [init_head_0, init_head_1]\n\n    opt_state = opt_init(init_params)\n\n    # calculate number of batches per epoch\n    batch_size = batch_size if batch_size < n else n\n    n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1\n    train_indices = onp.arange(n)\n\n    l_best = LARGE_VAL\n    p_curr = 0\n\n    # do training\n    for i in range(n_iter):\n        # shuffle data for minibatches\n        onp.random.shuffle(train_indices)\n        for b in range(n_batches):\n            idx_next = train_indices[\n                (b * batch_size) : min((b + 1) * batch_size, n - 1)\n            ]\n            next_batch = X[idx_next, :], y[idx_next, :], w[idx_next]\n            opt_state = 
update(\n                i * n_batches + b, opt_state, next_batch, penalty_l2, penalty_diff\n            )\n\n        if (i % n_iter_print == 0) or early_stopping:\n            params_curr = get_params(opt_state)\n            l_curr = loss_tnet(\n                params_curr, (X_val, y_val, w_val), penalty_l2, penalty_diff\n            )\n\n        if i % n_iter_print == 0:\n            log.debug(f\"Epoch: {i}, current {val_string} loss {l_curr}\")\n\n        if early_stopping and ((i + 1) * n_batches > n_iter_min):\n            if l_curr < l_best:\n                l_best = l_curr\n                p_curr = 0\n                params_best = params_curr\n            else:\n                if onp.isnan(l_curr):\n                    # if diverged, return best\n                    return params_best, predict_fun_head\n                p_curr = p_curr + 1\n\n            if p_curr > patience:\n                if return_val_loss:\n                    # return loss without penalty\n                    l_final = loss_tnet(params_curr, (X_val, y_val, w_val), 0, 0)\n                    return params_curr, predict_fun_head, l_final\n\n                return params_curr, predict_fun_head\n\n    # return the parameters\n    trained_params = get_params(opt_state)\n\n    if return_val_loss:\n        # return loss without penalty\n        l_final = loss_tnet(get_params(opt_state), (X_val, y_val, w_val), 0, 0)\n        return trained_params, predict_fun_head, l_final\n\n    return trained_params, predict_fun_head\n"
  },
  {
    "path": "catenets/models/jax/transformation_utils.py",
    "content": "\"\"\"\nUtils for transformations\n\"\"\"\n# Author: Alicia Curth\nfrom typing import Any, Optional\n\nimport numpy as np\n\nPW_TRANSFORMATION = \"PW\"\nDR_TRANSFORMATION = \"DR\"\nRA_TRANSFORMATION = \"RA\"\n\nALL_TRANSFORMATIONS = [PW_TRANSFORMATION, DR_TRANSFORMATION, RA_TRANSFORMATION]\n\n\ndef aipw_te_transformation(\n    y: np.ndarray,\n    w: np.ndarray,\n    p: Optional[np.ndarray],\n    mu_0: np.ndarray,\n    mu_1: np.ndarray,\n) -> np.ndarray:\n    \"\"\"\n    Transforms data to efficient influence function pseudo-outcome for CATE estimation\n\n    Parameters\n    ----------\n    y : array-like of shape (n_samples,) or (n_samples, )\n        The observed outcome variable\n    w: array-like of shape (n_samples,)\n        The observed treatment indicator\n    p: array-like of shape (n_samples,)\n        The treatment propensity, estimated or known. Can be None, then p=0.5 is assumed\n    mu_0: array-like of shape (n_samples,)\n        Estimated or known potential outcome mean of the control group\n    mu_1: array-like of shape (n_samples,)\n        Estimated or known potential outcome mean of the treatment group\n\n    Returns\n    -------\n    d_hat:\n        EIF transformation for CATE\n    \"\"\"\n    if p is None:\n        # assume equal\n        p = np.full(len(y), 0.5)\n\n    w_1 = w / p\n    w_0 = (1 - w) / (1 - p)\n    return (w_1 - w_0) * y + ((1 - w_1) * mu_1 - (1 - w_0) * mu_0)\n\n\ndef ht_te_transformation(\n    y: np.ndarray,\n    w: np.ndarray,\n    p: Optional[np.ndarray] = None,\n    mu_0: Optional[np.ndarray] = None,\n    mu_1: Optional[np.ndarray] = None,\n) -> np.ndarray:\n    \"\"\"\n    Transform data to Horvitz-Thompson transformation for CATE\n\n    Parameters\n    ----------\n    y : array-like of shape (n_samples,) or (n_samples, )\n        The observed outcome variable\n    w: array-like of shape (n_samples,)\n        The observed treatment indicator\n    p: array-like of shape (n_samples,)\n        The treatment 
propensity, estimated or known. Can be None, then p=0.5 is assumed\n    mu_0: array-like of shape (n_samples,)\n        Placeholder, not used. Estimated or known potential outcome mean of the control group\n    mu_1: array-like of shape (n_samples,)\n        Placeholder, not used. Estimated or known potential outcome mean of the treatment group\n\n    Returns\n    -------\n    res: array-like of shape (n_samples,)\n        Horvitz-Thompson transformed data\n    \"\"\"\n    if p is None:\n        # assume equal propensities\n        p = np.full(len(y), 0.5)\n    return (w / p - (1 - w) / (1 - p)) * y\n\n\ndef ra_te_transformation(\n    y: np.ndarray,\n    w: np.ndarray,\n    p: Optional[np.ndarray],\n    mu_0: np.ndarray,\n    mu_1: np.ndarray,\n) -> np.ndarray:\n    \"\"\"\n    Transform data to regression adjustment for CATE\n\n    Parameters\n    ----------\n    y : array-like of shape (n_samples,) or (n_samples, )\n        The observed outcome variable\n    w: array-like of shape (n_samples,)\n        The observed treatment indicator\n    p: array-like of shape (n_samples,)\n        Placeholder, not used. 
The treatment propensity, estimated or known.\n    mu_0: array-like of shape (n_samples,)\n         Estimated or known potential outcome mean of the control group\n    mu_1: array-like of shape (n_samples,)\n        Estimated or known potential outcome mean of the treatment group\n\n    Returns\n    -------\n    res: array-like of shape (n_samples,)\n        Regression adjusted transformation\n    \"\"\"\n    return w * (y - mu_0) + (1 - w) * (mu_1 - y)\n\n\nTRANSFORMATION_DICT = {\n    PW_TRANSFORMATION: ht_te_transformation,\n    RA_TRANSFORMATION: ra_te_transformation,\n    DR_TRANSFORMATION: aipw_te_transformation,\n}\n\n\ndef _get_transformation_function(transformation_name: str) -> Any:\n    \"\"\"\n    Get transformation function associated with a name\n    \"\"\"\n    if transformation_name not in ALL_TRANSFORMATIONS:\n        raise ValueError(\n            \"Parameter first stage should be in \"\n            \"catenets.models.transformations.ALL_TRANSFORMATIONS.\"\n            \" You passed {}\".format(transformation_name)\n        )\n    return TRANSFORMATION_DICT[transformation_name]\n"
  },
  {
    "path": "catenets/models/jax/xnet.py",
    "content": "\"\"\"\nModule implements X-learner from Kuenzel et al (2019) using NNs\n\"\"\"\n# Author: Alicia Curth\nfrom typing import Callable, Optional, Tuple\n\nimport jax.numpy as jnp\n\nimport catenets.logger as log\nfrom catenets.models.constants import (\n    DEFAULT_AVG_OBJECTIVE,\n    DEFAULT_BATCH_SIZE,\n    DEFAULT_LAYERS_OUT,\n    DEFAULT_LAYERS_OUT_T,\n    DEFAULT_LAYERS_R,\n    DEFAULT_LAYERS_R_T,\n    DEFAULT_N_ITER,\n    DEFAULT_N_ITER_MIN,\n    DEFAULT_N_ITER_PRINT,\n    DEFAULT_NONLIN,\n    DEFAULT_PATIENCE,\n    DEFAULT_PENALTY_L2,\n    DEFAULT_SEED,\n    DEFAULT_STEP_SIZE,\n    DEFAULT_STEP_SIZE_T,\n    DEFAULT_UNITS_OUT,\n    DEFAULT_UNITS_OUT_T,\n    DEFAULT_UNITS_R,\n    DEFAULT_UNITS_R_T,\n    DEFAULT_VAL_SPLIT,\n)\nfrom catenets.models.jax.base import BaseCATENet, train_output_net_only\nfrom catenets.models.jax.model_utils import check_shape_1d_data, check_X_is_np\nfrom catenets.models.jax.pseudo_outcome_nets import (  # same strategies as other nets\n    ALL_STRATEGIES,\n    FLEX_STRATEGY,\n    OFFSET_STRATEGY,\n    S1_STRATEGY,\n    S2_STRATEGY,\n    S3_STRATEGY,\n    S_STRATEGY,\n    T_STRATEGY,\n    predict_flextenet,\n    predict_offsetnet,\n    predict_snet,\n    predict_snet1,\n    predict_snet2,\n    predict_snet3,\n    predict_t_net,\n    train_flextenet,\n    train_offsetnet,\n    train_snet,\n    train_snet1,\n    train_snet2,\n    train_snet3,\n    train_tnet,\n)\n\n\nclass XNet(BaseCATENet):\n    \"\"\"\n    Class implements X-learner using NNs.\n\n    Parameters\n    ----------\n    weight_strategy: int, default None\n        Which strategy to use to weight the two CATE estimators in the second stage. 
weight_strategy\n        is coded as follows: for tau(x)=g(x)tau_0(x) + (1-g(x))tau_1(x) [eq 9, kuenzel et al (2019)]\n        weight_strategy=0 sets g(x)=0, weight_strategy=1 sets g(x)=1,\n        weight_strategy=None sets g(x)=pi(x) [propensity score],\n         weight_strategy=-1 sets g(x)=(1-pi(x))\n    binary_y: bool, default False\n        Whether the outcome is binary\n    n_layers_out: int\n        First stage Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)\n    n_units_out: int\n        First stage Number of hidden units in each hypothesis layer\n    n_layers_r: int\n        First stage Number of representation layers before hypothesis layers (distinction between\n        hypothesis layers and representation layers is made to match TARNet & SNets)\n    n_units_r: int\n        First stage Number of hidden units in each representation layer\n    n_layers_out_t: int\n        Second stage Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)\n    n_units_out_t: int\n        Second stage Number of hidden units in each hypothesis layer\n    n_layers_r_t: int\n        Second stage Number of representation layers before hypothesis layers (distinction between\n        hypothesis layers and representation layers is made to match TARNet & SNets)\n    n_units_r_t: int\n        Second stage Number of hidden units in each representation layer\n    penalty_l2: float\n        First stage l2 (ridge) penalty\n    penalty_l2_t: float\n        Second stage l2 (ridge) penalty\n    step_size: float\n        First stage learning rate for optimizer\n    step_size_t: float\n        Second stage learning rate for optimizer\n    n_iter: int\n        Maximum number of iterations\n    batch_size: int\n        Batch size\n    val_split_prop: float\n        Proportion of samples used for validation split (can be 0)\n    early_stopping: bool, default True\n        Whether to use early stopping\n    patience: int\n        Number of 
iterations to wait before early stopping after decrease in validation loss\n    n_iter_min: int\n        Minimum number of iterations to go through before starting early stopping\n    n_iter_print: int\n        Number of iterations after which to print updates\n    seed: int\n        Seed used\n    nonlin: string, default 'elu'\n        Nonlinearity to use in NN\n    \"\"\"\n\n    def __init__(\n        self,\n        weight_strategy: Optional[int] = None,\n        first_stage_strategy: str = T_STRATEGY,\n        first_stage_args: Optional[dict] = None,\n        binary_y: bool = False,\n        n_layers_out: int = DEFAULT_LAYERS_OUT,\n        n_layers_r: int = DEFAULT_LAYERS_R,\n        n_layers_out_t: int = DEFAULT_LAYERS_OUT_T,\n        n_layers_r_t: int = DEFAULT_LAYERS_R_T,\n        n_units_out: int = DEFAULT_UNITS_OUT,\n        n_units_r: int = DEFAULT_UNITS_R,\n        n_units_out_t: int = DEFAULT_UNITS_OUT_T,\n        n_units_r_t: int = DEFAULT_UNITS_R_T,\n        penalty_l2: float = DEFAULT_PENALTY_L2,\n        penalty_l2_t: float = DEFAULT_PENALTY_L2,\n        step_size: float = DEFAULT_STEP_SIZE,\n        step_size_t: float = DEFAULT_STEP_SIZE_T,\n        n_iter: int = DEFAULT_N_ITER,\n        batch_size: int = DEFAULT_BATCH_SIZE,\n        n_iter_min: int = DEFAULT_N_ITER_MIN,\n        val_split_prop: float = DEFAULT_VAL_SPLIT,\n        early_stopping: bool = True,\n        patience: int = DEFAULT_PATIENCE,\n        n_iter_print: int = DEFAULT_N_ITER_PRINT,\n        seed: int = DEFAULT_SEED,\n        nonlin: str = DEFAULT_NONLIN,\n    ):\n        # settings\n        self.weight_strategy = weight_strategy\n        self.first_stage_strategy = first_stage_strategy\n        self.first_stage_args = first_stage_args\n        self.binary_y = binary_y\n\n        # model architecture hyperparams\n        self.n_layers_out = n_layers_out\n        self.n_layers_out_t = n_layers_out_t\n        self.n_layers_r = n_layers_r\n        self.n_layers_r_t = n_layers_r_t\n   
     self.n_units_out = n_units_out\n        self.n_units_out_t = n_units_out_t\n        self.n_units_r = n_units_r\n        self.n_units_r_t = n_units_r_t\n        self.nonlin = nonlin\n\n        # other hyperparameters\n        self.penalty_l2 = penalty_l2\n        self.penalty_l2_t = penalty_l2_t\n        self.step_size = step_size\n        self.step_size_t = step_size_t\n        self.n_iter = n_iter\n        self.batch_size = batch_size\n        self.n_iter_print = n_iter_print\n        self.seed = seed\n        self.val_split_prop = val_split_prop\n        self.early_stopping = early_stopping\n        self.patience = patience\n        self.n_iter_min = n_iter_min\n\n    def _get_train_function(self) -> Callable:\n        return train_x_net\n\n    def _get_predict_function(self) -> Callable:\n        # Two step nets do not need this\n        return predict_x_net\n\n    def predict(\n        self, X: jnp.ndarray, return_po: bool = False, return_prop: bool = False\n    ) -> jnp.ndarray:\n        \"\"\"\n        Predict treatment effect estimates using a CATENet. 
Depending on method, can also return\n        potential outcome estimate and propensity score estimate.\n\n        Parameters\n        ----------\n        X: pd.DataFrame or np.array\n            Covariate matrix\n        return_po: bool, default False\n            Whether to return potential outcome estimate\n        return_prop: bool, default False\n            Whether to return propensity estimate\n\n        Returns\n        -------\n        array of CATE estimates, optionally also potential outcomes and propensity\n        \"\"\"\n        X = check_X_is_np(X)\n        predict_func = self._get_predict_function()\n        return predict_func(\n            X,\n            trained_params=self._params,\n            predict_funs=self._predict_funs,\n            return_po=return_po,\n            return_prop=return_prop,\n            weight_strategy=self.weight_strategy,\n        )\n\n\ndef train_x_net(\n    X: jnp.ndarray,\n    y: jnp.ndarray,\n    w: jnp.ndarray,\n    weight_strategy: Optional[int] = None,\n    first_stage_strategy: str = T_STRATEGY,\n    first_stage_args: Optional[dict] = None,\n    binary_y: bool = False,\n    n_layers_out: int = DEFAULT_LAYERS_OUT,\n    n_layers_r: int = DEFAULT_LAYERS_R,\n    n_layers_out_t: int = DEFAULT_LAYERS_OUT_T,\n    n_layers_r_t: int = DEFAULT_LAYERS_R_T,\n    n_units_out: int = DEFAULT_UNITS_OUT,\n    n_units_r: int = DEFAULT_UNITS_R,\n    n_units_out_t: int = DEFAULT_UNITS_OUT_T,\n    n_units_r_t: int = DEFAULT_UNITS_R_T,\n    penalty_l2: float = DEFAULT_PENALTY_L2,\n    penalty_l2_t: float = DEFAULT_PENALTY_L2,\n    step_size: float = DEFAULT_STEP_SIZE,\n    step_size_t: float = DEFAULT_STEP_SIZE_T,\n    n_iter: int = DEFAULT_N_ITER,\n    batch_size: int = DEFAULT_BATCH_SIZE,\n    n_iter_min: int = DEFAULT_N_ITER_MIN,\n    val_split_prop: float = DEFAULT_VAL_SPLIT,\n    early_stopping: bool = True,\n    patience: int = DEFAULT_PATIENCE,\n    n_iter_print: int = DEFAULT_N_ITER_PRINT,\n    seed: int = DEFAULT_SEED,\n    
nonlin: str = DEFAULT_NONLIN,\n    return_val_loss: bool = False,\n    avg_objective: bool = DEFAULT_AVG_OBJECTIVE,\n) -> Tuple:\n    y = check_shape_1d_data(y)\n    if len(w.shape) > 1:\n        w = w.reshape((len(w),))\n\n    if weight_strategy not in [0, 1, -1, None]:\n        # weight_strategy is coded as follows:\n        # for tau(x)=g(x)tau_0(x) + (1-g(x))tau_1(x) [eq 9, kuenzel et al (2019)]\n        # weight_strategy=0 sets g(x)=0, weight_strategy=1 sets g(x)=1,\n        # weight_strategy=None sets g(x)=pi(x) [propensity score],\n        # weight_strategy=-1 sets g(x)=(1-pi(x))\n        raise ValueError(\"XNet only implements weight_strategy in [0, 1, -1, None]\")\n\n    if first_stage_strategy not in ALL_STRATEGIES:\n        raise ValueError(\n            \"Parameter first stage should be in \"\n            \"catenets.models.twostep_nets.ALL_STRATEGIES. \"\n            \"You passed {}\".format(first_stage_strategy)\n        )\n\n    # first stage: get estimates of PO regression\n    log.debug(\"Training first stage\")\n\n    mu_hat_0, mu_hat_1 = _get_first_stage_pos(\n        X,\n        y,\n        w,\n        binary_y=binary_y,\n        n_layers_out=n_layers_out,\n        n_units_out=n_units_out,\n        n_layers_r=n_layers_r,\n        n_units_r=n_units_r,\n        penalty_l2=penalty_l2,\n        step_size=step_size,\n        n_iter=n_iter,\n        batch_size=batch_size,\n        val_split_prop=val_split_prop,\n        early_stopping=early_stopping,\n        patience=patience,\n        n_iter_min=n_iter_min,\n        n_iter_print=n_iter_print,\n        seed=seed,\n        nonlin=nonlin,\n        avg_objective=avg_objective,\n        first_stage_strategy=first_stage_strategy,\n        first_stage_args=first_stage_args,\n    )\n\n    if weight_strategy is None or weight_strategy == -1:\n        # also fit propensity estimator\n        log.debug(\"Training propensity net\")\n        params_prop, predict_fun_prop = train_output_net_only(\n            X,\n 
           w,\n            binary_y=True,\n            n_layers_out=n_layers_out,\n            n_units_out=n_units_out,\n            n_layers_r=n_layers_r,\n            n_units_r=n_units_r,\n            penalty_l2=penalty_l2,\n            step_size=step_size,\n            n_iter=n_iter,\n            batch_size=batch_size,\n            val_split_prop=val_split_prop,\n            early_stopping=early_stopping,\n            patience=patience,\n            n_iter_min=n_iter_min,\n            n_iter_print=n_iter_print,\n            seed=seed,\n            nonlin=nonlin,\n            avg_objective=avg_objective,\n        )\n\n    else:\n        params_prop, predict_fun_prop = None, None\n\n    # second stage\n    log.debug(\"Training second stage\")\n    if not weight_strategy == 0:\n        # fit tau_0\n        log.debug(\"Fitting tau_0\")\n        pseudo_outcome0 = mu_hat_1 - y[w == 0]\n        params_tau0, predict_fun_tau0 = train_output_net_only(\n            X[w == 0],\n            pseudo_outcome0,\n            binary_y=False,\n            n_layers_out=n_layers_out_t,\n            n_units_out=n_units_out_t,\n            n_layers_r=n_layers_r_t,\n            n_units_r=n_units_r_t,\n            penalty_l2=penalty_l2_t,\n            step_size=step_size_t,\n            n_iter=n_iter,\n            batch_size=batch_size,\n            val_split_prop=val_split_prop,\n            early_stopping=early_stopping,\n            patience=patience,\n            n_iter_min=n_iter_min,\n            n_iter_print=n_iter_print,\n            seed=seed,\n            return_val_loss=return_val_loss,\n            nonlin=nonlin,\n            avg_objective=avg_objective,\n        )\n    else:\n        params_tau0, predict_fun_tau0 = None, None\n\n    if not weight_strategy == 1:\n        # fit tau_1\n        log.debug(\"Fitting tau_1\")\n        pseudo_outcome1 = y[w == 1] - mu_hat_0\n        params_tau1, predict_fun_tau1 = train_output_net_only(\n            X[w == 1],\n            
pseudo_outcome1,\n            binary_y=False,\n            n_layers_out=n_layers_out_t,\n            n_units_out=n_units_out_t,\n            n_layers_r=n_layers_r_t,\n            n_units_r=n_units_r_t,\n            penalty_l2=penalty_l2_t,\n            step_size=step_size_t,\n            n_iter=n_iter,\n            batch_size=batch_size,\n            val_split_prop=val_split_prop,\n            early_stopping=early_stopping,\n            patience=patience,\n            n_iter_min=n_iter_min,\n            n_iter_print=n_iter_print,\n            seed=seed,\n            return_val_loss=return_val_loss,\n            nonlin=nonlin,\n            avg_objective=avg_objective,\n        )\n\n    else:\n        params_tau1, predict_fun_tau1 = None, None\n\n    params = params_tau0, params_tau1, params_prop\n    predict_funs = predict_fun_tau0, predict_fun_tau1, predict_fun_prop\n\n    return params, predict_funs\n\n\ndef _get_first_stage_pos(\n    X: jnp.ndarray,\n    y: jnp.ndarray,\n    w: jnp.ndarray,\n    first_stage_strategy: str = T_STRATEGY,\n    first_stage_args: Optional[dict] = None,\n    binary_y: bool = False,\n    n_layers_out: int = DEFAULT_LAYERS_OUT,\n    n_layers_r: int = DEFAULT_LAYERS_R,\n    n_units_out: int = DEFAULT_UNITS_OUT,\n    n_units_r: int = DEFAULT_UNITS_R,\n    penalty_l2: float = DEFAULT_PENALTY_L2,\n    step_size: float = DEFAULT_STEP_SIZE,\n    n_iter: int = DEFAULT_N_ITER,\n    batch_size: int = DEFAULT_BATCH_SIZE,\n    n_iter_min: int = DEFAULT_N_ITER_MIN,\n    val_split_prop: float = DEFAULT_VAL_SPLIT,\n    early_stopping: bool = True,\n    patience: int = DEFAULT_PATIENCE,\n    n_iter_print: int = DEFAULT_N_ITER_PRINT,\n    seed: int = DEFAULT_SEED,\n    nonlin: str = DEFAULT_NONLIN,\n    avg_objective: bool = DEFAULT_AVG_OBJECTIVE,\n) -> Tuple[jnp.ndarray, jnp.ndarray]:\n    if first_stage_args is None:\n        first_stage_args = {}\n\n    train_fun: Callable\n    predict_fun: Callable\n\n    if first_stage_strategy == T_STRATEGY:\n      
  train_fun, predict_fun = train_tnet, predict_t_net\n    elif first_stage_strategy == S_STRATEGY:\n        train_fun, predict_fun = train_snet, predict_snet\n    elif first_stage_strategy == S1_STRATEGY:\n        train_fun, predict_fun = train_snet1, predict_snet1\n    elif first_stage_strategy == S2_STRATEGY:\n        train_fun, predict_fun = train_snet2, predict_snet2\n    elif first_stage_strategy == S3_STRATEGY:\n        train_fun, predict_fun = train_snet3, predict_snet3\n    elif first_stage_strategy == OFFSET_STRATEGY:\n        train_fun, predict_fun = train_offsetnet, predict_offsetnet\n    elif first_stage_strategy == FLEX_STRATEGY:\n        train_fun, predict_fun = train_flextenet, predict_flextenet\n\n    trained_params, pred_fun = train_fun(\n        X,\n        y,\n        w,\n        binary_y=binary_y,\n        n_layers_r=n_layers_r,\n        n_units_r=n_units_r,\n        n_layers_out=n_layers_out,\n        n_units_out=n_units_out,\n        penalty_l2=penalty_l2,\n        step_size=step_size,\n        n_iter=n_iter,\n        batch_size=batch_size,\n        val_split_prop=val_split_prop,\n        early_stopping=early_stopping,\n        patience=patience,\n        n_iter_min=n_iter_min,\n        n_iter_print=n_iter_print,\n        seed=seed,\n        nonlin=nonlin,\n        avg_objective=avg_objective,\n        **first_stage_args\n    )\n\n    _, mu_0, mu_1 = predict_fun(X, trained_params, pred_fun, return_po=True)\n\n    return mu_0[w == 1], mu_1[w == 0]\n\n\ndef predict_x_net(\n    X: jnp.ndarray,\n    trained_params: dict,\n    predict_funs: list,\n    return_po: bool = False,\n    return_prop: bool = False,\n    weight_strategy: Optional[int] = None,\n) -> jnp.ndarray:\n    if return_po:\n        raise NotImplementedError(\"TwoStepNets have no Potential outcome predictors.\")\n\n    if return_prop:\n        raise NotImplementedError(\"TwoStepNets have no Propensity predictors.\")\n\n    params_tau0, params_tau1, params_prop = trained_params\n    
predict_fun_tau0, predict_fun_tau1, predict_fun_prop = predict_funs\n\n    tau0_pred: jnp.ndarray\n    tau1_pred: jnp.ndarray\n\n    if not weight_strategy == 0:\n        tau0_pred = predict_fun_tau0(params_tau0, X)\n    else:\n        tau0_pred = 0\n\n    if not weight_strategy == 1:\n        tau1_pred = predict_fun_tau1(params_tau1, X)\n    else:\n        tau1_pred = 0\n\n    if weight_strategy is None or weight_strategy == -1:\n        prop_pred = predict_fun_prop(params_prop, X)\n\n    if weight_strategy is None:\n        weight = prop_pred\n    elif weight_strategy == -1:\n        weight = 1 - prop_pred\n    elif weight_strategy == 0:\n        weight = 0\n    elif weight_strategy == 1:\n        weight = 1\n\n    return weight * tau0_pred + (1 - weight) * tau1_pred\n"
  },
  {
    "path": "catenets/models/torch/__init__.py",
    "content": "\"\"\"\nPyTorch-based implementations for the CATE estimators.\n\"\"\"\nfrom .flextenet import FlexTENet\nfrom .pseudo_outcome_nets import (\n    DRLearner,\n    PWLearner,\n    RALearner,\n    RLearner,\n    ULearner,\n    XLearner,\n)\nfrom .representation_nets import DragonNet, TARNet\nfrom .slearner import SLearner\nfrom .snet import SNet\nfrom .tlearner import TLearner\n\n__all__ = [\n    \"TLearner\",\n    \"SLearner\",\n    \"TARNet\",\n    \"DragonNet\",\n    \"XLearner\",\n    \"RLearner\",\n    \"ULearner\",\n    \"RALearner\",\n    \"PWLearner\",\n    \"DRLearner\",\n    \"SNet\",\n    \"FlexTENet\",\n]\n"
  },
  {
    "path": "catenets/models/torch/base.py",
    "content": "import abc\r\nfrom typing import Optional\r\n\r\nimport numpy as np\r\nimport torch\r\nfrom torch import nn\r\n\r\nimport catenets.logger as log\r\nfrom catenets.models.constants import (\r\n    DEFAULT_BATCH_SIZE,\r\n    DEFAULT_LAYERS_OUT,\r\n    DEFAULT_LAYERS_R,\r\n    DEFAULT_N_ITER,\r\n    DEFAULT_N_ITER_MIN,\r\n    DEFAULT_N_ITER_PRINT,\r\n    DEFAULT_NONLIN,\r\n    DEFAULT_PATIENCE,\r\n    DEFAULT_PENALTY_L2,\r\n    DEFAULT_SEED,\r\n    DEFAULT_STEP_SIZE,\r\n    DEFAULT_UNITS_OUT,\r\n    DEFAULT_UNITS_R,\r\n    DEFAULT_VAL_SPLIT,\r\n    LARGE_VAL,\r\n)\r\nfrom catenets.models.torch.utils.decorators import benchmark, check_input_train\r\nfrom catenets.models.torch.utils.model_utils import make_val_split\r\nfrom catenets.models.torch.utils.weight_utils import compute_importance_weights\r\n\r\nDEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\r\n\r\nEPS = 1e-8\r\n\r\nNONLIN = {\r\n    \"elu\": nn.ELU,\r\n    \"relu\": nn.ReLU,\r\n    \"leaky_relu\": nn.LeakyReLU,\r\n    \"selu\": nn.SELU,\r\n    \"sigmoid\": nn.Sigmoid,\r\n}\r\n\r\n\r\nclass BasicNet(nn.Module):\r\n    \"\"\"\r\n    Basic hypothesis neural net.\r\n\r\n    Parameters\r\n    ----------\r\n    n_unit_in: int\r\n        Number of features\r\n    n_layers_out: int\r\n        Number of hypothesis layers (n_layers_out x n_units_out + 1 x Linear layer)\r\n    n_units_out: int\r\n        Number of hidden units in each hypothesis layer\r\n    binary_y: bool, default False\r\n        Whether the outcome is binary. Impacts the loss function.\r\n    nonlin: string, default 'elu'\r\n        Nonlinearity to use in NN. Can be 'elu', 'relu', 'selu' or 'leaky_relu'.\r\n    lr: float\r\n        learning rate for optimizer. 
step_size equivalent in the JAX version.\r\n    weight_decay: float\r\n        l2 (ridge) penalty for the weights.\r\n    n_iter: int\r\n        Maximum number of iterations.\r\n    batch_size: int\r\n        Batch size\r\n    n_iter_print: int\r\n        Number of iterations after which to print updates and check the validation loss.\r\n    seed: int\r\n        Seed used\r\n    val_split_prop: float\r\n        Proportion of samples used for validation split (can be 0)\r\n    patience: int\r\n        Number of iterations to wait before early stopping after decrease in validation loss\r\n    n_iter_min: int\r\n        Minimum number of iterations to go through before starting early stopping\r\n    clipping_value: int, default 1\r\n        Gradients clipping value\r\n    \"\"\"\r\n\r\n    def __init__(\r\n        self,\r\n        name: str,\r\n        n_unit_in: int,\r\n        n_layers_out: int = DEFAULT_LAYERS_OUT,\r\n        n_units_out: int = DEFAULT_UNITS_OUT,\r\n        binary_y: bool = False,\r\n        nonlin: str = DEFAULT_NONLIN,\r\n        lr: float = DEFAULT_STEP_SIZE,\r\n        weight_decay: float = DEFAULT_PENALTY_L2,\r\n        n_iter: int = DEFAULT_N_ITER,\r\n        batch_size: int = DEFAULT_BATCH_SIZE,\r\n        n_iter_print: int = DEFAULT_N_ITER_PRINT,\r\n        seed: int = DEFAULT_SEED,\r\n        val_split_prop: float = DEFAULT_VAL_SPLIT,\r\n        patience: int = DEFAULT_PATIENCE,\r\n        n_iter_min: int = DEFAULT_N_ITER_MIN,\r\n        clipping_value: int = 1,\r\n        batch_norm: bool = True,\r\n        early_stopping: bool = True,\r\n        dropout: bool = False,\r\n        dropout_prob: float = 0.2,\r\n    ) -> None:\r\n        super(BasicNet, self).__init__()\r\n\r\n        self.name = name\r\n        if nonlin not in list(NONLIN.keys()):\r\n            raise ValueError(\"Unknown nonlinearity\")\r\n\r\n        NL = NONLIN[nonlin]\r\n\r\n        if n_layers_out > 0:\r\n            if batch_norm:\r\n                layers = [\r\n    
                nn.Linear(n_unit_in, n_units_out),\r\n                    nn.BatchNorm1d(n_units_out),\r\n                    NL(),\r\n                ]\r\n            else:\r\n                layers = [nn.Linear(n_unit_in, n_units_out), NL()]\r\n\r\n            # add required number of layers\r\n            for i in range(n_layers_out - 1):\r\n                if dropout:\r\n                    layers.extend([nn.Dropout(dropout_prob)])\r\n                if batch_norm:\r\n                    layers.extend(\r\n                        [\r\n                            nn.Linear(n_units_out, n_units_out),\r\n                            nn.BatchNorm1d(n_units_out),\r\n                            NL(),\r\n                        ]\r\n                    )\r\n                else:\r\n                    layers.extend(\r\n                        [\r\n                            nn.Linear(n_units_out, n_units_out),\r\n                            NL(),\r\n                        ]\r\n                    )\r\n\r\n            # add final layers\r\n            layers.append(nn.Linear(n_units_out, 1))\r\n        else:\r\n            layers = [nn.Linear(n_unit_in, 1)]\r\n\r\n        if binary_y:\r\n            layers.append(nn.Sigmoid())\r\n\r\n        # return final architecture\r\n        self.model = nn.Sequential(*layers).to(DEVICE)\r\n        self.binary_y = binary_y\r\n\r\n        self.n_iter = n_iter\r\n        self.batch_size = batch_size\r\n        self.n_iter_print = n_iter_print\r\n        self.seed = seed\r\n        self.val_split_prop = val_split_prop\r\n        self.patience = patience\r\n        self.n_iter_min = n_iter_min\r\n        self.clipping_value = clipping_value\r\n        self.early_stopping = early_stopping\r\n\r\n        self.optimizer = torch.optim.Adam(\r\n            self.parameters(), lr=lr, weight_decay=weight_decay\r\n        )\r\n\r\n    def forward(self, X: torch.Tensor) -> torch.Tensor:\r\n        return self.model(X)\r\n\r\n    def fit(\r\n    
    self, X: torch.Tensor, y: torch.Tensor, weight: Optional[torch.Tensor] = None\r\n    ) -> \"BasicNet\":\r\n        self.train()\r\n\r\n        X = self._check_tensor(X)\r\n        y = self._check_tensor(y).squeeze()\r\n\r\n        # get validation split (can be none)\r\n        X, y, X_val, y_val, val_string = make_val_split(\r\n            X, y, val_split_prop=self.val_split_prop, seed=self.seed\r\n        )\r\n        y_val = y_val.squeeze()\r\n        n = X.shape[0]  # could be different from before due to split\r\n\r\n        # calculate number of batches per epoch\r\n        batch_size = self.batch_size if self.batch_size < n else n\r\n        n_batches = int(np.round(n / batch_size)) if batch_size < n else 1\r\n        train_indices = np.arange(n)\r\n\r\n        # do training\r\n        val_loss_best = LARGE_VAL\r\n        patience = 0\r\n        for i in range(self.n_iter):\r\n            # shuffle data for minibatches\r\n            np.random.shuffle(train_indices)\r\n            train_loss = []\r\n            for b in range(n_batches):\r\n                self.optimizer.zero_grad()\r\n\r\n                idx_next = train_indices[\r\n                    (b * batch_size) : min((b + 1) * batch_size, n - 1)\r\n                ]\r\n\r\n                X_next = X[idx_next]\r\n                y_next = y[idx_next]\r\n\r\n                weight_next = None\r\n                if weight is not None:\r\n                    weight_next = weight[idx_next].detach()\r\n\r\n                loss = nn.BCELoss(weight=weight_next) if self.binary_y else nn.MSELoss()\r\n\r\n                preds = self.forward(X_next).squeeze()\r\n\r\n                batch_loss = loss(preds, y_next)\r\n\r\n                batch_loss.backward()\r\n\r\n                torch.nn.utils.clip_grad_norm_(self.parameters(), self.clipping_value)\r\n\r\n                self.optimizer.step()\r\n\r\n                train_loss.append(batch_loss.detach())\r\n\r\n            train_loss = 
torch.Tensor(train_loss).to(DEVICE)\r\n\r\n            if self.early_stopping or i % self.n_iter_print == 0:\r\n                loss = nn.BCELoss() if self.binary_y else nn.MSELoss()\r\n                with torch.no_grad():\r\n                    preds = self.forward(X_val).squeeze()\r\n                    val_loss = loss(preds, y_val)\r\n\r\n                    if self.early_stopping:\r\n                        if val_loss_best > val_loss:\r\n                            val_loss_best = val_loss\r\n                            patience = 0\r\n                        else:\r\n                            patience += 1\r\n\r\n                        if patience > self.patience and i > self.n_iter_min:\r\n                            break\r\n\r\n                    if i % self.n_iter_print == 0:\r\n                        log.info(\r\n                            f\"[{self.name}] Epoch: {i}, current {val_string} loss: {val_loss}, train_loss: {torch.mean(train_loss)}\"\r\n                        )\r\n\r\n        return self\r\n\r\n    def _check_tensor(self, X: torch.Tensor) -> torch.Tensor:\r\n        if isinstance(X, torch.Tensor):\r\n            return X.to(DEVICE)\r\n        else:\r\n            return torch.from_numpy(np.asarray(X)).to(DEVICE)\r\n\r\n\r\nclass RepresentationNet(nn.Module):\r\n    \"\"\"\r\n    Basic representation neural net\r\n\r\n    Parameters\r\n    ----------\r\n    n_unit_in: int\r\n        Number of features\r\n    n_layers: int\r\n        Number of shared representation layers before hypothesis layers\r\n    n_units: int\r\n        Number of hidden units in each representation layer\r\n    nonlin: string, default 'elu'\r\n        Nonlinearity to use in NN. 
Can be 'elu', 'relu', 'selu' or 'leaky_relu'.\r\n    \"\"\"\r\n\r\n    def __init__(\r\n        self,\r\n        n_unit_in: int,\r\n        n_layers: int = DEFAULT_LAYERS_R,\r\n        n_units: int = DEFAULT_UNITS_R,\r\n        nonlin: str = DEFAULT_NONLIN,\r\n        batch_norm: bool = True,\r\n    ) -> None:\r\n        super(RepresentationNet, self).__init__()\r\n        if nonlin not in list(NONLIN.keys()):\r\n            raise ValueError(\"Unknown nonlinearity\")\r\n\r\n        NL = NONLIN[nonlin]\r\n\r\n        if batch_norm:\r\n            layers = [nn.Linear(n_unit_in, n_units), nn.BatchNorm1d(n_units), NL()]\r\n        else:\r\n            layers = [nn.Linear(n_unit_in, n_units), NL()]\r\n        # add required number of layers\r\n        for i in range(n_layers - 1):\r\n            if batch_norm:\r\n                layers.extend(\r\n                    [nn.Linear(n_units, n_units), nn.BatchNorm1d(n_units), NL()]\r\n                )\r\n            else:\r\n                layers.extend([nn.Linear(n_units, n_units), NL()])\r\n\r\n        self.model = nn.Sequential(*layers).to(DEVICE)\r\n\r\n    def forward(self, X: torch.Tensor) -> torch.Tensor:\r\n        return self.model(X)\r\n\r\n\r\nclass PropensityNet(nn.Module):\r\n    \"\"\"\r\n    Basic propensity neural net\r\n\r\n    Parameters\r\n    ----------\r\n    name: str\r\n        Display name\r\n    n_unit_in: int\r\n        Number of features\r\n    n_unit_out: int\r\n        Number of output features\r\n    weighting_strategy: str\r\n        Weighting strategy\r\n    n_units_out_prop: int\r\n        Number of hidden units in each propensity score hypothesis layer\r\n    n_layers_out_prop: int\r\n        Number of hypothesis layers for propensity score(n_layers_out x n_units_out + 1 x Dense\r\n        layer)\r\n    nonlin: string, default 'elu'\r\n        Nonlinearity to use in NN. Can be 'elu', 'relu', 'selu' or 'leaky_relu'.\r\n    lr: float\r\n        learning rate for optimizer. 
step_size equivalent in the JAX version.\r\n    weight_decay: float\r\n        l2 (ridge) penalty for the weights.\r\n    n_iter: int\r\n        Maximum number of iterations.\r\n    batch_size: int\r\n        Batch size\r\n    n_iter_print: int\r\n        Number of iterations after which to print updates and check the validation loss.\r\n    seed: int\r\n        Seed used\r\n    val_split_prop: float\r\n        Proportion of samples used for validation split (can be 0)\r\n    patience: int\r\n        Number of iterations to wait before early stopping after decrease in validation loss\r\n    n_iter_min: int\r\n        Minimum number of iterations to go through before starting early stopping\r\n    clipping_value: int, default 1\r\n        Gradients clipping value\r\n    \"\"\"\r\n\r\n    def __init__(\r\n        self,\r\n        name: str,\r\n        n_unit_in: int,\r\n        n_unit_out: int,\r\n        weighting_strategy: str,\r\n        n_units_out_prop: int = DEFAULT_UNITS_OUT,\r\n        n_layers_out_prop: int = 0,\r\n        nonlin: str = DEFAULT_NONLIN,\r\n        lr: float = DEFAULT_STEP_SIZE,\r\n        weight_decay: float = DEFAULT_PENALTY_L2,\r\n        n_iter: int = DEFAULT_N_ITER,\r\n        batch_size: int = DEFAULT_BATCH_SIZE,\r\n        n_iter_print: int = DEFAULT_N_ITER_PRINT,\r\n        seed: int = DEFAULT_SEED,\r\n        val_split_prop: float = DEFAULT_VAL_SPLIT,\r\n        patience: int = DEFAULT_PATIENCE,\r\n        n_iter_min: int = DEFAULT_N_ITER_MIN,\r\n        clipping_value: int = 1,\r\n        batch_norm: bool = True,\r\n        early_stopping: bool = True,\r\n        dropout: bool = False,\r\n        dropout_prob: float = 0.2,\r\n    ) -> None:\r\n        super(PropensityNet, self).__init__()\r\n        if nonlin not in list(NONLIN.keys()):\r\n            raise ValueError(\"Unknown nonlinearity\")\r\n\r\n        NL = NONLIN[nonlin]\r\n\r\n        if batch_norm:\r\n            layers = [\r\n                nn.Linear(in_features=n_unit_in, 
out_features=n_units_out_prop),\r\n                nn.BatchNorm1d(n_units_out_prop),\r\n                NL(),\r\n            ]\r\n        else:\r\n            layers = [\r\n                nn.Linear(in_features=n_unit_in, out_features=n_units_out_prop),\r\n                NL(),\r\n            ]\r\n\r\n        for i in range(n_layers_out_prop - 1):\r\n            if dropout:\r\n                layers.extend([nn.Dropout(dropout_prob)])\r\n            if batch_norm:\r\n                layers.extend(\r\n                    [\r\n                        nn.Linear(\r\n                            in_features=n_units_out_prop, out_features=n_units_out_prop\r\n                        ),\r\n                        nn.BatchNorm1d(n_units_out_prop),\r\n                        NL(),\r\n                    ]\r\n                )\r\n            else:\r\n                layers.extend(\r\n                    [\r\n                        nn.Linear(\r\n                            in_features=n_units_out_prop, out_features=n_units_out_prop\r\n                        ),\r\n                        NL(),\r\n                    ]\r\n                )\r\n        layers.extend(\r\n            [\r\n                nn.Linear(in_features=n_units_out_prop, out_features=n_unit_out),\r\n                nn.Softmax(dim=-1),\r\n            ]\r\n        )\r\n\r\n        self.model = nn.Sequential(*layers).to(DEVICE)\r\n        self.name = name\r\n        self.weighting_strategy = weighting_strategy\r\n        self.n_iter = n_iter\r\n        self.batch_size = batch_size\r\n        self.n_iter_print = n_iter_print\r\n        self.seed = seed\r\n        self.val_split_prop = val_split_prop\r\n        self.patience = patience\r\n        self.n_iter_min = n_iter_min\r\n        self.clipping_value = clipping_value\r\n        self.early_stopping = early_stopping\r\n\r\n        self.optimizer = torch.optim.Adam(\r\n            self.parameters(), lr=lr, weight_decay=weight_decay\r\n        )\r\n\r\n    def 
forward(self, X: torch.Tensor) -> torch.Tensor:\r\n        return self.model(X)\r\n\r\n    def get_importance_weights(\r\n        self, X: torch.Tensor, w: Optional[torch.Tensor] = None\r\n    ) -> torch.Tensor:\r\n        p_pred = self.forward(X).squeeze()[:, 1]\r\n        return compute_importance_weights(p_pred, w, self.weighting_strategy, {})\r\n\r\n    def loss(self, y_pred: torch.Tensor, y_target: torch.Tensor) -> torch.Tensor:\r\n        return nn.NLLLoss()(torch.log(y_pred + EPS), y_target)\r\n\r\n    def fit(self, X: torch.Tensor, y: torch.Tensor) -> \"PropensityNet\":\r\n        self.train()\r\n\r\n        X = self._check_tensor(X)\r\n        y = self._check_tensor(y).long()\r\n\r\n        # get validation split (can be none)\r\n        X, y, X_val, y_val, val_string = make_val_split(\r\n            X, y, val_split_prop=self.val_split_prop, seed=self.seed\r\n        )\r\n        y_val = y_val.squeeze()\r\n        n = X.shape[0]  # could be different from before due to split\r\n\r\n        # calculate number of batches per epoch\r\n        batch_size = self.batch_size if self.batch_size < n else n\r\n        n_batches = int(np.round(n / batch_size)) if batch_size < n else 1\r\n        train_indices = np.arange(n)\r\n\r\n        # do training\r\n        val_loss_best = LARGE_VAL\r\n        patience = 0\r\n        for i in range(self.n_iter):\r\n            # shuffle data for minibatches\r\n            np.random.shuffle(train_indices)\r\n            train_loss = []\r\n            for b in range(n_batches):\r\n                self.optimizer.zero_grad()\r\n\r\n                idx_next = train_indices[\r\n                    (b * batch_size) : min((b + 1) * batch_size, n - 1)\r\n                ]\r\n\r\n                X_next = X[idx_next]\r\n                y_next = y[idx_next].squeeze()\r\n\r\n                preds = self.forward(X_next).squeeze()\r\n\r\n                batch_loss = self.loss(preds, y_next)\r\n\r\n                batch_loss.backward()\r\n\r\n 
               torch.nn.utils.clip_grad_norm_(self.parameters(), self.clipping_value)\r\n\r\n                self.optimizer.step()\r\n                train_loss.append(batch_loss.detach())\r\n\r\n            train_loss = torch.Tensor(train_loss).to(DEVICE)\r\n\r\n            if self.early_stopping or i % self.n_iter_print == 0:\r\n                with torch.no_grad():\r\n                    preds = self.forward(X_val).squeeze()\r\n                    val_loss = self.loss(preds, y_val)\r\n\r\n                    if self.early_stopping:\r\n                        if val_loss_best > val_loss:\r\n                            val_loss_best = val_loss\r\n                            patience = 0\r\n                        else:\r\n                            patience += 1\r\n                        if patience > self.patience and (\r\n                            (i + 1) * n_batches > self.n_iter_min\r\n                        ):\r\n                            break\r\n                    if i % self.n_iter_print == 0:\r\n                        log.info(\r\n                            f\"[{self.name}] Epoch: {i}, current {val_string} loss: {val_loss}, train_loss: {torch.mean(train_loss)}\"\r\n                        )\r\n\r\n        return self\r\n\r\n    def _check_tensor(self, X: torch.Tensor) -> torch.Tensor:\r\n        if isinstance(X, torch.Tensor):\r\n            return X.to(DEVICE)\r\n        else:\r\n            return torch.from_numpy(np.asarray(X)).to(DEVICE)\r\n\r\n\r\nclass BaseCATEEstimator(nn.Module):\r\n    \"\"\"\r\n    Interface for estimators of CATE.\r\n\r\n    The interface has train/forward API for PyTorch-based models and fit/predict API for sklearn-based models.\r\n    \"\"\"\r\n\r\n    def __init__(\r\n        self,\r\n    ) -> None:\r\n        super(BaseCATEEstimator, self).__init__()\r\n\r\n    def score(\r\n        self,\r\n        X: torch.Tensor,\r\n        y: torch.Tensor,\r\n    ) -> float:\r\n        \"\"\"\r\n        Return the sqrt PEHE 
error (oracle metric).\r\n\r\n        Parameters\r\n        ----------\r\n        X: torch.Tensor\r\n            Covariate matrix\r\n        y: torch.Tensor\r\n            Expected potential outcome vector\r\n        \"\"\"\r\n        X = self._check_tensor(X)\r\n        y = self._check_tensor(y)\r\n        if len(X) != len(y):\r\n            raise ValueError(\"X/y length mismatch for score\")\r\n        if y.shape[-1] != 2:\r\n            raise ValueError(f\"y has invalid shape {y.shape}\")\r\n\r\n        hat_te = self.predict(X)\r\n\r\n        return torch.sqrt(torch.mean(((y[:, 1] - y[:, 0]) - hat_te) ** 2))\r\n\r\n    @abc.abstractmethod\r\n    @check_input_train\r\n    @benchmark\r\n    def fit(\r\n        self,\r\n        X: torch.Tensor,\r\n        y: torch.Tensor,\r\n        w: torch.Tensor,\r\n    ) -> \"BaseCATEEstimator\":\r\n        \"\"\"\r\n        Train method for a CATEModel\r\n\r\n        Parameters\r\n        ----------\r\n        X: torch.Tensor\r\n            Covariate matrix\r\n        y: torch.Tensor\r\n            Outcome vector\r\n        w: torch.Tensor\r\n            Treatment indicator\r\n        \"\"\"\r\n        ...\r\n\r\n    @benchmark\r\n    def forward(self, X: torch.Tensor) -> torch.Tensor:\r\n        \"\"\"\r\n        Predict treatment effect estimates using a CATE estimator.\r\n\r\n        Parameters\r\n        ----------\r\n        X: pd.DataFrame or np.array\r\n            Covariate matrix\r\n        Returns\r\n        -------\r\n        potential outcomes probabilities\r\n        \"\"\"\r\n        return self.predict(X, return_po=False, training=True)\r\n\r\n    @abc.abstractmethod\r\n    @benchmark\r\n    def predict(\r\n        self, X: torch.Tensor, return_po: bool = False, training: bool = False\r\n    ) -> torch.Tensor:\r\n        \"\"\"\r\n        Predict treatment effect estimates using a CATE estimator.\r\n\r\n        Parameters\r\n        ----------\r\n        X: pd.DataFrame or np.array\r\n            Covariate 
matrix\r\n        return_po: bool, optional\r\n            Return the potential outcomes too\r\n        Returns\r\n        -------\r\n        potential outcomes probabilities\r\n        \"\"\"\r\n        ...\r\n\r\n    def _check_tensor(self, X: torch.Tensor) -> torch.Tensor:\r\n        if isinstance(X, torch.Tensor):\r\n            return X.to(DEVICE)\r\n        else:\r\n            return torch.from_numpy(np.asarray(X)).to(DEVICE)\r\n"
  },
  {
    "path": "catenets/models/torch/flextenet.py",
    "content": "from typing import Any, Callable, List\n\nimport numpy as np\nimport torch\nfrom torch import nn\n\nimport catenets.logger as log\nfrom catenets.models.constants import (\n    DEFAULT_BATCH_SIZE,\n    DEFAULT_DIM_P_OUT,\n    DEFAULT_DIM_P_R,\n    DEFAULT_DIM_S_OUT,\n    DEFAULT_DIM_S_R,\n    DEFAULT_LAYERS_OUT,\n    DEFAULT_LAYERS_R,\n    DEFAULT_N_ITER,\n    DEFAULT_N_ITER_MIN,\n    DEFAULT_N_ITER_PRINT,\n    DEFAULT_PATIENCE,\n    DEFAULT_PENALTY_L2,\n    DEFAULT_PENALTY_ORTHOGONAL,\n    DEFAULT_SEED,\n    DEFAULT_STEP_SIZE,\n    DEFAULT_VAL_SPLIT,\n    LARGE_VAL,\n)\nfrom catenets.models.torch.base import DEVICE, BaseCATEEstimator\nfrom catenets.models.torch.utils.model_utils import make_val_split\n\n\nclass FlexTELinearLayer(nn.Module):\n    \"\"\"Layer constructor function for a fully-connected layer. Adapted to allow passing\n    treatment indicator through layer without using it\"\"\"\n\n    def __init__(\n        self,\n        name: str,\n        dropout: bool = False,\n        dropout_prob: float = 0.5,\n        *args: Any,\n        **kwargs: Any,\n    ) -> None:\n        super(FlexTELinearLayer, self).__init__()\n        self.name = name\n        if dropout:\n            self.model = nn.Sequential(\n                nn.Dropout(dropout_prob), nn.Linear(*args, **kwargs)\n            ).to(DEVICE)\n        else:\n            self.model = nn.Sequential(nn.Linear(*args, **kwargs)).to(DEVICE)\n\n    def forward(self, tensors: List[torch.Tensor]) -> List:\n        if len(tensors) != 2:\n            raise ValueError(\n                \"Invalid number of tensor for the FlexLinearLayer layer. 
It requires the features vector and the treatments vector\"\n            )\n\n        features_vector = tensors[0]\n        treatments_vector = tensors[1]\n\n        return [self.model(features_vector), treatments_vector]\n\n\nclass FlexTESplitLayer(nn.Module):\n    \"\"\"\n    Create multitask layer has shape [shared, private_0, private_1]\n    \"\"\"\n\n    def __init__(\n        self,\n        name: str,\n        n_units_in: int,\n        n_units_in_p: int,\n        n_units_s: int,\n        n_units_p: int,\n        first_layer: bool,\n        dropout: bool = False,\n        dropout_prob: float = 0.5,\n    ) -> None:\n        super(FlexTESplitLayer, self).__init__()\n        self.name = name\n        self.first_layer = first_layer\n        self.n_units_in = n_units_in\n        self.n_units_in_p = n_units_in_p\n        self.n_units_s = n_units_s\n        self.n_units_p = n_units_p\n\n        if dropout:\n            self.net_shared = nn.Sequential(\n                nn.Dropout(dropout_prob), nn.Linear(n_units_in, n_units_s)\n            ).to(DEVICE)\n            self.net_p0 = nn.Sequential(\n                nn.Dropout(dropout_prob), nn.Linear(n_units_in_p, n_units_p)\n            ).to(DEVICE)\n            self.net_p1 = nn.Sequential(\n                nn.Dropout(dropout_prob), nn.Linear(n_units_in_p, n_units_p)\n            ).to(DEVICE)\n        else:\n            self.net_shared = nn.Sequential(nn.Linear(n_units_in, n_units_s)).to(DEVICE)\n            self.net_p0 = nn.Sequential(nn.Linear(n_units_in_p, n_units_p)).to(DEVICE)\n            self.net_p1 = nn.Sequential(nn.Linear(n_units_in_p, n_units_p)).to(DEVICE)\n\n    def forward(self, tensors: List[torch.Tensor]) -> List:\n        if self.first_layer and len(tensors) != 2:\n            raise ValueError(\n                \"Invalid number of tensor for the FlexSplitLayer layer. 
It requires the features vector and the treatments vector\"\n            )\n        if not self.first_layer and len(tensors) != 4:\n            raise ValueError(\n                \"Invalid number of tensor for the FlexSplitLayer layer. It requires X_s, X_p0, X_p1 and W as input\"\n            )\n\n        if self.first_layer:\n            X = tensors[0]\n            W = tensors[1]\n\n            rep_s = self.net_shared(X)\n            rep_p0 = self.net_p0(X)\n            rep_p1 = self.net_p1(X)\n\n        else:\n            X_s = tensors[0]\n            X_p0 = tensors[1]\n            X_p1 = tensors[2]\n            W = tensors[3]\n\n            rep_s = self.net_shared(X_s)\n            rep_p0 = self.net_p0(torch.cat([X_s, X_p0], dim=1))\n            rep_p1 = self.net_p1(torch.cat([X_s, X_p1], dim=1))\n\n        return [rep_s, rep_p0, rep_p1, W]\n\n\nclass FlexTEOutputLayer(nn.Module):\n    def __init__(\n        self,\n        n_units_in: int,\n        n_units_in_p: int,\n        private: bool,\n        dropout: bool = False,\n        dropout_prob: float = 0.5,\n    ) -> None:\n        super(FlexTEOutputLayer, self).__init__()\n        self.private = private\n        if dropout:\n            self.net_shared = nn.Sequential(\n                nn.Dropout(dropout_prob), nn.Linear(n_units_in, 1)\n            ).to(DEVICE)\n            self.net_p0 = nn.Sequential(\n                nn.Dropout(dropout_prob), nn.Linear(n_units_in_p, 1)\n            ).to(DEVICE)\n            self.net_p1 = nn.Sequential(\n                nn.Dropout(dropout_prob), nn.Linear(n_units_in_p, 1)\n            ).to(DEVICE)\n        else:\n            self.net_shared = nn.Sequential(nn.Linear(n_units_in, 1)).to(DEVICE)\n            self.net_p0 = nn.Sequential(nn.Linear(n_units_in_p, 1)).to(DEVICE)\n            self.net_p1 = nn.Sequential(nn.Linear(n_units_in_p, 1)).to(DEVICE)\n\n    def forward(self, tensors: List[torch.Tensor]) -> torch.Tensor:\n        if len(tensors) != 4:\n            raise 
ValueError(\n                \"Invalid number of tensor for the FlexSplitLayer layer. It requires X_s, X_p0, X_p1 and W as input\"\n            )\n        X_s = tensors[0]\n        X_p0 = tensors[1]\n        X_p1 = tensors[2]\n        W = tensors[3]\n\n        if self.private:\n            rep_p0 = self.net_p0(torch.cat([X_s, X_p0], dim=1)).squeeze()\n            rep_p1 = self.net_p1(torch.cat([X_s, X_p1], dim=1)).squeeze()\n\n            return (1 - W) * rep_p0 + W * rep_p1\n        else:\n            rep_s = self.net_shared(X_s).squeeze()\n            rep_p0 = self.net_p0(torch.cat([X_s, X_p0], dim=1)).squeeze()\n            rep_p1 = self.net_p1(torch.cat([X_s, X_p1], dim=1)).squeeze()\n\n            return (1 - W) * rep_p0 + W * rep_p1 + rep_s\n\n\nclass ElementWiseParallelActivation(nn.Module):\n    \"\"\"Layer that applies a scalar function elementwise on its inputs.\n\n    Input looks like: X_s, X_p0, X_p1, t = inputs\n    \"\"\"\n\n    def __init__(self, act: Callable, **act_kwargs: Any) -> None:\n        super(ElementWiseParallelActivation, self).__init__()\n        self.act = act\n        self.act_kwargs = act_kwargs\n\n    def forward(self, tensors: List[torch.Tensor]) -> List:\n        if len(tensors) != 4:\n            raise ValueError(\n                \"Invalid number of tensor for the ElementWiseParallelActivation layer. 
It requires X_s, X_p0, X_p1, t as input\"\n            )\n\n        return [\n            self.act(tensors[0], **self.act_kwargs),\n            self.act(tensors[1], **self.act_kwargs),\n            self.act(tensors[2], **self.act_kwargs),\n            tensors[3],\n        ]\n\n\nclass ElementWiseSplitActivation(nn.Module):\n    \"\"\"Layer that applies a scalar function elementwise on its inputs.\n\n    Input looks like: X, t = inputs\n    \"\"\"\n\n    def __init__(self, act: Callable, **act_kwargs: Any) -> None:\n        super(ElementWiseSplitActivation, self).__init__()\n        self.act = act\n        self.act_kwargs = act_kwargs\n\n    def forward(self, tensors: List[torch.Tensor]) -> List:\n        if len(tensors) != 2:\n            raise ValueError(\n                \"Invalid number of tensor for the ElementWiseSplitActivation layer. It requires X, t as input\"\n            )\n\n        return [\n            self.act(tensors[0], **self.act_kwargs),\n            tensors[1],\n        ]\n\n\nclass FlexTENet(BaseCATEEstimator):\n    \"\"\"\n    Class implements FlexTENet, an architecture for treatment effect estimation that allows for\n    both shared and private information in each layer of the network.\n\n    Parameters\n    ----------\n    n_unit_in: int\n        Number of features\n    binary_y: bool, default False\n        Whether the outcome is binary\n    n_layers_out: int\n        Number of hypothesis layers (n_layers_out x n_units_out + 1 x Linear layer)\n    n_units_s_out: int\n        Number of hidden units in each shared hypothesis layer\n    n_units_p_out: int\n        Number of hidden units in each private hypothesis layer\n    n_layers_r: int\n        Number of representation layers before hypothesis layers (distinction between\n        hypothesis layers and representation layers is made to match TARNet & SNets)\n    n_units_s_r: int\n        Number of hidden units in each shared representation layer\n    n_units_p_r: int\n        Number of hidden 
units in each private representation layer\n    private_out: bool, False\n        Whether the final prediction layer should be fully private, or retain a shared component.\n    weight_decay: float\n        l2 (ridge) penalty\n    penalty_orthogonal: float\n        orthogonalisation penalty\n    lr: float\n        learning rate for optimizer\n    n_iter: int\n        Maximum number of iterations\n    batch_size: int\n        Batch size\n    val_split_prop: float\n        Proportion of samples used for validation split (can be 0)\n    early_stopping: bool, default True\n        Whether to use early stopping\n    patience: int\n        Number of iterations to wait before early stopping after decrease in validation loss\n    n_iter_min: int\n        Minimum number of iterations to go through before starting early stopping\n    n_iter_print: int\n        Number of iterations after which to print updates\n    seed: int\n        Seed used\n    opt: str, default 'adam'\n        Optimizer to use, accepts 'adam' and 'sgd'\n    shared_repr: bool, False\n        Whether to use a shared representation block as TARNet\n    lr_scale: float\n        Whether to scale down the learning rate after unfreezing the private components of the\n        network (only used if pretrain_shared=True)\n    normalize_ortho: bool, False\n        Whether to normalize the orthogonality penalty (by depth of network)\n    clipping_value: int, default 1\n        Gradients clipping value\n    \"\"\"\n\n    def __init__(\n        self,\n        n_unit_in: int,\n        binary_y: bool,\n        n_layers_out: int = DEFAULT_LAYERS_OUT,\n        n_units_s_out: int = DEFAULT_DIM_S_OUT,\n        n_units_p_out: int = DEFAULT_DIM_P_OUT,\n        n_layers_r: int = DEFAULT_LAYERS_R,\n        n_units_s_r: int = DEFAULT_DIM_S_R,\n        n_units_p_r: int = DEFAULT_DIM_P_R,\n        private_out: bool = False,\n        weight_decay: float = DEFAULT_PENALTY_L2,\n        penalty_orthogonal: float = 
DEFAULT_PENALTY_ORTHOGONAL,\n        lr: float = DEFAULT_STEP_SIZE,\n        n_iter: int = DEFAULT_N_ITER,\n        batch_size: int = DEFAULT_BATCH_SIZE,\n        val_split_prop: float = DEFAULT_VAL_SPLIT,\n        early_stopping: bool = True,\n        patience: int = DEFAULT_PATIENCE,\n        n_iter_min: int = DEFAULT_N_ITER_MIN,\n        n_iter_print: int = DEFAULT_N_ITER_PRINT,\n        seed: int = DEFAULT_SEED,\n        shared_repr: bool = False,\n        normalize_ortho: bool = False,\n        mode: int = 1,\n        clipping_value: int = 1,\n        dropout: bool = False,\n        dropout_prob: float = 0.5,\n    ) -> None:\n        super(FlexTENet, self).__init__()\n\n        self.binary_y = binary_y\n        self.n_layers_r = n_layers_r if n_layers_r else 1\n        self.n_layers_out = n_layers_out\n        self.n_units_s_out = n_units_s_out\n        self.n_units_p_out = n_units_p_out\n        self.n_units_s_r = n_units_s_r\n        self.n_units_p_r = n_units_p_r\n        self.private_out = private_out\n        self.mode = mode\n\n        self.penalty_orthogonal = penalty_orthogonal\n        self.weight_decay = weight_decay\n        self.lr = lr\n        self.n_iter = n_iter\n        self.batch_size = batch_size\n        self.val_split_prop = val_split_prop\n        self.early_stopping = early_stopping\n        self.patience = patience\n        self.n_iter_min = n_iter_min\n        self.shared_repr = shared_repr\n        self.normalize_ortho = normalize_ortho\n        self.clipping_value = clipping_value\n        self.early_stopping = early_stopping\n        self.dropout = dropout\n        self.dropout_prob = dropout_prob\n\n        self.seed = seed\n        self.n_iter_print = n_iter_print\n\n        layers = []\n\n        if shared_repr:  # fully shared representation as in TARNet\n            layers.extend(\n                [\n                    FlexTELinearLayer(\n                        \"shared_repr_layer_0\",\n                        dropout,\n      
                  dropout_prob,\n                        n_unit_in,\n                        n_units_s_r,\n                    ),\n                    ElementWiseSplitActivation(nn.SELU(inplace=True)),\n                ]\n            )\n\n            # add required number of layers\n            for i in range(self.n_layers_r - 1):\n                layers.extend(\n                    [\n                        FlexTELinearLayer(\n                            f\"shared_repr_layer_{i + 1}\",\n                            dropout,\n                            dropout_prob,\n                            n_units_s_r,\n                            n_units_s_r,\n                        ),\n                        ElementWiseSplitActivation(nn.SELU(inplace=True)),\n                    ]\n                )\n\n        else:  # shared AND private representations\n            layers.extend(\n                [\n                    FlexTESplitLayer(\n                        \"shared_private_layer_0\",\n                        n_unit_in,\n                        n_unit_in,\n                        n_units_s_r,\n                        n_units_p_r,\n                        first_layer=True,\n                        dropout=dropout,\n                        dropout_prob=dropout_prob,\n                    ),\n                    ElementWiseParallelActivation(nn.SELU(inplace=True)),\n                ]\n            )\n\n            # add required number of layers\n            for i in range(n_layers_r - 1):\n                layers.extend(\n                    [\n                        FlexTESplitLayer(\n                            f\"shared_private_layer_{i + 1}\",\n                            n_units_s_r,\n                            n_units_s_r + n_units_p_r,\n                            n_units_s_r,\n                            n_units_p_r,\n                            first_layer=False,\n                            dropout=dropout,\n                            
dropout_prob=dropout_prob,\n                        ),\n                        ElementWiseParallelActivation(nn.SELU(inplace=True)),\n                    ]\n                )\n\n        # add output layers\n        layers.extend(\n            [\n                FlexTESplitLayer(\n                    \"output_layer_0\",\n                    n_units_s_r,\n                    n_units_s_r if shared_repr else n_units_s_r + n_units_p_r,\n                    n_units_s_out,\n                    n_units_p_out,\n                    first_layer=(shared_repr),\n                    dropout=dropout,\n                    dropout_prob=dropout_prob,\n                ),\n                ElementWiseParallelActivation(nn.SELU(inplace=True)),\n            ]\n        )\n\n        # add required number of layers\n        for i in range(n_layers_out - 1):\n            layers.extend(\n                [\n                    FlexTESplitLayer(\n                        f\"output_layer_{i + 1}\",\n                        n_units_s_out,\n                        n_units_s_out + n_units_p_out,\n                        n_units_s_out,\n                        n_units_p_out,\n                        first_layer=False,\n                        dropout=dropout,\n                        dropout_prob=dropout_prob,\n                    ),\n                    ElementWiseParallelActivation(nn.SELU(inplace=True)),\n                ]\n            )\n\n        # append final layer\n        layers.append(\n            FlexTEOutputLayer(\n                n_units_s_out,\n                n_units_s_out + n_units_p_out,\n                private=self.private_out,\n                dropout=dropout,\n                dropout_prob=dropout_prob,\n            )\n        )\n        if binary_y:\n            layers.append(nn.Sigmoid())\n\n        self.model = nn.Sequential(*layers).to(DEVICE)\n\n    def _ortho_penalty_asymmetric(self) -> torch.Tensor:\n        def _get_cos_reg(\n            params_0: torch.Tensor, params_1: 
torch.Tensor, normalize: bool\n        ) -> torch.Tensor:\n            if normalize:\n                params_0 = params_0 / torch.linalg.norm(params_0, dim=0)\n                params_1 = params_1 / torch.linalg.norm(params_1, dim=0)\n\n            x_min = min(params_0.shape[0], params_1.shape[0])\n            y_min = min(params_0.shape[1], params_1.shape[1])\n\n            return (\n                torch.linalg.norm(\n                    params_0[:x_min, :y_min] * params_1[:x_min, :y_min], \"fro\"\n                )\n                ** 2\n            )\n\n        def _apply_reg_split_layer(\n            layer: FlexTESplitLayer, full: bool = True\n        ) -> torch.Tensor:\n            _ortho_body = 0\n            if full:\n                _ortho_body = _get_cos_reg(\n                    layer.net_p0[-1].weight,\n                    layer.net_p1[-1].weight,\n                    self.normalize_ortho,\n                )\n            _ortho_body += torch.sum(\n                _get_cos_reg(\n                    layer.net_shared[-1].weight,\n                    layer.net_p0[-1].weight,\n                    self.normalize_ortho,\n                )\n                + _get_cos_reg(\n                    layer.net_shared[-1].weight,\n                    layer.net_p1[-1].weight,\n                    self.normalize_ortho,\n                )\n            )\n            return _ortho_body\n\n        ortho_body = 0\n        for layer in self.model:\n            if not isinstance(layer, (FlexTESplitLayer, FlexTEOutputLayer)):\n                continue\n\n            if isinstance(layer, FlexTESplitLayer):\n                ortho_body += _apply_reg_split_layer(layer, full=True)\n\n            if self.private_out:\n                continue\n\n            ortho_body += _apply_reg_split_layer(layer, full=False)\n\n        return self.penalty_orthogonal * ortho_body\n\n    def loss(\n        self,\n        y0_pred: torch.Tensor,\n        y1_pred: torch.Tensor,\n        y_true: 
torch.Tensor,\n        t_true: torch.Tensor,\n    ) -> torch.Tensor:\n        def head_loss(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:\n            if self.binary_y:\n                return nn.BCELoss()(y_pred, y_true)\n            else:\n                return (y_pred - y_true) ** 2\n\n        def po_loss() -> torch.Tensor:\n            loss0 = torch.mean((1.0 - t_true) * head_loss(y0_pred, y_true))\n            loss1 = torch.mean(t_true * head_loss(y1_pred, y_true))\n\n            return loss0 + loss1\n\n        return po_loss() + self._ortho_penalty_asymmetric()\n\n    def fit(\n        self,\n        X: torch.Tensor,\n        y: torch.Tensor,\n        w: torch.Tensor,\n    ) -> \"FlexTENet\":\n        \"\"\"\n        Fit treatment models.\n\n        Parameters\n        ----------\n        X : torch.Tensor of shape (n_samples, n_features)\n            The features to fit to\n        y : torch.Tensor of shape (n_samples,) or (n_samples, )\n            The outcome variable\n        w: torch.Tensor of shape (n_samples,)\n            The treatment indicator\n        \"\"\"\n        self.model.train()\n\n        X = torch.Tensor(X).to(DEVICE)\n        y = torch.Tensor(y).squeeze().to(DEVICE)\n        w = torch.Tensor(w).squeeze().long().to(DEVICE)\n\n        X, y, w, X_val, y_val, w_val, val_string = make_val_split(\n            X, y, w=w, val_split_prop=self.val_split_prop, seed=self.seed\n        )\n\n        n = X.shape[0]  # could be different from before due to split\n\n        # calculate number of batches per epoch\n        batch_size = self.batch_size if self.batch_size < n else n\n        n_batches = int(np.round(n / batch_size)) if batch_size < n else 1\n        train_indices = np.arange(n)\n\n        optimizer = torch.optim.Adam(\n            self.parameters(), lr=self.lr, weight_decay=self.weight_decay\n        )\n\n        # training\n        val_loss_best = LARGE_VAL\n        patience = 0\n        for i in range(self.n_iter):\n         
   # shuffle data for minibatches\n            np.random.shuffle(train_indices)\n            train_loss = []\n            for b in range(n_batches):\n                optimizer.zero_grad()\n\n                idx_next = train_indices[\n                    (b * batch_size) : min((b + 1) * batch_size, n - 1)\n                ]\n\n                X_next = X[idx_next]\n                y_next = y[idx_next].squeeze()\n                w_next = w[idx_next].squeeze()\n\n                _, mu0, mu1 = self.predict(X_next, return_po=True, training=True)\n                batch_loss = self.loss(mu0, mu1, y_next, w_next)\n\n                batch_loss.backward()\n\n                torch.nn.utils.clip_grad_norm_(self.parameters(), self.clipping_value)\n\n                optimizer.step()\n\n                train_loss.append(batch_loss.detach())\n\n            train_loss = torch.Tensor(train_loss).to(DEVICE)\n\n            if self.early_stopping or i % self.n_iter_print == 0:\n                with torch.no_grad():\n                    _, mu0, mu1 = self.predict(X_val, return_po=True, training=True)\n                    val_loss = self.loss(mu0, mu1, y_val, w_val).detach().cpu()\n                    if self.early_stopping:\n                        if val_loss_best > val_loss:\n                            val_loss_best = val_loss\n                            patience = 0\n                        else:\n                            patience += 1\n                        if patience > self.patience and (\n                            (i + 1) * n_batches > self.n_iter_min\n                        ):\n                            break\n                    if i % self.n_iter_print == 0:\n                        log.info(\n                            f\"[FlexTENet] Epoch: {i}, current {val_string} loss: {val_loss} train_loss: {torch.mean(train_loss)}\"\n                        )\n\n        return self\n\n    def predict(\n        self, X: torch.Tensor, return_po: bool = False, training: bool = 
False\n    ) -> torch.Tensor:\n        \"\"\"\n        Predict treatment effects and potential outcomes\n\n        Parameters\n        ----------\n        X: array-like of shape (n_samples, n_features)\n            Test-sample features\n        Returns\n        -------\n        y: array-like of shape (n_samples,)\n        \"\"\"\n        if not training:\n            self.model.eval()\n\n        X = self._check_tensor(X).float()\n        W0 = torch.zeros(X.shape[0]).to(DEVICE)\n        W1 = torch.ones(X.shape[0]).to(DEVICE)\n\n        mu0 = self.model([X, W0])\n        mu1 = self.model([X, W1])\n\n        te = mu1 - mu0\n\n        if return_po:\n            return te, mu0, mu1\n\n        return te\n"
  },
  {
    "path": "catenets/models/torch/pseudo_outcome_nets.py",
    "content": "import abc\nimport copy\nfrom typing import Any, Optional, Tuple\n\nimport numpy as np\nimport torch\nfrom sklearn.model_selection import StratifiedKFold\nfrom torch import nn\n\nfrom catenets.models.constants import (\n    DEFAULT_BATCH_SIZE,\n    DEFAULT_CF_FOLDS,\n    DEFAULT_LAYERS_OUT,\n    DEFAULT_LAYERS_OUT_T,\n    DEFAULT_N_ITER,\n    DEFAULT_N_ITER_MIN,\n    DEFAULT_N_ITER_PRINT,\n    DEFAULT_NONLIN,\n    DEFAULT_PATIENCE,\n    DEFAULT_PENALTY_L2,\n    DEFAULT_SEED,\n    DEFAULT_STEP_SIZE,\n    DEFAULT_STEP_SIZE_T,\n    DEFAULT_UNITS_OUT,\n    DEFAULT_UNITS_OUT_T,\n    DEFAULT_VAL_SPLIT,\n)\nfrom catenets.models.torch.base import (\n    DEVICE,\n    BaseCATEEstimator,\n    BasicNet,\n    PropensityNet,\n)\nfrom catenets.models.torch.utils.model_utils import predict_wrapper, train_wrapper\nfrom catenets.models.torch.utils.transformations import (\n    dr_transformation_cate,\n    pw_transformation_cate,\n    ra_transformation_cate,\n    u_transformation_cate,\n)\n\n\nclass PseudoOutcomeLearner(BaseCATEEstimator):\n    \"\"\"\n    Class implements TwoStepLearners based on pseudo-outcome regression as discussed in\n    Curth &vd Schaar (2021): RA-learner, PW-learner and DR-learner\n\n    Parameters\n    ----------\n    n_unit_in: int\n        Number of features\n    binary_y: bool, default False\n        Whether the outcome is binary\n    po_estimator: sklearn/PyTorch model, default: None\n        Custom potential outcome model. If this parameter is set, the rest of the parameters are ignored.\n    te_estimator: sklearn/PyTorch model, default: None\n        Custom treatment effects model. If this parameter is set, the rest of the parameters are ignored.\n    n_folds: int, default 1\n        Number of cross-fitting folds. 
If 1, no cross-fitting\n    n_layers_out: int\n        First stage Number of hypothesis layers (n_layers_out x n_units_out + 1 x Linear layer)\n    n_units_out: int\n        First stage Number of hidden units in each hypothesis layer\n    n_layers_r: int\n        Number of shared & private representation layers before hypothesis layers\n    n_units_r: int\n        Number of hidden units in representation shared before the hypothesis layers.\n    n_layers_out_t: int\n        Second stage Number of hypothesis layers (n_layers_out x n_units_out + 1 x Linear layer)\n    n_units_out_t: int\n        Second stage Number of hidden units in each hypothesis layer\n    n_layers_out_prop: int\n        Number of hypothesis layers for propensity score (n_layers_out x n_units_out + 1 x Dense\n        layer)\n    n_units_out_prop: int\n        Number of hidden units in each propensity score hypothesis layer\n    weight_decay: float\n        First stage l2 (ridge) penalty\n    weight_decay_t: float\n        Second stage l2 (ridge) penalty\n    lr: float\n        First stage learning rate for optimizer\n    lr_t: float\n        Second stage learning rate for optimizer\n    n_iter: int\n        Maximum number of iterations\n    batch_size: int\n        Batch size\n    val_split_prop: float\n        Proportion of samples used for validation split (can be 0)\n    n_iter_print: int\n        Number of iterations after which to print updates\n    seed: int\n        Seed used\n    nonlin: string, default 'elu'\n        Nonlinearity to use in NN. Can be 'elu', 'relu', 'selu' or 'leaky_relu'.\n    weighting_strategy: str, default \"prop\"\n        Weighting strategy. 
Can be \"prop\" or \"1-prop\".\n    patience: int\n        Number of iterations to wait before early stopping after decrease in validation loss\n    n_iter_min: int\n        Minimum number of iterations to go through before starting early stopping\n    \"\"\"\n\n    def __init__(\n        self,\n        n_unit_in: int,\n        binary_y: bool,\n        po_estimator: Any = None,\n        te_estimator: Any = None,\n        n_folds: int = DEFAULT_CF_FOLDS,\n        n_layers_out: int = DEFAULT_LAYERS_OUT,\n        n_layers_out_t: int = DEFAULT_LAYERS_OUT_T,\n        n_units_out: int = DEFAULT_UNITS_OUT,\n        n_units_out_t: int = DEFAULT_UNITS_OUT_T,\n        n_units_out_prop: int = DEFAULT_UNITS_OUT,\n        n_layers_out_prop: int = 0,\n        weight_decay: float = DEFAULT_PENALTY_L2,\n        weight_decay_t: float = DEFAULT_PENALTY_L2,\n        lr: float = DEFAULT_STEP_SIZE,\n        lr_t: float = DEFAULT_STEP_SIZE_T,\n        n_iter: int = DEFAULT_N_ITER,\n        batch_size: int = DEFAULT_BATCH_SIZE,\n        val_split_prop: float = DEFAULT_VAL_SPLIT,\n        n_iter_print: int = DEFAULT_N_ITER_PRINT,\n        seed: int = DEFAULT_SEED,\n        nonlin: str = DEFAULT_NONLIN,\n        weighting_strategy: Optional[str] = \"prop\",\n        patience: int = DEFAULT_PATIENCE,\n        n_iter_min: int = DEFAULT_N_ITER_MIN,\n        batch_norm: bool = True,\n        early_stopping: bool = True,\n        dropout: bool = False,\n        dropout_prob: float = 0.2,\n    ):\n        super(PseudoOutcomeLearner, self).__init__()\n        self.n_unit_in = n_unit_in\n        self.binary_y = binary_y\n        self.n_layers_out = n_layers_out\n        self.n_units_out = n_units_out\n        self.n_units_out_prop = n_units_out_prop\n        self.n_layers_out_prop = n_layers_out_prop\n        self.weight_decay_t = weight_decay_t\n        self.weight_decay = weight_decay\n        self.weighting_strategy = weighting_strategy\n        self.lr = lr\n        self.lr_t = lr_t\n        
self.n_iter = n_iter\n        self.batch_size = batch_size\n        self.val_split_prop = val_split_prop\n        self.n_iter_print = n_iter_print\n        self.seed = seed\n        self.nonlin = nonlin\n        self.n_folds = n_folds\n        self.patience = patience\n        self.n_iter_min = n_iter_min\n        self.n_layers_out_t = n_layers_out_t\n        self.n_units_out_t = n_units_out_t\n        self.n_layers_out = n_layers_out\n        self.n_units_out = n_units_out\n        self.batch_norm = batch_norm\n        self.early_stopping = early_stopping\n        self.dropout = dropout\n        self.dropout_prob = dropout_prob\n\n        # set estimators\n        self._te_template = te_estimator\n        self._po_template = po_estimator\n\n        self._te_estimator = self._generate_te_estimator()\n        self._po_estimator = self._generate_po_estimator()\n        if weighting_strategy is not None:\n            self._propensity_estimator = self._generate_propensity_estimator()\n\n    def _generate_te_estimator(self, name: str = \"te_estimator\") -> nn.Module:\n        if self._te_template is not None:\n            return copy.deepcopy(self._te_template)\n        return BasicNet(\n            name,\n            self.n_unit_in,\n            binary_y=False,\n            n_layers_out=self.n_layers_out_t,\n            n_units_out=self.n_units_out_t,\n            weight_decay=self.weight_decay_t,\n            lr=self.lr_t,\n            n_iter=self.n_iter,\n            batch_size=self.batch_size,\n            val_split_prop=self.val_split_prop,\n            n_iter_print=self.n_iter_print,\n            seed=self.seed,\n            nonlin=self.nonlin,\n            patience=self.patience,\n            n_iter_min=self.n_iter_min,\n            batch_norm=self.batch_norm,\n            early_stopping=self.early_stopping,\n            dropout=self.dropout,\n            dropout_prob=self.dropout_prob,\n        ).to(DEVICE)\n\n    def _generate_po_estimator(self, name: str = 
\"po_estimator\") -> nn.Module:\n        if self._po_template is not None:\n            return copy.deepcopy(self._po_template)\n\n        return BasicNet(\n            name,\n            self.n_unit_in,\n            binary_y=self.binary_y,\n            n_layers_out=self.n_layers_out,\n            n_units_out=self.n_units_out,\n            weight_decay=self.weight_decay,\n            lr=self.lr,\n            n_iter=self.n_iter,\n            batch_size=self.batch_size,\n            val_split_prop=self.val_split_prop,\n            n_iter_print=self.n_iter_print,\n            seed=self.seed,\n            nonlin=self.nonlin,\n            patience=self.patience,\n            n_iter_min=self.n_iter_min,\n            batch_norm=self.batch_norm,\n            early_stopping=self.early_stopping,\n            dropout=self.dropout,\n            dropout_prob=self.dropout_prob,\n        ).to(DEVICE)\n\n    def _generate_propensity_estimator(\n        self, name: str = \"propensity_estimator\"\n    ) -> nn.Module:\n        if self.weighting_strategy is None:\n            raise ValueError(\"Invalid weighting_strategy for PropensityNet\")\n        return PropensityNet(\n            name,\n            self.n_unit_in,\n            2,  # number of treatments\n            self.weighting_strategy,\n            n_units_out_prop=self.n_units_out_prop,\n            n_layers_out_prop=self.n_layers_out_prop,\n            weight_decay=self.weight_decay,\n            lr=self.lr,\n            n_iter=self.n_iter,\n            batch_size=self.batch_size,\n            n_iter_print=self.n_iter_print,\n            seed=self.seed,\n            nonlin=self.nonlin,\n            val_split_prop=self.val_split_prop,\n            batch_norm=self.batch_norm,\n            early_stopping=self.early_stopping,\n            dropout_prob=self.dropout_prob,\n            dropout=self.dropout,\n        ).to(DEVICE)\n\n    def fit(\n        self, X: torch.Tensor, y: torch.Tensor, w: torch.Tensor\n    ) -> 
\"PseudoOutcomeLearner\":\n        \"\"\"\n        Train treatment effects nets.\n\n        Parameters\n        ----------\n        X: array-like of shape (n_samples, n_features)\n            Train-sample features\n        y: array-like of shape (n_samples,)\n            Train-sample labels\n        w: array-like of shape (n_samples,)\n            Train-sample treatments\n        \"\"\"\n        self.train()\n\n        X = self._check_tensor(X).float()\n        y = self._check_tensor(y).squeeze().float()\n        w = self._check_tensor(w).squeeze().float()\n\n        n = len(y)\n\n        # STEP 1: fit plug-in estimators via cross-fitting\n        if self.n_folds == 1:\n            pred_mask = np.ones(n, dtype=bool)\n            # fit plug-in models\n            mu_0_pred, mu_1_pred, p_pred = self._first_step(\n                X, y, w, pred_mask, pred_mask\n            )\n        else:\n            mu_0_pred, mu_1_pred, p_pred = (\n                torch.zeros(n).to(DEVICE),\n                torch.zeros(n).to(DEVICE),\n                torch.zeros(n).to(DEVICE),\n            )\n\n            # create folds stratified by treatment assignment to ensure balance\n            splitter = StratifiedKFold(\n                n_splits=self.n_folds, shuffle=True, random_state=self.seed\n            )\n\n            for train_index, test_index in splitter.split(X.cpu(), w.cpu()):\n                # create masks\n                pred_mask = torch.zeros(n, dtype=bool).to(DEVICE)\n                pred_mask[test_index] = 1\n\n                # fit plug-in te_estimator\n                (\n                    mu_0_pred[pred_mask],\n                    mu_1_pred[pred_mask],\n                    p_pred[pred_mask],\n                ) = self._first_step(X, y, w, ~pred_mask, pred_mask)\n\n        # use estimated propensity scores\n        if self.weighting_strategy is not None:\n            p = p_pred\n\n        # STEP 2: direct TE estimation\n        self._second_step(X, y, w, p, 
mu_0_pred, mu_1_pred)\n\n        return self\n\n    def predict(\n        self, X: torch.Tensor, return_po: bool = False, training: bool = False\n    ) -> torch.Tensor:\n        \"\"\"\n        Predict treatment effects\n\n        Parameters\n        ----------\n        X: array-like of shape (n_samples, n_features)\n            Test-sample features\n        Returns\n        -------\n        te_est: array-like of shape (n_samples,)\n            Predicted treatment effects\n        \"\"\"\n        if return_po:\n            raise NotImplementedError(\n                \"PseudoOutcomeLearners have no Potential outcome predictors.\"\n            )\n        if not training:\n            self.eval()\n\n        X = self._check_tensor(X).float()\n        return predict_wrapper(self._te_estimator, X)\n\n    @abc.abstractmethod\n    def _first_step(\n        self,\n        X: torch.Tensor,\n        y: torch.Tensor,\n        w: torch.Tensor,\n        fit_mask: torch.Tensor,\n        pred_mask: torch.Tensor,\n    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        pass\n\n    @abc.abstractmethod\n    def _second_step(\n        self,\n        X: torch.Tensor,\n        y: torch.Tensor,\n        w: torch.Tensor,\n        p: torch.Tensor,\n        mu_0: torch.Tensor,\n        mu_1: torch.Tensor,\n    ) -> None:\n        pass\n\n    def _impute_pos(\n        self,\n        X: torch.Tensor,\n        y: torch.Tensor,\n        w: torch.Tensor,\n        fit_mask: torch.Tensor,\n        pred_mask: torch.Tensor,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        # split sample\n        X_fit, Y_fit, W_fit = X[fit_mask, :], y[fit_mask], w[fit_mask]\n\n        # fit two separate (standard) models\n        # untreated model\n        temp_model_0 = self._generate_po_estimator(\"po_estimator_0_impute_pos\")\n        train_wrapper(temp_model_0, X_fit[W_fit == 0], Y_fit[W_fit == 0])\n\n        # treated model\n        temp_model_1 = 
self._generate_po_estimator(\"po_estimator_1_impute_pos\")\n        train_wrapper(temp_model_1, X_fit[W_fit == 1], Y_fit[W_fit == 1])\n\n        mu_0_pred = predict_wrapper(temp_model_0, X[pred_mask, :])\n        mu_1_pred = predict_wrapper(temp_model_1, X[pred_mask, :])\n\n        return mu_0_pred, mu_1_pred\n\n    def _impute_propensity(\n        self,\n        X: torch.Tensor,\n        w: torch.Tensor,\n        fit_mask: torch.tensor,\n        pred_mask: torch.Tensor,\n    ) -> torch.Tensor:\n        # split sample\n        X_fit, W_fit = X[fit_mask, :], w[fit_mask]\n\n        # fit propensity estimator\n        temp_propensity_estimator = self._generate_propensity_estimator(\n            \"prop_estimator_impute_propensity\"\n        )\n        train_wrapper(temp_propensity_estimator, X_fit, W_fit)\n\n        # predict propensity on hold out\n        return temp_propensity_estimator.get_importance_weights(\n            X[pred_mask, :], w[pred_mask]\n        )\n\n    def _impute_unconditional_mean(\n        self,\n        X: torch.Tensor,\n        y: torch.Tensor,\n        fit_mask: torch.Tensor,\n        pred_mask: torch.Tensor,\n    ) -> torch.Tensor:\n        # R-learner and U-learner need to impute unconditional mean\n        X_fit, Y_fit = X[fit_mask, :], y[fit_mask]\n\n        # fit model\n        temp_model = self._generate_po_estimator(\"po_est_impute_unconditional_mean\")\n        train_wrapper(temp_model, X_fit, Y_fit)\n\n        return predict_wrapper(temp_model, X[pred_mask, :])\n\n\nclass DRLearner(PseudoOutcomeLearner):\n    \"\"\"\n    DR-learner for CATE estimation, based on doubly robust AIPW pseudo-outcome\n    \"\"\"\n\n    def _first_step(\n        self,\n        X: torch.Tensor,\n        y: torch.Tensor,\n        w: torch.Tensor,\n        fit_mask: torch.Tensor,\n        pred_mask: torch.Tensor,\n    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        mu0_pred, mu1_pred = self._impute_pos(X, y, w, fit_mask, pred_mask)\n        
p_pred = self._impute_propensity(X, w, fit_mask, pred_mask).squeeze()\n        return (\n            mu0_pred.squeeze().to(DEVICE),\n            mu1_pred.squeeze().to(DEVICE),\n            p_pred.to(DEVICE),\n        )\n\n    def _second_step(\n        self,\n        X: torch.Tensor,\n        y: torch.Tensor,\n        w: torch.Tensor,\n        p: torch.Tensor,\n        mu_0: torch.Tensor,\n        mu_1: torch.Tensor,\n    ) -> None:\n        pseudo_outcome = dr_transformation_cate(y, w, p, mu_0, mu_1)\n        train_wrapper(self._te_estimator, X, pseudo_outcome.detach())\n\n\nclass PWLearner(PseudoOutcomeLearner):\n    \"\"\"\n    PW-learner for CATE estimation, based on singly robust Horvitz Thompson pseudo-outcome\n    \"\"\"\n\n    def _first_step(\n        self,\n        X: torch.Tensor,\n        y: torch.Tensor,\n        w: torch.Tensor,\n        fit_mask: torch.Tensor,\n        pred_mask: torch.Tensor,\n    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n\n        mu0_pred, mu1_pred = np.nan, np.nan  # not needed\n        p_pred = self._impute_propensity(X, w, fit_mask, pred_mask).squeeze()\n        return mu0_pred.to(DEVICE), mu1_pred.to(DEVICE), p_pred.to(DEVICE)\n\n    def _second_step(\n        self,\n        X: torch.Tensor,\n        y: torch.Tensor,\n        w: torch.Tensor,\n        p: torch.Tensor,\n        mu_0: torch.Tensor,\n        mu_1: torch.Tensor,\n    ) -> None:\n        pseudo_outcome = pw_transformation_cate(y, w, p)\n        train_wrapper(self._te_estimator, X, pseudo_outcome.detach())\n\n\nclass RALearner(PseudoOutcomeLearner):\n    \"\"\"\n    RA-learner for CATE estimation, based on singly robust regression-adjusted pseudo-outcome\n    \"\"\"\n\n    def _first_step(\n        self,\n        X: torch.Tensor,\n        y: torch.Tensor,\n        w: torch.Tensor,\n        fit_mask: torch.Tensor,\n        pred_mask: torch.Tensor,\n    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        mu0_pred, mu1_pred = self._impute_pos(X, 
y, w, fit_mask, pred_mask)\n        p_pred = np.nan  # not needed\n        return mu0_pred.squeeze().to(DEVICE), mu1_pred.squeeze().to(DEVICE), p_pred\n\n    def _second_step(\n        self,\n        X: torch.Tensor,\n        y: torch.Tensor,\n        w: torch.Tensor,\n        p: torch.Tensor,\n        mu_0: torch.Tensor,\n        mu_1: torch.Tensor,\n    ) -> None:\n        pseudo_outcome = ra_transformation_cate(y, w, p, mu_0, mu_1)\n        train_wrapper(self._te_estimator, X, pseudo_outcome.detach())\n\n\nclass ULearner(PseudoOutcomeLearner):\n    \"\"\"\n    U-learner for CATE estimation. Based on pseudo-outcome (Y-mu(x))/(w-pi(x))\n    \"\"\"\n\n    def _first_step(\n        self,\n        X: torch.Tensor,\n        y: torch.Tensor,\n        w: torch.Tensor,\n        fit_mask: torch.Tensor,\n        pred_mask: torch.Tensor,\n    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n\n        mu_pred = self._impute_unconditional_mean(X, y, fit_mask, pred_mask).squeeze()\n        mu1_pred = np.nan  # only have one thing to impute here\n        p_pred = self._impute_propensity(X, w, fit_mask, pred_mask).squeeze()\n        return mu_pred.to(DEVICE), mu1_pred, p_pred.to(DEVICE)\n\n    def _second_step(\n        self,\n        X: torch.Tensor,\n        y: torch.Tensor,\n        w: torch.Tensor,\n        p: torch.Tensor,\n        mu_0: torch.Tensor,\n        mu_1: torch.Tensor,\n    ) -> None:\n        pseudo_outcome = u_transformation_cate(y, w, p, mu_0)\n        train_wrapper(self._te_estimator, X, pseudo_outcome.detach())\n\n\nclass RLearner(PseudoOutcomeLearner):\n    \"\"\"\n    R-learner for CATE estimation. 
Based on pseudo-outcome (Y-mu(x))/(w-pi(x)) and sample weight\n    (w-pi(x))^2 -- can only be implemented if .fit of te_estimator takes argument 'sample_weight'.\n    \"\"\"\n\n    def _first_step(\n        self,\n        X: torch.Tensor,\n        y: torch.Tensor,\n        w: torch.Tensor,\n        fit_mask: torch.Tensor,\n        pred_mask: torch.Tensor,\n    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        mu_pred = self._impute_unconditional_mean(X, y, fit_mask, pred_mask).squeeze()\n        mu1_pred = np.nan  # only have one thing to impute here\n        p_pred = self._impute_propensity(X, w, fit_mask, pred_mask).squeeze()\n        return mu_pred.to(DEVICE), mu1_pred, p_pred.to(DEVICE)\n\n    def _second_step(\n        self,\n        X: torch.Tensor,\n        y: torch.Tensor,\n        w: torch.Tensor,\n        p: torch.Tensor,\n        mu_0: torch.Tensor,\n        mu_1: torch.Tensor,\n    ) -> None:\n        pseudo_outcome = u_transformation_cate(y, w, p, mu_0)\n        train_wrapper(\n            self._te_estimator, X, pseudo_outcome.detach(), weight=(w - p) ** 2\n        )\n\n\nclass XLearner(PseudoOutcomeLearner):\n    \"\"\"\n    X-learner for CATE estimation. 
Combines two CATE estimates via a weighting function g(x):\n    tau(x) = g(x) tau_0(x) + (1-g(x)) tau_1(x)\n    \"\"\"\n\n    def __init__(\n        self,\n        *args: Any,\n        weighting_strategy: str = \"prop\",\n        **kwargs: Any,\n    ) -> None:\n        super().__init__(\n            *args,\n            **kwargs,\n        )\n        self.weighting_strategy = weighting_strategy\n\n    def _first_step(\n        self,\n        X: torch.Tensor,\n        y: torch.Tensor,\n        w: torch.Tensor,\n        fit_mask: torch.Tensor,\n        pred_mask: torch.Tensor,\n    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        mu0_pred, mu1_pred = self._impute_pos(X, y, w, fit_mask, pred_mask)\n        p_pred = np.nan\n        return mu0_pred.squeeze().to(DEVICE), mu1_pred.squeeze().to(DEVICE), p_pred\n\n    def _second_step(\n        self,\n        X: torch.Tensor,\n        y: torch.Tensor,\n        w: torch.Tensor,\n        p: torch.Tensor,\n        mu_0: torch.Tensor,\n        mu_1: torch.Tensor,\n    ) -> None:\n        # split by treatment status, fit one model per group\n        pseudo_0 = mu_1[w == 0] - y[w == 0]\n        self._te_estimator_0 = self._generate_te_estimator(\"te_estimator_0_xnet\")\n        train_wrapper(self._te_estimator_0, X[w == 0], pseudo_0.detach())\n\n        pseudo_1 = y[w == 1] - mu_0[w == 1]\n        self._te_estimator_1 = self._generate_te_estimator(\"te_estimator_1_xnet\")\n        train_wrapper(self._te_estimator_1, X[w == 1], pseudo_1.detach())\n\n        train_wrapper(self._propensity_estimator, X, w)\n\n    def predict(\n        self, X: torch.Tensor, return_po: bool = False, training: bool = False\n    ) -> torch.Tensor:\n        \"\"\"\n        Predict treatment effects\n\n        Parameters\n        ----------\n        X: array-like of shape (n_samples, n_features)\n            Test-sample features\n        return_po: bool, default False\n            Whether to return potential outcome predictions. 
Placeholder, can only accept False.\n        Returns\n        -------\n        te_est: array-like of shape (n_samples,)\n            Predicted treatment effects\n        \"\"\"\n        if return_po:\n            raise NotImplementedError(\n                \"PseudoOutcomeLearners have no Potential outcome predictors.\"\n            )\n\n        if not training:\n            self.eval()\n\n        X = self._check_tensor(X).float().to(DEVICE)\n        tau0_pred = predict_wrapper(self._te_estimator_0, X)\n        tau1_pred = predict_wrapper(self._te_estimator_1, X)\n\n        weight = self._propensity_estimator.get_importance_weights(X)\n\n        return weight * tau0_pred + (1 - weight) * tau1_pred\n"
  },
  {
    "path": "catenets/models/torch/representation_nets.py",
    "content": "import abc\nfrom typing import Any, Optional, Tuple\n\nimport numpy as np\nimport torch\nfrom torch import nn\n\nimport catenets.logger as log\nfrom catenets.models.constants import (\n    DEFAULT_BATCH_SIZE,\n    DEFAULT_LAYERS_OUT,\n    DEFAULT_LAYERS_R,\n    DEFAULT_N_ITER,\n    DEFAULT_N_ITER_MIN,\n    DEFAULT_N_ITER_PRINT,\n    DEFAULT_NONLIN,\n    DEFAULT_PATIENCE,\n    DEFAULT_PENALTY_DISC,\n    DEFAULT_PENALTY_L2,\n    DEFAULT_SEED,\n    DEFAULT_STEP_SIZE,\n    DEFAULT_UNITS_OUT,\n    DEFAULT_UNITS_R,\n    DEFAULT_VAL_SPLIT,\n    LARGE_VAL,\n)\nfrom catenets.models.torch.base import (\n    DEVICE,\n    BaseCATEEstimator,\n    BasicNet,\n    PropensityNet,\n    RepresentationNet,\n)\nfrom catenets.models.torch.utils.model_utils import make_val_split\n\nEPS = 1e-8\n\n\nclass BasicDragonNet(BaseCATEEstimator):\n    \"\"\"\n    Base class for TARNet and DragonNet.\n\n    Parameters\n    ----------\n    name: str\n        Estimator name\n    n_unit_in: int\n        Number of features\n    propensity_estimator: nn.Module\n        Propensity estimator\n    binary_y: bool, default False\n        Whether the outcome is binary\n    n_layers_out: int\n        Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)\n    n_units_out: int\n        Number of hidden units in each hypothesis layer\n    n_layers_r: int\n        Number of shared & private representation layers before the hypothesis layers.\n    n_units_r: int\n        Number of hidden units in representation before the hypothesis layers.\n    weight_decay: float\n        l2 (ridge) penalty\n    lr: float\n        learning rate for optimizer\n    n_iter: int\n        Maximum number of iterations\n    batch_size: int\n        Batch size\n    val_split_prop: float\n        Proportion of samples used for validation split (can be 0)\n    n_iter_print: int\n        Number of iterations after which to print updates\n    seed: int\n        Seed used\n    nonlin: string, default 
'elu'\n        Nonlinearity to use in the neural net. Can be 'elu', 'relu', 'selu', 'leaky_relu'.\n    weighting_strategy: optional str, None\n        Whether to include propensity head and which weightening strategy to use\n    penalty_disc: float, default zero\n         Discrepancy penalty.\n    \"\"\"\n\n    def __init__(\n        self,\n        name: str,\n        n_unit_in: int,\n        propensity_estimator: nn.Module,\n        binary_y: bool = False,\n        n_layers_r: int = DEFAULT_LAYERS_R,\n        n_units_r: int = DEFAULT_UNITS_R,\n        n_layers_out: int = DEFAULT_LAYERS_OUT,\n        n_units_out: int = DEFAULT_UNITS_OUT,\n        weight_decay: float = DEFAULT_PENALTY_L2,\n        lr: float = DEFAULT_STEP_SIZE,\n        n_iter: int = DEFAULT_N_ITER,\n        batch_size: int = DEFAULT_BATCH_SIZE,\n        val_split_prop: float = DEFAULT_VAL_SPLIT,\n        n_iter_print: int = DEFAULT_N_ITER_PRINT,\n        seed: int = DEFAULT_SEED,\n        nonlin: str = DEFAULT_NONLIN,\n        weighting_strategy: Optional[str] = None,\n        penalty_disc: float = 0,\n        batch_norm: bool = True,\n        early_stopping: bool = True,\n        prop_loss_multiplier: float = 1,\n        n_iter_min: int = DEFAULT_N_ITER_MIN,\n        patience: int = DEFAULT_PATIENCE,\n        dropout: bool = False,\n        dropout_prob: float = 0.2,\n    ) -> None:\n        super(BasicDragonNet, self).__init__()\n\n        self.name = name\n        self.val_split_prop = val_split_prop\n        self.seed = seed\n        self.batch_size = batch_size\n        self.n_iter = n_iter\n        self.n_iter_print = n_iter_print\n        self.lr = lr\n        self.weight_decay = weight_decay\n        self.binary_y = binary_y\n        self.penalty_disc = penalty_disc\n        self.early_stopping = early_stopping\n        self.prop_loss_multiplier = prop_loss_multiplier\n        self.n_iter_min = n_iter_min\n        self.patience = patience\n        self.dropout = dropout\n        
self.dropout_prob = dropout_prob\n\n        self._repr_estimator = RepresentationNet(\n            n_unit_in,\n            n_units=n_units_r,\n            n_layers=n_layers_r,\n            nonlin=nonlin,\n            batch_norm=batch_norm,\n        )\n        self._po_estimators = []\n        for idx in range(2):\n            self._po_estimators.append(\n                BasicNet(\n                    f\"{name}_po_estimator_{idx}\",\n                    n_units_r,\n                    binary_y=binary_y,\n                    n_layers_out=n_layers_out,\n                    n_units_out=n_units_out,\n                    nonlin=nonlin,\n                    batch_norm=batch_norm,\n                    dropout=dropout,\n                    dropout_prob=dropout_prob,\n                )\n            )\n        self._propensity_estimator = propensity_estimator\n\n    def loss(\n        self,\n        po_pred: torch.Tensor,\n        t_pred: torch.Tensor,\n        y_true: torch.Tensor,\n        t_true: torch.Tensor,\n        discrepancy: torch.Tensor,\n    ) -> torch.Tensor:\n        def head_loss(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:\n            if self.binary_y:\n                return nn.BCELoss()(y_pred, y_true)\n            else:\n                return (y_pred - y_true) ** 2\n\n        def po_loss(\n            po_pred: torch.Tensor, y_true: torch.Tensor, t_true: torch.Tensor\n        ) -> torch.Tensor:\n            y0_pred = po_pred[:, 0]\n            y1_pred = po_pred[:, 1]\n\n            loss0 = torch.mean((1.0 - t_true) * head_loss(y0_pred, y_true))\n            loss1 = torch.mean(t_true * head_loss(y1_pred, y_true))\n\n            return loss0 + loss1\n\n        def prop_loss(t_pred: torch.Tensor, t_true: torch.Tensor) -> torch.Tensor:\n            t_pred = t_pred + EPS\n            return nn.CrossEntropyLoss()(t_pred, t_true)\n\n        return (\n            po_loss(po_pred, y_true, t_true)\n            + self.prop_loss_multiplier * 
prop_loss(t_pred, t_true)\n            + discrepancy\n        )\n\n    def fit(\n        self,\n        X: torch.Tensor,\n        y: torch.Tensor,\n        w: torch.Tensor,\n    ) -> \"BasicDragonNet\":\n        \"\"\"\n        Fit the treatment models.\n\n        Parameters\n        ----------\n        X : torch.Tensor of shape (n_samples, n_features)\n            The features to fit to\n        y : torch.Tensor of shape (n_samples,) or (n_samples, )\n            The outcome variable\n        w: torch.Tensor of shape (n_samples,)\n            The treatment indicator\n        \"\"\"\n        self.train()\n\n        X = torch.Tensor(X).to(DEVICE)\n        y = torch.Tensor(y).squeeze().to(DEVICE)\n        w = torch.Tensor(w).squeeze().long().to(DEVICE)\n\n        X, y, w, X_val, y_val, w_val, val_string = make_val_split(\n            X, y, w=w, val_split_prop=self.val_split_prop, seed=self.seed\n        )\n\n        n = X.shape[0]  # could be different from before due to split\n\n        # calculate number of batches per epoch\n        batch_size = self.batch_size if self.batch_size < n else n\n        n_batches = int(np.round(n / batch_size)) if batch_size < n else 1\n        train_indices = np.arange(n)\n\n        params = (\n            list(self._repr_estimator.parameters())\n            + list(self._po_estimators[0].parameters())\n            + list(self._po_estimators[1].parameters())\n            + list(self._propensity_estimator.parameters())\n        )\n        optimizer = torch.optim.Adam(params, lr=self.lr, weight_decay=self.weight_decay)\n\n        # training\n        val_loss_best = LARGE_VAL\n        patience = 0\n        for i in range(self.n_iter):\n            # shuffle data for minibatches\n            np.random.shuffle(train_indices)\n            train_loss = []\n            for b in range(n_batches):\n                optimizer.zero_grad()\n\n                idx_next = train_indices[\n                    (b * batch_size) : min((b + 1) * batch_size, 
n - 1)\n                ]\n\n                X_next = X[idx_next]\n                y_next = y[idx_next].squeeze()\n                w_next = w[idx_next].squeeze()\n\n                po_preds, prop_preds, discr = self._step(X_next, w_next)\n                batch_loss = self.loss(po_preds, prop_preds, y_next, w_next, discr)\n\n                batch_loss.backward()\n\n                optimizer.step()\n\n                train_loss.append(batch_loss.detach())\n\n            train_loss = torch.Tensor(train_loss).to(DEVICE)\n\n            if self.early_stopping or i % self.n_iter_print == 0:\n                with torch.no_grad():\n                    po_preds, prop_preds, discr = self._step(X_val, w_val)\n                    val_loss = self.loss(po_preds, prop_preds, y_val, w_val, discr)\n                    if self.early_stopping:\n                        if val_loss_best > val_loss:\n                            val_loss_best = val_loss\n                            patience = 0\n                        else:\n                            patience += 1\n                        if patience > self.patience and (\n                            (i + 1) * n_batches > self.n_iter_min\n                        ):\n                            break\n                    if i % self.n_iter_print == 0:\n                        log.info(\n                            f\"[{self.name}] Epoch: {i}, current {val_string} loss: {val_loss} train_loss: {torch.mean(train_loss)}\"\n                        )\n\n        return self\n\n    @abc.abstractmethod\n    def _step(\n        self, X: torch.Tensor, w: torch.Tensor\n    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        ...\n\n    def _forward(self, X: torch.Tensor) -> torch.Tensor:\n        X = self._check_tensor(X)\n        repr_preds = self._repr_estimator(X).squeeze()\n        y0_preds = self._po_estimators[0](repr_preds).squeeze()\n        y1_preds = self._po_estimators[1](repr_preds).squeeze()\n\n        return 
torch.vstack((y0_preds, y1_preds)).T\n\n    def predict(\n        self, X: torch.Tensor, return_po: bool = False, training: bool = False\n    ) -> torch.Tensor:\n        \"\"\"\n        Predict the treatment effects\n\n        Parameters\n        ----------\n        X: array-like of shape (n_samples, n_features)\n            Test-sample features\n        Returns\n        -------\n        y: array-like of shape (n_samples,)\n        \"\"\"\n        if not training:\n            self.eval()\n\n        X = self._check_tensor(X).float()\n        preds = self._forward(X)\n        y0_preds = preds[:, 0]\n        y1_preds = preds[:, 1]\n\n        outcome = y1_preds - y0_preds\n\n        if return_po:\n            return outcome, y0_preds, y1_preds\n\n        return outcome\n\n    def _maximum_mean_discrepancy(\n        self, X: torch.Tensor, w: torch.Tensor\n    ) -> torch.Tensor:\n        n = w.shape[0]\n        n_t = torch.sum(w)\n\n        X = X / torch.sqrt(torch.var(X, dim=0) + EPS)\n        w = w.unsqueeze(dim=0)\n\n        mean_control = (n / (n - n_t)) * torch.mean((1 - w).T * X, dim=0)\n        mean_treated = (n / n_t) * torch.mean(w.T * X, dim=0)\n\n        return self.penalty_disc * torch.sum((mean_treated - mean_control) ** 2)\n\n\nclass TARNet(BasicDragonNet):\n    \"\"\"\n    Class implements Shalit et al (2017)'s TARNet\n    \"\"\"\n\n    def __init__(\n        self,\n        n_unit_in: int,\n        binary_y: bool = False,\n        n_units_out_prop: int = DEFAULT_UNITS_OUT,\n        n_layers_out_prop: int = 0,\n        nonlin: str = DEFAULT_NONLIN,\n        penalty_disc: float = DEFAULT_PENALTY_DISC,\n        batch_norm: bool = True,\n        dropout: bool = False,\n        dropout_prob: float = 0.2,\n        **kwargs: Any,\n    ) -> None:\n        propensity_estimator = PropensityNet(\n            \"tarnet_propensity_estimator\",\n            n_unit_in,\n            2,\n            \"prop\",\n            n_layers_out_prop=n_layers_out_prop,\n            
n_units_out_prop=n_units_out_prop,\n            nonlin=nonlin,\n            batch_norm=batch_norm,\n            dropout_prob=dropout_prob,\n            dropout=dropout,\n        ).to(DEVICE)\n        super(TARNet, self).__init__(\n            \"TARNet\",\n            n_unit_in,\n            propensity_estimator,\n            binary_y=binary_y,\n            nonlin=nonlin,\n            penalty_disc=penalty_disc,\n            batch_norm=batch_norm,\n            dropout=dropout,\n            dropout_prob=dropout_prob,\n            **kwargs,\n        )\n        self.prop_loss_multiplier = 0\n\n    def _step(\n        self, X: torch.Tensor, w: torch.Tensor\n    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        repr_preds = self._repr_estimator(X).squeeze()\n\n        y0_preds = self._po_estimators[0](repr_preds).squeeze()\n        y1_preds = self._po_estimators[1](repr_preds).squeeze()\n\n        po_preds = torch.vstack((y0_preds, y1_preds)).T\n\n        prop_preds = self._propensity_estimator(X)\n\n        return po_preds, prop_preds, self._maximum_mean_discrepancy(repr_preds, w)\n\n\nclass DragonNet(BasicDragonNet):\n    \"\"\"\n    Class implements a variant based on Shi et al (2019)'s DragonNet.\n    \"\"\"\n\n    def __init__(\n        self,\n        n_unit_in: int,\n        binary_y: bool = False,\n        n_units_out_prop: int = DEFAULT_UNITS_OUT,\n        n_layers_out_prop: int = 0,\n        nonlin: str = DEFAULT_NONLIN,\n        n_units_r: int = DEFAULT_UNITS_R,\n        batch_norm: bool = True,\n        dropout: bool = False,\n        dropout_prob: float = 0.2,\n        **kwargs: Any,\n    ) -> None:\n        propensity_estimator = PropensityNet(\n            \"dragonnet_propensity_estimator\",\n            n_units_r,\n            2,\n            \"prop\",\n            n_layers_out_prop=n_layers_out_prop,\n            n_units_out_prop=n_units_out_prop,\n            nonlin=nonlin,\n            batch_norm=batch_norm,\n            dropout=dropout,\n   
         dropout_prob=dropout_prob,\n        ).to(DEVICE)\n        super(DragonNet, self).__init__(\n            \"DragonNet\",\n            n_unit_in,\n            propensity_estimator,\n            binary_y=binary_y,\n            nonlin=nonlin,\n            batch_norm=batch_norm,\n            dropout=dropout,\n            dropout_prob=dropout_prob,\n            **kwargs,\n        )\n\n    def _step(\n        self, X: torch.Tensor, w: torch.Tensor\n    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        repr_preds = self._repr_estimator(X).squeeze()\n\n        y0_preds = self._po_estimators[0](repr_preds).squeeze()\n        y1_preds = self._po_estimators[1](repr_preds).squeeze()\n\n        po_preds = torch.vstack((y0_preds, y1_preds)).T\n\n        prop_preds = self._propensity_estimator(repr_preds)\n\n        return po_preds, prop_preds, self._maximum_mean_discrepancy(repr_preds, w)\n"
  },
  {
    "path": "catenets/models/torch/slearner.py",
    "content": "from typing import Any, Optional\n\nimport torch\n\nimport catenets.logger as log\nfrom catenets.models.constants import (\n    DEFAULT_BATCH_SIZE,\n    DEFAULT_LAYERS_OUT,\n    DEFAULT_N_ITER,\n    DEFAULT_N_ITER_PRINT,\n    DEFAULT_NONLIN,\n    DEFAULT_PENALTY_L2,\n    DEFAULT_SEED,\n    DEFAULT_STEP_SIZE,\n    DEFAULT_UNITS_OUT,\n    DEFAULT_VAL_SPLIT,\n)\nfrom catenets.models.torch.base import (\n    DEVICE,\n    BaseCATEEstimator,\n    BasicNet,\n    PropensityNet,\n)\nfrom catenets.models.torch.utils.model_utils import predict_wrapper\n\n\nclass SLearner(BaseCATEEstimator):\n    \"\"\"\n    S-learner for treatment effect estimation (single learner, treatment indicator just another feature).\n\n    Parameters\n    ----------\n    n_unit_in: int\n        Number of features\n    binary_y: bool\n        Whether the outcome is binary\n    po_estimator: sklearn/PyTorch model, default: None\n        Custom potential outcome model. If this parameter is set, the rest of the parameters are ignored.\n    n_layers_out: int\n        Number of hypothesis layers (n_layers_out x n_units_out + 1 x Linear layer)\n    n_layers_out_prop: int\n        Number of hypothesis layers for propensity score(n_layers_out x n_units_out + 1 x Linear\n        layer)\n    n_units_out: int\n        Number of hidden units in each hypothesis layer\n    n_units_out_prop: int\n        Number of hidden units in each propensity score hypothesis layer\n    weight_decay: float\n        l2 (ridge) penalty\n    lr: float\n        learning rate for optimizer\n    n_iter: int\n        Maximum number of iterations\n    batch_size: int\n        Batch size\n    val_split_prop: float\n        Proportion of samples used for validation split (can be 0)\n    n_iter_print: int\n        Number of iterations after which to print updates\n    seed: int\n        Seed used\n    nonlin: string, default 'elu'\n        Nonlinearity to use in the neural net. 
Can be 'elu', 'relu', 'selu' or 'leaky_relu'.\n    weighting_strategy: optional str, None\n        Whether to include propensity head and which weightening strategy to use\n    \"\"\"\n\n    def __init__(\n        self,\n        n_unit_in: int,\n        binary_y: bool,\n        po_estimator: Any = None,\n        n_layers_out: int = DEFAULT_LAYERS_OUT,\n        n_units_out: int = DEFAULT_UNITS_OUT,\n        n_units_out_prop: int = DEFAULT_UNITS_OUT,\n        n_layers_out_prop: int = DEFAULT_LAYERS_OUT,\n        weight_decay: float = DEFAULT_PENALTY_L2,\n        lr: float = DEFAULT_STEP_SIZE,\n        n_iter: int = DEFAULT_N_ITER,\n        batch_size: int = DEFAULT_BATCH_SIZE,\n        val_split_prop: float = DEFAULT_VAL_SPLIT,\n        n_iter_print: int = DEFAULT_N_ITER_PRINT,\n        seed: int = DEFAULT_SEED,\n        nonlin: str = DEFAULT_NONLIN,\n        weighting_strategy: Optional[str] = None,\n        batch_norm: bool = True,\n        early_stopping: bool = True,\n        dropout: bool = False,\n        dropout_prob: float = 0.2,\n    ) -> None:\n        super(SLearner, self).__init__()\n\n        self._weighting_strategy = weighting_strategy\n        if po_estimator is not None:\n            self._po_estimator = po_estimator\n        else:\n            self._po_estimator = BasicNet(\n                \"slearner_po_estimator\",\n                n_unit_in + 1,\n                binary_y=binary_y,\n                n_layers_out=n_layers_out,\n                n_units_out=n_units_out,\n                weight_decay=weight_decay,\n                lr=lr,\n                n_iter=n_iter,\n                batch_size=batch_size,\n                val_split_prop=val_split_prop,\n                n_iter_print=n_iter_print,\n                seed=seed,\n                nonlin=nonlin,\n                batch_norm=batch_norm,\n                early_stopping=early_stopping,\n                dropout_prob=dropout_prob,\n                dropout=dropout,\n            ).to(DEVICE)\n      
  if weighting_strategy is not None:\n            self._propensity_estimator = PropensityNet(\n                \"slearner_prop_estimator\",\n                n_unit_in,\n                2,  # number of treatments\n                weighting_strategy,\n                n_units_out_prop=n_units_out_prop,\n                n_layers_out_prop=n_layers_out_prop,\n                weight_decay=weight_decay,\n                lr=lr,\n                n_iter=n_iter,\n                batch_size=batch_size,\n                n_iter_print=n_iter_print,\n                seed=seed,\n                nonlin=nonlin,\n                val_split_prop=val_split_prop,\n                batch_norm=batch_norm,\n                early_stopping=early_stopping,\n                dropout=dropout,\n                dropout_prob=dropout_prob,\n            ).to(DEVICE)\n\n    def fit(\n        self,\n        X: torch.Tensor,\n        y: torch.Tensor,\n        w: torch.Tensor,\n    ) -> \"SLearner\":\n        \"\"\"\n        Fit treatment models.\n\n        Parameters\n        ----------\n        X : torch.Tensor of shape (n_samples, n_features)\n            The features to fit to\n        y : torch.Tensor of shape (n_samples,) or (n_samples, )\n            The outcome variable\n        w: torch.Tensor of shape (n_samples,)\n            The treatment indicator\n        \"\"\"\n        self.train()\n\n        X = torch.Tensor(X).to(DEVICE)\n        y = torch.Tensor(y).to(DEVICE)\n        w = torch.Tensor(w).to(DEVICE)\n\n        # add indicator as additional variable\n        X_ext = torch.cat((X, w.reshape((-1, 1))), dim=1).to(DEVICE)\n\n        if not (\n            hasattr(self._po_estimator, \"train\") or hasattr(self._po_estimator, \"fit\")\n        ):\n            raise NotImplementedError(\"invalid po_estimator for the slearner\")\n\n        if hasattr(self._po_estimator, \"fit\"):\n            log.info(\"Fit the sklearn po_estimator\")\n            self._po_estimator.fit(\n                
X_ext.detach().cpu().numpy(), y.detach().cpu().numpy()\n            )\n            return self\n\n        if self._weighting_strategy is None:\n            # fit standard S-learner\n            log.info(\"Fit the PyTorch po_estimator\")\n            self._po_estimator.fit(X_ext, y)\n            return self\n\n        # use reweighting within the outcome model\n        log.info(\"Fit the PyTorch po_estimator with the propensity estimator\")\n        self._propensity_estimator.fit(X, w)\n        weights = self._propensity_estimator.get_importance_weights(X, w)\n        self._po_estimator.fit(X_ext, y, weight=weights)\n\n        return self\n\n    def _create_extended_matrices(self, X: torch.Tensor) -> torch.Tensor:\n        n = X.shape[0]\n        X = self._check_tensor(X)\n\n        # create extended matrices\n        w_1 = torch.ones((n, 1)).to(DEVICE)\n        w_0 = torch.zeros((n, 1)).to(DEVICE)\n        X_ext_0 = torch.cat((X, w_0), dim=1).to(DEVICE)\n        X_ext_1 = torch.cat((X, w_1), dim=1).to(DEVICE)\n\n        return [X_ext_0, X_ext_1]\n\n    def predict(\n        self, X: torch.Tensor, return_po: bool = False, training: bool = False\n    ) -> torch.Tensor:\n        \"\"\"\n        Predict treatment effects and potential outcomes\n\n        Parameters\n        ----------\n        X: array-like of shape (n_samples, n_features)\n            Test-sample features\n        Returns\n        -------\n        y: array-like of shape (n_samples,)\n        \"\"\"\n        if not training:\n            self.eval()\n\n        X = self._check_tensor(X).float()\n        X_ext = self._create_extended_matrices(X)\n\n        y = []\n        for ext_mat in X_ext:\n            y.append(predict_wrapper(self._po_estimator, ext_mat).to(DEVICE))\n\n        outcome = y[1] - y[0]\n\n        if return_po:\n            return outcome, y[0], y[1]\n\n        return outcome\n"
  },
  {
    "path": "catenets/models/torch/snet.py",
    "content": "from typing import Tuple\n\nimport numpy as np\nimport torch\nfrom torch import nn\n\nimport catenets.logger as log\nfrom catenets.models.constants import (\n    DEFAULT_BATCH_SIZE,\n    DEFAULT_LAYERS_OUT,\n    DEFAULT_LAYERS_R,\n    DEFAULT_N_ITER,\n    DEFAULT_N_ITER_MIN,\n    DEFAULT_N_ITER_PRINT,\n    DEFAULT_NONLIN,\n    DEFAULT_PATIENCE,\n    DEFAULT_PENALTY_DISC,\n    DEFAULT_PENALTY_L2,\n    DEFAULT_PENALTY_ORTHOGONAL,\n    DEFAULT_SEED,\n    DEFAULT_STEP_SIZE,\n    DEFAULT_UNITS_OUT,\n    DEFAULT_UNITS_R_BIG_S,\n    DEFAULT_UNITS_R_SMALL_S,\n    DEFAULT_VAL_SPLIT,\n    LARGE_VAL,\n)\nfrom catenets.models.torch.base import (\n    DEVICE,\n    BaseCATEEstimator,\n    BasicNet,\n    PropensityNet,\n    RepresentationNet,\n)\nfrom catenets.models.torch.utils.model_utils import make_val_split\n\nEPS = 1e-8\n\n\nclass SNet(BaseCATEEstimator):\n    \"\"\"\n    Class implements SNet as discussed in Curth & van der Schaar (2021). Additionally to the\n    version implemented in the AISTATS paper, we also include an implementation that does not\n    have propensity heads (set with_prop=False)\n    Parameters\n    ----------\n    n_unit_in: int\n        Number of features\n    binary_y: bool, default False\n        Whether the outcome is binary\n    n_layers_r: int\n        Number of shared & private representation layers before the hypothesis layers.\n    n_units_r: int\n        Number of hidden units in representation shared before the hypothesis layer.\n    n_layers_out: int\n        Number of hypothesis layers (n_layers_out x n_units_out + 1 x Linear layer)\n    n_layers_out_prop: int\n        Number of hypothesis layers for propensity score(n_layers_out x n_units_out + 1 x Linear\n        layer)\n    n_units_out: int\n        Number of hidden units in each hypothesis layer\n    n_units_out_prop: int\n        Number of hidden units in each propensity score hypothesis layer\n    n_units_r_small: int\n        Number of hidden units in each PO 
functions private representation\n    weight_decay: float\n        l2 (ridge) penalty\n    lr: float\n        learning rate for optimizer\n    n_iter: int\n        Maximum number of iterations\n    batch_size: int\n        Batch size\n    val_split_prop: float\n        Proportion of samples used for validation split (can be 0)\n    patience: int\n        Number of iterations to wait before early stopping after decrease in validation loss\n    n_iter_min: int\n        Minimum number of iterations to go through before starting early stopping\n    n_iter_print: int\n        Number of iterations after which to print updates\n    seed: int\n        Seed used\n    nonlin: string, default 'elu'\n        Nonlinearity to use in the neural net. Can be 'elu', 'relu', 'selu' or 'leaky_relu'.\n    penalty_disc: float, default zero\n        Discrepancy penalty. Defaults to zero as this feature is not tested.\n    clipping_value: int, default 1\n        Gradients clipping value\n    \"\"\"\n\n    def __init__(\n        self,\n        n_unit_in: int,\n        binary_y: bool = False,\n        n_layers_r: int = DEFAULT_LAYERS_R,\n        n_units_r: int = DEFAULT_UNITS_R_BIG_S,\n        n_layers_out: int = DEFAULT_LAYERS_OUT,\n        n_units_r_small: int = DEFAULT_UNITS_R_SMALL_S,\n        n_units_out: int = DEFAULT_UNITS_OUT,\n        n_units_out_prop: int = DEFAULT_UNITS_OUT,\n        n_layers_out_prop: int = DEFAULT_LAYERS_OUT,\n        weight_decay: float = DEFAULT_PENALTY_L2,\n        penalty_orthogonal: float = DEFAULT_PENALTY_ORTHOGONAL,\n        penalty_disc: float = DEFAULT_PENALTY_DISC,\n        lr: float = DEFAULT_STEP_SIZE,\n        n_iter: int = DEFAULT_N_ITER,\n        n_iter_min: int = DEFAULT_N_ITER_MIN,\n        batch_size: int = DEFAULT_BATCH_SIZE,\n        val_split_prop: float = DEFAULT_VAL_SPLIT,\n        n_iter_print: int = DEFAULT_N_ITER_PRINT,\n        seed: int = DEFAULT_SEED,\n        nonlin: str = DEFAULT_NONLIN,\n        ortho_reg_type: str = \"abs\",\n   
     patience: int = DEFAULT_PATIENCE,\n        clipping_value: int = 1,\n        batch_norm: bool = True,\n        with_prop: bool = True,\n        early_stopping: bool = True,\n        prop_loss_multiplier: float = 1,\n        dropout: bool = False,\n        dropout_prob: float = 0.2,\n    ) -> None:\n        super(SNet, self).__init__()\n\n        self.n_unit_in = n_unit_in\n        self.binary_y = binary_y\n        self.penalty_orthogonal = penalty_orthogonal\n        self.penalty_disc = penalty_disc\n        self.n_iter = n_iter\n        self.batch_size = batch_size\n        self.val_split_prop = val_split_prop\n        self.n_iter_print = n_iter_print\n        self.seed = seed\n        self.ortho_reg_type = ortho_reg_type\n        self.clipping_value = clipping_value\n        self.patience = patience\n        self.with_prop = with_prop\n        self.early_stopping = early_stopping\n        self.n_iter_min = n_iter_min\n        self.prop_loss_multiplier = prop_loss_multiplier\n        self.dropout = dropout\n        self.dropout_prob = dropout_prob\n\n        self._reps_mu0 = RepresentationNet(\n            n_unit_in,\n            n_units=n_units_r_small,\n            n_layers=n_layers_r,\n            nonlin=nonlin,\n            batch_norm=batch_norm,\n        )\n        self._reps_mu1 = RepresentationNet(\n            n_unit_in,\n            n_units=n_units_r_small,\n            n_layers=n_layers_r,\n            nonlin=nonlin,\n            batch_norm=batch_norm,\n        )\n\n        self._po_estimators = []\n\n        if self.with_prop:\n            self._reps_c = RepresentationNet(\n                n_unit_in,\n                n_units=n_units_r,\n                n_layers=n_layers_r,\n                nonlin=nonlin,\n                batch_norm=batch_norm,\n            )\n\n            self._reps_o = RepresentationNet(\n                n_unit_in,\n                n_units=n_units_r_small,\n                n_layers=n_layers_r,\n                nonlin=nonlin,\n    
            batch_norm=batch_norm,\n            )\n\n            self._reps_prop = RepresentationNet(\n                n_unit_in,\n                n_units=n_units_r,\n                n_layers=n_layers_r,\n                nonlin=nonlin,\n                batch_norm=batch_norm,\n            )\n\n            for idx in range(2):\n                self._po_estimators.append(\n                    BasicNet(\n                        f\"snet_po_estimator_{idx}\",\n                        n_units_r\n                        + n_units_r_small\n                        + n_units_r_small,  # (reps_c, reps_o, reps_mu{idx})\n                        binary_y=binary_y,\n                        n_layers_out=n_layers_out,\n                        n_units_out=n_units_out,\n                        nonlin=nonlin,\n                        batch_norm=batch_norm,\n                        dropout_prob=dropout_prob,\n                        dropout=dropout,\n                    )\n                )\n            self._propensity_estimator = PropensityNet(\n                \"snet_propensity_estimator\",\n                n_units_r + n_units_r,  # reps_c, reps_w\n                2,\n                \"prop\",\n                n_layers_out_prop=n_layers_out_prop,\n                n_units_out_prop=n_units_out_prop,\n                nonlin=nonlin,\n                batch_norm=batch_norm,\n                dropout=dropout,\n                dropout_prob=dropout_prob,\n            ).to(DEVICE)\n\n            params = (\n                list(self._reps_c.parameters())\n                + list(self._reps_o.parameters())\n                + list(self._reps_mu0.parameters())\n                + list(self._reps_mu1.parameters())\n                + list(self._reps_prop.parameters())\n                + list(self._po_estimators[0].parameters())\n                + list(self._po_estimators[1].parameters())\n                + list(self._propensity_estimator.parameters())\n            )\n        else:\n            
self._reps_o = RepresentationNet(\n                n_unit_in,\n                n_units=n_units_r,\n                n_layers=n_layers_r,\n                nonlin=nonlin,\n                batch_norm=batch_norm,\n            )\n\n            for idx in range(2):\n                self._po_estimators.append(\n                    BasicNet(\n                        f\"snet_po_estimator_{idx}\",\n                        n_units_r + n_units_r_small,  # (reps_o, reps_mu{idx})\n                        binary_y=binary_y,\n                        n_layers_out=n_layers_out,\n                        n_units_out=n_units_out,\n                        nonlin=nonlin,\n                        batch_norm=batch_norm,\n                    )\n                )\n\n            params = (\n                list(self._reps_o.parameters())\n                + list(self._reps_mu0.parameters())\n                + list(self._reps_mu1.parameters())\n                + list(self._po_estimators[0].parameters())\n                + list(self._po_estimators[1].parameters())\n            )\n\n        self.optimizer = torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)\n\n    def loss(\n        self,\n        y0_pred: torch.Tensor,\n        y1_pred: torch.Tensor,\n        t_pred: torch.Tensor,\n        discrepancy: torch.Tensor,\n        y_true: torch.Tensor,\n        t_true: torch.Tensor,\n    ) -> torch.Tensor:\n        def head_loss(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:\n            if self.binary_y:\n                return nn.BCELoss()(y_pred, y_true)\n            else:\n                return (y_pred - y_true) ** 2\n\n        def po_loss(\n            y0_pred: torch.Tensor,\n            y1_pred: torch.Tensor,\n            y_true: torch.Tensor,\n            t_true: torch.Tensor,\n        ) -> torch.Tensor:\n            loss0 = torch.mean((1.0 - t_true) * head_loss(y0_pred, y_true))\n            loss1 = torch.mean(t_true * head_loss(y1_pred, y_true))\n\n            return 
loss0 + loss1\n\n        def prop_loss(\n            t_pred: torch.Tensor,\n            t_true: torch.Tensor,\n        ) -> torch.Tensor:\n            if self.with_prop:\n                t_pred = t_pred + EPS\n                return nn.CrossEntropyLoss()(t_pred, t_true)\n            else:\n                return 0\n\n        return (\n            po_loss(y0_pred, y1_pred, y_true, t_true)\n            + self.prop_loss_multiplier * prop_loss(t_pred, t_true)\n            + discrepancy\n            + self._ortho_reg()\n        )\n\n    def fit(\n        self,\n        X: torch.Tensor,\n        y: torch.Tensor,\n        w: torch.Tensor,\n    ) -> \"SNet\":\n        \"\"\"\n        Fit treatment models.\n\n        Parameters\n        ----------\n        X : torch.Tensor of shape (n_samples, n_features)\n            The features to fit to\n        y : torch.Tensor of shape (n_samples,) or (n_samples, )\n            The outcome variable\n        w: torch.Tensor of shape (n_samples,)\n            The treatment indicator\n        \"\"\"\n        self.train()\n\n        X = torch.Tensor(X).to(DEVICE)\n        y = torch.Tensor(y).squeeze().to(DEVICE)\n        w = torch.Tensor(w).squeeze().long().to(DEVICE)\n\n        X, y, w, X_val, y_val, w_val, val_string = make_val_split(\n            X, y, w=w, val_split_prop=self.val_split_prop, seed=self.seed\n        )\n\n        n = X.shape[0]  # could be different from before due to split\n\n        # calculate number of batches per epoch\n        batch_size = self.batch_size if self.batch_size < n else n\n        n_batches = int(np.round(n / batch_size)) if batch_size < n else 1\n        train_indices = np.arange(n)\n\n        # training\n        val_loss_best = LARGE_VAL\n        patience = 0\n        for i in range(self.n_iter):\n            # shuffle data for minibatches\n            np.random.shuffle(train_indices)\n            train_loss = []\n            for b in range(n_batches):\n                self.optimizer.zero_grad()\n\n 
               idx_next = train_indices[\n                    (b * batch_size) : min((b + 1) * batch_size, n - 1)\n                ]\n\n                X_next = X[idx_next]\n                y_next = y[idx_next].squeeze()\n                w_next = w[idx_next].squeeze()\n\n                y0_preds, y1_preds, prop_preds, discrepancy = self._step(X_next, w_next)\n                batch_loss = self.loss(\n                    y0_preds, y1_preds, prop_preds, discrepancy, y_next, w_next\n                )\n\n                batch_loss.backward()\n\n                torch.nn.utils.clip_grad_norm_(self.parameters(), self.clipping_value)\n\n                self.optimizer.step()\n\n                train_loss.append(batch_loss.detach())\n\n            train_loss = torch.Tensor(train_loss).to(DEVICE)\n\n            if self.early_stopping or i % self.n_iter_print == 0:\n                with torch.no_grad():\n                    y0_preds, y1_preds, prop_preds, discrepancy = self._step(\n                        X_val, w_val\n                    )\n                    val_loss = (\n                        self.loss(\n                            y0_preds, y1_preds, prop_preds, discrepancy, y_val, w_val\n                        )\n                        .detach()\n                        .cpu()\n                    )\n                    if self.early_stopping:\n                        if val_loss_best > val_loss:\n                            val_loss_best = val_loss\n                            patience = 0\n                        else:\n                            patience += 1\n                        if patience > self.patience and (\n                            (i + 1) * n_batches > self.n_iter_min\n                        ):\n                            break\n\n                    if i % self.n_iter_print == 0:\n                        log.info(\n                            f\"[SNet] Epoch: {i}, current {val_string} loss: {val_loss} train_loss: {torch.mean(train_loss)}\"\n      
                  )\n\n        return self\n\n    def _ortho_reg(self) -> float:\n        def _get_absolute_rowsums(mat: torch) -> torch.Tensor:\n            return torch.sum(torch.abs(mat), dim=0)\n\n        def _get_cos_reg(\n            params_0: torch.Tensor, params_1: torch.Tensor, normalize: bool = False\n        ) -> torch.Tensor:\n            if normalize:\n                params_0 = params_0 / torch.linalg.norm(params_0, dim=0)\n                params_1 = params_1 / torch.linalg.norm(params_1, dim=0)\n\n            x_min = min(params_0.shape[0], params_1.shape[0])\n            y_min = min(params_0.shape[1], params_1.shape[1])\n\n            return (\n                torch.linalg.norm(\n                    params_0[:x_min, :y_min] * params_1[:x_min, :y_min], \"fro\"\n                )\n                ** 2\n            )\n\n        reps_o_params = self._reps_o.model[0].weight\n        reps_mu0_params = self._reps_mu0.model[0].weight\n        reps_mu1_params = self._reps_mu1.model[0].weight\n\n        if self.with_prop:\n            reps_c_params = self._reps_c.model[0].weight\n            reps_prop_params = self._reps_prop.model[0].weight\n\n        # define ortho-reg function\n        if self.ortho_reg_type == \"abs\":\n            col_o = _get_absolute_rowsums(reps_o_params)\n            col_mu0 = _get_absolute_rowsums(reps_mu0_params)\n            col_mu1 = _get_absolute_rowsums(reps_mu1_params)\n            if self.with_prop:\n                col_c = _get_absolute_rowsums(reps_c_params)\n                col_w = _get_absolute_rowsums(reps_prop_params)\n\n                return self.penalty_orthogonal * torch.sum(\n                    col_c * col_o\n                    + col_c * col_w\n                    + col_c * col_mu1\n                    + col_c * col_mu0\n                    + col_w * col_o\n                    + col_mu0 * col_o\n                    + col_o * col_mu1\n                    + col_mu0 * col_mu1\n                    + col_mu0 * col_w\n  
                  + col_w * col_mu1\n                )\n            else:\n                return self.penalty_orthogonal * torch.sum(\n                    +col_mu0 * col_o + col_o * col_mu1 + col_mu0 * col_mu1\n                )\n\n        elif self.ortho_reg_type == \"fro\":\n            if self.with_prop:\n                return self.penalty_orthogonal * (\n                    _get_cos_reg(reps_c_params, reps_o_params)\n                    + _get_cos_reg(reps_c_params, reps_mu0_params)\n                    + _get_cos_reg(reps_c_params, reps_mu1_params)\n                    + _get_cos_reg(reps_c_params, reps_prop_params)\n                    + _get_cos_reg(reps_o_params, reps_mu0_params)\n                    + _get_cos_reg(reps_o_params, reps_mu1_params)\n                    + _get_cos_reg(reps_o_params, reps_prop_params)\n                    + _get_cos_reg(reps_mu0_params, reps_mu1_params)\n                    + _get_cos_reg(reps_mu0_params, reps_prop_params)\n                    + _get_cos_reg(reps_mu1_params, reps_prop_params)\n                )\n            else:\n                return self.penalty_orthogonal * (\n                    +_get_cos_reg(reps_o_params, reps_mu0_params)\n                    + _get_cos_reg(reps_o_params, reps_mu1_params)\n                    + _get_cos_reg(reps_mu0_params, reps_mu1_params)\n                )\n\n        else:\n            raise ValueError(f\"Invalid orth_reg_typ {self.ortho_reg_type}\")\n\n    def _maximum_mean_discrepancy(\n        self, X: torch.Tensor, w: torch.Tensor\n    ) -> torch.Tensor:\n        n = w.shape[0]\n        n_t = torch.sum(w)\n\n        X = X / torch.sqrt(torch.var(X, dim=0) + EPS)\n        w = w.unsqueeze(dim=0)\n\n        mean_control = (n / (n - n_t)) * torch.mean((1 - w).T * X, dim=0)\n        mean_treated = (n / n_t) * torch.mean(w.T * X, dim=0)\n\n        return torch.sum((mean_treated - mean_control) ** 2)\n\n    def _step(\n        self, X: torch.Tensor, w: torch.Tensor\n    ) -> 
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n        y0_preds, y1_preds, prop_preds, reps_o = self._forward(X)\n\n        discrepancy = self.penalty_disc * self._maximum_mean_discrepancy(reps_o, w)\n\n        return y0_preds, y1_preds, prop_preds, discrepancy\n\n    def _forward(\n        self, X: torch.Tensor\n    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n        reps_o = self._reps_o(X)\n        reps_mu0 = self._reps_mu0(X)\n        reps_mu1 = self._reps_mu1(X)\n\n        if self.with_prop:\n            reps_c = self._reps_c(X)\n            reps_w = self._reps_prop(X)\n\n            reps_po_0 = torch.cat((reps_c, reps_o, reps_mu0), dim=1)\n            reps_po_1 = torch.cat((reps_c, reps_o, reps_mu1), dim=1)\n            reps_w = torch.cat((reps_c, reps_w), dim=1)\n            prop_preds = self._propensity_estimator(reps_w)\n        else:\n            reps_po_0 = torch.cat((reps_o, reps_mu0), dim=1)\n            reps_po_1 = torch.cat((reps_o, reps_mu1), dim=1)\n            prop_preds = 0.5 * torch.ones(len(X))  # no probability predictions\n\n        y0_preds = self._po_estimators[0](reps_po_0).squeeze()\n        y1_preds = self._po_estimators[1](reps_po_1).squeeze()\n\n        return y0_preds, y1_preds, prop_preds, reps_o\n\n    def predict(\n        self, X: torch.Tensor, return_po: bool = False, training: bool = False\n    ) -> torch.Tensor:\n        \"\"\"\n        Predict treatment effects and potential outcomes\n\n        Parameters\n        ----------\n        X: array-like of shape (n_samples, n_features)\n            Test-sample features\n        Returns\n        -------\n        y: array-like of shape (n_samples,)\n        \"\"\"\n        if not training:\n            self.eval()\n\n        X = self._check_tensor(X).float()\n        y0_preds, y1_preds, _, _ = self._forward(X)\n\n        outcome = y1_preds - y0_preds\n\n        if return_po:\n            return outcome, y0_preds, y1_preds\n\n        return 
outcome\n"
  },
  {
    "path": "catenets/models/torch/tlearner.py",
    "content": "import copy\nfrom typing import Any\n\nimport torch\n\nfrom catenets.models.constants import (\n    DEFAULT_BATCH_SIZE,\n    DEFAULT_LAYERS_OUT,\n    DEFAULT_N_ITER,\n    DEFAULT_N_ITER_PRINT,\n    DEFAULT_NONLIN,\n    DEFAULT_PENALTY_L2,\n    DEFAULT_SEED,\n    DEFAULT_STEP_SIZE,\n    DEFAULT_UNITS_OUT,\n    DEFAULT_VAL_SPLIT,\n)\nfrom catenets.models.torch.base import DEVICE, BaseCATEEstimator, BasicNet\nfrom catenets.models.torch.utils.model_utils import predict_wrapper, train_wrapper\n\n\nclass TLearner(BaseCATEEstimator):\n    \"\"\"\n    TLearner class -- two separate functions learned for each Potential Outcome function\n\n    Parameters\n    ----------\n    n_unit_in: int\n        Number of features\n    binary_y: bool, default False\n        Whether the outcome is binary\n    po_estimator: sklearn/PyTorch model, default: None\n        Custom plugin model. If this parameter is set, the rest of the parameters are ignored.\n    n_layers_out: int\n        Number of hypothesis layers (n_layers_out x n_units_out + 1 x Linear layer)\n    n_units_out: int\n        Number of hidden units in each hypothesis layer\n    weight_decay: float\n        l2 (ridge) penalty\n    lr: float\n        learning rate for optimizer\n    n_iter: int\n        Maximum number of iterations\n    batch_size: int\n        Batch size\n    val_split_prop: float\n        Proportion of samples used for validation split (can be 0)\n    n_iter_print: int\n        Number of iterations after which to print updates\n    seed: int\n        Seed used\n    nonlin: string, default 'elu'\n        Nonlinearity to use in the neural net. 
Can be 'elu', 'relu', 'selu' or 'leaky_relu'.\n    \"\"\"\n\n    def __init__(\n        self,\n        n_unit_in: int,\n        binary_y: bool,\n        po_estimator: Any = None,\n        n_layers_out: int = DEFAULT_LAYERS_OUT,\n        n_units_out: int = DEFAULT_UNITS_OUT,\n        weight_decay: float = DEFAULT_PENALTY_L2,\n        lr: float = DEFAULT_STEP_SIZE,\n        n_iter: int = DEFAULT_N_ITER,\n        batch_size: int = DEFAULT_BATCH_SIZE,\n        val_split_prop: float = DEFAULT_VAL_SPLIT,\n        n_iter_print: int = DEFAULT_N_ITER_PRINT,\n        seed: int = DEFAULT_SEED,\n        nonlin: str = DEFAULT_NONLIN,\n        batch_norm: bool = True,\n        early_stopping: bool = True,\n        dropout: bool = False,\n        dropout_prob: float = 0.2,\n    ) -> None:\n        super(TLearner, self).__init__()\n\n        self._plug_in: Any = []\n        plugins = [f\"tlearner_po_estimator_{i}\" for i in range(2)]\n        if po_estimator is not None:\n            for plugin in plugins:\n                self._plug_in.append(copy.deepcopy(po_estimator))\n        else:\n            for plugin in plugins:\n                self._plug_in.append(\n                    BasicNet(\n                        plugin,\n                        n_unit_in,\n                        binary_y=binary_y,\n                        n_layers_out=n_layers_out,\n                        n_units_out=n_units_out,\n                        weight_decay=weight_decay,\n                        lr=lr,\n                        n_iter=n_iter,\n                        batch_size=batch_size,\n                        val_split_prop=val_split_prop,\n                        n_iter_print=n_iter_print,\n                        seed=seed,\n                        nonlin=nonlin,\n                        batch_norm=batch_norm,\n                        early_stopping=early_stopping,\n                        dropout_prob=dropout_prob,\n                        dropout=dropout,\n                    ).to(DEVICE),\n 
               )\n\n    def predict(\n        self, X: torch.Tensor, return_po: bool = False, training: bool = False\n    ) -> torch.Tensor:\n        \"\"\"\n        Predict treatment effects and potential outcomes\n        Parameters\n        ----------\n        X: torch.Tensor of shape (n_samples, n_features)\n            Test-sample features\n        return_po: bool\n            Return potential outcomes too\n\n        Returns\n        -------\n        y: torch.Tensor of shape (n_samples,)\n        \"\"\"\n        if not training:\n            self.eval()\n\n        X = self._check_tensor(X).float()\n\n        y_hat = []\n        for widx, plugin in enumerate(self._plug_in):\n            y_hat.append(predict_wrapper(plugin, X))\n\n        outcome = y_hat[1] - y_hat[0]\n\n        if return_po:\n            return outcome, y_hat[0], y_hat[1]\n\n        return outcome\n\n    def fit(\n        self,\n        X: torch.Tensor,\n        y: torch.Tensor,\n        w: torch.Tensor,\n    ) -> \"TLearner\":\n        \"\"\"\n        Train plug-in models.\n\n        Parameters\n        ----------\n        X : torch.Tensor (n_samples, n_features)\n            The features to fit to\n        y : torch.Tensor (n_samples,) or (n_samples, )\n            The outcome variable\n        w: torch.Tensor (n_samples,)\n            The treatment indicator\n        \"\"\"\n        self.train()\n\n        X = torch.Tensor(X).to(DEVICE)\n        y = torch.Tensor(y).to(DEVICE)\n        w = torch.Tensor(w).to(DEVICE)\n\n        for widx, plugin in enumerate(self._plug_in):\n            train_wrapper(plugin, X[w == widx], y[w == widx])\n\n        return self\n"
  },
  {
    "path": "catenets/models/torch/utils/__init__.py",
    "content": ""
  },
  {
    "path": "catenets/models/torch/utils/decorators.py",
    "content": "import time\nfrom typing import Any, Callable\n\nimport torch\n\nimport catenets.logger as log\n\n\ndef check_input_train(func: Callable) -> Callable:\n    \"\"\"Decorator used for checking training params.\n\n    Args:\n        func: the training function to wrap; it is called only after the treatment indicator w is checked to be binary.\n\n    Returns:\n        Callable: the decorator\n\n    \"\"\"\n\n    def wrapper(self: Any, X: torch.Tensor, y: torch.Tensor, w: torch.Tensor) -> Any:\n\n        w = torch.Tensor(w)\n\n        if not ((w == 0) | (w == 1)).all():\n            raise ValueError(\"W should be binary\")\n\n        return func(self, X, y, w)\n\n    return wrapper\n\n\ndef benchmark(func: Callable) -> Callable:\n    \"\"\"Decorator used for function duration benchmarking. It is active only with DEBUG loglevel.\n\n    Args:\n        func: the function to be benchmarked.\n\n    Returns:\n        Callable: the decorator\n\n    \"\"\"\n\n    def wrapper(*args: Any, **kwargs: Any) -> Any:\n        start = time.time()\n        res = func(*args, **kwargs)\n        end = time.time()\n\n        log.debug(f\"{func.__qualname__} took {round(end - start, 4)} seconds\")\n        return res\n\n    return wrapper\n"
  },
  {
    "path": "catenets/models/torch/utils/model_utils.py",
    "content": "\"\"\"\r\nModel utils shared across different nets\r\n\"\"\"\r\n# Author: Alicia Curth, Bogdan Cebere\r\nfrom typing import Any, Optional\r\n\r\nimport torch\r\nfrom sklearn.model_selection import train_test_split\r\n\r\nimport catenets.logger as log\r\nfrom catenets.models.constants import DEFAULT_SEED, DEFAULT_VAL_SPLIT\r\n\r\nDEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\r\n\r\nTRAIN_STRING = \"training\"\r\nVALIDATION_STRING = \"validation\"\r\n\r\n\r\ndef make_val_split(\r\n    X: torch.Tensor,\r\n    y: torch.Tensor,\r\n    w: Optional[torch.Tensor] = None,\r\n    val_split_prop: float = DEFAULT_VAL_SPLIT,\r\n    seed: int = DEFAULT_SEED,\r\n    stratify_w: bool = True,\r\n) -> Any:\r\n    if val_split_prop == 0:\r\n        # return original data\r\n        if w is None:\r\n            return X, y, X, y, TRAIN_STRING\r\n\r\n        return X, y, w, X, y, w, TRAIN_STRING\r\n\r\n    X = X.cpu()\r\n    y = y.cpu()\r\n    # make actual split\r\n    if w is None:\r\n        X_t, X_val, y_t, y_val = train_test_split(\r\n            X, y, test_size=val_split_prop, random_state=seed, shuffle=True\r\n        )\r\n        return (\r\n            X_t.to(DEVICE),\r\n            y_t.to(DEVICE),\r\n            X_val.to(DEVICE),\r\n            y_val.to(DEVICE),\r\n            VALIDATION_STRING,\r\n        )\r\n\r\n    w = w.cpu()\r\n    if stratify_w:\r\n        # split to stratify by group\r\n        X_t, X_val, y_t, y_val, w_t, w_val = train_test_split(\r\n            X,\r\n            y,\r\n            w,\r\n            test_size=val_split_prop,\r\n            random_state=seed,\r\n            stratify=w,\r\n            shuffle=True,\r\n        )\r\n    else:\r\n        X_t, X_val, y_t, y_val, w_t, w_val = train_test_split(\r\n            X, y, w, test_size=val_split_prop, random_state=seed, shuffle=True\r\n        )\r\n\r\n    return (\r\n        X_t.to(DEVICE),\r\n        y_t.to(DEVICE),\r\n        w_t.to(DEVICE),\r\n        
X_val.to(DEVICE),\r\n        y_val.to(DEVICE),\r\n        w_val.to(DEVICE),\r\n        VALIDATION_STRING,\r\n    )\r\n\r\n\r\ndef train_wrapper(\r\n    estimator: Any,\r\n    X: torch.Tensor,\r\n    y: torch.Tensor,\r\n    **kwargs: Any,\r\n) -> None:\r\n    if hasattr(estimator, \"train\"):\r\n        log.debug(f\"Train PyTorch network {estimator}\")\r\n        estimator.fit(X, y, **kwargs)\r\n    elif hasattr(estimator, \"fit\"):\r\n        log.debug(f\"Train sklearn estimator {estimator}\")\r\n        estimator.fit(X.detach().cpu().numpy(), y.detach().cpu().numpy())\r\n    else:\r\n        raise NotImplementedError(f\"Invalid estimator for the {estimator}\")\r\n\r\n\r\ndef predict_wrapper(estimator: Any, X: torch.Tensor) -> torch.Tensor:\r\n    if hasattr(estimator, \"forward\"):\r\n        return estimator(X)\r\n    elif hasattr(estimator, \"predict_proba\"):\r\n        X_np = X.detach().cpu().numpy()\r\n        no_event_proba = estimator.predict_proba(X_np)[:, 0]  # no event probability\r\n\r\n        return torch.Tensor(no_event_proba)\r\n    elif hasattr(estimator, \"predict\"):\r\n        X_np = X.detach().cpu().numpy()\r\n        no_event_proba = estimator.predict(X_np)\r\n\r\n        return torch.Tensor(no_event_proba)\r\n    else:\r\n        raise NotImplementedError(f\"Invalid estimator for the {estimator}\")\r\n"
  },
  {
    "path": "catenets/models/torch/utils/transformations.py",
    "content": "\"\"\"\nUnbiased Transformations for CATE\n\"\"\"\n# Author: Alicia Curth\nfrom typing import Optional\n\nimport torch\n\n\ndef dr_transformation_cate(\n    y: torch.Tensor,\n    w: torch.Tensor,\n    p: torch.Tensor,\n    mu_0: torch.Tensor,\n    mu_1: torch.Tensor,\n) -> torch.Tensor:\n    \"\"\"\n    Transforms data to efficient influence function/aipw pseudo-outcome for CATE estimation\n\n    Parameters\n    ----------\n    y : array-like of shape (n_samples,) or (n_samples, )\n        The observed outcome variable\n    w: array-like of shape (n_samples,)\n        The observed treatment indicator\n    p: array-like of shape (n_samples,)\n        The treatment propensity, estimated or known. Can be None, then p=0.5 is assumed\n    mu_0: array-like of shape (n_samples,)\n        Estimated or known potential outcome mean of the control group\n    mu_1: array-like of shape (n_samples,)\n        Estimated or known potential outcome mean of the treatment group\n    Returns\n    -------\n    d_hat:\n        EIF transformation for CATE\n    \"\"\"\n    if p is None:\n        # assume equal\n        p = torch.full(y.shape, 0.5)\n\n    EPS = 1e-7\n    w_1 = w / (p + EPS)\n    w_0 = (1 - w) / (EPS + 1 - p)\n    return (w_1 - w_0) * y + ((1 - w_1) * mu_1 - (1 - w_0) * mu_0)\n\n\ndef pw_transformation_cate(\n    y: torch.Tensor,\n    w: torch.Tensor,\n    p: Optional[torch.Tensor] = None,\n    mu_0: Optional[torch.Tensor] = None,\n    mu_1: Optional[torch.Tensor] = None,\n) -> torch.Tensor:\n    \"\"\"\n    Transform data to Horvitz-Thompson transformation for CATE\n    Parameters\n    ----------\n    y : array-like of shape (n_samples,) or (n_samples, )\n        The observed outcome variable\n    w: array-like of shape (n_samples,)\n        The observed treatment indicator\n    p: array-like of shape (n_samples,)\n        The treatment propensity, estimated or known. 
Can be None, then p=0.5 is assumed\n    mu_0: array-like of shape (n_samples,)\n         Estimated or known potential outcome mean of the control group. Placeholder, not used.\n    mu_1: array-like of shape (n_samples,)\n        Estimated or known potential outcome mean of the treatment group. Placeholder, not used.\n    Returns\n    -------\n    res: array-like of shape (n_samples,)\n        Horvitz-Thompson transformed data\n    \"\"\"\n    if p is None:\n        # assume equal propensities\n        p = torch.full(y.shape, 0.5)\n    return (w / p - (1 - w) / (1 - p)) * y\n\n\ndef ra_transformation_cate(\n    y: torch.Tensor,\n    w: torch.Tensor,\n    p: torch.Tensor,\n    mu_0: torch.Tensor,\n    mu_1: torch.Tensor,\n) -> torch.Tensor:\n    \"\"\"\n    Transform data to regression adjustment for CATE\n\n    Parameters\n    ----------\n    y : array-like of shape (n_samples,) or (n_samples, )\n        The observed outcome variable\n    w: array-like of shape (n_samples,)\n        The observed treatment indicator\n    p: array-like of shape (n_samples,)\n        Placeholder, not used. 
The treatment propensity, estimated or known.\n    mu_0: array-like of shape (n_samples,)\n         Estimated or known potential outcome mean of the control group\n    mu_1: array-like of shape (n_samples,)\n        Estimated or known potential outcome mean of the treatment group\n\n    Returns\n    -------\n    res: array-like of shape (n_samples,)\n        Regression adjusted transformation\n    \"\"\"\n    return w * (y - mu_0) + (1 - w) * (mu_1 - y)\n\n\ndef u_transformation_cate(\n    y: torch.Tensor, w: torch.Tensor, p: torch.Tensor, mu: torch.Tensor\n) -> torch.Tensor:\n    \"\"\"\n    Transform data to U-transformation (described in Kuenzel et al, 2019, Nie & Wager, 2017)\n    which underlies both R-learner and U-learner\n\n    Parameters\n    ----------\n    y : array-like of shape (n_samples,) or (n_samples, )\n        The observed outcome variable\n    w: array-like of shape (n_samples,)\n        The observed treatment indicator\n    p: array-like of shape (n_samples,)\n        The treatment propensity, estimated or known. Can be None, then p=0.5 is assumed\n    mu: array-like of shape (n_samples,)\n        Estimated or known outcome mean that is subtracted from y\n\n    Returns\n    -------\n    res: array-like of shape (n_samples,)\n        U-transformed data\n    \"\"\"\n    if p is None:\n        # assume equal propensities\n        p = torch.full(y.shape, 0.5)\n    return (y - mu) / (w - p)\n"
  },
  {
    "path": "catenets/models/torch/utils/weight_utils.py",
    "content": "\"\"\"\nImplement different reweighting/balancing strategies as in Li et al (2018)\n\"\"\"\n# Author: Alicia Curth\nfrom typing import Optional\n\nimport torch\n\nIPW_NAME = \"ipw\"\nTRUNC_IPW_NAME = \"truncipw\"\nOVERLAP_NAME = \"overlap\"\nMATCHING_NAME = \"match\"\nPROP = \"prop\"\nONE_MINUS_PROP = \"1-prop\"\n\nALL_WEIGHTING_STRATEGIES = [\n    IPW_NAME,\n    TRUNC_IPW_NAME,\n    OVERLAP_NAME,\n    MATCHING_NAME,\n    PROP,\n    ONE_MINUS_PROP,\n]\n\n\ndef compute_importance_weights(\n    propensity: torch.Tensor,\n    w: torch.Tensor,\n    weighting_strategy: str,\n    weight_args: Optional[dict] = None,\n) -> torch.Tensor:\n    if weighting_strategy not in ALL_WEIGHTING_STRATEGIES:\n        raise ValueError(\n            f\"weighting_strategy should be in {ALL_WEIGHTING_STRATEGIES}\"\n            f\"You passed {weighting_strategy}\"\n        )\n    if weight_args is None:\n        weight_args = {}\n\n    if weighting_strategy == PROP:\n        return propensity\n    elif weighting_strategy == ONE_MINUS_PROP:\n        return 1 - propensity\n    elif weighting_strategy == IPW_NAME:\n        return compute_ipw(propensity, w)\n    elif weighting_strategy == TRUNC_IPW_NAME:\n        return compute_trunc_ipw(propensity, w, **weight_args)\n    elif weighting_strategy == OVERLAP_NAME:\n        return compute_overlap_weights(propensity, w)\n    elif weighting_strategy == MATCHING_NAME:\n        return compute_matching_weights(propensity, w)\n\n\ndef compute_ipw(propensity: torch.Tensor, w: torch.Tensor) -> torch.Tensor:\n    p_hat = torch.mean(w)\n    return w * p_hat / propensity + (1 - w) * (1 - p_hat) / (1 - propensity)\n\n\ndef compute_trunc_ipw(\n    propensity: torch.Tensor, w: torch.Tensor, cutoff: float = 0.05\n) -> torch.Tensor:\n    ipw = compute_ipw(propensity, w)\n    return torch.where((propensity > cutoff) & (propensity < 1 - cutoff), ipw, 0)\n\n\n# TODO check normalizing these weights\ndef compute_matching_weights(propensity: 
torch.Tensor, w: torch.Tensor) -> torch.Tensor:\n    ipw = compute_ipw(propensity, w)\n    return torch.minimum(ipw, 1 - ipw) * ipw\n\n\ndef compute_overlap_weights(propensity: torch.Tensor, w: torch.Tensor) -> torch.Tensor:\n    ipw = compute_ipw(propensity, w)\n    return propensity * (1 - propensity) * ipw\n"
  },
  {
    "path": "catenets/version.py",
    "content": "__version__ = \"0.2.3\"\n"
  },
  {
    "path": "docs/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line, and also\n# from the environment for the first two.\nSPHINXOPTS    ?=\nSPHINXBUILD   ?= sphinx-build\nSOURCEDIR     = .\nBUILDDIR      = _build\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n"
  },
  {
    "path": "docs/conf.py",
    "content": "# Configuration file for the Sphinx documentation builder.\n#\n# This file only contains a selection of the most common options. For a full\n# list see the documentation:\n# https://www.sphinx-doc.org/en/master/usage/configuration.html\n\n# -- Path setup --------------------------------------------------------------\n\n# If extensions (or modules to document with autodoc) are in another directory,\n# add these directories to sys.path here. If the directory is relative to the\n# documentation root, use os.path.abspath to make it absolute, like shown here.\n#\n# import os\n# import sys\n# sys.path.insert(0, os.path.abspath('.'))\n\nimport datetime\nimport os\nimport shutil\nimport subprocess\nimport sys\n\nimport sphinx_rtd_theme\n\nsys.path.insert(0, os.path.abspath(\"..\"))\n\nsubprocess.run(\n    [\n        \"sphinx-apidoc\",\n        \"--ext-autodoc\",\n        \"--ext-doctest\",\n        \"--ext-mathjax\",\n        \"--ext-viewcode\",\n        \"-e\",\n        \"-T\",\n        \"-M\",\n        \"-F\",\n        \"-P\",\n        \"-f\",\n        \"-o\",\n        \"generated\",\n        \"../catenets/\",\n    ]\n)\n# -- Project information -----------------------------------------------------\nnow = datetime.datetime.now()\n\nproject = \"CATENets\"\nauthor = \"Alicia Curth\"\ncopyright = f\"{now.year}, {author}\"\n\n\n# -- General configuration ---------------------------------------------------\n\n# Add any Sphinx extension module names here, as strings. 
They can be\n# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom\n# ones.\nextensions = [\n    \"sphinx.ext.autodoc\",\n    \"sphinx.ext.autosummary\",\n    \"sphinx.ext.napoleon\",\n    \"m2r2\",\n]\nautodoc_default_options = {\n    \"members\": True,\n    \"inherited-members\": False,\n    \"inherit_docstrings\": False,\n}\n\n\nadd_module_names = False\nautosummary_generate = True\n\n# Add any paths that contain templates here, relative to this directory.\ntemplates_path = [\"_templates\"]\n\n# List of patterns, relative to source directory, that match files and\n# directories to ignore when looking for source files.\n# This pattern also affects html_static_path and html_extra_path.\nexclude_patterns = [\"_build\", \"Thumbs.db\", \".DS_Store\"]\n\n\n# -- Options for HTML output -------------------------------------------------\n\n# The theme to use for HTML and HTML Help pages.  See the documentation for\n# a list of builtin themes.\n#\nhtml_theme = \"sphinx_rtd_theme\"\n\n# Add any paths that contain custom static files (such as style sheets) here,\n# relative to this directory. They are copied after the builtin static files,\n# so a file named \"default.css\" will overwrite the builtin \"default.css\".\nhtml_static_path = [\"_static\"]\n"
  },
  {
    "path": "docs/datasets.rst",
    "content": "Datasets\n=========================\n\nDataloaders for datasets used for experiments.\n\n.. toctree::\n    :glob:\n    :maxdepth: 2\n\n    IHDP dataset <generated/catenets.datasets.dataset_ihdp.rst>\n    Twins dataset <generated/catenets.datasets.dataset_twins.rst>\n    ACIC dataset <generated/catenets.datasets.dataset_acic2016.rst>\n    Helpers <generated/catenets.datasets.network.rst>\n"
  },
  {
    "path": "docs/index.rst",
    "content": "Welcome to CATENets's documentation!\n====================================\n\n.. mdinclude:: ../README.md\n\n\nAPI documentation\n=================\n\nJAX models\n==========\n.. toctree::\n    :glob:\n    :maxdepth: 2\n\n    jax_models.rst\n\nPyTorch models\n==============\n.. toctree::\n    :glob:\n    :maxdepth: 2\n\n    torch_models.rst\n\n\nDatasets\n========\n.. toctree::\n    :glob:\n    :maxdepth: 2\n\n    datasets.rst\n"
  },
  {
    "path": "docs/jax_models.rst",
    "content": "JAX models\n=========================\n\nJAX-based CATE estimators\n\n.. toctree::\n    :glob:\n    :maxdepth: 2\n\n    T-Learners <generated/catenets.models.jax.tnet.rst>\n    R-Learners <generated/catenets.models.jax.rnet.rst>\n    X-Learners <generated/catenets.models.jax.xnet.rst>\n    Pseudo-Outcome Nets <generated/catenets.models.jax.pseudo_outcome_nets.rst>\n    Representation Nets <generated/catenets.models.jax.representation_nets.rst>\n    Disentangled Nets <generated/catenets.models.jax.disentangled_nets.rst>\n    S-Nets <generated/catenets.models.jax.snet.rst>\n    FlexTENet <generated/catenets.models.jax.flextenet.rst>\n    OffsetNet <generated/catenets.models.jax.offsetnet.rst>\n"
  },
  {
    "path": "docs/make.bat",
    "content": "@ECHO OFF\r\n\r\npushd %~dp0\r\n\r\nREM Command file for Sphinx documentation\r\n\r\nif \"%SPHINXBUILD%\" == \"\" (\r\n\tset SPHINXBUILD=sphinx-build\r\n)\r\nset SOURCEDIR=.\r\nset BUILDDIR=_build\r\n\r\nif \"%1\" == \"\" goto help\r\n\r\n%SPHINXBUILD% >NUL 2>NUL\r\nif errorlevel 9009 (\r\n\techo.\r\n\techo.The 'sphinx-build' command was not found. Make sure you have Sphinx\r\n\techo.installed, then set the SPHINXBUILD environment variable to point\r\n\techo.to the full path of the 'sphinx-build' executable. Alternatively you\r\n\techo.may add the Sphinx directory to PATH.\r\n\techo.\r\n\techo.If you don't have Sphinx installed, grab it from\r\n\techo.http://sphinx-doc.org/\r\n\texit /b 1\r\n)\r\n\r\n%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%\r\ngoto end\r\n\r\n:help\r\n%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%\r\n\r\n:end\r\npopd\r\n"
  },
  {
    "path": "docs/requirements.txt",
    "content": "autodoc\nbandit\nblack\ncatboost\nflake8\ngdown\njax>=0.3.16\njaxlib>=0.3.14; sys_platform != 'win32'\njupyter\nloguru>=0.5.3\nm2r2\nmyst-parser\nnotebook\nnumpy>=1.20\npandas>=1.3\npre-commit\npytest>=6.2.4\npytest\npytest\npytest-cov\nrequests\nscikit_learn>=0.24.2\nscipy>=1.2\nsetuptools\nsklearn\nsphinx-autopackagesummary\nsphinx-rtd-theme\nsphinxcontrib-napoleon\ntorch>=1.9\nxgboost\n"
  },
  {
    "path": "docs/torch_models.rst",
    "content": "PyTorch models\n=========================\n\nPyTorch-based CATE estimators\n\n.. toctree::\n    :glob:\n    :maxdepth: 2\n\n    T-Learners <generated/catenets.models.torch.tlearner.rst>\n    S-Learners <generated/catenets.models.torch.slearner.rst>\n    Pseudo-Outcome Nets <generated/catenets.models.torch.pseudo_outcome_nets.rst>\n    Representation Nets <generated/catenets.models.torch.representation_nets.rst>\n    S-Nets <generated/catenets.models.torch.snet.rst>\n"
  },
  {
    "path": "experiments/__init__.py",
    "content": ""
  },
  {
    "path": "experiments/experiments_AISTATS21/ihdp_experiments.py",
    "content": "\"\"\"\nScript to run experiments on Johansson's IHDP dataset (retrieved via https://www.fredjo.com/)\n\"\"\"\n# Author: Alicia Curth\nimport csv\nimport os\nfrom pathlib import Path\nfrom typing import Optional, Union\n\nfrom sklearn import clone\n\nimport catenets.logger as log\nfrom catenets.datasets.dataset_ihdp import get_one_data_set, load_raw, prepare_ihdp_data\nfrom catenets.experiment_utils.base import eval_root_mse, get_model_set\nfrom catenets.models.jax import PSEUDOOUT_NAME, PseudoOutcomeNet\nfrom catenets.models.jax.transformation_utils import RA_TRANSFORMATION\n\n# Some constants\nDATA_DIR = Path(\"catenets/datasets/data/\")\nRESULT_DIR = Path(\"results/experiments_AISTATS21/ihdp/\")\nSEP = \"_\"\n\n# Hyperparameters for experiments on IHDP\nLAYERS_OUT = 2\nLAYERS_R = 3\nPENALTY_L2 = 0.01 / 100\nPENALTY_ORTHOGONAL_IHDP = 0\n\nMODEL_PARAMS = {\n    \"n_layers_out\": LAYERS_OUT,\n    \"n_layers_r\": LAYERS_R,\n    \"penalty_l2\": PENALTY_L2,\n    \"penalty_orthogonal\": PENALTY_ORTHOGONAL_IHDP,\n    \"n_layers_out_t\": LAYERS_OUT,\n    \"n_layers_r_t\": LAYERS_R,\n    \"penalty_l2_t\": PENALTY_L2,\n}\n\n# get basic models\nALL_MODELS_IHDP = get_model_set(model_selection=\"all\", model_params=MODEL_PARAMS)\n\nCOMBINED_MODELS_IHDP = {\n    PSEUDOOUT_NAME\n    + SEP\n    + RA_TRANSFORMATION\n    + SEP\n    + \"S2\": PseudoOutcomeNet(\n        n_layers_r=LAYERS_R,\n        n_layers_out=LAYERS_OUT,\n        penalty_l2=PENALTY_L2,\n        n_layers_r_t=LAYERS_R,\n        n_layers_out_t=LAYERS_OUT,\n        penalty_l2_t=PENALTY_L2,\n        transformation=RA_TRANSFORMATION,\n        first_stage_strategy=\"S2\",\n    )\n}\n\nFULL_MODEL_SET_IHDP = dict(**ALL_MODELS_IHDP, **COMBINED_MODELS_IHDP)\n\n\ndef do_ihdp_experiments(\n    n_exp: Union[int, list] = 100,\n    file_name: str = \"ihdp_results_scaled\",\n    model_params: Optional[dict] = None,\n    scale_cate: bool = True,\n    models: Union[list, dict, str, None] = None,\n) -> None:\n    if 
models is None:\n        models = FULL_MODEL_SET_IHDP\n    elif isinstance(models, (list, str)):\n        models = get_model_set(models)\n\n    # make path\n    if not os.path.exists(RESULT_DIR):\n        os.makedirs(RESULT_DIR)\n\n    # get file to write in\n    out_file = open(RESULT_DIR / (file_name + \".csv\"), \"w\", buffering=1)\n    writer = csv.writer(out_file)\n    header = [name + \"_in\" for name in models.keys()] + [\n        name + \"_out\" for name in models.keys()\n    ]\n    writer.writerow(header)\n\n    # get data\n    data_train, data_test = load_raw(DATA_DIR)\n\n    if isinstance(n_exp, int):\n        experiment_loop = list(range(1, n_exp + 1))\n    elif isinstance(n_exp, list):\n        experiment_loop = n_exp\n    else:\n        raise ValueError(\"n_exp should be either an integer or a list of integers.\")\n\n    for i_exp in experiment_loop:\n        pehe_in = []\n        pehe_out = []\n\n        # get data\n        data_exp = get_one_data_set(data_train, i_exp=i_exp, get_po=True)\n        data_exp_test = get_one_data_set(data_test, i_exp=i_exp, get_po=True)\n\n        X, y, w, cate_true_in, X_t, cate_true_out = prepare_ihdp_data(\n            data_exp, data_exp_test, rescale=scale_cate\n        )\n\n        for model_name, estimator in models.items():\n            log.info(f\"Experiment {i_exp} with {model_name}\")\n            estimator_temp = clone(estimator)\n            if model_params is not None:\n                estimator_temp.set_params(**model_params)\n\n            # fit estimator\n            estimator_temp.fit(X=X, y=y, w=w)\n\n            cate_pred_in = estimator_temp.predict(X, return_po=False)\n            cate_pred_out = estimator_temp.predict(X_t, return_po=False)\n\n            pehe_in.append(eval_root_mse(cate_pred_in, cate_true_in))\n            pehe_out.append(eval_root_mse(cate_pred_out, cate_true_out))\n\n        writer.writerow(pehe_in + pehe_out)\n\n    out_file.close()\n"
  },
  {
    "path": "experiments/experiments_AISTATS21/simulations_AISTATS.py",
    "content": "\"\"\"\nScript to generate synthetic simulations in AISTATS paper\n\"\"\"\n# Author: Alicia Curth\nimport csv\nimport os\nfrom typing import Any, Optional, Union\n\nfrom sklearn import clone\n\nimport catenets.logger as log\nfrom catenets.experiment_utils.base import eval_root_mse, get_model_set\nfrom catenets.experiment_utils.simulation_utils import simulate_treatment_setup\nfrom catenets.models.jax import PSEUDOOUT_NAME, PseudoOutcomeNet\nfrom catenets.models.jax.pseudo_outcome_nets import S1_STRATEGY, S_STRATEGY\nfrom catenets.models.jax.snet import DEFAULT_UNITS_R_BIG_S, DEFAULT_UNITS_R_SMALL_S\nfrom catenets.models.jax.transformation_utils import (\n    DR_TRANSFORMATION,\n    RA_TRANSFORMATION,\n)\n\n# some constants\nRESULT_DIR = \"results/experiments_AISTATS21/simulations/\"\nCSV_STRING = \".csv\"\nSEP = \"_\"\n\n# hyperparameters for experiments\nLAYERS_OUT = 2\nLAYERS_R = 3\nPENALTY_L2 = 0.01 / 100\nPENALTY_ORTHOGONAL = 1 / 100\n\nMODEL_PARAMS_AISTATS = {\n    \"n_layers_out\": LAYERS_OUT,\n    \"n_layers_r\": LAYERS_R,\n    \"penalty_l2\": PENALTY_L2,\n    \"penalty_orthogonal\": PENALTY_ORTHOGONAL,\n    \"n_layers_out_t\": LAYERS_OUT,\n    \"n_layers_r_t\": LAYERS_R,\n    \"penalty_l2_t\": PENALTY_L2,\n}\n\n# get basic models\nALL_MODELS_AISTATS = get_model_set(\n    model_selection=\"all\", model_params=MODEL_PARAMS_AISTATS\n)\n\n# model-twostep combinations\nCOMBINED_MODELS = {\n    PSEUDOOUT_NAME\n    + SEP\n    + DR_TRANSFORMATION\n    + SEP\n    + S_STRATEGY: PseudoOutcomeNet(\n        transformation=DR_TRANSFORMATION,\n        first_stage_strategy=S_STRATEGY,\n        n_units_r=DEFAULT_UNITS_R_BIG_S,\n        n_layers_out=LAYERS_OUT,\n        n_layers_r=LAYERS_R,\n        penalty_l2_t=PENALTY_L2,\n        penalty_l2=PENALTY_L2,\n        n_layers_out_t=LAYERS_OUT,\n        first_stage_args={\n            \"n_units_r_small\": DEFAULT_UNITS_R_SMALL_S,\n            \"penalty_orthogonal\": PENALTY_ORTHOGONAL,\n        },\n    ),\n    
PSEUDOOUT_NAME\n    + SEP\n    + RA_TRANSFORMATION\n    + SEP\n    + S_STRATEGY: PseudoOutcomeNet(\n        transformation=RA_TRANSFORMATION,\n        first_stage_strategy=S_STRATEGY,\n        n_units_r=DEFAULT_UNITS_R_BIG_S,\n        n_layers_out=LAYERS_OUT,\n        n_layers_r=LAYERS_R,\n        penalty_l2_t=PENALTY_L2,\n        penalty_l2=PENALTY_L2,\n        n_layers_out_t=LAYERS_OUT,\n        n_layers_r_t=LAYERS_R,\n        first_stage_args={\n            \"n_units_r_small\": DEFAULT_UNITS_R_SMALL_S,\n            \"penalty_orthogonal\": PENALTY_ORTHOGONAL,\n        },\n    ),\n    PSEUDOOUT_NAME\n    + SEP\n    + DR_TRANSFORMATION\n    + SEP\n    + S1_STRATEGY: PseudoOutcomeNet(\n        transformation=DR_TRANSFORMATION,\n        first_stage_strategy=S1_STRATEGY,\n        n_layers_out=LAYERS_OUT,\n        n_layers_r=LAYERS_R,\n        penalty_l2_t=PENALTY_L2,\n        penalty_l2=PENALTY_L2,\n        n_layers_out_t=LAYERS_OUT,\n        n_layers_r_t=LAYERS_R,\n    ),\n    PSEUDOOUT_NAME\n    + SEP\n    + RA_TRANSFORMATION\n    + SEP\n    + S1_STRATEGY: PseudoOutcomeNet(\n        transformation=RA_TRANSFORMATION,\n        first_stage_strategy=S1_STRATEGY,\n        n_layers_out=LAYERS_OUT,\n        n_layers_r=LAYERS_R,\n        penalty_l2_t=PENALTY_L2,\n        penalty_l2=PENALTY_L2,\n        n_layers_out_t=LAYERS_OUT,\n        n_layers_r_t=LAYERS_R,\n    ),\n}\n\nFULL_MODEL_SET_AISTATS = dict(**ALL_MODELS_AISTATS, **COMBINED_MODELS)\n\n# some more constants for experiments\nNTRAIN_BASE = 2000\nNTEST_BASE = 500\nD_BASE = 25\nBASE_XI = 3\nTARGET_PROP_BASE = None\n\nXI_STRING = \"xi\"\nN_STRING = \"n\"\nD_T_STRING = \"dim_t\"\nPROPENSITY_CONSTANT_STRING = \"p\"\nTARGET_STRING = \"target_p\"\n\n\ndef simulation_experiment_loop(\n    range_change: list,\n    change_dim: str = N_STRING,\n    n_train: int = NTRAIN_BASE,\n    n_test: int = NTEST_BASE,\n    n_repeats: int = 10,\n    d: int = D_BASE,\n    n_w: int = 0,\n    n_c: int = 5,\n    n_o: int = 5,\n    n_t: int = 
0,\n    file_base: str = \"results\",\n    xi: float = BASE_XI,\n    mu_1_model: Optional[dict] = None,\n    correlated_x: bool = False,\n    mu_1_model_params: Optional[dict] = None,\n    mu_0_model_params: Optional[dict] = None,\n    models: Optional[dict] = None,\n    nonlinear_prop: bool = True,\n    prop_offset: Union[float, str] = \"center\",\n    target_prop: Optional[float] = TARGET_PROP_BASE,\n) -> None:\n    if change_dim is N_STRING:\n        for n in range_change:\n            log.debug(f\"Running experiments for {N_STRING} set to {n}\")\n            do_one_experiment_repeat(\n                n_train=n,\n                n_test=n_test,\n                n_repeats=n_repeats,\n                d=d,\n                n_w=n_w,\n                n_c=n_c,\n                n_o=n_o,\n                n_t=n_t,\n                file_base=file_base,\n                xi=xi,\n                mu_1_model=mu_1_model,\n                correlated_x=correlated_x,\n                models=models,\n                mu_1_model_params=mu_1_model_params,\n                mu_0_model_params=mu_0_model_params,\n                nonlinear_prop=nonlinear_prop,\n                prop_offset=prop_offset,\n                target_prop=target_prop,\n            )\n    elif change_dim is XI_STRING:\n        for xi_temp in range_change:\n            log.debug(f\"Running experiments for {XI_STRING} set to {xi_temp}\")\n            do_one_experiment_repeat(\n                n_train=n_train,\n                n_test=n_test,\n                n_repeats=n_repeats,\n                d=d,\n                n_w=n_w,\n                n_c=n_c,\n                n_o=n_o,\n                n_t=n_t,\n                file_base=file_base,\n                xi=xi_temp,\n                mu_1_model=mu_1_model,\n                correlated_x=correlated_x,\n                models=models,\n                mu_1_model_params=mu_1_model_params,\n                mu_0_model_params=mu_0_model_params,\n                
nonlinear_prop=nonlinear_prop,\n                prop_offset=prop_offset,\n                target_prop=target_prop,\n            )\n\n    elif change_dim is D_T_STRING:\n        for d_t_temp in range_change:\n            log.debug(f\"Running experiments for {D_T_STRING} set to {d_t_temp}\")\n            do_one_experiment_repeat(\n                n_train=n_train,\n                n_test=n_test,\n                n_repeats=n_repeats,\n                d=d,\n                n_w=n_w,\n                n_c=n_c,\n                n_o=n_o,\n                n_t=d_t_temp,\n                file_base=file_base,\n                xi=xi,\n                mu_1_model=mu_1_model,\n                correlated_x=correlated_x,\n                models=models,\n                mu_1_model_params=mu_1_model_params,\n                mu_0_model_params=mu_0_model_params,\n                nonlinear_prop=nonlinear_prop,\n                prop_offset=prop_offset,\n                target_prop=target_prop,\n            )\n\n    elif change_dim is TARGET_STRING:\n        for target_prop_temp in range_change:\n            log.debug(\n                f\"Running experiments for {TARGET_STRING} set to {target_prop_temp}\"\n            )\n            do_one_experiment_repeat(\n                n_train=n_train,\n                n_test=n_test,\n                n_repeats=n_repeats,\n                d=d,\n                n_w=n_w,\n                n_c=n_c,\n                n_o=n_o,\n                n_t=n_t,\n                file_base=file_base,\n                xi=xi,\n                mu_1_model=mu_1_model,\n                correlated_x=correlated_x,\n                models=models,\n                mu_1_model_params=mu_1_model_params,\n                mu_0_model_params=mu_0_model_params,\n                nonlinear_prop=nonlinear_prop,\n                prop_offset=prop_offset,\n                target_prop=target_prop_temp,\n            )\n\n\ndef do_one_experiment_repeat(\n    n_train: int = NTRAIN_BASE,\n    
n_test: int = NTEST_BASE,\n    n_repeats: int = 10,\n    d: int = D_BASE,\n    n_w: int = 0,\n    n_c: int = 0,\n    n_o: int = 0,\n    n_t: int = 0,\n    file_base: str = \"results\",\n    xi: float = BASE_XI,\n    mu_1_model: Optional[dict] = None,\n    correlated_x: bool = True,\n    mu_1_model_params: Optional[dict] = None,\n    mu_0_model_params: Optional[dict] = None,\n    models: Optional[dict] = None,\n    nonlinear_prop: bool = True,\n    range_exp: Optional[list] = None,\n    prop_offset: Union[float, str] = 0,\n    target_prop: Optional[float] = None,\n) -> None:\n    # make path\n    if not os.path.exists(RESULT_DIR):\n        os.makedirs(RESULT_DIR)\n\n    if range_exp is None:\n        range_exp = list(range(1, n_repeats + 1))\n\n    if models is None:\n        models = FULL_MODEL_SET_AISTATS\n\n    if target_prop is None:\n        prop_string = str(prop_offset)\n    else:\n        prop_string = str(target_prop)\n\n    # create file name and file\n    file_name = (\n        file_base\n        + SEP\n        + str(n_train)\n        + SEP\n        + str(d)\n        + SEP\n        + str(n_w)\n        + SEP\n        + str(n_c)\n        + SEP\n        + str(n_o)\n        + SEP\n        + str(n_t)\n        + SEP\n        + str(xi)\n        + SEP\n        + prop_string\n    )\n\n    out_file = open(RESULT_DIR + file_name + CSV_STRING, \"w\", buffering=1)\n    writer = csv.writer(out_file)\n    header = [name for name in models.keys()]\n    writer.writerow(header)\n\n    for i in range_exp:\n        log.debug(f\"Running experiment {i}.\")\n        mses = one_simulation_experiment(\n            n_train=n_train,\n            n_test=n_test,\n            d=d,\n            n_w=n_w,\n            n_c=n_c,\n            n_o=n_o,\n            n_t=n_t,\n            seed=i,\n            xi=xi,\n            mu_1_model=mu_1_model,\n            correlated_x=correlated_x,\n            models=models,\n            nonlinear_prop=nonlinear_prop,\n            
mu_0_model_params=mu_0_model_params,\n            mu_1_model_params=mu_1_model_params,\n            prop_offset=prop_offset,\n            target_prop=target_prop,\n        )\n        writer.writerow(mses)\n\n    out_file.close()\n    return None\n\n\ndef one_simulation_experiment(\n    n_train: int,\n    n_test: int = NTEST_BASE,\n    d: int = D_BASE,\n    n_w: int = 0,\n    n_c: int = 0,\n    n_o: int = 0,\n    n_t: int = 0,\n    xi: float = BASE_XI,\n    seed: int = 42,\n    mu_1_model: Optional[dict] = None,\n    propensity_model: Optional[dict] = None,\n    correlated_x: bool = False,\n    mu_1_model_params: Optional[dict] = None,\n    mu_0_model_params: Optional[dict] = None,\n    models: Optional[dict] = None,\n    nonlinear_prop: bool = False,\n    prop_offset: Union[float, str] = 0,\n    target_prop: Optional[float] = None,\n) -> list:\n    if models is None:\n        models = FULL_MODEL_SET_AISTATS\n\n    # get data\n    X, y, w, p, t = simulate_treatment_setup(\n        n_train + n_test,\n        d=d,\n        n_w=n_w,\n        n_c=n_c,\n        n_o=n_o,\n        n_t=n_t,\n        propensity_model=propensity_model,\n        propensity_model_params={\n            \"xi\": xi,\n            \"nonlinear\": nonlinear_prop,\n            \"offset\": prop_offset,\n            \"target_prop\": target_prop,\n        },\n        seed=seed,\n        mu_1_model=mu_1_model,\n        mu_0_model_params=mu_0_model_params,\n        mu_1_model_params=mu_1_model_params,\n        covariate_model_params={\"correlated\": correlated_x},\n    )\n    # split data\n    X_train, y_train, w_train, _ = (\n        X[:n_train, :],\n        y[:n_train],\n        w[:n_train],\n        p[:n_train],\n    )\n    X_test, t_test = X[n_train:, :], t[n_train:]\n\n    rmses = []\n    for model_name, model in models.items():\n        log.debug(f\"Training model {model_name}\")\n\n        estimator = clone(model)\n        estimator.fit(X=X_train, y=y_train, w=w_train)\n\n        cate_test = 
estimator.predict(X_test, return_po=False)\n        rmses.append(eval_root_mse(cate_test, t_test))\n\n    return rmses\n\n\ndef main_AISTATS(\n    setting: int = 1,\n    models: Any = None,\n    file_name: str = \"res\",\n    n_repeats: int = 10,\n) -> None:\n    if models is None:\n        models = FULL_MODEL_SET_AISTATS\n    elif type(models) is list or type(models) is str:\n        models = get_model_set(models)\n\n    if setting == 1:\n        # no treatment effect, with confounding, by n\n        simulation_experiment_loop(\n            [1000, 2000, 5000, 10000],\n            change_dim=\"n\",\n            n_t=0,\n            n_w=0,\n            n_c=5,\n            n_o=5,\n            file_base=file_name,\n            models=models,\n            n_repeats=n_repeats,\n        )\n    elif setting == 2:\n        # with treatment effect, with confounding, by n\n        simulation_experiment_loop(\n            [1000, 2000, 5000, 10000],\n            change_dim=\"n\",\n            n_t=5,\n            n_w=0,\n            n_c=5,\n            n_o=5,\n            file_base=file_name,\n            models=models,\n            n_repeats=n_repeats,\n        )\n    elif setting == 3:\n        # Potential outcomes are supported on independent covariates, no confounding, by n\n        simulation_experiment_loop(\n            [1000, 2000, 5000, 10000],\n            change_dim=\"n\",\n            n_t=10,\n            n_w=0,\n            n_c=0,\n            n_o=10,\n            file_base=file_name,\n            models=models,\n            xi=0.5,\n            mu_1_model_params={\"withbase\": False},\n            n_repeats=n_repeats,\n        )\n    elif setting == 4:\n        # vary number of predictive features at n=2000\n        simulation_experiment_loop(\n            [0, 1, 3, 5, 7, 10],\n            change_dim=D_T_STRING,\n            n_train=2000,\n            n_c=5,\n            n_o=5,\n            file_base=file_name,\n            models=models,\n            
n_repeats=n_repeats,\n        )\n    elif setting == 5:\n        # vary percentage treated at n=2000\n        simulation_experiment_loop(\n            [0.1, 0.2, 0.3, 0.4, 0.5],\n            change_dim=TARGET_STRING,\n            n_train=2000,\n            n_c=5,\n            n_o=5,\n            n_t=0,\n            n_repeats=n_repeats,\n            file_base=file_name,\n            models=models,\n        )\n"
  },
  {
    "path": "experiments/experiments_benchmarks_NeurIPS21/README.md",
    "content": "# Replication code for \"Really Doing Great at Estimating CATE? A Critical Look at ML Benchmarking Practices in Treatment Effect Estimation\"\n\nThis folder contains the files to replicate the benchmarking studies of random forest (RF) and neural network (NN) based CATE estimators using the IHDP, ACIC2016 and Twins datasets.\n\nThe code for RFs is in R and relies on the R-package ‘grf’, which is available on CRAN. The code for NNs relies on the Python package ‘catenets’ in this repo.\n\nThis folder provides both Python and R code to replicate the results of all empirical studies. Always run the Python code first: it downloads and/or creates the datasets that are used in both the Python and R code.\nThe Python code can also be run using the file 'run_experiments_benchmarks_NeurIPS.py' in the root of the repo.\n\nFor IHDP: Setting ‘original’ reproduces results as reported in Figure 3a and setting ‘modified’ reproduces results in Figure 3b.\n\nFor ACIC: we considered the simulation numbers (`simu_num`) 2, 26 and 7.\n"
  },
  {
    "path": "experiments/experiments_benchmarks_NeurIPS21/__init__.py",
    "content": ""
  },
  {
    "path": "experiments/experiments_benchmarks_NeurIPS21/acic_experiments_catenets.py",
    "content": "\"\"\"\nUtils to replicate ACIC2016 experiments with catenets\n\"\"\"\n# Author: Alicia Curth\nimport csv\nimport os\nfrom pathlib import Path\n\nimport numpy as np\nfrom sklearn import clone\n\nfrom catenets.datasets import load\nfrom catenets.experiment_utils.base import eval_root_mse\nfrom catenets.models.jax import RNET_NAME, T_NAME, TARNET_NAME, RNet, TARNet, TNet\n\nRESULT_DIR = Path(\"results/experiments_benchmarking/acic2016/\")\nSEP = \"_\"\n\nPARAMS_DEPTH = {\"n_layers_r\": 3, \"n_layers_out\": 2}\nPARAMS_DEPTH_2 = {\n    \"n_layers_r\": 3,\n    \"n_layers_out\": 2,\n    \"n_layers_r_t\": 3,\n    \"n_layers_out_t\": 2,\n}\n\nALL_MODELS = {\n    T_NAME: TNet(**PARAMS_DEPTH),\n    TARNET_NAME: TARNet(**PARAMS_DEPTH),\n    RNET_NAME: RNet(**PARAMS_DEPTH_2),\n}\n\n\ndef do_acic_experiments(\n    n_exp: int = 10,\n    n_reps=5,\n    file_name: str = \"results_catenets\",\n    simu_num: int = 1,\n    models: dict = None,\n    train_size: int = 4000,\n    pre_trans: bool = True,\n):\n    if models is None:\n        models = ALL_MODELS\n\n    # get file to write in\n    if not os.path.isdir(RESULT_DIR):\n        os.makedirs(RESULT_DIR)\n\n    out_file = open(\n        RESULT_DIR\n        / (\n            file_name\n            + SEP\n            + str(pre_trans)\n            + SEP\n            + str(simu_num)\n            + SEP\n            + str(train_size)\n            + \".csv\"\n        ),\n        \"w\",\n        buffering=1,\n    )\n    writer = csv.writer(out_file)\n    header = (\n        [\"file_name\", \"run\", \"cate_var_in\", \"cate_var_out\", \"y_var_in\"]\n        + [name + \"_in\" for name in models.keys()]\n        + [name + \"_out\" for name in models.keys()]\n    )\n    writer.writerow(header)\n\n    for i_exp in range(n_exp):\n        # get data\n        X, w, y, po_train, X_test, w_test, y_test, po_test = load(\n            \"acic2016\",\n            preprocessed=pre_trans,\n            original_acic_outcomes=True,\n            
i_exp=i_exp,\n            simu_num=simu_num,\n            train_size=train_size,\n        )\n\n        cate_in = po_train[:, 1] - po_train[:, 0]\n        cate_out = po_test[:, 1] - po_test[:, 0]\n\n        cate_var_in = np.var(cate_in)\n        cate_var_out = np.var(cate_out)\n        y_var_in = np.var(y)\n        for k in range(n_reps):\n            pehe_in = []\n            pehe_out = []\n\n            for model_name, estimator in models.items():\n                print(f\"Experiment {i_exp}, run {k}, with {model_name}\")\n                estimator_temp = clone(estimator)\n                estimator_temp.set_params(seed=k)\n\n                # fit estimator\n                estimator_temp.fit(X=X, y=y, w=w)\n\n                cate_pred_in = estimator_temp.predict(X, return_po=False)\n                cate_pred_out = estimator_temp.predict(X_test, return_po=False)\n\n                pehe_in.append(eval_root_mse(cate_pred_in, cate_in))\n                pehe_out.append(eval_root_mse(cate_pred_out, cate_out))\n\n            writer.writerow(\n                [i_exp, k, cate_var_in, cate_var_out, y_var_in] + pehe_in + pehe_out\n            )\n\n    out_file.close()\n"
  },
  {
    "path": "experiments/experiments_benchmarks_NeurIPS21/acic_experiments_grf.R",
    "content": "library(grf)\n\ndo_acic_exper_loop <-\n  function(simnums = c(2, 26, 7),\n           n_reps = 5,\n           n_exp = 10,\n           with_t = F) {\n    # function to loop over multiple simulation settings\n    for (k in simnums) {\n      do_acic_exper(k,\n                    n_reps = n_reps,\n                    n_exp = n_exp,\n                    with_t = with_t)\n    }\n  }\n\ndo_acic_exper <- function(simnum,\n                          n_reps = 5,\n                          n_exp = 10,\n                          with_t = F) {\n  # function to do acic experiments for one simulation setting (simnum)\n  # n_reps indicates the number of replications (random seeds used)\n  # n_exp indicates the number of simulations to use within this setting (1-100)\n  # with_t indicates whether to create additional results with pre-transformed data\n\n  X <- data.matrix(read.csv('catenets/datasets/data/data_cf_all/x.csv'))\n  X_trans <- data.matrix(read.csv('catenets/datasets/data/x_trans.csv'))\n  range_train = 1:4000\n  range_test = 4001:4802\n\n  # get files\n  sim_dir = paste0('catenets/datasets/data/data_cf_all/', simnum, '/')\n  file_list <- list.files(sim_dir)\n\n  for (i in 1:(n_exp)) {\n    # loop over simulations within this setting\n    print(paste0('Experiment number ', i))\n    for (k in 1:n_reps) {\n      # loop over seeds\n      print(paste0('Iteration number ', k))\n      set.seed(k * i)\n\n      X_train <- X[range_train,]\n      X_test <- X[range_test,]\n      X_t_train <- X_trans[range_train,]\n\n      outcomes = read.csv(paste0(sim_dir, file_list[i]))\n      z = outcomes$z\n      y = outcomes$z * outcomes$y1 + (1 - outcomes$z) * outcomes$y0\n      t = outcomes$mu1 - outcomes$mu0\n\n      z_train = z[range_train]\n      y_train = y[range_train]\n      t_train = t[range_train]\n      t_test = t[range_test]\n\n      # causal forest\n      print('causal forest')\n      cf <- causal_forest(X_train, y_train, z_train, seed = k * i)\n      pred_cf <- 
predict(cf, X)$predictions\n      rmse_cf_in <- sqrt(mean((t_train - pred_cf[range_train]) ^ 2))\n      rmse_cf_out <- sqrt(mean((t_test - pred_cf[range_test]) ^ 2))\n\n      if (with_t == T) {\n        # also fit estimators using pre-transformed data\n        cf.t <- causal_forest(X_t_train, y_train, z_train,  seed = k * i)\n        pred_cf.t <- predict(cf.t, X_trans)$predictions\n        rmse_cf_in.t <- sqrt(mean((t_train - pred_cf.t[range_train]) ^ 2))\n        rmse_cf_out.t <- sqrt(mean((t_test - pred_cf.t[range_test]) ^ 2))\n      }\n\n      # t-learner\n      print('t learner')\n      y0.forest <-\n        regression_forest(subset(X_train, z_train == 0), y_train[z_train == 0],  seed =\n                            k * i)\n      y1.forest <-\n        regression_forest(subset(X_train, z_train == 1), y_train[z_train == 1],  seed =\n                            k * i)\n      pred_t <-\n        predict(y1.forest, X)$predictions - predict(y0.forest, X)$predictions\n      rmse_t_in <- sqrt(mean((t_train - pred_t[range_train]) ^ 2))\n      rmse_t_out <- sqrt(mean((t_test - pred_t[range_test]) ^ 2))\n\n      if (with_t == T) {\n        # also fit estimators using pre-transformed data\n        y0.forest.t <-\n          regression_forest(subset(X_t_train, z_train == 0), y_train[z_train == 0],  seed =\n                              k * i)\n        y1.forest.t <-\n          regression_forest(subset(X_t_train, z_train == 1), y_train[z_train == 1],  seed =\n                              k * i)\n        pred_t.t <-\n          predict(y1.forest.t, X_trans)$predictions - predict(y0.forest.t, X_trans)$predictions\n        rmse_t_in.t <- sqrt(mean((t_train - pred_t.t[range_train]) ^ 2))\n        rmse_t_out.t <- sqrt(mean((t_test - pred_t.t[range_test]) ^ 2))\n      }\n\n      # s-learner\n      print('s learner')\n      s_forest <-\n        regression_forest(cbind(X_train, z_train), y_train,  seed = k * i)\n      n_total <- nrow(X)\n      test_treated <- rep(1, n_total)\n      
test_control <- rep(0, n_total)\n      pred_s <-\n        predict(s_forest, cbind(X, test_treated))$predictions - predict(s_forest, cbind(X, test_control))$predictions\n      rmse_s_in <- sqrt(mean((t_train - pred_s[range_train]) ^ 2))\n      rmse_s_out <- sqrt(mean((t_test - pred_s[range_test]) ^ 2))\n\n      if (with_t == T) {\n        # also fit estimators using pre-transformed data\n        s_forest.t <-\n          regression_forest(data.matrix(cbind(X_t_train, z_train)), y_train,  seed =\n                              k * i)\n        pred_s.t <-\n          predict(s_forest.t, data.matrix(cbind(X_trans, test_treated)))$predictions - predict(s_forest.t, data.matrix(cbind(X_trans, test_control)))$predictions\n        rmse_s_in.t <- sqrt(mean((t_train - pred_s.t[range_train]) ^ 2))\n        rmse_s_out.t <- sqrt(mean((t_test - pred_s.t[range_test]) ^ 2))\n      }\n\n\n      if (with_t == T) {\n        df_res <-\n          data.frame(\n            file = file_list[i],\n            run = k,\n            cf_in = rmse_cf_in,\n            cf_t_in = rmse_cf_in.t,\n            t_in = rmse_t_in,\n            t_t_in = rmse_t_in.t,\n            s_in = rmse_s_in,\n            s_in_t = rmse_s_in.t,\n            cf_out = rmse_cf_out,\n            cf_t_out = rmse_cf_out.t,\n            t_out = rmse_t_out,\n            t_t_out = rmse_t_out,\n            s_out = rmse_s_out,\n            s_t_out = rmse_s_out.t\n          )\n      } else{\n        df_res <-\n          data.frame(\n            file = file_list[i],\n            run = k,\n            cf_in = rmse_cf_in,\n            t_in = rmse_t_in,\n            s_in = rmse_s_in,\n            cf_out = rmse_cf_out,\n            t_out = rmse_t_out,\n            s_out = rmse_s_out\n          )\n      }\n\n      if (i * k == 1) {\n        write.table(\n          df_res,\n          file = paste0(\n            'results/experiments_benchmarking/acic2016/grf_',\n            simnum,\n            '_',\n            with_t,\n            '_',\n    
        n_exp,\n            '_',\n            n_reps,\n            '.csv'\n          ),\n          col.names = T,\n          sep = ',',\n          row.names = F\n        )\n      }\n      else{\n        write.table(\n          df_res,\n          file = paste0(\n            'results/experiments_benchmarking/acic2016/grf_',\n            simnum,\n            '_',\n            with_t,\n            '_',\n            n_exp,\n            '_',\n            n_reps,\n            '.csv'\n          ),\n          col.names = F,\n          sep = ',',\n          row.names = F,\n          append = T\n        )\n      }\n\n    }\n  }\n}\n"
  },
  {
    "path": "experiments/experiments_benchmarks_NeurIPS21/ihdp_experiments_catenets.py",
    "content": "\"\"\"\nUtils to replicate IHDP experiments with catenets\n\"\"\"\n# Author: Alicia Curth\nimport csv\nimport os\nfrom pathlib import Path\nfrom typing import Optional, Union\n\nimport numpy as np\nfrom sklearn import clone\n\nfrom catenets.datasets.dataset_ihdp import get_one_data_set, load_raw, prepare_ihdp_data\nfrom catenets.experiment_utils.base import eval_root_mse\nfrom catenets.models.jax import RNET_NAME, T_NAME, TARNET_NAME, RNet, TARNet, TNet\n\nDATA_DIR = Path(\"catenets/datasets/data/\")\nRESULT_DIR = Path(\"results/experiments_benchmarking/ihdp/\")\nSEP = \"_\"\n\nPARAMS_DEPTH = {\"n_layers_r\": 3, \"n_layers_out\": 2}\nPARAMS_DEPTH_2 = {\n    \"n_layers_r\": 3,\n    \"n_layers_out\": 2,\n    \"n_layers_r_t\": 3,\n    \"n_layers_out_t\": 2,\n}\n\nALL_MODELS = {\n    T_NAME: TNet(**PARAMS_DEPTH),\n    TARNET_NAME: TARNet(**PARAMS_DEPTH),\n    RNET_NAME: RNet(**PARAMS_DEPTH_2),\n}\n\n\ndef do_ihdp_experiments(\n    n_exp: Union[int, list] = 100,\n    n_reps: int = 5,\n    file_name: str = \"ihdp_all\",\n    model_params: Optional[dict] = None,\n    models: Optional[dict] = None,\n    setting: str = \"original\",\n) -> None:\n    if models is None:\n        models = ALL_MODELS\n\n    if (setting == \"original\") or (setting == \"C\"):\n        setting = \"C\"\n    elif (setting == \"modified\") or (setting == \"D\"):\n        setting = \"D\"\n    else:\n        raise ValueError(\n            f\"Setting should be one of original or modified. 
You passed {setting}.\"\n        )\n\n    # get file to write in\n    if not os.path.isdir(RESULT_DIR):\n        os.makedirs(RESULT_DIR)\n\n    out_file = open(RESULT_DIR / (file_name + SEP + setting + \".csv\"), \"w\", buffering=1)\n    writer = csv.writer(out_file)\n    header = (\n        [\"exp\", \"run\", \"cate_var_in\", \"cate_var_out\", \"y_var_in\"]\n        + [name + \"_in\" for name in models.keys()]\n        + [name + \"_out\" for name in models.keys()]\n    )\n    writer.writerow(header)\n\n    # get data\n    data_train, data_test = load_raw(DATA_DIR)\n\n    if isinstance(n_exp, int):\n        experiment_loop = list(range(1, n_exp + 1))\n    elif isinstance(n_exp, list):\n        experiment_loop = n_exp\n    else:\n        raise ValueError(\"n_exp should be either an integer or a list of integers.\")\n\n    for i_exp in experiment_loop:\n        # get data\n        data_exp = get_one_data_set(data_train, i_exp=i_exp, get_po=True)\n        data_exp_test = get_one_data_set(data_test, i_exp=i_exp, get_po=True)\n\n        X, y, w, cate_true_in, X_t, cate_true_out = prepare_ihdp_data(\n            data_exp, data_exp_test, setting=setting\n        )\n\n        # compute some stats\n        cate_var_in = np.var(cate_true_in)\n        cate_var_out = np.var(cate_true_out)\n        y_var_in = np.var(y)\n\n        for k in range(n_reps):\n            pehe_in = []\n            pehe_out = []\n\n            for model_name, estimator in models.items():\n                print(f\"Experiment {i_exp}, run {k}, with {model_name}\")\n                estimator_temp = clone(estimator)\n                estimator_temp.set_params(seed=k)\n                if model_params is not None:\n                    estimator_temp.set_params(**model_params)\n\n                # fit estimator\n                estimator_temp.fit(X=X, y=y, w=w)\n\n                cate_pred_in = estimator_temp.predict(X, return_po=False)\n                cate_pred_out = estimator_temp.predict(X_t, 
return_po=False)\n\n                pehe_in.append(eval_root_mse(cate_pred_in, cate_true_in))\n                pehe_out.append(eval_root_mse(cate_pred_out, cate_true_out))\n\n            writer.writerow(\n                [i_exp, k, cate_var_in, cate_var_out, y_var_in] + pehe_in + pehe_out\n            )\n\n    out_file.close()\n"
  },
  {
    "path": "experiments/experiments_benchmarks_NeurIPS21/ihdp_experiments_grf.R",
    "content": "library(grf)\nlibrary(reticulate)\n\ndo_ihdp_exper <- function(n_exp = 100,\n                          n_reps = 5,\n                          setup = 'original') {\n  # read IHDP data (originally saved in numpy format)\n  np <- import(\"numpy\")\n  npz_train <- np$load('catenets/datasets/data/ihdp_npci_1-100.train.npz')\n\n  x_train <- npz_train$f[['x']]\n  y_train <- npz_train$f[['yf']]\n  w_train <- npz_train$f[['t']]\n  mu0_train <- npz_train$f[['mu0']]\n  mu1_train <- npz_train$f[['mu1']]\n\n\n  npz_test <- np$load('catenets/datasets/data/ihdp_npci_1-100.test.npz')\n\n  x_test <- npz_test$f[['x']]\n  y_test <- npz_test$f[['yf']]\n  w_test <- npz_test$f[['t']]\n  mu0_test <- npz_test$f[['mu0']]\n  mu1_test <- npz_test$f[['mu1']]\n\n\n  if (setup == 'modified') {\n    # make TE additive instead\n    y_train[w_train == 1] = y_train[w_train == 1] + mu0_train[w_train == 1]\n    mu1_train = mu0_train + mu1_train\n    mu1_test = mu0_test + mu1_test\n  }\n\n  cate_train <- mu1_train - mu0_train\n  cate_test <- mu1_test - mu0_test\n\n  for (i in 1:n_exp) {\n    # loop over runs\n    print(paste0('Experiment number', i))\n    for (k in 1:n_reps) {\n      # loop over seeds\n\n      # Causal forest ------------------------------\n      print('causal forest')\n      cf <-\n        causal_forest(x_train[, , i], y_train[, i], w_train[, i], seed = k)\n\n      # predict CATE\n      pred_cf_in <- predict(cf, x_train[, , i])$predictions\n      pred_cf_out <- predict(cf, x_test[, , i])$predictions\n\n      # Evaluate\n      rmse_cf_in <- sqrt(mean((cate_train[, i] - pred_cf_in) ^ 2))\n      rmse_cf_out <- sqrt(mean((cate_test[, i] - pred_cf_out) ^ 2))\n\n\n      # T-learner -----------------------------------------------------\n      print('t learner')\n      y0.forest <- regression_forest(subset(x_train[, , i], w_train[, i] == 0),\n                                     y_train[w_train[, i] == 0, i], seed = k)\n      y1.forest <- regression_forest(subset(x_train[, , 
i], w_train[, i] == 1),\n                                     y_train[w_train[, i] == 1, i], seed = k)\n      # predict CATE\n      pred_t_in <-\n        predict(y1.forest, x_train[, , i])$predictions - predict(y0.forest, x_train[, , i])$predictions\n      pred_t_out <-\n        predict(y1.forest, x_test[, , i])$predictions - predict(y0.forest, x_test[, , i])$predictions\n      # Evaluate\n      rmse_t_in <- sqrt(mean((cate_train[, i] - pred_t_in) ^ 2))\n      rmse_t_out <- sqrt(mean((cate_test[, i] - pred_t_out) ^ 2))\n\n      # s-learner -------------------------------------------------------------\n      print('s learner')\n      s_forest <-\n        regression_forest(cbind(x_train[, , i], w_train[, i]), y_train[, i], seed =\n                            k)\n      # create extended feature matrices\n      n_train <- nrow(x_train[, , i])\n      n_test <- nrow(x_test[, , i])\n      train_treated <- rep(1, n_train)\n      train_control <- rep(0, n_train)\n      test_treated <- rep(1, n_test)\n      test_control <- rep(0, n_test)\n\n      # predict CATE\n      pred_s_in <-\n        predict(s_forest, cbind(x_train[, , i], train_treated))$predictions - predict(s_forest, cbind(x_train[, , i], train_control))$predictions\n      pred_s_out <-\n        predict(s_forest, cbind(x_test[, , i], test_treated))$predictions - predict(s_forest, cbind(x_test[, , i], test_control))$predictions\n      # evaluate\n      rmse_s_in <- sqrt(mean((cate_train[, i] - pred_s_in) ^ 2))\n      rmse_s_out <- sqrt(mean((cate_test[, i] - pred_s_out) ^ 2))\n\n\n      df_res <-\n        data.frame(\n          simu = i,\n          run = k,\n          cf_in = rmse_cf_in,\n          t_in = rmse_t_in,\n          s_in = rmse_s_in,\n          cf_out = rmse_cf_out,\n          t_out = rmse_t_out,\n          s_out = rmse_s_out\n        )\n\n\n      if (i * k == 1) {\n        write.table(\n          df_res,\n          file = paste0('results/experiments_benchmarking/ihdp/grf_', setup, '.csv'),\n          
col.names = T,\n          sep = ',',\n          row.names = F\n        )\n      }\n      else{\n        write.table(\n          df_res,\n          file = paste0('results/experiments_benchmarking/ihdp/grf_', setup, '.csv'),\n          col.names = F,\n          append = T,\n          sep = ',',\n          row.names = F\n        )\n      }\n    }\n  }\n}\n"
  },
  {
    "path": "experiments/experiments_benchmarks_NeurIPS21/twins_experiments_catenets.py",
    "content": "\"\"\"\nUtils to replicate Twins experiments with catenets\n\"\"\"\nimport csv\n\n# Author: Alicia Curth\nimport os\nfrom pathlib import Path\n\nimport numpy as onp\nimport pandas as pd\nfrom sklearn import clone\nfrom sklearn.model_selection import train_test_split\n\nfrom catenets.datasets import load\nfrom catenets.experiment_utils.base import eval_root_mse\nfrom catenets.models.jax import RNET_NAME, T_NAME, TARNET_NAME, RNet, TARNet, TNet\n\nRESULT_DIR = Path(\"results/experiments_benchmarking/twins/\")\nEXP_DIR = Path(\"experiments/experiments_benchmarks_NeurIPS21/twins_datasets/\")\nSEP = \"_\"\n\nPARAMS_DEPTH = {\"n_layers_r\": 1, \"n_layers_out\": 1}\nPARAMS_DEPTH_2 = {\n    \"n_layers_r\": 1,\n    \"n_layers_out\": 1,\n    \"n_layers_r_t\": 1,\n    \"n_layers_out_t\": 1,\n}\n\nALL_MODELS = {\n    T_NAME: TNet(**PARAMS_DEPTH),\n    TARNET_NAME: TARNet(**PARAMS_DEPTH),\n    RNET_NAME: RNet(**PARAMS_DEPTH_2),\n}\n\n\ndef do_twins_experiment_loop(\n    n_train_loop=[500, 1000, 2000, 5000, None],\n    n_exp: int = 10,\n    file_name: str = \"twins\",\n    models: dict = None,\n    test_size=0.5,\n):\n    for n in n_train_loop:\n        print(f\"Running twins experiments for subset_train {n}\")\n        do_twins_experiments(\n            n_exp=n_exp,\n            file_name=file_name,\n            models=models,\n            subset_train=n,\n            test_size=test_size,\n        )\n\n\ndef do_twins_experiments(\n    n_exp: int = 10,\n    file_name: str = \"twins\",\n    models: dict = None,\n    subset_train: int = None,\n    prop_treated=0.5,\n    test_size=0.5,\n):\n    if models is None:\n        models = ALL_MODELS\n\n    # get file to write in\n    if not os.path.isdir(RESULT_DIR):\n        os.makedirs(RESULT_DIR)\n    out_file = open(\n        RESULT_DIR\n        / (file_name + SEP + str(prop_treated) + SEP + str(subset_train) + \".csv\"),\n        \"w\",\n        buffering=1,\n    )\n\n    writer = csv.writer(out_file)\n    header = 
[name + \"_pehe\" for name in models.keys()]\n\n    writer.writerow(header)\n\n    for i_exp in range(n_exp):\n        pehe_out = []\n\n        # get data\n        X, X_t, y, w, y0_out, y1_out = prepare_twins(\n            seed=i_exp,\n            treat_prop=prop_treated,\n            subset_train=subset_train,\n            test_size=test_size,\n        )\n\n        ite_out = y1_out - y0_out\n\n        # split data\n        for model_name, estimator in models.items():\n            print(f\"Experiment {i_exp} with {model_name}\")\n            estimator_temp = clone(estimator)\n            estimator_temp.set_params(**{\"binary_y\": True, \"seed\": i_exp})\n\n            # fit estimator\n            estimator_temp.fit(X=X, y=y, w=w)\n\n            cate_pred_out = estimator_temp.predict(X_t)\n\n            pehe_out.append(eval_root_mse(cate_pred_out, ite_out))\n\n        writer.writerow(pehe_out)\n\n    out_file.close()\n\n\n# utils ---------------------------------------------------------------------\ndef prepare_twins(treat_prop=0.5, seed=42, test_size=0.5, subset_train: int = None):\n    if not os.path.isdir(EXP_DIR):\n        os.makedirs(EXP_DIR)\n\n    out_base = (\n        \"preprocessed\"\n        + SEP\n        + str(treat_prop)\n        + SEP\n        + str(subset_train)\n        + SEP\n        + str(test_size)\n        + SEP\n        + str(seed)\n    )\n    outfile_train = EXP_DIR / (out_base + SEP + \"train.csv\")\n    outfile_test = EXP_DIR / (out_base + SEP + \"test.csv\")\n\n    feat_list = [\n        \"dmage\",\n        \"mpcb\",\n        \"cigar\",\n        \"drink\",\n        \"wtgain\",\n        \"gestat\",\n        \"dmeduc\",\n        \"nprevist\",\n        \"dmar\",\n        \"anemia\",\n        \"cardiac\",\n        \"lung\",\n        \"diabetes\",\n        \"herpes\",\n        \"hydra\",\n        \"hemo\",\n        \"chyper\",\n        \"phyper\",\n        \"eclamp\",\n        \"incervix\",\n        \"pre4000\",\n        \"dtotord\",\n        
\"preterm\",\n        \"renal\",\n        \"rh\",\n        \"uterine\",\n        \"othermr\",\n        \"adequacy_1\",\n        \"adequacy_2\",\n        \"adequacy_3\",\n        \"pldel_1\",\n        \"pldel_2\",\n        \"pldel_3\",\n        \"pldel_4\",\n        \"pldel_5\",\n        \"resstatb_1\",\n        \"resstatb_2\",\n        \"resstatb_3\",\n        \"resstatb_4\",\n    ]\n\n    if os.path.exists(outfile_train):\n        print(f\"Reading existing preprocessed twins file {out_base}\")\n        # use existing file\n        df_train = pd.read_csv(outfile_train)\n        X = onp.asarray(df_train[feat_list])\n        y = onp.asarray(df_train[[\"y\"]]).reshape((-1,))\n        w = onp.asarray(df_train[[\"w\"]]).reshape((-1,))\n\n        df_test = pd.read_csv(outfile_test)\n        X_t = onp.asarray(df_test[feat_list])\n        y0_out = onp.asarray(df_test[[\"y0\"]]).reshape((-1,))\n        y1_out = onp.asarray(df_test[[\"y1\"]]).reshape((-1,))\n    else:\n        # create file\n        print(f\"Creating preprocessed twins file {out_base}\")\n        onp.random.seed(seed)\n\n        x, w, y, pos, _, _ = load(\n            \"twins\", seed=seed, treat_prop=treat_prop, train_ratio=1\n        )\n\n        X, X_t, y, y_t, w, w_t, y0_in, y0_out, y1_in, y1_out = train_test_split(\n            x, y, w, pos[:, 0], pos[:, 1], test_size=test_size, random_state=seed\n        )\n        if subset_train is not None:\n            X, y, w, y0_in, y1_in = (\n                X[:subset_train, :],\n                y[:subset_train],\n                w[:subset_train],\n                y0_in[:subset_train],\n                y1_in[:subset_train],\n            )\n\n        # save data\n        save_df_train = pd.DataFrame(X, columns=feat_list)\n        save_df_train[\"y0\"] = y0_in\n        save_df_train[\"y1\"] = y1_in\n        save_df_train[\"w\"] = w\n        save_df_train[\"y\"] = y\n        save_df_train.to_csv(outfile_train)\n\n        save_df_train = pd.DataFrame(X_t, 
columns=feat_list)\n        save_df_train[\"y0\"] = y0_out\n        save_df_train[\"y1\"] = y1_out\n        save_df_train[\"w\"] = w_t\n        save_df_train[\"y\"] = y_t\n        save_df_train.to_csv(outfile_test)\n\n    return X, X_t, y, w, y0_out, y1_out\n"
  },
  {
    "path": "experiments/experiments_benchmarks_NeurIPS21/twins_experiments_grf.R",
    "content": "library(grf)\n\ndo_twins_exper <- function(\n                          n_reps = 10,\n                          subset_train = 500,\n                          test_size = 0.5,\n                          treat_prop=0.5) {\n  i=1\n  for (k in 0:(n_reps-1)) {\n      # loop over seeds\n      print(paste0('Iteration number ', k))\n      set.seed(k)\n\n      # read data (need to run the catenets script first; that creates the preprocessed data)\n      if (subset_train == 5700){\n        df_train <- read.csv(paste0('experiments/experiments_benchmarks_NeurIPS21/twins_datasets/preprocessed_', treat_prop, '_None_', test_size, '_', k, '_train.csv'))\n        df_test <- read.csv(paste0('experiments/experiments_benchmarks_NeurIPS21/twins_datasets/preprocessed_', treat_prop, '_None_', test_size, '_', k, '_test.csv'))\n      }else{\n      df_train <- read.csv(paste0('experiments/experiments_benchmarks_NeurIPS21/twins_datasets/preprocessed_', treat_prop, '_', subset_train, '_', test_size, '_', k, '_train.csv'))\n      df_test <- read.csv(paste0('experiments/experiments_benchmarks_NeurIPS21/twins_datasets/preprocessed_', treat_prop, '_', subset_train, '_', test_size, '_', k, '_test.csv'))\n      }\n      X_train <- data.matrix(df_train[,2:40])\n      X_test <- data.matrix(df_test[,2:40])\n\n\n      z_train = df_train$w\n      y_train = df_train$y\n      t_train = df_train$y1 - df_train$y0\n\n\n      z_test = df_test$w\n      y_test = df_test$y\n      t_test = df_test$y1 - df_test$y0\n\n      # causal forest\n      print('causal forest')\n      cf <- causal_forest(X_train, y_train, z_train, seed = k)\n      pred_cf_in <- predict(cf, X_train)$predictions\n      pred_cf_out <- predict(cf, X_test)$predictions\n      rmse_cf_in <- sqrt(mean((t_train - pred_cf_in) ^ 2))\n      rmse_cf_out <- sqrt(mean((t_test - pred_cf_out) ^ 2))\n\n\n      # t-learner\n      print('t learner')\n      y0.forest <-\n        regression_forest(subset(X_train, z_train == 0), y_train[z_train == 
0],  seed =\n                            k * i)\n      y1.forest <-\n        regression_forest(subset(X_train, z_train == 1), y_train[z_train == 1],  seed =\n                            k * i)\n      pred_t_in <-\n        predict(y1.forest, X_train)$predictions - predict(y0.forest, X_train)$predictions\n      pred_t_out <-\n        predict(y1.forest, X_test)$predictions - predict(y0.forest, X_test)$predictions\n      rmse_t_in <- sqrt(mean((t_train - pred_t_in) ^ 2))\n      rmse_t_out <- sqrt(mean((t_test - pred_t_out) ^ 2))\n\n\n      # s-learner\n      print('s learner')\n      s_forest <-\n        regression_forest(cbind(X_train, z_train), y_train,  seed = k * i)\n      n_train <- nrow(X_train)\n      n_test <- nrow(X_test)\n      train_treated <- rep(1, n_train)\n      train_control <- rep(0, n_train)\n      test_treated <- rep(1, n_test)\n      test_control <- rep(0, n_test)\n      pred_s_in <-\n        predict(s_forest, cbind(X_train, train_treated))$predictions - predict(s_forest, cbind(X_train, train_control))$predictions\n\n      pred_s_out <-\n        predict(s_forest, cbind(X_test, test_treated))$predictions - predict(s_forest, cbind(X_test, test_control))$predictions\n      rmse_s_in <- sqrt(mean((t_train - pred_s_in) ^ 2))\n      rmse_s_out <- sqrt(mean((t_test - pred_s_out) ^ 2))\n\n\n\n      df_res <-\n          data.frame(\n            run = k,\n            cf_in = rmse_cf_in,\n            t_in = rmse_t_in,\n            s_in = rmse_s_in,\n            cf_out = rmse_cf_out,\n            t_out = rmse_t_out,\n            s_out = rmse_s_out\n          )\n\n\n      if (k == 0) {\n        write.table(\n          df_res,\n          file = paste0(\n            'results/experiments_benchmarking/twins/twins_grf_',\n            subset_train,\n            '_',\n            n_reps,\n            '.csv'\n          ),\n          col.names = T,\n          sep = ',',\n          row.names = F\n        )\n      }\n      else{\n        write.table(\n          df_res,\n   
       file = paste0(\n            'results/experiments_benchmarking/twins/twins_grf_',\n            subset_train,\n            '_',\n            n_reps,\n            '.csv'\n          ),\n          col.names = F,\n          sep = ',',\n          row.names = F,\n          append = T\n        )\n      }\n\n  }\n}\n"
  },
  {
    "path": "experiments/experiments_inductivebias_NeurIPS21/__init__.py",
    "content": ""
  },
  {
    "path": "experiments/experiments_inductivebias_NeurIPS21/experiments_AB.py",
    "content": "\"\"\"\nUtils to replicate setups A & B\n\"\"\"\n# Author: Alicia Curth\nimport csv\nimport os\nfrom typing import Optional, Tuple, Union\n\nimport numpy as onp\nfrom sklearn import clone\n\nfrom catenets.datasets import load\nfrom catenets.experiment_utils.base import eval_root_mse\nfrom catenets.models.jax import (\n    DRAGON_NAME,\n    DRNET_NAME,\n    FLEXTE_NAME,\n    OFFSET_NAME,\n    RANET_NAME,\n    RNET_NAME,\n    SNET_NAME,\n    T_NAME,\n    TARNET_NAME,\n    XNET_NAME,\n    DragonNet,\n    DRNet,\n    FlexTENet,\n    OffsetNet,\n    RANet,\n    RNet,\n    SNet,\n    TARNet,\n    TNet,\n    XNet,\n)\n\nRESULT_DIR_SIMU = \"results/experiments_inductive_bias/acic2016/simulations/\"\nSEP = \"_\"\n\n# Hyperparms for all models\nPARAMS_DEPTH: dict = {\"n_layers_r\": 1, \"n_layers_out\": 1}\nPARAMS_DEPTH_2: dict = {\n    \"n_layers_r\": 1,\n    \"n_layers_out\": 1,\n    \"n_layers_r_t\": 1,\n    \"n_layers_out_t\": 1,\n}\nPENALTY_DIFF = 0.01\nPENALTY_ORTHOGONAL = 0.1\n\n# For main results\nALL_MODELS = {\n    T_NAME: TNet(**PARAMS_DEPTH),\n    T_NAME\n    + \"_reg\": TNet(train_separate=False, penalty_diff=PENALTY_DIFF, **PARAMS_DEPTH),\n    TARNET_NAME: TARNet(**PARAMS_DEPTH),\n    TARNET_NAME\n    + \"_reg\": TARNet(\n        reg_diff=True, penalty_diff=PENALTY_DIFF, same_init=True, **PARAMS_DEPTH\n    ),\n    OFFSET_NAME: OffsetNet(penalty_l2_p=PENALTY_DIFF, **PARAMS_DEPTH),\n    FLEXTE_NAME: FlexTENet(\n        penalty_orthogonal=PENALTY_ORTHOGONAL, penalty_l2_p=PENALTY_DIFF, **PARAMS_DEPTH\n    ),\n    FLEXTE_NAME + \"_noortho_reg_same\": FlexTENet(penalty_orthogonal=0, **PARAMS_DEPTH),\n    DRNET_NAME: DRNet(**PARAMS_DEPTH_2),\n    DRNET_NAME + \"_TAR\": DRNet(first_stage_strategy=\"Tar\", **PARAMS_DEPTH_2),\n}\n\n# For figure 4 in main text\nABLATIONS = {\n    T_NAME: TNet(**PARAMS_DEPTH),\n    T_NAME\n    + \"_reg\": TNet(train_separate=False, penalty_diff=PENALTY_DIFF, **PARAMS_DEPTH),\n    T_NAME + \"_reg_same\": 
TNet(train_separate=False, **PARAMS_DEPTH),\n    OFFSET_NAME: OffsetNet(penalty_l2_p=PENALTY_DIFF, **PARAMS_DEPTH),\n    OFFSET_NAME + \"_reg_same\": OffsetNet(**PARAMS_DEPTH),\n    FLEXTE_NAME: FlexTENet(\n        penalty_orthogonal=PENALTY_ORTHOGONAL, penalty_l2_p=PENALTY_DIFF, **PARAMS_DEPTH\n    ),\n    FLEXTE_NAME\n    + \"_reg_same\": FlexTENet(penalty_orthogonal=PENALTY_ORTHOGONAL, **PARAMS_DEPTH),\n    FLEXTE_NAME\n    + \"_noortho\": FlexTENet(\n        penalty_orthogonal=0, penalty_l2_p=PENALTY_DIFF, **PARAMS_DEPTH\n    ),\n    FLEXTE_NAME + \"_noortho_reg_same\": FlexTENet(penalty_orthogonal=0, **PARAMS_DEPTH),\n}\n\n# For results in Appendix B.3\nFLEX_LAMBDA = {\n    \"FlexTENet_001\": FlexTENet(\n        penalty_orthogonal=PENALTY_ORTHOGONAL, penalty_l2_p=1 / 100, **PARAMS_DEPTH\n    ),\n    \"FlexTENet_01\": FlexTENet(\n        penalty_orthogonal=PENALTY_ORTHOGONAL, penalty_l2_p=1 / 10, **PARAMS_DEPTH\n    ),\n    \"FlexTENet_0001\": FlexTENet(\n        penalty_orthogonal=PENALTY_ORTHOGONAL, penalty_l2_p=1 / 1000, **PARAMS_DEPTH\n    ),\n    \"FlexTENet_00001\": FlexTENet(\n        penalty_orthogonal=PENALTY_ORTHOGONAL, penalty_l2_p=1 / 10000, **PARAMS_DEPTH\n    ),\n}\n\nT_LAMBDA = {\n    T_NAME: TNet(**PARAMS_DEPTH),\n    T_NAME + \"_reg_01\": TNet(train_separate=False, penalty_diff=1 / 10, **PARAMS_DEPTH),\n    T_NAME\n    + \"_reg_001\": TNet(train_separate=False, penalty_diff=1 / 100, **PARAMS_DEPTH),\n    T_NAME\n    + \"_reg_0001\": TNet(train_separate=False, penalty_diff=1 / 1000, **PARAMS_DEPTH),\n    T_NAME\n    + \"_reg_00001\": TNet(train_separate=False, penalty_diff=1 / 10000, **PARAMS_DEPTH),\n}\n\nOFFSET_LAMBDA = {\n    OFFSET_NAME + \"_reg_01\": OffsetNet(penalty_l2_p=1 / 10, **PARAMS_DEPTH),\n    OFFSET_NAME + \"_reg_001\": OffsetNet(penalty_l2_p=1 / 100, **PARAMS_DEPTH),\n    OFFSET_NAME + \"_reg_0001\": OffsetNet(penalty_l2_p=1 / 1000, **PARAMS_DEPTH),\n    OFFSET_NAME + \"_reg_00001\": OffsetNet(penalty_l2_p=1 / 10000, 
**PARAMS_DEPTH),\n}\n\n# For results in appendix D.1\nTWOSTEP_LEARNERS = {\n    XNET_NAME: XNet(**PARAMS_DEPTH_2),\n    RANET_NAME: RANet(**PARAMS_DEPTH_2),\n    RNET_NAME: RNet(**PARAMS_DEPTH_2),\n    DRNET_NAME: DRNet(**PARAMS_DEPTH_2),\n    T_NAME: TNet(**PARAMS_DEPTH),\n}\n\n# For results in Appendix D.2\nDRAGON_VARIANTS = {\n    DRAGON_NAME: DragonNet(**PARAMS_DEPTH),\n    DRAGON_NAME\n    + \"_reg\": DragonNet(\n        reg_diff=True, penalty_diff=PENALTY_DIFF, same_init=True, **PARAMS_DEPTH\n    ),\n}\n\nSNET_VARIANTS = {\n    SNET_NAME: SNet(\n        n_units_r=100,\n        n_units_r_small=100,\n        ortho_reg_type=\"fro\",\n        penalty_orthogonal=PENALTY_ORTHOGONAL,\n        with_prop=False,\n        **PARAMS_DEPTH,\n    ),\n    SNET_NAME\n    + \"_reg\": SNet(\n        n_units_r=100,\n        n_units_r_small=100,\n        ortho_reg_type=\"fro\",\n        penalty_orthogonal=PENALTY_ORTHOGONAL,\n        with_prop=False,\n        penalty_diff=PENALTY_DIFF,\n        same_init=True,\n        reg_diff=True,\n        **PARAMS_DEPTH,\n    ),\n}\n\n# For results in appendix D.6\nDR_VARIANTS = {\n    DRNET_NAME\n    + \"_t_reg\": DRNet(\n        first_stage_args={\"train_separate\": False, \"penalty_diff\": PENALTY_DIFF},\n        **PARAMS_DEPTH_2,\n    ),\n    DRNET_NAME\n    + \"_Flex\": DRNet(\n        first_stage_strategy=\"Flex\",\n        first_stage_args={\n            \"private_out\": False,\n            \"penalty_orthogonal\": PENALTY_ORTHOGONAL,\n            \"penalty_l2_p\": PENALTY_DIFF,\n            \"normalize_ortho\": False,\n        },\n        **PARAMS_DEPTH_2,\n    ),\n}\n\n# results in appendix D.6\nX_VARIANTS = {\n    XNET_NAME\n    + \"_t_reg\": XNet(\n        first_stage_args={\"train_separate\": False, \"penalty_diff\": PENALTY_DIFF},\n        **PARAMS_DEPTH_2,\n    ),\n    XNET_NAME\n    + \"_Flex\": XNet(\n        first_stage_strategy=\"Flex\",\n        first_stage_args={\n            \"private_out\": False,\n            
\"penalty_orthogonal\": PENALTY_ORTHOGONAL,\n            \"penalty_l2_p\": PENALTY_DIFF,\n            \"normalize_ortho\": False,\n        },\n        **PARAMS_DEPTH_2,\n    ),\n}\n\n\ndef do_acic_simu_loops(\n    rho_loop: list = [0, 0.05, 0.1, 0.2, 0.5, 0.8],\n    n1_loop: list = [200, 2000, 500],\n    n_exp: int = 10,\n    file_name: str = \"acic_simu\",\n    models: Optional[dict] = None,\n    n_0: int = 2000,\n    n_test: int = 500,\n    setting: str = \"A\",\n    factual_eval: bool = False,\n) -> None:\n    if models is None:\n        models = ALL_MODELS\n\n    for n_1 in n1_loop:\n        if setting == \"A\":\n            for rho in rho_loop:\n                do_acic_simu(\n                    n_1=n_1,\n                    n_exp=n_exp,\n                    file_name=file_name,\n                    models=models,\n                    n_0=n_0,\n                    n_test=n_test,\n                    prop_omega=0,\n                    prop_gamma=rho,\n                    factual_eval=factual_eval,\n                )\n        else:\n            for rho in rho_loop:\n                do_acic_simu(\n                    n_1=n_1,\n                    n_exp=n_exp,\n                    file_name=file_name,\n                    models=models,\n                    n_0=n_0,\n                    n_test=n_test,\n                    prop_gamma=0,\n                    prop_omega=rho,\n                    factual_eval=factual_eval,\n                )\n\n\ndef do_acic_simu(\n    n_exp: Union[int, list] = 10,\n    file_name: str = \"acic_simu\",\n    models: Union[dict, str, None] = None,\n    n_0: int = 2000,\n    n_1: int = 200,\n    n_test: int = 500,\n    error_sd: float = 1,\n    sp_lin: float = 0.6,\n    sp_nonlin: float = 0.3,\n    prop_gamma: float = 0,\n    ate_goal: float = 0,\n    inter: bool = True,\n    prop_omega: float = 0,\n    factual_eval: bool = False,\n) -> None:\n    if models is None:\n        models = ALL_MODELS\n    elif isinstance(models, str):\n        
if models == \"all\":\n            models = ALL_MODELS\n        elif models == \"ablations\":\n            models = ABLATIONS\n        elif models == \"flex_lambda\":\n            models = FLEX_LAMBDA\n        elif models == \"t_lambda\":\n            models = T_LAMBDA\n        elif models == \"offset_lambda\":\n            models = OFFSET_LAMBDA\n        elif models == \"snet\":\n            models = SNET_VARIANTS\n        elif models == \"dragon\":\n            models = DRAGON_VARIANTS\n        elif models == \"twostep\":\n            models = TWOSTEP_LEARNERS\n        elif models == \"dr\":\n            models = DR_VARIANTS\n        elif models == \"x\":\n            models = X_VARIANTS\n        else:\n            raise ValueError(f\"{models} is not a valid model selection string.\")\n\n    # get file to write in\n    if not os.path.isdir(RESULT_DIR_SIMU):\n        os.makedirs(RESULT_DIR_SIMU)\n\n    out_file = open(\n        RESULT_DIR_SIMU\n        + file_name\n        + SEP\n        + str(n_0)\n        + SEP\n        + str(n_1)\n        + SEP\n        + str(prop_gamma)\n        + SEP\n        + str(prop_omega)\n        + \".csv\",\n        \"w\",\n        buffering=1,\n    )\n    writer = csv.writer(out_file)\n    header = (\n        [\"y_var\", \"cate_var\"]\n        + [name + \"_cate\" for name in models.keys()]\n        + [\n            name + \"_mu0\"\n            for name in models.keys()\n            if \"R\" not in name and \"X\" not in name\n        ]\n        + [\n            name + \"_mu1\"\n            for name in models.keys()\n            if \"R\" not in name and \"X\" not in name\n        ]\n    )\n\n    if factual_eval:\n        header = header + [\n            name + \"_factual\"\n            for name in models.keys()\n            if \"R\" not in name and \"X\" not in name\n        ]\n\n    writer.writerow(header)\n\n    if isinstance(n_exp, int):\n        experiment_loop = list(range(1, n_exp + 1))\n    elif isinstance(n_exp, list):\n        
experiment_loop = n_exp\n    else:\n        raise ValueError(\"n_exp should be either an integer or a list of integers.\")\n\n    for i_exp in experiment_loop:\n        rmse_cate = []\n        rmse_mu0 = []\n        rmse_mu1 = []\n\n        # get data\n        if not factual_eval:\n            X, y, w, X_t, mu_0_t, mu_1_t, cate_t = acic_simu(\n                i_exp,\n                n_0=n_0,\n                n_1=n_1,\n                n_test=n_test,\n                error_sd=error_sd,\n                sp_lin=sp_lin,\n                sp_nonlin=sp_nonlin,\n                prop_gamma=prop_gamma,\n                ate_goal=ate_goal,\n                inter=inter,\n                prop_omega=prop_omega,\n            )\n        else:\n            rmse_factual = []\n            X, y, w, X_t, y_t, w_t, mu_0_t, mu_1_t, cate_t = acic_simu(\n                i_exp,\n                n_0=n_0,\n                n_1=n_1,\n                n_test=n_test,\n                error_sd=error_sd,\n                sp_lin=sp_lin,\n                sp_nonlin=sp_nonlin,\n                prop_gamma=prop_gamma,\n                ate_goal=ate_goal,\n                inter=inter,\n                prop_omega=prop_omega,\n                return_ytest=True,\n            )\n\n        y_var = onp.var(y)\n        cate_var = onp.var(cate_t)\n\n        # split data\n        for model_name, estimator in models.items():\n            print(f\"Experiment {i_exp} with {model_name}\")\n            estimator_temp = clone(estimator)\n\n            # fit estimator\n            estimator_temp.fit(X=X, y=y, w=w)\n\n            if \"R\" not in model_name and \"X\" not in model_name:\n                cate_pred_out, mu0_pred, mu1_pred = estimator_temp.predict(\n                    X_t, return_po=True\n                )\n                rmse_mu0.append(eval_root_mse(mu0_pred, mu_0_t))\n                rmse_mu1.append(eval_root_mse(mu1_pred, mu_1_t))\n                if factual_eval:\n                    pred_factual = w_t * 
mu1_pred + (1 - w_t) * mu0_pred\n                    rmse_factual.append(eval_root_mse(pred_factual, y_t))\n            else:\n                cate_pred_out = estimator_temp.predict(X_t)\n\n            rmse_cate.append(eval_root_mse(cate_pred_out, cate_t))\n\n        if not factual_eval:\n            writer.writerow([y_var, cate_var] + rmse_cate + rmse_mu0 + rmse_mu1)\n        else:\n            writer.writerow(\n                [y_var, cate_var] + rmse_cate + rmse_mu0 + rmse_mu1 + rmse_factual\n            )\n\n    out_file.close()\n\n\ndef acic_simu(\n    i_exp: onp.ndarray,\n    n_0: int = 2000,\n    n_1: int = 200,\n    n_test: int = 500,\n    error_sd: float = 1,\n    sp_lin: float = 0.6,\n    sp_nonlin: float = 0.3,\n    prop_gamma: float = 0,\n    prop_omega: float = 0,\n    ate_goal: float = 0,\n    inter: bool = True,\n    return_ytest: bool = False,\n) -> Tuple:\n    X_train, w_train, y_train, _, X_test, w_test, y_test, po_test = load(\n        \"acic2016\",\n        i_exp=i_exp,\n        n_0=n_0,\n        n_1=n_1,\n        n_test=n_test,\n        error_sd=error_sd,\n        sp_lin=sp_lin,\n        sp_nonlin=sp_nonlin,\n        prop_gamma=prop_gamma,\n        prop_omega=prop_omega,\n        ate_goal=ate_goal,\n        inter=inter,\n    )\n    mu_0_t = po_test[:, 0]\n    mu_1_t = po_test[:, 1]\n    cate_t = mu_1_t - mu_0_t\n\n    if return_ytest:\n        return X_train, y_train, w_train, X_test, y_test, w_test, mu_0_t, mu_1_t, cate_t\n\n    return X_train, y_train, w_train, X_test, mu_0_t, mu_1_t, cate_t\n"
  },
  {
    "path": "experiments/experiments_inductivebias_NeurIPS21/experiments_CD.py",
    "content": "\"\"\"\nUtils to replicate experiments C and D\n\"\"\"\n# Author: Alicia Curth\nimport csv\nimport os\nfrom pathlib import Path\nfrom typing import Optional, Union\n\nfrom sklearn import clone\n\nfrom catenets.datasets.dataset_ihdp import get_one_data_set, load_raw, prepare_ihdp_data\nfrom catenets.experiment_utils.base import eval_root_mse\nfrom catenets.models.jax import (\n    DRNET_NAME,\n    FLEXTE_NAME,\n    OFFSET_NAME,\n    T_NAME,\n    TARNET_NAME,\n    DRNet,\n    FlexTENet,\n    OffsetNet,\n    TARNet,\n    TNet,\n)\n\nDATA_DIR = Path(\"catenets/datasets/data/\")\nRESULT_DIR = Path(\"results/experiments_inductive_bias/ihdp/\")\nSEP = \"_\"\n\nPARAMS_DEPTH: dict = {\"n_layers_r\": 2, \"n_layers_out\": 2}\nPARAMS_DEPTH_2: dict = {\n    \"n_layers_r\": 2,\n    \"n_layers_out\": 2,\n    \"n_layers_r_t\": 2,\n    \"n_layers_out_t\": 2,\n}\nPENALTY_DIFF = 0.01\nPENALTY_ORTHOGONAL = 0.1\n\nALL_MODELS = {\n    T_NAME: TNet(**PARAMS_DEPTH),\n    T_NAME\n    + \"_reg\": TNet(train_separate=False, penalty_diff=PENALTY_DIFF, **PARAMS_DEPTH),\n    TARNET_NAME: TARNet(**PARAMS_DEPTH),\n    TARNET_NAME\n    + \"_reg\": TARNet(\n        reg_diff=True, penalty_diff=PENALTY_DIFF, same_init=True, **PARAMS_DEPTH\n    ),\n    OFFSET_NAME: OffsetNet(penalty_l2_p=PENALTY_DIFF, **PARAMS_DEPTH),\n    FLEXTE_NAME: FlexTENet(\n        penalty_orthogonal=PENALTY_ORTHOGONAL, penalty_l2_p=PENALTY_DIFF, **PARAMS_DEPTH\n    ),\n    FLEXTE_NAME + \"_noortho_reg_same\": FlexTENet(penalty_orthogonal=0, **PARAMS_DEPTH),\n    DRNET_NAME: DRNet(**PARAMS_DEPTH_2),\n    DRNET_NAME + \"_TAR\": DRNet(first_stage_strategy=\"Tar\", **PARAMS_DEPTH_2),\n}\n\n\ndef do_ihdp_experiments(\n    n_exp: Union[int, list] = 100,\n    file_name: str = \"ihdp_all\",\n    model_params: Optional[dict] = None,\n    models: Optional[dict] = None,\n    setting: str = \"C\",\n) -> None:\n    if models is None:\n        models = ALL_MODELS\n\n    # get file to write in\n    if not 
os.path.isdir(RESULT_DIR):\n        os.makedirs(RESULT_DIR)\n\n    out_file = open(RESULT_DIR / (file_name + SEP + setting + \".csv\"), \"w\", buffering=1)\n    writer = csv.writer(out_file)\n    header = [name + \"_in\" for name in models.keys()] + [\n        name + \"_out\" for name in models.keys()\n    ]\n    writer.writerow(header)\n\n    # get data\n    data_train, data_test = load_raw(DATA_DIR)\n\n    if isinstance(n_exp, int):\n        experiment_loop = list(range(1, n_exp + 1))\n    elif isinstance(n_exp, list):\n        experiment_loop = n_exp\n    else:\n        raise ValueError(\"n_exp should be either an integer or a list of integers.\")\n\n    for i_exp in experiment_loop:\n        pehe_in = []\n        pehe_out = []\n\n        # get data\n        data_exp = get_one_data_set(data_train, i_exp=i_exp, get_po=True)\n        data_exp_test = get_one_data_set(data_test, i_exp=i_exp, get_po=True)\n\n        X, y, w, cate_true_in, X_t, cate_true_out = prepare_ihdp_data(\n            data_exp, data_exp_test, setting=setting\n        )\n\n        for model_name, estimator in models.items():\n            print(f\"Experiment {i_exp} with {model_name}\")\n            estimator_temp = clone(estimator)\n            if model_params is not None:\n                estimator_temp.set_params(**model_params)\n\n            # fit estimator\n            estimator_temp.fit(X=X, y=y, w=w)\n\n            cate_pred_in = estimator_temp.predict(X, return_po=False)\n            cate_pred_out = estimator_temp.predict(X_t, return_po=False)\n\n            pehe_in.append(eval_root_mse(cate_pred_in, cate_true_in))\n            pehe_out.append(eval_root_mse(cate_pred_out, cate_true_out))\n\n        writer.writerow(pehe_in + pehe_out)\n\n    out_file.close()\n"
  },
  {
    "path": "experiments/experiments_inductivebias_NeurIPS21/experiments_acic.py",
    "content": "\"\"\"\nUtils to replicate ACIC2016 experiments (Appendix E.1)\n\"\"\"\n# Author: Alicia Curth\nimport csv\nimport os\nfrom pathlib import Path\n\nimport numpy as np\nfrom sklearn import clone\n\nfrom catenets.datasets import load\nfrom catenets.experiment_utils.base import eval_root_mse\nfrom catenets.models.jax import (\n    DRNET_NAME,\n    FLEXTE_NAME,\n    OFFSET_NAME,\n    T_NAME,\n    TARNET_NAME,\n    DRNet,\n    FlexTENet,\n    OffsetNet,\n    TARNet,\n    TNet,\n)\n\nRESULT_DIR = Path(\"results/experiments_inductive_bias/acic2016/original\")\nSEP = \"_\"\n\nPARAMS_DEPTH = {\"n_layers_r\": 1, \"n_layers_out\": 1}\nPARAMS_DEPTH_2 = {\n    \"n_layers_r\": 1,\n    \"n_layers_out\": 1,\n    \"n_layers_r_t\": 1,\n    \"n_layers_out_t\": 1,\n}\nPENALTY_DIFF = 0.01\nPENALTY_ORTHOGONAL = 0.1\n\n\nALL_MODELS = {\n    T_NAME: TNet(**PARAMS_DEPTH),\n    T_NAME\n    + \"_reg\": TNet(train_separate=False, penalty_diff=PENALTY_DIFF, **PARAMS_DEPTH),\n    TARNET_NAME: TARNet(**PARAMS_DEPTH),\n    TARNET_NAME\n    + \"_reg\": TARNet(\n        reg_diff=True, penalty_diff=PENALTY_DIFF, same_init=True, **PARAMS_DEPTH\n    ),\n    OFFSET_NAME: OffsetNet(penalty_l2_p=PENALTY_DIFF, **PARAMS_DEPTH),\n    FLEXTE_NAME: FlexTENet(\n        penalty_orthogonal=PENALTY_ORTHOGONAL, penalty_l2_p=PENALTY_DIFF, **PARAMS_DEPTH\n    ),\n    FLEXTE_NAME + \"_noortho_reg_same\": FlexTENet(penalty_orthogonal=0, **PARAMS_DEPTH),\n    DRNET_NAME: DRNet(**PARAMS_DEPTH_2),\n    DRNET_NAME + \"_TAR\": DRNet(first_stage_strategy=\"Tar\", **PARAMS_DEPTH_2),\n}\n\n\ndef do_acic_orig_loop(\n    simu_nums,\n    n_exp: int = 10,\n    file_name: str = \"results\",\n    models: dict = None,\n    train_size: float = 0.8,\n):\n    if models is None:\n        models = ALL_MODELS\n    for simu_num in simu_nums:\n        print(f\"Running simulation setting {simu_num}\")\n        do_acic_experiments(\n            n_exp=n_exp,\n            file_name=file_name,\n            simu_num=simu_num,\n     
       models=models,\n            train_size=train_size,\n        )\n\n\ndef do_acic_experiments(\n    n_exp: int = 10,\n    file_name: str = \"results_catenets\",\n    simu_num: int = 1,\n    models: dict = None,\n    train_size: float = 0.8,\n    pre_trans: bool = False,\n):\n    if models is None:\n        models = ALL_MODELS\n\n    # get file to write in\n    if not os.path.isdir(RESULT_DIR):\n        os.makedirs(RESULT_DIR)\n\n    out_file = open(\n        RESULT_DIR\n        / (\n            file_name\n            + SEP\n            + str(pre_trans)\n            + SEP\n            + str(simu_num)\n            + SEP\n            + str(train_size)\n            + \".csv\"\n        ),\n        \"w\",\n        buffering=1,\n    )\n    writer = csv.writer(out_file)\n    header = (\n        [\"file_name\", \"cate_var_in\", \"cate_var_out\", \"y_var_in\"]\n        + [name + \"_in\" for name in models.keys()]\n        + [name + \"_out\" for name in models.keys()]\n    )\n    writer.writerow(header)\n\n    for i_exp in range(n_exp):\n        # get data\n        X, w, y, po_train, X_test, w_test, y_test, po_test = load(\n            \"acic2016\",\n            preprocessed=pre_trans,\n            original_acic_outcomes=True,\n            keep_categorical=False,\n            random_split=True,\n            i_exp=i_exp,\n            simu_num=simu_num,\n            train_size=train_size,\n        )\n\n        cate_in = po_train[:, 1] - po_train[:, 0]\n        cate_out = po_test[:, 1] - po_test[:, 0]\n\n        cate_var_in = np.var(cate_in)\n        cate_var_out = np.var(cate_out)\n        y_var_in = np.var(y)\n\n        pehe_in = []\n        pehe_out = []\n\n        for model_name, estimator in models.items():\n            print(f\"Experiment {i_exp} with {model_name}\")\n            estimator_temp = clone(estimator)\n\n            # fit estimator\n            estimator_temp.fit(X=X, y=y, w=w)\n\n            cate_pred_in = estimator_temp.predict(X, return_po=False)\n       
     cate_pred_out = estimator_temp.predict(X_test, return_po=False)\n\n            pehe_in.append(eval_root_mse(cate_pred_in, cate_in))\n            pehe_out.append(eval_root_mse(cate_pred_out, cate_out))\n\n        writer.writerow(\n            [i_exp, cate_var_in, cate_var_out, y_var_in] + pehe_in + pehe_out\n        )\n\n    out_file.close()\n"
  },
  {
    "path": "experiments/experiments_inductivebias_NeurIPS21/experiments_twins.py",
    "content": "\"\"\"\nUtils to replicate Twins experiments (Appendix E.2)\n\"\"\"\n# Author: Alicia Curth\nimport csv\nimport os\nfrom pathlib import Path\n\nimport numpy as np\nfrom sklearn import clone\nfrom sklearn.metrics import average_precision_score, roc_auc_score\nfrom sklearn.model_selection import train_test_split\nfrom sklearn.preprocessing import label_binarize\n\nfrom catenets.datasets import load\nfrom catenets.experiment_utils.base import eval_root_mse\nfrom catenets.models.jax import (\n    DRNET_NAME,\n    FLEXTE_NAME,\n    OFFSET_NAME,\n    T_NAME,\n    TARNET_NAME,\n    DRNet,\n    FlexTENet,\n    OffsetNet,\n    TARNet,\n    TNet,\n)\nfrom catenets.models.jax.base import check_shape_1d_data\n\nRESULT_DIR = Path(\"results/experiments_inductive_bias/twins\")\nSEP = \"_\"\n\nPARAMS_DEPTH = {\"n_layers_r\": 1, \"n_layers_out\": 1}\nPARAMS_DEPTH_2 = {\n    \"n_layers_r\": 1,\n    \"n_layers_out\": 1,\n    \"n_layers_r_t\": 1,\n    \"n_layers_out_t\": 1,\n}\nPENALTY_DIFF = 0.01\nPENALTY_ORTHOGONAL = 0.1\n\nALL_MODELS = {\n    T_NAME: TNet(**PARAMS_DEPTH),\n    T_NAME\n    + \"_reg\": TNet(train_separate=False, penalty_diff=PENALTY_DIFF, **PARAMS_DEPTH),\n    TARNET_NAME: TARNet(**PARAMS_DEPTH),\n    TARNET_NAME\n    + \"_reg\": TARNet(\n        reg_diff=True, penalty_diff=PENALTY_DIFF, same_init=True, **PARAMS_DEPTH\n    ),\n    OFFSET_NAME: OffsetNet(penalty_l2_p=PENALTY_DIFF, **PARAMS_DEPTH),\n    FLEXTE_NAME: FlexTENet(\n        penalty_orthogonal=PENALTY_ORTHOGONAL, penalty_l2_p=PENALTY_DIFF, **PARAMS_DEPTH\n    ),\n    FLEXTE_NAME + \"_noortho_reg_same\": FlexTENet(penalty_orthogonal=0, **PARAMS_DEPTH),\n    DRNET_NAME: DRNet(**PARAMS_DEPTH_2),\n    DRNET_NAME + \"_TAR\": DRNet(first_stage_strategy=\"Tar\", **PARAMS_DEPTH_2),\n}\n\n\ndef do_twins_experiment_loop(\n    n_train_loop=[500, 1000, 2000, 5000, None],\n    prop_loop=[0.1, 0.25, 0.5, 0.75, 0.9],\n    n_exp: int = 10,\n    file_name: str = \"twins\",\n    models: dict = None,\n    
test_size=0.5,\n):\n    for n in n_train_loop:\n        for prop in prop_loop:\n            print(\n                \"Running twins experiment for {} training samples with {} treated.\".format(\n                    n, prop\n                )\n            )\n            do_twins_experiments(\n                n_exp=n_exp,\n                file_name=file_name,\n                models=models,\n                subset_train=n,\n                prop_treated=prop,\n                test_size=test_size,\n            )\n\n\ndef do_twins_experiments(\n    n_exp: int = 10,\n    file_name: str = \"twins\",\n    models: dict = None,\n    subset_train: int = None,\n    prop_treated=0.5,\n    test_size=0.5,\n):\n    if models is None:\n        models = ALL_MODELS\n\n    # get file to write in\n    if not os.path.isdir(RESULT_DIR):\n        os.makedirs(RESULT_DIR)\n\n    out_file = open(\n        RESULT_DIR\n        / (file_name + SEP + str(prop_treated) + SEP + str(subset_train) + \".csv\"),\n        \"w\",\n        buffering=1,\n    )\n    writer = csv.writer(out_file)\n    header = (\n        [name + \"_cate\" for name in models.keys()]\n        + [\n            name + \"_auc_ite\"\n            for name in models.keys()\n            if \"R\" not in name and \"X\" not in name\n        ]\n        + [\n            name + \"_auc_mu0\"\n            for name in models.keys()\n            if \"R\" not in name and \"X\" not in name\n        ]\n        + [\n            name + \"_auc_mu1\"\n            for name in models.keys()\n            if \"R\" not in name and \"X\" not in name\n        ]\n        + [\n            name + \"_ap_mu0\"\n            for name in models.keys()\n            if \"R\" not in name and \"X\" not in name\n        ]\n        + [\n            name + \"_ap_mu1\"\n            for name in models.keys()\n            if \"R\" not in name and \"X\" not in name\n        ]\n    )\n\n    writer.writerow(header)\n\n    for i_exp in range(n_exp):\n        pehe_out = []\n      
  auc_ite = []\n        auc_mu0 = []\n        auc_mu1 = []\n        ap_mu0 = []\n        ap_mu1 = []\n\n        # get data\n        x, w, y, pos, _, _ = load(\n            \"twins\", seed=i_exp, treat_prop=prop_treated, train_ratio=1\n        )\n\n        # split data\n        X, X_t, y, y_t, w, w_t, y0_in, y0_out, y1_in, y1_out = split_data(\n            x,\n            y,\n            w,\n            pos,\n            random_state=i_exp,\n            subset_train=subset_train,\n            test_size=test_size,\n        )\n\n        ite_out = y1_out - y0_out\n\n        ite_out_encoded = label_binarize(ite_out, [-1, 0, 1])\n\n        n_test = X_t.shape[0]\n\n        # split data\n        for model_name, estimator in models.items():\n            print(f\"Experiment {i_exp} with {model_name}\")\n            estimator_temp = clone(estimator)\n            estimator_temp.set_params(**{\"binary_y\": True})\n\n            # fit estimator\n            estimator_temp.fit(X=X, y=y, w=w)\n\n            if (\n                \"DR\" not in model_name\n                and \"R\" not in model_name\n                and \"X\" not in model_name\n            ):\n                cate_pred_out, mu0_pred, mu1_pred = estimator_temp.predict(\n                    X_t, return_po=True\n                )\n\n                # create probabilities for each possible level of ITE\n                probs = np.zeros((n_test, 3))\n                probs[:, 0] = (mu0_pred * (1 - mu1_pred)).reshape((-1,))  # P(Y1-Y0=-1)\n                probs[:, 1] = (\n                    (mu0_pred * mu1_pred) + ((1 - mu0_pred) * (1 - mu1_pred))\n                ).reshape(\n                    (-1,)\n                )  # P(Y1-Y0=0)\n                probs[:, 2] = (mu1_pred * (1 - mu0_pred)).reshape((-1,))  # P(Y1-Y0=1)\n                auc_ite.append(roc_auc_score(ite_out_encoded, probs))\n\n                # evaluate performance on potential outcomes\n                auc_mu0.append(eval_roc_auc(y0_out, mu0_pred))\n      
          auc_mu1.append(eval_roc_auc(y1_out, mu1_pred))\n                ap_mu0.append(eval_ap(y0_out, mu0_pred))\n                ap_mu1.append(eval_ap(y1_out, mu1_pred))\n            else:\n                cate_pred_out = estimator_temp.predict(X_t)\n\n            pehe_out.append(eval_root_mse(cate_pred_out, ite_out))\n\n        writer.writerow(pehe_out + auc_ite + auc_mu0 + auc_mu1 + ap_mu0 + ap_mu1)\n\n    out_file.close()\n\n\n# utils -------\ndef split_data(X, y, w, pos, test_size=0.5, random_state=42, subset_train: int = None):\n    X, X_t, y, y_t, w, w_t, y0_in, y0_out, y1_in, y1_out = train_test_split(\n        X, y, w, pos[:, 0], pos[:, 1], test_size=test_size, random_state=random_state\n    )\n    if subset_train is not None:\n        X, y, w, y0_in, y1_in = (\n            X[:subset_train, :],\n            y[:subset_train],\n            w[:subset_train],\n            y0_in[:subset_train],\n            y1_in[:subset_train],\n        )\n\n    return X, X_t, y, y_t, w, w_t, y0_in, y0_out, y1_in, y1_out\n\n\ndef eval_roc_auc(targets, preds):\n    preds = check_shape_1d_data(preds)\n    targets = check_shape_1d_data(targets)\n    return roc_auc_score(targets, preds)\n\n\ndef eval_ap(targets, preds):\n    preds = check_shape_1d_data(preds)\n    targets = check_shape_1d_data(targets)\n    return average_precision_score(targets, preds)\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\n# AVOID CHANGING REQUIRES: IT WILL BE UPDATED BY PYSCAFFOLD!\nrequires = [\"setuptools>=46.1.0\", \"wheel\"]\nbuild-backend = \"setuptools.build_meta\"\n"
  },
  {
    "path": "pytest.ini",
    "content": "[pytest]\nmarkers =\n    slow: mark a test as slow.\n"
  },
  {
    "path": "run_experiments_AISTATS.py",
    "content": "\"\"\"\nFile to run AISTATS experiments from shell\n\"\"\"\n# Author: Alicia Curth\nimport argparse\nimport sys\nfrom typing import Any\n\nimport catenets.logger as log\nfrom experiments.experiments_AISTATS21.ihdp_experiments import do_ihdp_experiments\nfrom experiments.experiments_AISTATS21.simulations_AISTATS import main_AISTATS\n\nlog.add(sink=sys.stderr, level=\"DEBUG\")\n\n\ndef init_arg() -> Any:\n    # arg parser if script is run from shell\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--experiment\", default=\"simulation\", type=str)\n    parser.add_argument(\"--setting\", default=1, type=int)\n    parser.add_argument(\"--models\", default=None, type=str)\n    parser.add_argument(\"--file_name\", default=\"results\", type=str)\n    parser.add_argument(\"--n_repeats\", default=10, type=int)\n    return parser.parse_args()\n\n\nif __name__ == \"__main__\":\n    args = init_arg()\n    if args.experiment == \"simulation\":\n        main_AISTATS(\n            setting=args.setting,\n            models=args.models,\n            file_name=args.file_name,\n            n_repeats=args.n_repeats,\n        )\n    elif args.experiment == \"ihdp\":\n        do_ihdp_experiments(\n            models=args.models, file_name=args.file_name, n_exp=args.n_repeats\n        )\n"
  },
  {
    "path": "run_experiments_benchmarks_NeurIPS.py",
    "content": "\"\"\"\nFile to run the catenets experiments for\n\"Really Doing Great at Estimating CATE? A Critical Look at ML Benchmarking Practices in\nTreatment Effect Estimation\" (Curth & vdS, NeurIPS21)\nfrom shell\n\"\"\"\n# Author: Alicia Curth\nimport argparse\nimport sys\nfrom typing import Any\n\nimport catenets.logger as log\nfrom experiments.experiments_benchmarks_NeurIPS21.acic_experiments_catenets import (\n    do_acic_experiments,\n)\nfrom experiments.experiments_benchmarks_NeurIPS21.ihdp_experiments_catenets import (\n    do_ihdp_experiments,\n)\nfrom experiments.experiments_benchmarks_NeurIPS21.twins_experiments_catenets import (\n    do_twins_experiment_loop,\n)\n\nlog.add(sink=sys.stderr, level=\"DEBUG\")\n\n\ndef init_arg() -> Any:\n    # arg parser\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--setting\", default=\"C\", type=str)\n    parser.add_argument(\"--experiment\", default=\"ihdp\", type=str)\n    parser.add_argument(\"--file_name\", default=\"results\", type=str)\n    parser.add_argument(\"--n_exp\", default=10, type=int)\n    parser.add_argument(\"--n_reps\", default=5, type=int)\n    parser.add_argument(\"--pre_trans\", type=bool, default=False)\n    parser.add_argument(\"--simu_num\", type=int, default=2)\n    return parser.parse_args()\n\n\nif __name__ == \"__main__\":\n    args = init_arg()\n    if (args.experiment == \"ihdp\") or (args.experiment == \"IHDP\"):\n        do_ihdp_experiments(\n            file_name=args.file_name,\n            n_exp=args.n_exp,\n            setting=args.setting,\n            n_reps=args.n_reps,\n        )\n    elif (args.experiment == \"acic\") or (args.experiment == \"ACIC\"):\n        do_acic_experiments(\n            file_name=args.file_name,\n            n_reps=args.n_reps,\n            simu_num=args.simu_num,\n            n_exp=args.n_exp,\n            pre_trans=args.pre_trans,\n        )\n    elif (args.experiment == \"twins\") or (args.experiment == \"Twins\"):\n       
 do_twins_experiment_loop(file_name=args.file_name, n_exp=args.n_reps)\n\n    else:\n        raise ValueError(\n            f\"Experiment should be one of ihdp/IHDP, acic/ACIC and twins/Twins. You \"\n            f\"passed {args.experiment}\"\n        )\n"
  },
  {
    "path": "run_experiments_inductive_bias_NeurIPS.py",
    "content": "\"\"\"\nFile to run experiments for\n\"On Inductive Biases for Heterogeneous Treatment Effect Estimation\" (Curth & vdS, NeurIPS21)\nfrom shell\n\"\"\"\n# Author: Alicia Curth\nimport argparse\nimport sys\nfrom typing import Any\n\nimport catenets.logger as log\nfrom experiments.experiments_inductivebias_NeurIPS21.experiments_AB import (\n    do_acic_simu_loops,\n)\nfrom experiments.experiments_inductivebias_NeurIPS21.experiments_acic import (\n    do_acic_orig_loop,\n)\nfrom experiments.experiments_inductivebias_NeurIPS21.experiments_CD import (\n    do_ihdp_experiments,\n)\nfrom experiments.experiments_inductivebias_NeurIPS21.experiments_twins import (\n    do_twins_experiment_loop,\n)\n\nlog.add(sink=sys.stderr, level=\"DEBUG\")\n\n\ndef init_arg() -> Any:\n    # arg parser\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--setup\", default=\"A\", type=str)\n    parser.add_argument(\"--file_name\", default=\"results\", type=str)\n    parser.add_argument(\"--n_exp\", default=10, type=int)\n    parser.add_argument(\"--n_0\", default=2000, type=int)\n    parser.add_argument(\"--models\", default=None, type=str)\n    parser.add_argument(\"--n1_loop\", nargs=\"+\", default=[200, 2000, 500], type=int)\n    parser.add_argument(\n        \"--rho_loop\", nargs=\"+\", default=[0, 0.05, 0.1, 0.2, 0.5, 0.8], type=float\n    )\n    parser.add_argument(\"--factual_eval\", default=False, type=bool)\n    parser.add_argument(\n        \"--simu_nums\", nargs=\"+\", default=[x for x in range(1, 78)], type=int\n    )\n    return parser.parse_args()\n\n\nif __name__ == \"__main__\":\n    args = init_arg()\n    if (args.setup == \"A\") or (args.setup == \"B\"):\n        do_acic_simu_loops(\n            n_exp=args.n_exp,\n            file_name=args.file_name,\n            setting=args.setup,\n            n_0=args.n_0,\n            models=args.models,\n            n1_loop=args.n1_loop,\n            rho_loop=args.rho_loop,\n            
factual_eval=args.factual_eval,\n        )\n    elif (args.setup == \"C\") or (args.setup == \"D\"):\n        do_ihdp_experiments(\n            file_name=args.file_name, n_exp=args.n_exp, setting=args.setup\n        )\n    elif (args.setup == \"acic\") or (args.setup == \"ACIC\"):\n        # Appendix E.1\n        do_acic_orig_loop(\n            simu_nums=args.simu_nums, n_exp=args.n_exp, file_name=args.file_name\n        )\n    elif (args.setup == \"twins\") or (args.setup == \"Twins\"):\n        # Appendix E.2\n        do_twins_experiment_loop(file_name=args.file_name, n_exp=args.n_exp)\n    else:\n        raise ValueError(\n            f\"Setup should be one of A, B, C, D, acic/ACIC or twins/Twins You passed\"\n            f\" {args.setup}\"\n        )\n"
  },
  {
    "path": "setup.py",
    "content": "# stdlib\nimport os\nimport re\n\n# third party\nfrom setuptools import setup\n\nPKG_DIR = os.path.dirname(os.path.abspath(__file__))\n\n\ndef read(fname: str) -> str:\n    return open(os.path.join(os.path.dirname(__file__), fname)).read()\n\n\ndef find_version() -> str:\n    version_file = read(\"catenets/version.py\").split(\"\\n\")[0]\n    version_re = r\"__version__ = \\\"(?P<version>.+)\\\"\"\n    version_raw = re.match(version_re, version_file)\n\n    if version_raw is None:\n        return \"0.0.1\"\n\n    version = version_raw.group(\"version\")\n    return version\n\n\nif __name__ == \"__main__\":\n    try:\n        setup(\n            version=find_version(),\n        )\n    except:  # noqa\n        print(\n            \"\\n\\nAn error occurred while building the project, \"\n            \"please ensure you have the most updated version of setuptools, \"\n            \"setuptools_scm and wheel with:\\n\"\n            \"   pip install -U setuptools setuptools_scm wheel\\n\\n\"\n        )\n        raise\n"
  },
  {
    "path": "tests/conftest.py",
    "content": "import sys\n\nimport catenets.logger as log\n\nlog.add(sink=sys.stderr, level=\"CRITICAL\")\n"
  },
  {
    "path": "tests/datasets/test_datasets.py",
    "content": "import pytest\n\nfrom catenets.datasets import load\n\n\n@pytest.mark.parametrize(\"train_ratio\", [0.5, 0.8])\n@pytest.mark.parametrize(\"treatment_type\", [\"rand\", \"logistic\"])\n@pytest.mark.parametrize(\"treat_prop\", [0.1, 0.9])\ndef test_dataset_sanity_twins(\n    train_ratio: float, treatment_type: str, treat_prop: float\n) -> None:\n    X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(\n        \"twins\",\n        train_ratio=train_ratio,\n        treatment_type=treatment_type,\n        treat_prop=treat_prop,\n    )\n\n    total = X_train.shape[0] + X_test.shape[0]\n\n    assert int(total * train_ratio) == X_train.shape[0]\n    assert X_train.shape[1] == X_test.shape[1]\n    assert X_train.shape[0] == Y_train.shape[0]\n    assert X_train.shape[0] == Y_train_full.shape[0]\n    assert X_train.shape[0] == W_train.shape[0]\n    assert X_test.shape[0] == Y_test.shape[0]\n\n\ndef test_dataset_sanity_ihdp() -> None:\n    X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(\"ihdp\")\n\n    assert X_train.shape[1] == X_test.shape[1]\n    assert X_train.shape[0] == Y_train.shape[0]\n    assert X_train.shape[0] == Y_train_full.shape[0]\n    assert X_train.shape[0] == W_train.shape[0]\n    assert X_test.shape[0] == Y_test.shape[0]\n\n\n@pytest.mark.slow\n@pytest.mark.parametrize(\"preprocessed\", [False, True])\ndef test_dataset_sanity_acic2016(preprocessed: bool) -> None:\n    X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(\n        \"acic2016\", preprocessed=preprocessed\n    )\n\n    assert X_train.shape[1] == X_test.shape[1]\n    assert X_train.shape[0] == Y_train.shape[0]\n    assert X_train.shape[0] == Y_train_full.shape[0]\n    assert X_train.shape[0] == W_train.shape[0]\n    assert X_test.shape[0] == Y_test.shape[0]\n"
  },
  {
    "path": "tests/models/jax/test_jax_ite.py",
    "content": "from copy import deepcopy\n\nimport pytest\n\nfrom catenets.datasets import load\nfrom catenets.experiment_utils.tester import evaluate_treatments_model\nfrom catenets.models.jax import FLEXTE_NAME, OFFSET_NAME, FlexTENet, OffsetNet\n\nLAYERS_OUT = 2\nLAYERS_R = 3\nPENALTY_L2 = 0.01 / 100\nPENALTY_ORTHOGONAL_IHDP = 0\n\nPARAMS_DEPTH: dict = {\"n_layers_r\": 2, \"n_layers_out\": 2, \"n_iter\": 10}\nPENALTY_DIFF = 0.01\nPENALTY_ORTHOGONAL = 0.1\n\nALL_MODELS = {\n    OFFSET_NAME: OffsetNet(penalty_l2_p=PENALTY_DIFF, **PARAMS_DEPTH),\n    FLEXTE_NAME: FlexTENet(\n        penalty_orthogonal=PENALTY_ORTHOGONAL, penalty_l2_p=PENALTY_DIFF, **PARAMS_DEPTH\n    ),\n}\n\nmodels = list(ALL_MODELS.keys())\n\n\n@pytest.mark.parametrize(\"dataset, pehe_threshold\", [(\"twins\", 0.4), (\"ihdp\", 3)])\n@pytest.mark.parametrize(\"model_name\", models)\ndef test_model_sanity(dataset: str, pehe_threshold: float, model_name: str) -> None:\n    model = deepcopy(ALL_MODELS[model_name])\n\n    X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(dataset)\n\n    score = evaluate_treatments_model(model, X_train, Y_train, Y_train_full, W_train)\n    print(f\"Evaluation for model jax.{model_name} on {dataset} = {score['str']}\")\n\n\ndef test_model_score() -> None:\n    model = OffsetNet(n_iter=10)\n\n    X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(\"ihdp\")\n\n    model.fit(X_train[:10], Y_train[:10], W_train[:10])\n\n    result = model.score(X_test, Y_test)\n\n    assert result > 0\n\n    with pytest.raises(ValueError):\n        model.score(X_train, Y_train)  # Y_train has just one outcome\n"
  },
  {
    "path": "tests/models/jax/test_jax_model_utils.py",
    "content": "from typing import Any\n\nimport jax.numpy as jnp\nimport numpy as np\nimport pandas as pd\nimport pytest\n\nfrom catenets.models.jax.model_utils import (\n    check_shape_1d_data,\n    check_X_is_np,\n    make_val_split,\n)\n\n\n@pytest.mark.parametrize(\"data\", [np.array([1, 2, 3]), np.array([[1, 2], [3, 4]])])\ndef test_check_shape_1d_data_sanity(data: np.ndarray) -> None:\n    out = check_shape_1d_data(data)\n\n    assert len(out.shape) == 2\n\n\n@pytest.mark.parametrize(\"data\", [np.array([1, 2, 3]), pd.DataFrame([1, 2])])\ndef test_check_X_is_np_sanity(data: Any) -> None:\n    out = check_X_is_np(data)\n\n    assert isinstance(out, jnp.ndarray)\n\n\ndef test_make_val_split_sanity() -> None:\n    X = np.random.rand(1000, 5)\n    y = np.random.randint(0, 1, size=1000)\n    w = np.random.randint(0, 1, size=1000)\n\n    X_t, y_t, w_t, X_val, y_val, w_val, VALIDATION_STRING = make_val_split(X, y, w)\n\n    assert X_t.shape[0] == 700\n    assert y_t.shape[0] == 700\n    assert w_t.shape[0] == 700\n    assert X_val.shape[0] == 300\n    assert y_val.shape[0] == 300\n    assert w_val.shape[0] == 300\n    assert VALIDATION_STRING == \"validation\"\n"
  },
  {
    "path": "tests/models/jax/test_jax_transformation_utils.py",
    "content": "from typing import Callable\n\nimport numpy as np\nimport pytest\n\nfrom catenets.models.jax.transformation_utils import (\n    ALL_TRANSFORMATIONS,\n    DR_TRANSFORMATION,\n    PW_TRANSFORMATION,\n    RA_TRANSFORMATION,\n    _get_transformation_function,\n    aipw_te_transformation,\n    ht_te_transformation,\n    ra_te_transformation,\n)\n\n\ndef test_get_transformation_function_sanity() -> None:\n    expected_fns = [ht_te_transformation, aipw_te_transformation, ra_te_transformation]\n\n    for tr, expected in zip(ALL_TRANSFORMATIONS, expected_fns):\n        assert _get_transformation_function(tr) is expected\n\n    with pytest.raises(ValueError):\n        _get_transformation_function(\"invalid\")\n\n\n@pytest.mark.parametrize(\n    \"fn\", [aipw_te_transformation, _get_transformation_function(DR_TRANSFORMATION)]\n)\ndef test_aipw_te_transformation_sanity(fn: Callable) -> None:\n    res = fn(\n        y=np.array([0, 1]),\n        w=np.array([1, 0]),\n        p=None,\n        mu_0=np.array([0.4, 0.6]),\n        mu_1=np.array([0.6, 0.4]),\n    )\n    assert res.shape[0] == 2\n\n\n@pytest.mark.parametrize(\n    \"fn\", [ht_te_transformation, _get_transformation_function(PW_TRANSFORMATION)]\n)\ndef test_ht_te_transformation_sanity(fn: Callable) -> None:\n    res = fn(\n        y=np.array([0, 1]),\n        w=np.array([1, 0]),\n    )\n    assert res.shape[0] == 2\n\n\n@pytest.mark.parametrize(\n    \"fn\", [ra_te_transformation, _get_transformation_function(RA_TRANSFORMATION)]\n)\ndef test_ra_te_transformation_sanity(fn: Callable) -> None:\n    res = fn(\n        y=np.array([0, 1]),\n        w=np.array([1, 0]),\n        p=None,\n        mu_0=np.array([0.4, 0.6]),\n        mu_1=np.array([0.6, 0.4]),\n    )\n    assert res.shape[0] == 2\n"
  },
  {
    "path": "tests/models/torch/test_torch_flextenet.py",
    "content": "import numpy as np\nimport pytest\n\nfrom catenets.datasets import load\nfrom catenets.experiment_utils.tester import evaluate_treatments_model\nfrom catenets.models.torch import FlexTENet\n\n\ndef test_flextenet_model_params() -> None:\n    model = FlexTENet(\n        2,\n        binary_y=True,\n        n_layers_out=1,\n        n_layers_r=2,\n        n_units_s_out=20,\n        n_units_p_out=30,\n        n_units_s_r=40,\n        n_units_p_r=50,\n        private_out=True,\n        weight_decay=1e-5,\n        penalty_orthogonal=1e-7,\n        lr=1e-2,\n        n_iter=123,\n        batch_size=234,\n        early_stopping=True,\n        patience=5,\n        n_iter_min=13,\n        n_iter_print=7,\n        seed=42,\n        shared_repr=False,\n        normalize_ortho=False,\n        mode=1,\n    )\n\n    assert model.binary_y is True\n    assert model.n_layers_out == 1\n    assert model.n_layers_r == 2\n    assert model.n_units_s_out == 20\n    assert model.n_units_p_out == 30\n    assert model.n_units_s_r == 40\n    assert model.n_units_p_r == 50\n    assert model.private_out is True\n    assert model.weight_decay == 1e-5\n    assert model.penalty_orthogonal == 1e-7\n    assert model.lr == 1e-2\n    assert model.n_iter == 123\n    assert model.batch_size == 234\n    assert model.early_stopping is True\n    assert model.patience == 5\n    assert model.n_iter_min == 13\n    assert model.n_iter_print == 7\n    assert model.seed == 42\n    assert model.shared_repr is False\n    assert model.normalize_ortho is False\n    assert model.mode == 1\n\n\n@pytest.mark.parametrize(\"dataset, pehe_threshold\", [(\"twins\", 0.4), (\"ihdp\", 1.5)])\ndef test_flextenet_model_sanity(dataset: str, pehe_threshold: float) -> None:\n    X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(dataset)\n    W_train = W_train.ravel()\n\n    model = FlexTENet(\n        X_train.shape[1],\n        binary_y=(len(np.unique(Y_train)) == 2),\n        batch_size=1024,\n        
lr=1e-3,\n        n_iter=10,\n    )\n\n    score = evaluate_treatments_model(\n        model, X_train, Y_train, Y_train_full, W_train, n_folds=2\n    )\n\n    print(f\"Evaluation for model FlexTENet on {dataset} = {score['str']}\")\n\n\n@pytest.mark.parametrize(\"shared_repr\", [False, True])\n@pytest.mark.parametrize(\"private_out\", [False, True])\n@pytest.mark.parametrize(\"n_units_p_r\", [50, 150])\ndef test_flextenet_model_predict_api(\n    shared_repr: bool, private_out: bool, n_units_p_r: int\n) -> None:\n    X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(\"ihdp\")\n    W_train = W_train.ravel()\n\n    model = FlexTENet(\n        X_train.shape[1],\n        binary_y=(len(np.unique(Y_train)) == 2),\n        batch_size=1024,\n        lr=1e-3,\n        shared_repr=shared_repr,\n        private_out=private_out,\n        n_units_p_r=n_units_p_r,\n        n_iter=10,\n    )\n    model.fit(X_train, Y_train, W_train)\n\n    out = model.predict(X_test)\n\n    assert len(out) == len(X_test)\n\n    out, p0, p1 = model.predict(X_test, return_po=True)\n    assert len(out) == len(X_test)\n    assert len(p0) == len(X_test)\n    assert len(p1) == len(X_test)\n\n    score = model.score(X_test, Y_test)\n\n    assert score > 0\n"
  },
  {
    "path": "tests/models/torch/test_torch_pseudo_outcome_nets.py",
    "content": "from typing import Any\n\nimport numpy as np\nimport pytest\nfrom sklearn.ensemble import RandomForestRegressor\nfrom torch import nn\nfrom xgboost import XGBClassifier\n\nfrom catenets.datasets import load\nfrom catenets.experiment_utils.tester import evaluate_treatments_model\nfrom catenets.models.torch import (\n    DRLearner,\n    PWLearner,\n    RALearner,\n    RLearner,\n    ULearner,\n    XLearner,\n)\n\n\n@pytest.mark.parametrize(\n    \"model_t\", [DRLearner, PWLearner, RALearner, RLearner, ULearner, XLearner]\n)\ndef test_nn_model_params(model_t: Any) -> None:\n    model = model_t(\n        2,\n        binary_y=True,\n    )\n\n    assert model._te_estimator is not None\n    assert model._po_estimator is not None\n    assert model._propensity_estimator is not None\n\n\n@pytest.mark.parametrize(\"nonlin\", [\"elu\", \"relu\", \"sigmoid\"])\n@pytest.mark.parametrize(\n    \"model_t\", [DRLearner, PWLearner, RALearner, RLearner, ULearner, XLearner]\n)\ndef test_nn_model_params_nonlin(nonlin: str, model_t: Any) -> None:\n    model = model_t(2, binary_y=True, nonlin=nonlin)\n\n    nonlins = {\n        \"elu\": nn.ELU,\n        \"relu\": nn.ReLU,\n        \"sigmoid\": nn.Sigmoid,\n    }\n\n    for mod in [model._te_estimator, model._po_estimator, model._propensity_estimator]:\n        assert isinstance(mod.model[2], nonlins[nonlin])\n\n\n@pytest.mark.parametrize(\"dataset, pehe_threshold\", [(\"twins\", 0.4), (\"ihdp\", 4)])\n@pytest.mark.parametrize(\"model_t\", [DRLearner, RALearner, XLearner])\ndef test_nn_model_sanity(dataset: str, pehe_threshold: float, model_t: Any) -> None:\n    X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(dataset)\n    W_train = W_train.ravel()\n\n    model = model_t(\n        X_train.shape[1], binary_y=(len(np.unique(Y_train)) == 2), n_iter=10\n    )\n\n    score = evaluate_treatments_model(model, X_train, Y_train, Y_train_full, W_train)\n\n    print(\n        f\"Evaluation for model torch.{model_t} 
with NNs on {dataset} = {score['str']}\"\n    )\n\n\n@pytest.mark.parametrize(\"dataset, pehe_threshold\", [(\"twins\", 0.4)])\n@pytest.mark.parametrize(\n    \"po_estimator\",\n    [\n        XGBClassifier(\n            n_estimators=100,\n            reg_lambda=1e-3,\n            reg_alpha=1e-3,\n            colsample_bytree=0.1,\n            colsample_bynode=0.1,\n            colsample_bylevel=0.1,\n            max_depth=6,\n            tree_method=\"hist\",\n            learning_rate=1e-2,\n            min_child_weight=0,\n            max_bin=256,\n            random_state=0,\n            eval_metric=\"logloss\",\n            use_label_encoder=False,\n        ),\n    ],\n)\n@pytest.mark.parametrize(\n    \"te_estimator\",\n    [\n        RandomForestRegressor(\n            n_estimators=100,\n            max_depth=6,\n        ),\n    ],\n)\n@pytest.mark.parametrize(\"model_t\", [DRLearner, RALearner])\ndef test_sklearn_model_pseudo_outcome_binary(\n    dataset: str,\n    pehe_threshold: float,\n    po_estimator: Any,\n    te_estimator: Any,\n    model_t: Any,\n) -> None:\n    X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(dataset)\n    W_train = W_train.ravel()\n\n    model = model_t(\n        X_train.shape[1],\n        binary_y=True,\n        po_estimator=po_estimator,\n        te_estimator=te_estimator,\n        batch_size=1024,\n        n_iter=10,\n    )\n\n    score = evaluate_treatments_model(\n        model, X_train, Y_train, Y_train_full, W_train, n_folds=3\n    )\n\n    print(\n        f\"Evaluation for model {model_t} with po_estimator = {type(po_estimator)},\"\n        f\"te_estimator = {type(te_estimator)} on {dataset} = {score['str']}\"\n    )\n\n\ndef test_model_predict_api() -> None:\n    X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(\"ihdp\")\n    W_train = W_train.ravel()\n\n    model = XLearner(X_train.shape[1], binary_y=False, batch_size=1024, n_iter=10)\n    model.fit(X_train, Y_train, W_train)\n\n    out = 
model.predict(X_test)\n\n    assert len(out) == len(X_test)\n\n    score = model.score(X_test, Y_test)\n\n    assert score > 0\n"
  },
  {
    "path": "tests/models/torch/test_torch_representation_net.py",
    "content": "from typing import Type\n\nimport pytest\nfrom torch import nn\n\nfrom catenets.datasets import load\nfrom catenets.experiment_utils.tester import evaluate_treatments_model\nfrom catenets.models.torch import DragonNet, TARNet\n\n\n@pytest.mark.parametrize(\"snet\", [TARNet, DragonNet])\ndef test_model_params(snet: Type) -> None:\n    model = snet(\n        2,\n        binary_y=True,\n        n_layers_out=1,\n        n_units_out=2,\n        n_layers_r=3,\n        n_units_r=4,\n        weight_decay=0.5,\n        lr=0.6,\n        n_iter=700,\n        batch_size=80,\n        val_split_prop=0.9,\n        n_iter_print=10,\n        seed=11,\n    )\n\n    assert model._repr_estimator is not None\n    assert model._propensity_estimator is not None\n    assert len(model._po_estimators) == 2\n\n    for mod in model._po_estimators:\n        assert len(mod.model) == 5  # 1 in + NL + 4 * (n_layers_out - 1) + 1 out + NL\n\n    assert len(model._repr_estimator.model) == 9\n\n\n@pytest.mark.parametrize(\"nonlin\", [\"elu\", \"relu\", \"sigmoid\"])\n@pytest.mark.parametrize(\"snet\", [TARNet, DragonNet])\ndef test_model_params_nonlin(nonlin: str, snet: Type) -> None:\n    model = snet(2, nonlin=nonlin)\n\n    nonlins = {\n        \"elu\": nn.ELU,\n        \"relu\": nn.ReLU,\n        \"sigmoid\": nn.Sigmoid,\n    }\n\n    for mod in [\n        model._repr_estimator,\n        model._po_estimators[0],\n        model._po_estimators[1],\n        model._propensity_estimator,\n    ]:\n        assert isinstance(mod.model[2], nonlins[nonlin])\n\n\n@pytest.mark.parametrize(\"dataset, pehe_threshold\", [(\"twins\", 0.4)])\n@pytest.mark.parametrize(\"snet\", [TARNet, DragonNet])\ndef test_model_sanity(dataset: str, pehe_threshold: float, snet: Type) -> None:\n    X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(dataset)\n    W_train = W_train.ravel()\n\n    model = snet(\n        X_train.shape[1],\n        batch_size=256,\n        n_iter=10,\n    )\n\n    score = 
evaluate_treatments_model(\n        model, X_train, Y_train, Y_train_full, W_train, n_folds=3\n    )\n\n    print(f\"Evaluation for model {snet} on {dataset} = {score['str']}\")\n    assert score[\"raw\"][\"pehe\"][0] < pehe_threshold\n\n\ndef test_model_predict_api() -> None:\n    X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(\"ihdp\")\n    W_train = W_train.ravel()\n\n    model = TARNet(X_train.shape[1], batch_size=1024, n_iter=10)\n    model.fit(X_train, Y_train, W_train)\n\n    out = model.predict(X_test)\n\n    assert len(out) == len(X_test)\n\n    out, p0, p1 = model.predict(X_test, return_po=True)\n    assert len(out) == len(X_test)\n    assert len(p0) == len(X_test)\n    assert len(p1) == len(X_test)\n\n    score = model.score(X_test, Y_test)\n\n    assert score > 0\n"
  },
  {
    "path": "tests/models/torch/test_torch_slearner.py",
    "content": "from typing import Any, Optional\n\nimport numpy as np\nimport pytest\nfrom sklearn.ensemble import RandomForestClassifier, RandomForestRegressor\nfrom sklearn.linear_model import LogisticRegression\nfrom torch import nn\nfrom xgboost import XGBClassifier\n\nfrom catenets.datasets import load\nfrom catenets.experiment_utils.tester import evaluate_treatments_model\nfrom catenets.models.torch import SLearner\n\n\ndef test_nn_model_params() -> None:\n    model = SLearner(\n        2,\n        binary_y=True,\n        n_layers_out=1,\n        n_units_out=2,\n        n_units_out_prop=33,\n        n_layers_out_prop=12,\n        weight_decay=0.5,\n        lr=0.6,\n        n_iter=700,\n        batch_size=80,\n        val_split_prop=0.9,\n        n_iter_print=10,\n        seed=11,\n        weighting_strategy=\"ipw\",\n    )\n\n    assert model._weighting_strategy == \"ipw\"\n    assert model._propensity_estimator is not None\n    assert model._po_estimator is not None\n\n    assert model._po_estimator.n_iter == 700\n    assert model._po_estimator.batch_size == 80\n    assert model._po_estimator.n_iter_print == 10\n    assert model._po_estimator.seed == 11\n    assert model._po_estimator.val_split_prop == 0.9\n    assert (\n        len(model._po_estimator.model) == 5\n    )  # 1 in + NL + 3 * (n_layers_hidden -1) + out + Sigmoid\n\n    assert model._propensity_estimator.n_iter == 700\n    assert model._propensity_estimator.batch_size == 80\n    assert model._propensity_estimator.n_iter_print == 10\n    assert model._propensity_estimator.seed == 11\n    assert model._propensity_estimator.val_split_prop == 0.9\n    assert (\n        len(model._propensity_estimator.model) == 38\n    )  # 1 in + NL + 3 * (n_layers_hidden - 1) + out + Softmax\n\n\n@pytest.mark.parametrize(\"nonlin\", [\"elu\", \"relu\", \"sigmoid\"])\ndef test_nn_model_params_nonlin(nonlin: str) -> None:\n    model = SLearner(2, True, nonlin=nonlin, weighting_strategy=\"ipw\")\n\n    nonlins = {\n  
      \"elu\": nn.ELU,\n        \"relu\": nn.ReLU,\n        \"sigmoid\": nn.Sigmoid,\n    }\n\n    for mod in [model._propensity_estimator, model._po_estimator]:\n        assert isinstance(mod.model[2], nonlins[nonlin])\n\n\n@pytest.mark.parametrize(\"weighting_strategy\", [\"ipw\", None])\n@pytest.mark.parametrize(\"dataset, pehe_threshold\", [(\"twins\", 0.4), (\"ihdp\", 1.5)])\ndef test_nn_model_sanity(\n    dataset: str, pehe_threshold: float, weighting_strategy: Optional[str]\n) -> None:\n    X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(dataset)\n    W_train = W_train.ravel()\n\n    model = SLearner(\n        X_train.shape[1],\n        binary_y=(len(np.unique(Y_train)) == 2),\n        weighting_strategy=weighting_strategy,\n        n_iter=10,\n    )\n\n    score = evaluate_treatments_model(model, X_train, Y_train, Y_train_full, W_train)\n\n    print(\n        f\"Evaluation for model torch.SLearner(NN)(weighting_strategy={weighting_strategy}) on {dataset} = {score['str']}\"\n    )\n\n\n@pytest.mark.parametrize(\"dataset, pehe_threshold\", [(\"twins\", 0.4)])\n@pytest.mark.parametrize(\n    \"po_estimator\",\n    [\n        XGBClassifier(\n            n_estimators=100,\n            reg_lambda=1e-3,\n            reg_alpha=1e-3,\n            colsample_bytree=0.1,\n            colsample_bynode=0.1,\n            colsample_bylevel=0.1,\n            max_depth=6,\n            tree_method=\"hist\",\n            learning_rate=1e-2,\n            min_child_weight=0,\n            max_bin=256,\n            random_state=0,\n            eval_metric=\"logloss\",\n            use_label_encoder=False,\n        ),\n        RandomForestClassifier(\n            n_estimators=100,\n            max_depth=6,\n        ),\n        LogisticRegression(\n            C=1.0,\n            solver=\"sag\",\n            max_iter=10000,\n            penalty=\"l2\",\n        ),\n    ],\n)\ndef test_sklearn_model_sanity_binary_output(\n    dataset: str, pehe_threshold: float, 
po_estimator: Any\n) -> None:\n    X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(dataset)\n    W_train = W_train.ravel()\n\n    model = SLearner(\n        X_train.shape[1],\n        binary_y=True,\n        po_estimator=po_estimator,\n        n_iter=10,\n    )\n\n    score = evaluate_treatments_model(model, X_train, Y_train, Y_train_full, W_train)\n\n    print(\n        f\"Evaluation for model torch.SLearner with {po_estimator.__class__} on {dataset} = {score['str']}\"\n    )\n    assert score[\"raw\"][\"pehe\"][0] < pehe_threshold\n\n\n@pytest.mark.parametrize(\"exp\", [1, 10, 40, 50, 99])\n@pytest.mark.parametrize(\n    \"po_estimator\",\n    [\n        RandomForestRegressor(\n            n_estimators=100,\n            max_depth=6,\n        ),\n    ],\n)\ndef test_slearner_sklearn_model_ihdp(po_estimator: Any, exp: int) -> None:\n    X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(\n        \"ihdp\", exp=exp, rescale=True\n    )\n    W_train = W_train.ravel()\n\n    model = SLearner(\n        X_train.shape[1],\n        binary_y=False,\n        po_estimator=po_estimator,\n        n_iter=10,\n    )\n    score = evaluate_treatments_model(model, X_train, Y_train, Y_train_full, W_train)\n\n    print(\n        f\"Evaluation for model torch.SLearner with {po_estimator.__class__} on ihdp[{exp}] = {score['str']}\"\n    )\n    assert score[\"raw\"][\"pehe\"][0] < 1.5\n\n\ndef test_model_predict_api() -> None:\n    X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(\"ihdp\")\n    W_train = W_train.ravel()\n\n    model = SLearner(X_train.shape[1], binary_y=False, batch_size=1024, n_iter=10)\n    model.fit(X_train, Y_train, W_train)\n\n    out = model.predict(X_test)\n\n    assert len(out) == len(X_test)\n\n    out, p0, p1 = model.predict(X_test, return_po=True)\n    assert len(out) == len(X_test)\n    assert len(p0) == len(X_test)\n    assert len(p1) == len(X_test)\n\n    score = model.score(X_test, Y_test)\n\n    assert score > 0\n"
  },
  {
    "path": "tests/models/torch/test_torch_snet.py",
    "content": "import numpy as np\nimport pytest\nfrom torch import nn\n\nfrom catenets.datasets import load\nfrom catenets.experiment_utils.tester import evaluate_treatments_model\nfrom catenets.models.torch import SNet\n\n\ndef test_model_params() -> None:\n    # with propensity estimator\n    model = SNet(\n        2,\n        binary_y=True,\n        n_layers_out=1,\n        n_units_out=2,\n        n_layers_r=3,\n        n_units_r=4,\n        weight_decay=0.5,\n        lr=0.6,\n        n_iter=700,\n        batch_size=80,\n        val_split_prop=0.9,\n        n_iter_print=10,\n        seed=11,\n    )\n\n    assert model._reps_c is not None\n    assert model._reps_o is not None\n    assert model._reps_mu0 is not None\n    assert model._reps_mu1 is not None\n    assert model._reps_prop is not None\n    assert model._propensity_estimator is not None\n    assert len(model._po_estimators) == 2\n\n    for mod in model._po_estimators:\n        assert len(mod.model) == 5  # 1 in + NL + 4 * (n_layers_out - 1) + 1 out + NL\n\n    assert len(model._reps_c.model) == 9\n    assert len(model._reps_o.model) == 9\n    assert len(model._reps_mu0.model) == 9\n    assert len(model._reps_mu1.model) == 9\n    assert len(model._propensity_estimator.model) == 8\n\n    # remove propensity estimator\n    model = SNet(\n        2,\n        binary_y=True,\n        n_layers_out=1,\n        n_units_out=2,\n        n_layers_r=3,\n        n_units_r=4,\n        weight_decay=0.5,\n        lr=0.6,\n        n_iter=700,\n        batch_size=80,\n        val_split_prop=0.9,\n        n_iter_print=10,\n        seed=11,\n        with_prop=False,\n    )\n\n    with np.testing.assert_raises(AttributeError):\n        model._reps_c\n    with np.testing.assert_raises(AttributeError):\n        model._reps_prop\n    with np.testing.assert_raises(AttributeError):\n        model._propensity_estimator\n    assert model._reps_o is not None\n    assert model._reps_mu0 is not None\n    assert model._reps_mu1 is not 
None\n    assert len(model._po_estimators) == 2\n\n    for mod in model._po_estimators:\n        assert len(mod.model) == 5  # 1 in + NL + 4 * (n_layers_out - 1) + 1 out + NL\n\n    assert len(model._reps_o.model) == 9\n    assert len(model._reps_mu0.model) == 9\n    assert len(model._reps_mu1.model) == 9\n\n\n@pytest.mark.parametrize(\"nonlin\", [\"elu\", \"relu\", \"sigmoid\", \"selu\", \"leaky_relu\"])\ndef test_model_params_nonlin(nonlin: str) -> None:\n    model = SNet(2, nonlin=nonlin)\n\n    nonlins = {\n        \"elu\": nn.ELU,\n        \"relu\": nn.ReLU,\n        \"sigmoid\": nn.Sigmoid,\n        \"selu\": nn.SELU,\n        \"leaky_relu\": nn.LeakyReLU,\n    }\n\n    for mod in [\n        model._reps_c,\n        model._reps_o,\n        model._reps_mu0,\n        model._reps_mu1,\n        model._reps_prop,\n        model._po_estimators[0],\n        model._po_estimators[1],\n        model._propensity_estimator,\n    ]:\n        assert isinstance(mod.model[2], nonlins[nonlin])\n\n\n@pytest.mark.parametrize(\"dataset, pehe_threshold\", [(\"twins\", 0.4)])\ndef test_model_sanity(dataset: str, pehe_threshold: float) -> None:\n    X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(dataset)\n    W_train = W_train.ravel()\n\n    # with propensity estimator\n    model = SNet(\n        X_train.shape[1],\n        binary_y=(len(np.unique(Y_train)) == 2),\n        batch_size=1024,\n        n_iter=10,\n    )\n\n    score = evaluate_treatments_model(\n        model, X_train, Y_train, Y_train_full, W_train, n_folds=3\n    )\n\n    print(f\"Evaluation for model SNet on {dataset} = {score['str']}\")\n\n    model = SNet(\n        X_train.shape[1],\n        binary_y=(len(np.unique(Y_train)) == 2),\n        batch_size=1024,\n        n_iter=10,\n        with_prop=False,\n    )\n\n    score = evaluate_treatments_model(\n        model, X_train, Y_train, Y_train_full, W_train, n_folds=3\n    )\n\n    print(f\"Evaluation for model SNet (with_prop=False) on {dataset} = 
{score['str']}\")\n\n\ndef test_model_predict_api() -> None:\n    X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(\"ihdp\")\n    W_train = W_train.ravel()\n\n    model = SNet(X_train.shape[1], batch_size=1024, n_iter=10)\n    model.fit(X_train, Y_train, W_train)\n\n    out = model.predict(X_test)\n\n    assert len(out) == len(X_test)\n\n    out, p0, p1 = model.predict(X_test, return_po=True)\n    assert len(out) == len(X_test)\n    assert len(p0) == len(X_test)\n    assert len(p1) == len(X_test)\n\n    score = model.score(X_test, Y_test)\n\n    assert score > 0\n"
  },
  {
    "path": "tests/models/torch/test_torch_tlearner.py",
    "content": "from typing import Any\n\nimport numpy as np\nimport pytest\nfrom sklearn.ensemble import RandomForestClassifier, RandomForestRegressor\nfrom sklearn.linear_model import LogisticRegression\nfrom torch import nn\nfrom xgboost import XGBClassifier, XGBRegressor\n\nfrom catenets.datasets import load\nfrom catenets.experiment_utils.tester import evaluate_treatments_model\nfrom catenets.models.torch import TLearner\n\n\ndef test_nn_model_params() -> None:\n    model = TLearner(\n        2,\n        True,\n        n_layers_out=1,\n        n_units_out=2,\n        weight_decay=0.5,\n        lr=0.6,\n        n_iter=700,\n        batch_size=80,\n        val_split_prop=0.9,\n        n_iter_print=10,\n        seed=11,\n    )\n\n    assert len(model._plug_in) == 2\n\n    for mod in model._plug_in:\n        assert mod.n_iter == 700\n        assert mod.batch_size == 80\n        assert mod.n_iter_print == 10\n        assert mod.seed == 11\n        assert mod.val_split_prop == 0.9\n        assert len(mod.model) == 5  # 2 in + NL + 3 * (n_layers_hidden - 1) + 2 out\n\n\n@pytest.mark.parametrize(\"nonlin\", [\"elu\", \"relu\", \"sigmoid\"])\ndef test_nn_model_params_nonlin(nonlin: str) -> None:\n    model = TLearner(2, True, nonlin=nonlin)\n\n    assert len(model._plug_in) == 2\n\n    nonlins = {\n        \"elu\": nn.ELU,\n        \"relu\": nn.ReLU,\n        \"sigmoid\": nn.Sigmoid,\n    }\n\n    for mod in model._plug_in:\n        assert isinstance(mod.model[2], nonlins[nonlin])\n\n\n@pytest.mark.parametrize(\"dataset, pehe_threshold\", [(\"twins\", 0.4), (\"ihdp\", 1.5)])\ndef test_nn_model_sanity(dataset: str, pehe_threshold: float) -> None:\n    X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(dataset)\n    W_train = W_train.ravel()\n\n    model = TLearner(\n        X_train.shape[1], binary_y=(len(np.unique(Y_train)) == 2), n_iter=10\n    )\n\n    score = evaluate_treatments_model(model, X_train, Y_train, Y_train_full, W_train)\n\n    
print(f\"Evaluation for model torch.TLearner(NN) on {dataset} = {score['str']}\")\n\n\n@pytest.mark.parametrize(\"dataset, pehe_threshold\", [(\"twins\", 0.4)])\n@pytest.mark.parametrize(\n    \"po_estimator\",\n    [\n        XGBClassifier(\n            n_estimators=100,\n            reg_lambda=1e-3,\n            reg_alpha=1e-3,\n            colsample_bytree=0.1,\n            colsample_bynode=0.1,\n            colsample_bylevel=0.1,\n            max_depth=6,\n            tree_method=\"hist\",\n            learning_rate=1e-2,\n            min_child_weight=0,\n            max_bin=256,\n            random_state=0,\n            eval_metric=\"logloss\",\n            use_label_encoder=False,\n        ),\n        RandomForestClassifier(\n            n_estimators=100,\n            max_depth=6,\n        ),\n        LogisticRegression(\n            C=1.0,\n            solver=\"sag\",\n            max_iter=10000,\n            penalty=\"l2\",\n        ),\n    ],\n)\ndef test_sklearn_model_sanity_binary_output(\n    dataset: str, pehe_threshold: float, po_estimator: Any\n) -> None:\n    X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(dataset)\n    W_train = W_train.ravel()\n\n    model = TLearner(\n        X_train.shape[1],\n        binary_y=True,\n        po_estimator=po_estimator,\n        n_iter=10,\n    )\n\n    score = evaluate_treatments_model(model, X_train, Y_train, Y_train_full, W_train)\n\n    print(\n        f\"Evaluation for model torch.TLearner with {po_estimator.__class__} on {dataset} = {score['str']}\"\n    )\n    assert score[\"raw\"][\"pehe\"][0] < pehe_threshold\n\n\n@pytest.mark.parametrize(\"dataset, pehe_threshold\", [(\"ihdp\", 1.5)])\n@pytest.mark.parametrize(\n    \"po_estimator\",\n    [\n        XGBRegressor(\n            n_estimators=1000,\n            reg_lambda=1e-3,\n            reg_alpha=1e-3,\n            colsample_bytree=0.1,\n            colsample_bynode=0.1,\n            colsample_bylevel=0.1,\n            max_depth=7,\n       
     tree_method=\"hist\",\n            learning_rate=1e-2,\n            min_child_weight=0,\n            max_bin=256,\n            random_state=0,\n            eval_metric=\"logloss\",\n        ),\n        RandomForestRegressor(\n            n_estimators=100,\n            max_depth=6,\n        ),\n    ],\n)\ndef test_sklearn_model_sanity_regression(\n    dataset: str, pehe_threshold: float, po_estimator: Any\n) -> None:\n    X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(dataset)\n    W_train = W_train.ravel()\n\n    model = TLearner(\n        X_train.shape[1],\n        binary_y=False,\n        po_estimator=po_estimator,\n        n_iter=10,\n    )\n    score = evaluate_treatments_model(model, X_train, Y_train, Y_train_full, W_train)\n\n    print(\n        f\"Evaluation for model torch.TLearner with {po_estimator.__class__ } on {dataset} = {score['str']}\"\n    )\n\n\ndef test_model_predict_api() -> None:\n    X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(\"ihdp\")\n    W_train = W_train.ravel()\n\n    model = TLearner(\n        X_train.shape[1],\n        binary_y=False,\n        n_iter=10,\n    )\n    model.fit(X_train, Y_train, W_train)\n\n    out = model.predict(X_test)\n\n    assert len(out) == len(X_test)\n\n    out, p0, p1 = model.predict(X_test, return_po=True)\n    assert len(out) == len(X_test)\n    assert len(p0) == len(X_test)\n    assert len(p1) == len(X_test)\n\n    score = model.score(X_test, Y_test)\n\n    assert score > 0\n"
  }
]