Repository: AliciaCurth/CATENets Branch: main Commit: f8c961c30776 Files: 91 Total size: 523.3 KB Directory structure: gitextract_9k1q0bj_/ ├── .github/ │ └── workflows/ │ ├── release.yml │ ├── scripts/ │ │ ├── release_linux.sh │ │ ├── release_osx.sh │ │ └── release_windows.bat │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── catenets/ │ ├── __init__.py │ ├── datasets/ │ │ ├── __init__.py │ │ ├── dataset_acic2016.py │ │ ├── dataset_ihdp.py │ │ ├── dataset_twins.py │ │ └── network.py │ ├── experiment_utils/ │ │ ├── __init__.py │ │ ├── base.py │ │ ├── simulation_utils.py │ │ ├── tester.py │ │ └── torch_metrics.py │ ├── logger.py │ ├── models/ │ │ ├── __init__.py │ │ ├── constants.py │ │ ├── jax/ │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── disentangled_nets.py │ │ │ ├── flextenet.py │ │ │ ├── model_utils.py │ │ │ ├── offsetnet.py │ │ │ ├── pseudo_outcome_nets.py │ │ │ ├── representation_nets.py │ │ │ ├── rnet.py │ │ │ ├── snet.py │ │ │ ├── tnet.py │ │ │ ├── transformation_utils.py │ │ │ └── xnet.py │ │ └── torch/ │ │ ├── __init__.py │ │ ├── base.py │ │ ├── flextenet.py │ │ ├── pseudo_outcome_nets.py │ │ ├── representation_nets.py │ │ ├── slearner.py │ │ ├── snet.py │ │ ├── tlearner.py │ │ └── utils/ │ │ ├── __init__.py │ │ ├── decorators.py │ │ ├── model_utils.py │ │ ├── transformations.py │ │ └── weight_utils.py │ └── version.py ├── docs/ │ ├── Makefile │ ├── conf.py │ ├── datasets.rst │ ├── index.rst │ ├── jax_models.rst │ ├── make.bat │ ├── requirements.txt │ └── torch_models.rst ├── experiments/ │ ├── __init__.py │ ├── experiments_AISTATS21/ │ │ ├── ihdp_experiments.py │ │ └── simulations_AISTATS.py │ ├── experiments_benchmarks_NeurIPS21/ │ │ ├── README.md │ │ ├── __init__.py │ │ ├── acic_experiments_catenets.py │ │ ├── acic_experiments_grf.R │ │ ├── ihdp_experiments_catenets.py │ │ ├── ihdp_experiments_grf.R │ │ ├── twins_experiments_catenets.py │ │ └── twins_experiments_grf.R │ └── experiments_inductivebias_NeurIPS21/ 
│ ├── __init__.py │ ├── experiments_AB.py │ ├── experiments_CD.py │ ├── experiments_acic.py │ └── experiments_twins.py ├── pyproject.toml ├── pytest.ini ├── run_experiments_AISTATS.py ├── run_experiments_benchmarks_NeurIPS.py ├── run_experiments_inductive_bias_NeurIPS.py ├── setup.py └── tests/ ├── conftest.py ├── datasets/ │ └── test_datasets.py └── models/ ├── jax/ │ ├── test_jax_ite.py │ ├── test_jax_model_utils.py │ └── test_jax_transformation_utils.py └── torch/ ├── test_torch_flextenet.py ├── test_torch_pseudo_outcome_nets.py ├── test_torch_representation_net.py ├── test_torch_slearner.py ├── test_torch_snet.py └── test_torch_tlearner.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/workflows/release.yml ================================================ name: Package release on: release: types: [created] jobs: deploy_osx: runs-on: ${{ matrix.os }} strategy: matrix: python-version: ["3.7", "3.8", "3.9", "3.10"] os: [macos-latest] steps: - uses: actions/checkout@v2 with: submodules: true - name: Set up Python uses: actions/setup-python@v1 with: python-version: ${{ matrix.python-version }} - name: Build and publish env: TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} run: ${GITHUB_WORKSPACE}/.github/workflows/scripts/release_osx.sh deploy_linux: strategy: matrix: python-version: - cp37-cp37m - cp38-cp38 - cp39-cp39 - cp310-cp310 runs-on: ubuntu-latest container: quay.io/pypa/manylinux2014_x86_64 steps: - uses: actions/checkout@v1 with: submodules: true - name: Set target Python version PATH run: | echo "/opt/python/${{ matrix.python-version }}/bin" >> $GITHUB_PATH - name: Build and publish env: TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} run: ${GITHUB_WORKSPACE}/.github/workflows/scripts/release_linux.sh deploy_windows: runs-on: windows-latest 
strategy: matrix: python-version: ["3.7", "3.8", "3.9", "3.10"] steps: - uses: actions/checkout@v2 with: submodules: true - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v1 with: python-version: ${{ matrix.python-version }} - name: Build and publish env: TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} run: | ../../.github/workflows/scripts/release_windows.bat ================================================ FILE: .github/workflows/scripts/release_linux.sh ================================================ #!/bin/bash set -e yum makecache -y yum install centos-release-scl -y yum-config-manager --enable rhel-server-rhscl-7-rpms yum install llvm-toolset-7.0 python3 python3-devel -y # Python python3 -m pip install --upgrade pip python3 -m pip install setuptools wheel twine auditwheel # Publish python3 -m pip wheel . -w dist/ --no-deps twine upload --verbose --skip-existing dist/* ================================================ FILE: .github/workflows/scripts/release_osx.sh ================================================ #!/bin/sh export MACOSX_DEPLOYMENT_TARGET=10.14 python -m pip install --upgrade pip pip install setuptools wheel twine auditwheel python3 setup.py build bdist_wheel --plat-name macosx_10_14_x86_64 --dist-dir wheel twine upload --skip-existing wheel/* ================================================ FILE: .github/workflows/scripts/release_windows.bat ================================================ echo on python -m pip install --upgrade pip pip install setuptools wheel twine auditwheel pip wheel . 
-w wheel/ --no-deps twine upload --skip-existing wheel/* ================================================ FILE: .github/workflows/test.yml ================================================ name: CATENets Tests on: push: branches: [main, release] pull_request: types: [opened, synchronize, reopened] schedule: - cron: '0 0 * * 0' workflow_dispatch: jobs: Linter: runs-on: ${{ matrix.os }} strategy: matrix: python-version: [3.8] os: [ubuntu-latest] steps: - uses: actions/checkout@v2 with: submodules: true - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v1 with: python-version: ${{ matrix.python-version }} - name: Install dependencies run: pip install .[testing] - name: pre-commit validation run: pre-commit run --files catenets/* - name: Security checks run: | bandit -r catenets/* Library: needs: [Linter] runs-on: ${{ matrix.os }} strategy: matrix: python-version: ['3.8', '3.9', "3.10"] os: [macos-latest, ubuntu-latest, windows-latest] steps: - uses: actions/checkout@v2 with: submodules: true - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v1 with: python-version: ${{ matrix.python-version }} - name: Install MacOS dependencies run: | brew install libomp if: ${{ matrix.os == 'macos-latest' }} - name: Install dependencies run: | python -m pip install --upgrade pip pip install .[testing] - name: Test with pytest Unix run: pytest -vvvsx -m "not slow" if: ${{ matrix.os != 'windows-latest' }} - name: Test with pytest Windows run: | cd tests\datasets pytest -vvvsx -m "not slow" cd ..\.. 
cd tests\models\torch pytest -vvvsx -m "not slow" if: ${{ matrix.os == 'windows-latest' }} ================================================ FILE: .gitignore ================================================ *.pyc *.xml *.iml *.csv *.xlsx *.Rhistory .idea/ .coverage .ipynb_checkpoints .ipynb_checkpoints/ */.ipynb_checkpoints/ */bin/ */include/ */lib/ */lib64/ */share/ *.cfg .pytest_cache data/ build/ catenets.egg-info/ dist/ generated/ _build ================================================ FILE: .pre-commit-config.yaml ================================================ exclude: 'setup.py|^docs' repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v3.4.0 hooks: - id: trailing-whitespace - id: check-added-large-files - id: check-ast - id: check-json - id: check-merge-conflict - id: check-xml - id: check-yaml - id: debug-statements - id: check-executables-have-shebangs - id: end-of-file-fixer - id: requirements-txt-fixer - id: mixed-line-ending args: ['--fix=auto'] # replace 'auto' with 'lf' to enforce Linux/Mac line endings or 'crlf' for Windows - repo: https://github.com/pycqa/isort rev: 5.12.0 hooks: - id: isort - repo: https://github.com/psf/black rev: 22.3.0 hooks: - id: black language_version: python3 - repo: https://github.com/pycqa/flake8 rev: 3.9.1 hooks: - id: flake8 args: [ "--max-line-length=140", "--extend-ignore=E203,W503" ] - repo: https://github.com/pre-commit/mirrors-mypy rev: v0.812 hooks: - id: mypy args: [ "--ignore-missing-imports", "--scripts-are-modules", "--disallow-incomplete-defs", "--no-implicit-optional", "--warn-unused-ignores", "--warn-redundant-casts", "--strict-equality", "--warn-unreachable", "--disallow-untyped-defs", "--disallow-untyped-calls", ] - repo: local hooks: - id: flynt name: flynt entry: flynt args: [--fail-on-change] types: [python] language: python additional_dependencies: - flynt ================================================ FILE: LICENSE ================================================ BSD 3-Clause License 
Copyright (c) 2021, Alicia Curth All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
================================================ FILE: README.md ================================================ # CATENets - Conditional Average Treatment Effect Estimation Using Neural Networks [![CATENets Tests](https://github.com/AliciaCurth/CATENets/actions/workflows/test.yml/badge.svg)](https://github.com/AliciaCurth/CATENets/actions/workflows/test.yml) [![Documentation Status](https://readthedocs.org/projects/catenets/badge/?version=latest)](https://catenets.readthedocs.io/en/latest/?badge=latest) [![License](https://img.shields.io/badge/License-BSD%203--Clause-blue.svg)](https://github.com/AliciaCurth/CATENets/blob/main/LICENSE) Code Author: Alicia Curth (amc253@cam.ac.uk) This repo contains Jax-based, sklearn-style implementations of Neural Network-based Conditional Average Treatment Effect (CATE) Estimators, which were used in the AISTATS21 paper ['Nonparametric Estimation of Heterogeneous Treatment Effects: From Theory to Learning Algorithms']( https://arxiv.org/abs/2101.10943) (Curth & vd Schaar, 2021a) as well as the follow up NeurIPS21 paper ["On Inductive Biases for Heterogeneous Treatment Effect Estimation"](https://arxiv.org/abs/2106.03765) (Curth & vd Schaar, 2021b) and the NeurIPS21 Datasets & Benchmarks track paper ["Really Doing Great at Estimating CATE? A Critical Look at ML Benchmarking Practices in Treatment Effect Estimation"](https://openreview.net/forum?id=FQLzQqGEAH) (Curth et al, 2021). We implement the SNet-class we introduce in Curth & vd Schaar (2021a), as well as FlexTENet and OffsetNet as discussed in Curth & vd Schaar (2021b), and re-implement a number of NN-based algorithms from existing literature (Shalit et al (2017), Shi et al (2019), Hassanpour & Greiner (2020)). 
We also provide Neural Network (NN)-based instantiations of a number of so-called meta-learners for CATE estimation, including two-step pseudo-outcome regression estimators (the DR-learner (Kennedy, 2020) and single-robust propensity-weighted (PW) and regression-adjusted (RA) learners), Nie & Wager (2017)'s R-learner and Kuenzel et al (2019)'s X-learner. The jax implementations in ``catenets.models.jax`` were used in all papers listed; additionally, pytorch versions of some models (``catenets.models.torch``) were contributed by [Bogdan Cebere](https://github.com/bcebere). ### Interface The repo contains a package ``catenets``, which contains all general code used for modeling and evaluation, and a folder ``experiments``, in which the code for replicating experimental results is contained. All implemented learning algorithms in ``catenets`` (``SNet, FlexTENet, OffsetNet, TNet, SNet1 (TARNet), SNet2 (DragonNet), SNet3, DRNet, RANet, PWNet, RNet, XNet``) come with a sklearn-style wrapper, implementing a ``.fit(X, y, w)`` and a ``.predict(X)`` method, where predict returns CATE by default. All hyperparameters are documented in detail in the respective files in ``catenets.models`` folder. Example usage: ```python from catenets.models.jax import TNet, SNet from catenets.experiment_utils.simulation_utils import simulate_treatment_setup # simulate some data (here: unconfounded, 10 prognostic variables and 5 predictive variables) X, y, w, p, cate = simulate_treatment_setup(n=2000, n_o=10, n_t=5, n_c=0) # estimate CATE using TNet t = TNet() t.fit(X, y, w) cate_pred_t = t.predict(X) # without potential outcomes cate_pred_t, po0_pred_t, po1_pred_t = t.predict(X, return_po=True) # predict potential outcomes too # estimate CATE using SNet s = SNet(penalty_orthogonal=0.01) s.fit(X, y, w) cate_pred_s = s.predict(X) ``` All experiments in Curth & vd Schaar (2021a) can be replicated using this repository; the necessary code is in ``experiments.experiments_AISTATS21``. 
To do so from shell, clone the repo, create a new virtual environment and run ```shell pip install catenets # install the library from PyPI # OR pip install . # install the library from the local repository # Run the experiments python run_experiments_AISTATS.py ``` ```shell Options: --experiment # defaults to 'simulation', 'ihdp' will run ihdp experiments --setting # different simulation settings in synthetic experiments (can be 1-5) --models # defaults to None which will train all models considered in the paper, # can be string of model name (e.g. 'TNet'), 'plug' for all plugin models, # 'pseudo' for all pseudo-outcome regression models --file_name # base file name to write to, defaults to 'results' --n_repeats # number of experiments to run for each configuration, defaults to 10 (should be set to 100 for IHDP) ``` Similarly, the experiments in Curth & vd Schaar (2021b) can be replicated using the code in ``experiments.experiments_inductivebias_NeurIPS21`` (or from shell using ``python run_experiments_inductive_bias_NeurIPS.py``) and the experiments in Curth et al (2021) can be replicated using the code in ``experiments.experiments_benchmarks_NeurIPS21`` (the catenets experiments can also be run from shell using ``python run_experiments_benchmarks_NeurIPS.py``). The code can also be installed as a python package (``catenets``). From a local copy of the repo, run ``python setup.py install``. Note: jax is currently only supported on macOS and Linux, but can be run from Windows using WSL (the Windows Subsystem for Linux).
### Citing If you use this software please cite the corresponding paper(s): ``` @inproceedings{curth2021nonparametric, title={Nonparametric Estimation of Heterogeneous Treatment Effects: From Theory to Learning Algorithms}, author={Curth, Alicia and van der Schaar, Mihaela}, year={2021}, booktitle={Proceedings of the 24th International Conference on Artificial Intelligence and Statistics (AISTATS)}, organization={PMLR} } @article{curth2021inductive, title={On Inductive Biases for Heterogeneous Treatment Effect Estimation}, author={Curth, Alicia and van der Schaar, Mihaela}, booktitle={Proceedings of the Thirty-Fifth Conference on Neural Information Processing Systems}, year={2021} } @article{curth2021really, title={Really Doing Great at Estimating CATE? A Critical Look at ML Benchmarking Practices in Treatment Effect Estimation}, author={Curth, Alicia and Svensson, David and Weatherall, James and van der Schaar, Mihaela}, booktitle={Proceedings of the Neural Information Processing Systems Track on Datasets and Benchmarks}, year={2021} } ``` ================================================ FILE: catenets/__init__.py ================================================ import sys from . import logger # noqa: F401 from . import datasets, models # noqa: F401 logger.add(sink=sys.stderr, level="CRITICAL") ================================================ FILE: catenets/datasets/__init__.py ================================================ # stdlib import os from pathlib import Path from typing import Any, Tuple from . import dataset_acic2016, dataset_ihdp, dataset_twins DATA_PATH = Path(os.path.dirname(__file__)) / Path("data") try: os.mkdir(DATA_PATH) except BaseException: pass def load(dataset: str, *args: Any, **kwargs: Any) -> Tuple: """ Input: dataset: the name of the dataset to load Outputs: - Train_X, Test_X: Train and Test features - Train_Y: Observable outcomes - Train_T: Assigned treatment - Test_Y: Potential outcomes. 
""" if dataset == "twins": return dataset_twins.load(DATA_PATH, *args, **kwargs) if dataset == "ihdp": return dataset_ihdp.load(DATA_PATH, *args, **kwargs) if dataset == "acic2016": return dataset_acic2016.load(DATA_PATH, *args, **kwargs) else: raise Exception("Unsupported dataset") __all__ = ["dataset_ihdp", "dataset_twins", "dataset_acic2016", "load"] ================================================ FILE: catenets/datasets/dataset_acic2016.py ================================================ """ ACIC2016 dataset """ import glob # stdlib import random from pathlib import Path from typing import Any, Tuple # third party import numpy as np import pandas as pd from sklearn.model_selection import train_test_split from sklearn.preprocessing import OneHotEncoder, StandardScaler import catenets.logger as log from .network import download_if_needed np.random.seed(0) random.seed(0) FILE_ID = "0B7pG5PPgj6A3N09ibmFwNWE1djA" PREPROCESSED_FILE_ID = "1iOfEAk402o3jYBs2Prfiz6oaailwWcR5" NUMERIC_COLS = [ 0, 3, 4, 16, 17, 18, 20, 21, 22, 24, 24, 25, 30, 31, 32, 33, 39, 40, 41, 53, 54, ] N_NUM_COLS = len(NUMERIC_COLS) def get_acic_covariates( fn_csv: Path, keep_categorical: bool = False, preprocessed: bool = True ) -> np.ndarray: X = pd.read_csv(fn_csv) if not keep_categorical: X = X.drop(columns=["x_2", "x_21", "x_24"]) else: # encode categorical features feature_list = [] for cols_ in X.columns: if type(X.loc[X.index[0], cols_]) not in [np.int64, np.float64]: enc = OneHotEncoder(drop="first") enc.fit(np.array(X[[cols_]]).reshape((-1, 1))) for k in range(len(list(enc.get_feature_names()))): X[cols_ + list(enc.get_feature_names())[k]] = enc.transform( np.array(X[[cols_]]).reshape((-1, 1)) ).toarray()[:, k] feature_list.append(cols_) X.drop(feature_list, axis=1, inplace=True) if preprocessed: X_t = X.values else: scaler = StandardScaler() X_t = scaler.fit_transform(X) return X_t def preprocess_simu( fn_csv: Path, n_0: int = 2000, n_1: int = 200, n_test: int = 500, error_sd: float = 1, 
sp_lin: float = 0.6, sp_nonlin: float = 0.3, prop_gamma: float = 0, prop_omega: float = 0, ate_goal: float = 0, inter: bool = True, i_exp: int = 0, keep_categorical: bool = False, preprocessed: bool = True, ) -> Tuple: X = get_acic_covariates( fn_csv, keep_categorical=keep_categorical, preprocessed=preprocessed ) np.random.seed(i_exp) # shuffle indices n_total, n_cov = X.shape ind = np.arange(n_total) np.random.shuffle(ind) ind_test = ind[-n_test:] ind_1 = ind[n_0 : (n_0 + n_1)] # create treatment indicator (treatment assignment does not matter in test set) w = np.zeros(n_total).reshape((-1, 1)) w[ind_1] = 1 # create dgp coeffs_ = [0, 1] # sample baseline coefficients beta_0 = np.random.choice(coeffs_, size=n_cov, replace=True, p=[1 - sp_lin, sp_lin]) intercept = np.random.choice([x for x in np.arange(-1, 1.25, 0.25)]) # sample treatment effect coefficients gamma = np.random.choice( coeffs_, size=n_cov, replace=True, p=[1 - prop_gamma, prop_gamma] ) omega = np.random.choice( [0, 1], replace=True, size=n_cov, p=[prop_omega, 1 - prop_omega] ) # simulate mu_0 and mu_1 mu_0 = (intercept + np.dot(X, beta_0)).reshape((-1, 1)) mu_1 = (intercept + np.dot(X, gamma + beta_0 * omega)).reshape((-1, 1)) if sp_nonlin > 0: coefs_sq = [0, 0.1] beta_sq = np.random.choice( coefs_sq, size=N_NUM_COLS, replace=True, p=[1 - sp_nonlin, sp_nonlin] ) omega = np.random.choice( [0, 1], replace=True, size=N_NUM_COLS, p=[prop_omega, 1 - prop_omega] ) X_sq = X[:, NUMERIC_COLS] ** 2 mu_0 = mu_0 + np.dot(X_sq, beta_sq).reshape((-1, 1)) mu_1 = mu_1 + np.dot(X_sq, beta_sq * omega).reshape((-1, 1)) if inter: # randomly add some interactions ind_c = np.arange(n_cov) np.random.shuffle(ind_c) inter_list = list() for i in range(0, n_cov - 2, 2): inter_list.append(X[:, ind_c[i]] * X[:, ind_c[i + 1]]) X_inter = np.array(inter_list).T n_inter = X_inter.shape[1] beta_inter = np.random.choice( coefs_sq, size=n_inter, replace=True, p=[1 - sp_nonlin, sp_nonlin] ) omega = np.random.choice( [0, 1], replace=True, 
size=n_inter, p=[prop_omega, 1 - prop_omega] ) mu_0 = mu_0 + np.dot(X_inter, beta_inter).reshape((-1, 1)) mu_1 = mu_1 + np.dot(X_inter, beta_inter * omega).reshape((-1, 1)) ate = np.mean(mu_1 - mu_0) mu_1 = mu_1 - ate + ate_goal y = ( w * mu_1 + (1 - w) * mu_0 + np.random.normal(0, error_sd, n_total).reshape((-1, 1)) ) X_train, y_train, w_train, mu_0_train, mu_1_train = ( X[ind[: (n_0 + n_1)], :], y[ind[: (n_0 + n_1)]], w[ind[: (n_0 + n_1)]], mu_0[ind[: (n_0 + n_1)]], mu_1[ind[: (n_0 + n_1)]], ) X_test, y_test, w_test, mu_0_t, mu_1_t = ( X[ind_test, :], y[ind_test], w[ind_test], mu_0[ind_test], mu_1[ind_test], ) return ( X_train, w_train, y_train, np.asarray([mu_0_train, mu_1_train]).squeeze().T, X_test, w_test, y_test, np.asarray([mu_0_t, mu_1_t]).squeeze().T, ) def get_acic_orig_filenames(data_path: Path, simu_num: int) -> list: return sorted( glob.glob( (data_path / ("data_cf_all/" + str(simu_num) + "/zymu_*.csv")).__str__() ) ) def get_acic_orig_outcomes(data_path: Path, simu_num: int, i_exp: int) -> Tuple: file_list = get_acic_orig_filenames(data_path=data_path, simu_num=simu_num) out = pd.read_csv(file_list[i_exp]) w = out["z"] y = w * out["y1"] + (1 - w) * out["y0"] mu_0, mu_1 = out["mu0"], out["mu1"] return y.values, w.values, mu_0.values, mu_1.values def preprocess_acic_orig( fn_csv: Path, data_path: Path, preprocessed: bool = False, keep_categorical: bool = True, simu_num: int = 1, i_exp: int = 0, train_size: int = 4000, random_split: bool = False, ) -> Tuple: X = get_acic_covariates( fn_csv, keep_categorical=keep_categorical, preprocessed=preprocessed ) y, w, mu_0, mu_1 = get_acic_orig_outcomes( data_path=data_path, simu_num=simu_num, i_exp=i_exp ) if not random_split: X_train, y_train, w_train, mu_0_train, mu_1_train = ( X[:train_size, :], y[:train_size], w[:train_size], mu_0[:train_size], mu_1[:train_size], ) X_test, y_test, w_test, mu_0_test, mu_1_test = ( X[train_size:, :], y[train_size:], w[train_size:], mu_0[train_size:], mu_1[train_size:], ) else: 
( X_train, X_test, y_train, y_test, w_train, w_test, mu_0_train, mu_0_test, mu_1_train, mu_1_test, ) = train_test_split( X, y, w, mu_0, mu_1, test_size=1 - train_size, random_state=i_exp ) return ( X_train, w_train, y_train, np.asarray([mu_0_train, mu_1_train]).squeeze().T, X_test, w_test, y_test, np.asarray([mu_0_test, mu_1_test]).squeeze().T, ) def preprocess( fn_csv: Path, data_path: Path, preprocessed: bool = True, original_acic_outcomes: bool = False, **kwargs: Any, ) -> Tuple: if not original_acic_outcomes: return preprocess_simu(fn_csv=fn_csv, preprocessed=preprocessed, **kwargs) else: return preprocess_acic_orig( fn_csv=fn_csv, preprocessed=preprocessed, data_path=data_path, **kwargs ) def load( data_path: Path, preprocessed: bool = True, original_acic_outcomes: bool = False, **kwargs: Any, ) -> Tuple: """ ACIC2016 dataset dataloader. - Download the dataset if needed. - Load the dataset. - Preprocess the data. - Return train/test split. Parameters ---------- data_path: Path Path to the CSV. If it is missing, it will be downloaded. preprocessed: bool Switch between the raw and preprocessed versions of the dataset. original_acic_outcomes: bool Switch between new simulations (Inductive bias paper) and original acic outcomes Returns ------- train_x: array or pd.DataFrame Features in training data. train_t: array or pd.DataFrame Treatments in training data. train_y: array or pd.DataFrame Observed outcomes in training data. train_potential_y: array or pd.DataFrame Potential outcomes in training data. test_x: array or pd.DataFrame Features in testing data. test_potential_y: array or pd.DataFrame Potential outcomes in testing data. 
""" if preprocessed: csv = data_path / "x_trans.csv" download_if_needed(csv, file_id=PREPROCESSED_FILE_ID) else: arch = data_path / "data_cf_all.tar.gz" download_if_needed( arch, file_id=FILE_ID, unarchive=True, unarchive_folder=data_path ) csv = data_path / "data_cf_all/x.csv" log.debug(f"load dataset {csv}") return preprocess( csv, data_path=data_path, preprocessed=preprocessed, original_acic_outcomes=original_acic_outcomes, **kwargs, ) ================================================ FILE: catenets/datasets/dataset_ihdp.py ================================================ """ IHDP (Infant Health and Development Program) dataset """ # stdlib import os import random from pathlib import Path from typing import Any, Tuple # third party import numpy as np import catenets.logger as log from .network import download_if_needed np.random.seed(0) random.seed(0) TRAIN_DATASET = "ihdp_npci_1-100.train.npz" TEST_DATASET = "ihdp_npci_1-100.test.npz" TRAIN_URL = "https://www.fredjo.com/files/ihdp_npci_1-100.train.npz" TEST_URL = "https://www.fredjo.com/files/ihdp_npci_1-100.test.npz" # helper functions def load_data_npz(fname: Path, get_po: bool = True) -> dict: """ Helper function for loading the IHDP data set (adapted from https://github.com/clinicalml/cfrnet) Parameters ---------- fname: Path Dataset path Returns ------- data: dict Raw IHDP dict, with X, w, y and yf keys. """ data_in = np.load(fname) data = {"X": data_in["x"], "w": data_in["t"], "y": data_in["yf"]} try: data["ycf"] = data_in["ycf"] except BaseException: data["ycf"] = None if get_po: data["mu0"] = data_in["mu0"] data["mu1"] = data_in["mu1"] data["HAVE_TRUTH"] = not data["ycf"] is None data["dim"] = data["X"].shape[1] data["n"] = data["X"].shape[0] return data def prepare_ihdp_data( data_train: dict, data_test: dict, rescale: bool = False, setting: str = "C", return_pos: bool = False, ) -> Tuple: """ Helper for preprocessing the IHDP dataset. 
Parameters ---------- data_train: pd.DataFrame or dict Train dataset data_test: pd.DataFrame or dict Test dataset rescale: bool, default False Rescale the outcomes to have similar scale setting: str, default C Experiment setting return_pos: bool Return potential outcomes Returns ------- X: dict or pd.DataFrame Training Feature set y: pd.DataFrame or list Outcome list t: pd.DataFrame or list Treatment list cate_true_in: pd.DataFrame or list Average treatment effects for the training set X_t: pd.Dataframe or list Test feature set cate_true_out: pd.DataFrame of list Average treatment effects for the testing set """ X, y, w, mu0, mu1 = ( data_train["X"], data_train["y"], data_train["w"], data_train["mu0"], data_train["mu1"], ) X_t, _, _, mu0_t, mu1_t = ( data_test["X"], data_test["y"], data_test["w"], data_test["mu0"], data_test["mu1"], ) if setting == "D": y[w == 1] = y[w == 1] + mu0[w == 1] mu1 = mu0 + mu1 mu1_t = mu0_t + mu1_t if rescale: # rescale all outcomes to have similar scale of CATEs if sd_cate > 1 cate_in = mu0 - mu1 sd_cate = np.sqrt(cate_in.var()) if sd_cate > 1: # training data error = y - w * mu1 - (1 - w) * mu0 mu0 = mu0 / sd_cate mu1 = mu1 / sd_cate y = w * mu1 + (1 - w) * mu0 + error # test data mu0_t = mu0_t / sd_cate mu1_t = mu1_t / sd_cate cate_true_in = mu1 - mu0 cate_true_out = mu1_t - mu0_t if return_pos: return X, y, w, cate_true_in, X_t, cate_true_out, mu0, mu1, mu0_t, mu1_t return X, y, w, cate_true_in, X_t, cate_true_out def get_one_data_set(D: dict, i_exp: int, get_po: bool = True) -> dict: """ Helper for getting the IHDP data for one experiment. 
Adapted from https://github.com/clinicalml/cfrnet Parameters ---------- D: dict or pd.DataFrame All the experiment i_exp: int Experiment number Returns ------- data: dict or pd.Dataframe dict with the experiment """ D_exp = {} D_exp["X"] = D["X"][:, :, i_exp - 1] D_exp["w"] = D["w"][:, i_exp - 1 : i_exp] D_exp["y"] = D["y"][:, i_exp - 1 : i_exp] if D["HAVE_TRUTH"]: D_exp["ycf"] = D["ycf"][:, i_exp - 1 : i_exp] else: D_exp["ycf"] = None if get_po: D_exp["mu0"] = D["mu0"][:, i_exp - 1 : i_exp] D_exp["mu1"] = D["mu1"][:, i_exp - 1 : i_exp] return D_exp def load(data_path: Path, exp: int = 1, rescale: bool = False, **kwargs: Any) -> Tuple: """ Get IHDP train/test datasets with treatments and labels. Parameters ---------- data_path: Path Path to the dataset csv. If the data is missing, it will be downloaded. Returns ------- X: pd.Dataframe or array The training feature set w: pd.DataFrame or array Training treatment assignments. y: pd.Dataframe or array The training labels training potential outcomes: pd.DataFrame or array. Potential outcomes for the training set. X_t: pd.DataFrame or array The testing feature set testing potential outcomes: pd.DataFrame of array Potential outcomes for the testing set. """ data_train, data_test = load_raw(data_path) data_exp = get_one_data_set(data_train, i_exp=exp, get_po=True) data_exp_test = get_one_data_set(data_test, i_exp=exp, get_po=True) ( X, y, w, cate_true_in, X_t, cate_true_out, mu0, mu1, mu0_t, mu1_t, ) = prepare_ihdp_data( data_exp, data_exp_test, rescale=rescale, return_pos=True, ) return ( X, w, y, np.asarray([mu0, mu1]).squeeze().T, X_t, np.asarray([mu0_t, mu1_t]).squeeze().T, ) def load_raw(data_path: Path) -> Tuple: """ Get IHDP raw train/test sets. Parameters ---------- data_path: Path Path to the dataset csv. If the data is missing, it will be downloaded. 
Returns ------- data_train: dict or pd.DataFrame Training data data_test: dict or pd.DataFrame Testing data """ try: os.mkdir(data_path) except BaseException: pass train_csv = data_path / TRAIN_DATASET test_csv = data_path / TEST_DATASET log.debug(f"load raw dataset {train_csv}") download_if_needed(train_csv, http_url=TRAIN_URL) download_if_needed(test_csv, http_url=TEST_URL) data_train = load_data_npz(train_csv, get_po=True) data_test = load_data_npz(test_csv, get_po=True) return data_train, data_test ================================================ FILE: catenets/datasets/dataset_twins.py ================================================ """ Twins dataset Load real-world individualized treatment effects estimation datasets - Reference: http://data.nber.org/data/linked-birth-infant-death-data-vital-statistics-data.html """ # stdlib import random from pathlib import Path from typing import Tuple # third party import numpy as np import pandas as pd from sklearn.preprocessing import MinMaxScaler import catenets.logger as log from .network import download_if_needed DATASET = "Twin_Data.csv.gz" URL = "https://bitbucket.org/mvdschaar/mlforhealthlabpub/raw/0b0190bcd38a76c405c805f1ca774971fcd85233/data/twins/Twin_Data.csv.gz" # noqa: E501 def preprocess( fn_csv: Path, train_ratio: float = 0.8, treatment_type: str = "rand", seed: int = 42, treat_prop: float = 0.5, ) -> Tuple: """Helper for preprocessing the Twins dataset. Parameters ---------- fn_csv: Path Dataset CSV file path. train_ratio: float The ratio of training data. treatment_type: string The treatment selection strategy. seed: float Random seed. Returns ------- train_x: array or pd.DataFrame Features in training data. train_t: array or pd.DataFrame Treatments in training data. train_y: array or pd.DataFrame Observed outcomes in training data. train_potential_y: array or pd.DataFrame Potential outcomes in training data. test_x: array or pd.DataFrame Features in testing data. 
test_potential_y: array or pd.DataFrame Potential outcomes in testing data. """ np.random.seed(seed) random.seed(seed) # Load original data (11400 patients, 30 features, 2 dimensional potential outcomes) df = pd.read_csv(fn_csv) cleaned_columns = [] for col in df.columns: cleaned_columns.append(col.replace("'", "").replace("’", "")) df.columns = cleaned_columns feat_list = list(df) # 8: factor not on certificate, 9: factor not classifiable --> np.nan --> mode imputation medrisk_list = [ "anemia", "cardiac", "lung", "diabetes", "herpes", "hydra", "hemo", "chyper", "phyper", "eclamp", "incervix", "pre4000", "dtotord", "preterm", "renal", "rh", "uterine", "othermr", ] # 99: missing other_list = ["cigar", "drink", "wtgain", "gestat", "dmeduc", "nprevist"] other_list2 = ["pldel", "resstatb"] # but no samples are missing.. bin_list = ["dmar"] + medrisk_list con_list = ["dmage", "mpcb"] + other_list cat_list = ["adequacy"] + other_list2 for feat in medrisk_list: df[feat] = df[feat].apply(lambda x: df[feat].mode()[0] if x in [8, 9] else x) for feat in other_list: df.loc[df[feat] == 99, feat] = df.loc[df[feat] != 99, feat].mean() df_features = df[con_list + bin_list] for feat in cat_list: df_features = pd.concat( [df_features, pd.get_dummies(df[feat], prefix=feat)], axis=1 ) # Define features feat_list = [ "dmage", "mpcb", "cigar", "drink", "wtgain", "gestat", "dmeduc", "nprevist", "dmar", "anemia", "cardiac", "lung", "diabetes", "herpes", "hydra", "hemo", "chyper", "phyper", "eclamp", "incervix", "pre4000", "dtotord", "preterm", "renal", "rh", "uterine", "othermr", "adequacy_1", "adequacy_2", "adequacy_3", "pldel_1", "pldel_2", "pldel_3", "pldel_4", "pldel_5", "resstatb_1", "resstatb_2", "resstatb_3", "resstatb_4", ] x = np.asarray(df_features[feat_list]) y0 = np.asarray(df[["outcome(t=0)"]]).reshape((-1,)) y0 = np.array(y0 < 9999, dtype=int) y1 = np.asarray(df[["outcome(t=1)"]]).reshape((-1,)) y1 = np.array(y1 < 9999, dtype=int) # Preprocessing scaler = MinMaxScaler() 
scaler.fit(x) x = scaler.transform(x) no, dim = x.shape if treatment_type == "rand": # assign with p=0.5 prob = np.ones(x.shape[0]) * treat_prop elif treatment_type == "logistic": # assign with logistic prob coef = np.random.uniform(-0.1, 0.1, size=[np.shape(x)[1], 1]) prob = 1 / (1 + np.exp(-np.matmul(x, coef))) w = np.random.binomial(1, prob) y = y1 * w + y0 * (1 - w) potential_y = np.vstack((y0, y1)).T # Train/test division if train_ratio < 1: idx = np.random.permutation(no) train_idx = idx[: int(train_ratio * no)] test_idx = idx[int(train_ratio * no) :] train_x = x[train_idx, :] train_w = w[train_idx] train_y = y[train_idx] train_potential_y = potential_y[train_idx, :] test_x = x[test_idx, :] test_potential_y = potential_y[test_idx, :] else: train_x = x train_w = w train_y = y train_potential_y = potential_y test_x = None test_potential_y = None return train_x, train_w, train_y, train_potential_y, test_x, test_potential_y def load( data_path: Path, train_ratio: float = 0.8, treatment_type: str = "rand", seed: int = 42, treat_prop: float = 0.5, ) -> Tuple: """ Twins dataset dataloader. - Download the dataset if needed. - Load the dataset. - Preprocess the data. - Return train/test split. Parameters ---------- data_path: Path Path to the CSV. If it is missing, it will be downloaded. train_ratio: float Train/test ratio treatment_type: str Treatment generation strategy seed: float Random seed treat_prop: float Treatment proportion Returns ------- train_x: array or pd.DataFrame Features in training data. train_t: array or pd.DataFrame Treatments in training data. train_y: array or pd.DataFrame Observed outcomes in training data. train_potential_y: array or pd.DataFrame Potential outcomes in training data. test_x: array or pd.DataFrame Features in testing data. test_potential_y: array or pd.DataFrame Potential outcomes in testing data. 
""" csv = data_path / DATASET download_if_needed(csv, http_url=URL) log.debug(f"load dataset {csv}") return preprocess( csv, train_ratio=train_ratio, treatment_type=treatment_type, seed=seed, treat_prop=treat_prop, ) ================================================ FILE: catenets/datasets/network.py ================================================ """ Utilities and helpers for retrieving the datasets """ # stdlib import tarfile import urllib.request from pathlib import Path from typing import Optional import gdown def download_gdrive_if_needed(path: Path, file_id: str) -> None: """ Helper for downloading a file from Google Drive, if it is now already on the disk. Parameters ---------- path: Path Where to download the file file_id: str Google Drive File ID. Details: https://developers.google.com/drive/api/v3/about-files """ path = Path(path) if path.exists(): return gdown.download(id=file_id, output=str(path), quiet=False) def download_http_if_needed(path: Path, url: str) -> None: """ Helper for downloading a file, if it is now already on the disk. Parameters ---------- path: Path Where to download the file. url: URL string HTTP URL for the dataset. """ path = Path(path) if path.exists(): return if url.lower().startswith("http"): urllib.request.urlretrieve(url, path) # nosec return raise ValueError(f"Invalid url provided {url}") def unarchive_if_needed(path: Path, output_folder: Path) -> None: """ Helper for uncompressing archives. Supports .tar.gz and .tar. Parameters ---------- path: Path Source archive. output_folder: Path Where to unarchive. 
""" if str(path).endswith(".tar.gz"): tar = tarfile.open(path, "r:gz") tar.extractall(path=output_folder) # nosec tar.close() elif str(path).endswith(".tar"): tar = tarfile.open(path, "r:") tar.extractall(path=output_folder) # nosec tar.close() else: raise NotImplementedError(f"archive not supported {path}") def download_if_needed( download_path: Path, file_id: Optional[str] = None, # used for downloading from Google Drive http_url: Optional[str] = None, # used for downloading from a HTTP URL unarchive: bool = False, # unzip a downloaded archive unarchive_folder: Optional[Path] = None, # unzip folder ) -> None: """ Helper for retrieving online datasets. Parameters ---------- download_path: str Where to download the archive file_id: str, optional Set this if you want to download from a public Google drive share http_url: str, optional Set this if you want to download from a HTTP URL unarchive: bool Set this if you want to try to unarchive the downloaded file unarchive_folder: str Mandatory if you set unarchive to True. """ download_path = Path(download_path) if file_id is not None: download_gdrive_if_needed(download_path, file_id) elif http_url is not None: download_http_if_needed(download_path, http_url) else: raise ValueError("Please provide a download URL") if unarchive and unarchive_folder is None: raise ValueError("Please provide a folder for the archive") if unarchive and unarchive_folder is not None: try: unarchive_if_needed(download_path, unarchive_folder) except BaseException as e: print(f"Failed to unpack {download_path}. 
Error {e}") download_path.unlink() ================================================ FILE: catenets/experiment_utils/__init__.py ================================================ ================================================ FILE: catenets/experiment_utils/base.py ================================================ """ Some utils for experiments """ # Author: Alicia Curth from typing import Callable, Dict, Optional, Union import jax.numpy as jnp from catenets.models.jax import ( DRNET_NAME, PSEUDOOUT_NAME, RANET_NAME, RNET_NAME, SNET1_NAME, SNET2_NAME, SNET3_NAME, SNET_NAME, T_NAME, XNET_NAME, PseudoOutcomeNet, get_catenet, ) from catenets.models.jax.base import check_shape_1d_data from catenets.models.jax.transformation_utils import ( DR_TRANSFORMATION, PW_TRANSFORMATION, RA_TRANSFORMATION, ) SEP = "_" def eval_mse_model( inputs: jnp.ndarray, targets: jnp.ndarray, predict_fun: Callable, params: jnp.ndarray, ) -> jnp.ndarray: # evaluate the mse of a model given its function and params preds = predict_fun(params, inputs) return jnp.mean((preds - targets) ** 2) def eval_mse(preds: jnp.ndarray, targets: jnp.ndarray) -> jnp.ndarray: preds = check_shape_1d_data(preds) targets = check_shape_1d_data(targets) return jnp.mean((preds - targets) ** 2) def eval_root_mse(cate_pred: jnp.ndarray, cate_true: jnp.ndarray) -> jnp.ndarray: cate_true = check_shape_1d_data(cate_true) cate_pred = check_shape_1d_data(cate_pred) return jnp.sqrt(eval_mse(cate_pred, cate_true)) def eval_abs_error_ate(cate_pred: jnp.ndarray, cate_true: jnp.ndarray) -> jnp.ndarray: cate_true = check_shape_1d_data(cate_true) cate_pred = check_shape_1d_data(cate_pred) return jnp.abs(jnp.mean(cate_pred) - jnp.mean(cate_true)) def get_model_set( model_selection: Union[str, list] = "all", model_params: Optional[dict] = None ) -> Dict: """Helper function to retrieve a set of models""" # get model selection if type(model_selection) is str: if model_selection == "snet": models = get_all_snets() elif model_selection == 
"pseudo": models = get_all_pseudoout_models() elif model_selection == "twostep": models = get_all_twostep_models() elif model_selection == "all": models = dict(**get_all_snets(), **get_all_pseudoout_models()) else: models = {model_selection: get_catenet(model_selection)()} # type: ignore elif type(model_selection) is list: models = {} for model in model_selection: models.update({model: get_catenet(model)()}) else: raise ValueError("model_selection should be string or list.") # set hyperparameters if model_params is not None: for model in models.values(): existing_params = model.get_params() new_params = { key: val for key, val in model_params.items() if key in existing_params.keys() } model.set_params(**new_params) return models ALL_SNETS = [T_NAME, SNET1_NAME, SNET2_NAME, SNET3_NAME, SNET_NAME] ALL_PSEUDOOUT_MODELS = [DR_TRANSFORMATION, PW_TRANSFORMATION, RA_TRANSFORMATION] ALL_TWOSTEP_MODELS = [DRNET_NAME, RANET_NAME, XNET_NAME, RNET_NAME] def get_all_snets() -> Dict: model_dict = {} for name in ALL_SNETS: model_dict.update({name: get_catenet(name)()}) return model_dict def get_all_pseudoout_models() -> Dict: # DR, RA, PW learner model_dict = {} for trans in ALL_PSEUDOOUT_MODELS: model_dict.update( {PSEUDOOUT_NAME + SEP + trans: PseudoOutcomeNet(transformation=trans)} ) return model_dict def get_all_twostep_models() -> Dict: # DR, RA, R, X learner model_dict = {} for name in ALL_TWOSTEP_MODELS: model_dict.update({name: get_catenet(name)()}) return model_dict ================================================ FILE: catenets/experiment_utils/simulation_utils.py ================================================ """ Simulation utils, allowing to flexibly consider different DGPs """ # Author: Alicia Curth from typing import Any, Optional, Tuple import numpy as np from scipy.special import expit def simulate_treatment_setup( n: int, d: int = 25, n_w: int = 0, n_c: int = 0, n_o: int = 0, n_t: int = 0, covariate_model: Any = None, covariate_model_params: Optional[dict] = 
None, propensity_model: Any = None, propensity_model_params: Optional[dict] = None, mu_0_model: Any = None, mu_0_model_params: Optional[dict] = None, mu_1_model: Any = None, mu_1_model_params: Optional[dict] = None, error_sd: float = 1, seed: int = 42, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """ Generic function to flexibly simulate a treatment setup. Parameters ---------- n: int Number of observations to generate d: int dimension of X to generate n_o: int Dimension of outcome-factor n_c: int Dimension of confounding factor n_t: int Dimension of purely predictive variables (support of tau(x) n_w: int Dimension of treatment assignment factor covariate_model: Model to generate covariates. Default: multivariate normal covariate_model_params: dict Additional parameters to pass to covariate model propensity_model: Model to generate propensity scores propensity_model_params: Additional parameters to pass to propensity model mu_0_model: Model to generate untreated outcomes mu_0_model_params: Additional parameters to pass to untreated outcome model mu_1_model: Model to generate treated outcomes. 
mu_1_model_params: Additional parameters to pass to treated outcome model error_sd: float, default 1 Standard deviation of normal errors seed: int Seed Returns ------- X, y, w, p, t - Covariates, observed outcomes, treatment indicators, propensities, CATE """ # input checks n_nuisance = d - (n_c + n_o + n_w + n_t) if n_nuisance < 0: raise ValueError("Dimensions should add up to maximally d.") # set defaults if covariate_model is None: covariate_model = normal_covariate_model if covariate_model_params is None: covariate_model_params = {} if propensity_model is None: propensity_model = propensity_AISTATS if propensity_model_params is None: propensity_model_params = {} if mu_0_model is None: mu_0_model = mu0_AISTATS if mu_0_model_params is None: mu_0_model_params = {} if mu_1_model is None: mu_1_model = mu1_AISTATS if mu_1_model_params is None: mu_1_model_params = {} np.random.seed(seed) # generate data and outcomes X = covariate_model( n=n, n_nuisance=n_nuisance, n_c=n_c, n_o=n_o, n_w=n_w, n_t=n_t, **covariate_model_params ) mu_0 = mu_0_model(X, n_c=n_c, n_o=n_o, n_w=n_w, **mu_0_model_params) mu_1 = mu_1_model( X, n_c=n_c, n_o=n_o, n_w=n_w, n_t=n_t, mu_0=mu_0, **mu_1_model_params ) t = mu_1 - mu_0 # generate treatments p = propensity_model(X, n_c=n_c, n_w=n_w, **propensity_model_params) w = np.random.binomial(1, p=p) # generate observables y = w * mu_1 + (1 - w) * mu_0 + np.random.normal(0, error_sd, n) return X, y, w, p, t # normal covariate model (Adapted from Hassanpour & Greiner, 2020) ------------- def get_multivariate_normal_params( m: int, correlated: bool = False ) -> Tuple[np.ndarray, np.ndarray]: # Adapted from Hassanpour & Greiner (2020) if correlated: mu = np.zeros(m) # np.random.normal(size=m)/10 temp = np.random.uniform(size=(m, m)) temp = 0.5 * (np.transpose(temp) + temp) sig = (np.ones((m, m)) - np.eye(m)) * temp / 10 + 0.5 * np.eye( m ) # (temp + m * np.eye(m)) / 10 else: mu = np.zeros(m) sig = np.eye(m) return mu, sig def 
get_set_normal_covariates(m: int, n: int, correlated: bool = False) -> np.ndarray: if m == 0: return mu, sig = get_multivariate_normal_params(m, correlated=correlated) return np.random.multivariate_normal(mean=mu, cov=sig, size=n) def normal_covariate_model( n: int, n_nuisance: int = 25, n_c: int = 0, n_o: int = 0, n_w: int = 0, n_t: int = 0, correlated: bool = False, ) -> np.ndarray: X_stack: Tuple = () for n_x in [n_w, n_c, n_o, n_t, n_nuisance]: if n_x > 0: X_stack = (*X_stack, get_set_normal_covariates(n_x, n, correlated)) return np.hstack(X_stack) def propensity_AISTATS( X: np.ndarray, n_c: int = 0, n_w: int = 0, xi: float = 0.5, nonlinear: bool = True, offset: Any = 0, target_prop: Optional[np.ndarray] = None, ) -> np.ndarray: if n_c + n_w == 0: # constant propensity return xi * np.ones(X.shape[0]) else: coefs = np.ones(n_c + n_w) if nonlinear: z = np.dot(X[:, : (n_c + n_w)] ** 2, coefs) / (n_c + n_w) else: z = np.dot(X[:, : (n_c + n_w)], coefs) / (n_c + n_w) if type(offset) is float or type(offset) is int: prop = expit(xi * z + offset) if target_prop is not None: avg_prop = np.average(prop) prop = target_prop / avg_prop * prop return prop elif offset == "center": # center the propensity scores to median 0.5 prop = expit(xi * (z - np.median(z))) if target_prop is not None: avg_prop = np.average(prop) prop = target_prop / avg_prop * prop return prop else: raise ValueError("Not a valid value for offset") def propensity_constant( X: np.ndarray, n_c: int = 0, n_w: int = 0, xi: float = 0.5 ) -> np.ndarray: return xi * np.ones(X.shape[0]) def mu0_AISTATS( X: np.ndarray, n_w: int = 0, n_c: int = 0, n_o: int = 0, scale: bool = False ) -> np.ndarray: if n_c + n_o == 0: return np.zeros((X.shape[0])) else: if not scale: coefs = np.ones(n_c + n_o) else: coefs = 10 * np.ones(n_c + n_o) / (n_c + n_o) return np.dot(X[:, n_w : (n_w + n_c + n_o)] ** 2, coefs) def mu1_AISTATS( X: np.ndarray, n_w: int = 0, n_c: int = 0, n_o: int = 0, n_t: int = 0, mu_0: Optional[np.ndarray] = 
None, nonlinear: int = 2, withbase: bool = True, scale: bool = False, ) -> np.ndarray: if n_t == 0: return mu_0 # use additive effect else: if scale: coefs = 10 * np.ones(n_t) / n_t else: coefs = np.ones(n_t) X_sel = X[:, (n_w + n_c + n_o) : (n_w + n_c + n_o + n_t)] if withbase: return mu_0 + np.dot(X_sel**nonlinear, coefs) else: return np.dot(X_sel**nonlinear, coefs) # Other simulation settings not used in AISTATS paper # uniform covariate model def uniform_covariate_model( n: int, n_nuisance: int = 0, n_c: int = 0, n_o: int = 0, n_w: int = 0, n_t: int = 0, low: int = -1, high: int = 1, ) -> np.ndarray: d = n_nuisance + n_c + n_o + n_w + n_t return np.random.uniform(low=low, high=high, size=(n, d)) def mu1_additive( X: np.ndarray, n_w: int = 0, n_c: int = 0, n_o: int = 0, n_t: int = 0, mu_0: Optional[np.ndarray] = None, ) -> np.ndarray: if n_t == 0: return mu_0 else: coefs = np.random.normal(size=n_t) return np.dot(X[:, (n_w + n_c + n_o) : (n_w + n_c + n_o + n_t)], coefs) / n_t # regression surfaces from Hassanpour & Greiner def mu0_hg(X: np.ndarray, n_w: int = 0, n_c: int = 0, n_o: int = 0) -> np.ndarray: if n_c + n_o == 0: return np.zeros((X.shape[0])) else: coefs = np.random.normal(size=n_c + n_o) return np.dot(X[:, n_w : (n_w + n_c + n_o)], coefs) / (n_c + n_o) def mu1_hg( X: np.ndarray, n_w: int = 0, n_c: int = 0, n_o: int = 0, n_t: int = 0, mu_0: Optional[np.ndarray] = None, ) -> np.ndarray: if n_c + n_o == 0: return np.zeros((X.shape[0])) else: coefs = np.random.normal(size=n_c + n_o) return np.dot(X[:, n_w : (n_w + n_c + n_o)] ** 2, coefs) / (n_c + n_o) def propensity_hg( X: np.ndarray, n_c: int = 0, n_w: int = 0, xi: Optional[float] = None ) -> np.ndarray: # propensity set-up used in Hassanpour & Greiner (2020) if n_c + n_w == 0: return 0.5 * np.ones(X.shape[0]) else: if xi is None: xi = 1 coefs = np.random.normal(size=n_c + n_w) z = np.dot(X[:, : (n_c + n_w)], coefs) return expit(xi * z) ================================================ FILE: 
catenets/experiment_utils/tester.py ================================================ # stdlib import copy from typing import Any, Tuple # third party import numpy as np import torch from sklearn.model_selection import KFold, StratifiedKFold from catenets.experiment_utils.torch_metrics import abs_error_ATE, sqrt_PEHE def generate_score(metric: np.ndarray) -> Tuple[float, float]: percentile_val = 1.96 return (np.mean(metric), percentile_val * np.std(metric) / np.sqrt(len(metric))) def print_score(score: Tuple[float, float]) -> str: return str(round(score[0], 4)) + " +/- " + str(round(score[1], 4)) def evaluate_treatments_model( estimator: Any, X: torch.Tensor, Y: torch.Tensor, Y_full: torch.Tensor, W: torch.Tensor, n_folds: int = 3, seed: int = 0, ) -> dict: metric_pehe = np.zeros(n_folds) metric_ate = np.zeros(n_folds) indx = 0 if len(np.unique(Y)) == 2: skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed) else: skf = KFold(n_splits=n_folds, shuffle=True, random_state=seed) for train_index, test_index in skf.split(X, Y): X_train = X[train_index] Y_train = Y[train_index] W_train = W[train_index] X_test = X[test_index] Y_full_test = Y_full[test_index] model = copy.deepcopy(estimator) model.fit(X_train, Y_train, W_train) try: te_pred = model.predict(X_test).detach().cpu().numpy() except BaseException: te_pred = np.asarray(model.predict(X_test)) metric_ate[indx] = abs_error_ATE(Y_full_test, te_pred) metric_pehe[indx] = sqrt_PEHE(Y_full_test, te_pred) indx += 1 output_pehe = generate_score(metric_pehe) output_ate = generate_score(metric_ate) return { "raw": { "pehe": output_pehe, "ate": output_ate, }, "str": { "pehe": print_score(output_pehe), "ate": print_score(output_ate), }, } ================================================ FILE: catenets/experiment_utils/torch_metrics.py ================================================ # third party import torch def sqrt_PEHE(po: torch.Tensor, hat_te: torch.Tensor) -> torch.Tensor: """ Precision in Estimation of 
Heterogeneous Effect(PyTorch version). PEHE reflects the ability to capture individual variation in treatment effects. Args: po: expected outcome. hat_te: estimated outcome. """ po = torch.Tensor(po) hat_te = torch.Tensor(hat_te) return torch.sqrt(torch.mean(((po[:, 1] - po[:, 0]) - hat_te) ** 2)) def abs_error_ATE(po: torch.Tensor, hat_te: torch.Tensor) -> torch.Tensor: """ Average Treatment Effect. ATE measures what is the expected causal effect of the treatment across all individuals in the population. Args: po: expected outcome. hat_te: estimated outcome. """ po = torch.Tensor(po) hat_te = torch.Tensor(hat_te) return torch.abs(torch.mean(po[:, 1] - po[:, 0]) - torch.mean(hat_te)) ================================================ FILE: catenets/logger.py ================================================ # stdlib import logging import os from typing import Any, Callable, NoReturn, TextIO, Union # third party from loguru import logger LOG_FORMAT = "[{time}][{process.id}][{level}] {message}" logger.remove() DEFAULT_SINK = "catenets_{time}.log" def remove() -> None: logger.remove() def add( sink: Union[None, str, os.PathLike, TextIO, logging.Handler] = None, level: str = "ERROR", ) -> None: sink = DEFAULT_SINK if sink is None else sink try: logger.add( sink=sink, format=LOG_FORMAT, enqueue=True, colorize=False, diagnose=True, backtrace=True, rotation="10 MB", retention="1 day", level=level, ) except BaseException: logger.add( sink=sink, format=LOG_FORMAT, colorize=False, diagnose=True, backtrace=True, level=level, ) def traceback_and_raise(e: Any, verbose: bool = False) -> NoReturn: try: if verbose: logger.opt(lazy=True).exception(e) else: logger.opt(lazy=True).critical(e) except BaseException as ex: logger.debug("failed to print exception", ex) if not issubclass(type(e), Exception): e = Exception(e) raise e def create_log_and_print_function(level: str) -> Callable: def log_and_print(*args: Any, **kwargs: Any) -> None: try: method = getattr(logger.opt(lazy=True), 
level, None) if method is not None: method(*args, **kwargs) else: logger.debug(*args, **kwargs) except BaseException as e: msg = f"failed to log exception. {e}" try: logger.debug(msg) except Exception as e: print(f"{msg}. {e}") return log_and_print def traceback(*args: Any, **kwargs: Any) -> None: return create_log_and_print_function(level="exception")(*args, **kwargs) def critical(*args: Any, **kwargs: Any) -> None: return create_log_and_print_function(level="critical")(*args, **kwargs) def error(*args: Any, **kwargs: Any) -> None: return create_log_and_print_function(level="error")(*args, **kwargs) def warning(*args: Any, **kwargs: Any) -> None: return create_log_and_print_function(level="warning")(*args, **kwargs) def info(*args: Any, **kwargs: Any) -> None: return create_log_and_print_function(level="info")(*args, **kwargs) def debug(*args: Any, **kwargs: Any) -> None: return create_log_and_print_function(level="debug")(*args, **kwargs) def trace(*args: Any, **kwargs: Any) -> None: return create_log_and_print_function(level="trace")(*args, **kwargs) ================================================ FILE: catenets/models/__init__.py ================================================ import catenets.logger as log try: from . import jax except ImportError: log.error("JAX models disabled") try: from . 
import torch except ImportError: log.error("PyTorch models disabled") __all__ = ["jax", "torch"] ================================================ FILE: catenets/models/constants.py ================================================ """ Define some constants for initialisation of hyperparamters etc """ import numpy as np # default model architectures DEFAULT_LAYERS_OUT = 2 DEFAULT_LAYERS_OUT_T = 2 DEFAULT_LAYERS_R = 3 DEFAULT_LAYERS_R_T = 3 DEFAULT_UNITS_OUT = 100 DEFAULT_UNITS_R = 200 DEFAULT_UNITS_OUT_T = 100 DEFAULT_UNITS_R_T = 200 DEFAULT_NONLIN = "elu" # other default hyperparameters DEFAULT_STEP_SIZE = 0.0001 DEFAULT_STEP_SIZE_T = 0.0001 DEFAULT_N_ITER = 10000 DEFAULT_BATCH_SIZE = 100 DEFAULT_PENALTY_L2 = 1e-4 DEFAULT_PENALTY_DISC = 0 DEFAULT_PENALTY_ORTHOGONAL = 1 / 100 DEFAULT_AVG_OBJECTIVE = True # defaults for early stopping DEFAULT_VAL_SPLIT = 0.3 DEFAULT_N_ITER_MIN = 200 DEFAULT_PATIENCE = 10 # Defaults for crossfitting DEFAULT_CF_FOLDS = 2 # other defaults DEFAULT_SEED = 42 DEFAULT_N_ITER_PRINT = 50 LARGE_VAL = np.iinfo(np.int32).max DEFAULT_UNITS_R_BIG_S = 100 DEFAULT_UNITS_R_SMALL_S = 50 DEFAULT_UNITS_R_BIG_S3 = 150 DEFAULT_UNITS_R_SMALL_S3 = 50 N_SUBSPACES = 3 DEFAULT_DIM_S_OUT = 50 DEFAULT_DIM_S_R = 100 DEFAULT_DIM_P_OUT = 50 DEFAULT_DIM_P_R = 100 ================================================ FILE: catenets/models/jax/__init__.py ================================================ """ JAX-based implementations for the CATE estimators. 
""" from typing import Any from catenets.models.jax.disentangled_nets import SNet3 from catenets.models.jax.flextenet import FlexTENet from catenets.models.jax.offsetnet import OffsetNet from catenets.models.jax.pseudo_outcome_nets import ( DRNet, PseudoOutcomeNet, PWNet, RANet, ) from catenets.models.jax.representation_nets import DragonNet, SNet1, SNet2, TARNet from catenets.models.jax.rnet import RNet from catenets.models.jax.snet import SNet from catenets.models.jax.tnet import TNet from catenets.models.jax.xnet import XNet SNET1_NAME = "SNet1" T_NAME = "TNet" SNET2_NAME = "SNet2" PSEUDOOUT_NAME = "PseudoOutcomeNet" SNET3_NAME = "SNet3" SNET_NAME = "SNet" XNET_NAME = "XNet" RNET_NAME = "RNet" DRNET_NAME = "DRNet" PWNET_NAME = "PWNet" RANET_NAME = "RANet" TARNET_NAME = "TARNet" FLEXTE_NAME = "FlexTENet" OFFSET_NAME = "OffsetNet" DRAGON_NAME = "DragonNet" ALL_MODELS = [ T_NAME, SNET1_NAME, SNET2_NAME, SNET3_NAME, SNET_NAME, PSEUDOOUT_NAME, RNET_NAME, XNET_NAME, DRNET_NAME, PWNET_NAME, RANET_NAME, TARNET_NAME, FLEXTE_NAME, OFFSET_NAME, ] MODEL_DICT = { T_NAME: TNet, SNET1_NAME: SNet1, SNET2_NAME: SNet2, SNET3_NAME: SNet3, SNET_NAME: SNet, PSEUDOOUT_NAME: PseudoOutcomeNet, RNET_NAME: RNet, XNET_NAME: XNet, DRNET_NAME: DRNet, PWNET_NAME: PWNet, RANET_NAME: RANet, TARNET_NAME: TARNet, DRAGON_NAME: DragonNet, OFFSET_NAME: OffsetNet, FLEXTE_NAME: FlexTENet, } __all__ = [ T_NAME, SNET1_NAME, SNET2_NAME, SNET3_NAME, SNET_NAME, PSEUDOOUT_NAME, RNET_NAME, XNET_NAME, DRNET_NAME, PWNET_NAME, RANET_NAME, TARNET_NAME, DRAGON_NAME, FLEXTE_NAME, OFFSET_NAME, ] def get_catenet(name: str) -> Any: if name not in ALL_MODELS: raise ValueError( f"Model name should be in catenets.models.jax.ALL_MODELS You passed {name}" ) return MODEL_DICT[name] ================================================ FILE: catenets/models/jax/base.py ================================================ """ Base modules shared across different nets """ # Author: Alicia Curth import abc from typing import Any, 
Callable, List, Optional, Tuple import jax.numpy as jnp import numpy as onp from jax import grad, jit, random from jax.example_libraries import optimizers, stax from jax.example_libraries.stax import Dense, Elu, Relu, Sigmoid from sklearn.base import BaseEstimator, RegressorMixin from sklearn.model_selection import ParameterGrid import catenets.logger as log from catenets.models.constants import ( DEFAULT_BATCH_SIZE, DEFAULT_LAYERS_OUT, DEFAULT_N_ITER, DEFAULT_N_ITER_MIN, DEFAULT_N_ITER_PRINT, DEFAULT_NONLIN, DEFAULT_PATIENCE, DEFAULT_PENALTY_L2, DEFAULT_SEED, DEFAULT_STEP_SIZE, DEFAULT_UNITS_OUT, DEFAULT_UNITS_R, DEFAULT_VAL_SPLIT, LARGE_VAL, ) from catenets.models.jax.model_utils import ( check_shape_1d_data, check_X_is_np, make_val_split, ) def ReprBlock( n_layers: int = 3, n_units: int = 100, nonlin: str = DEFAULT_NONLIN ) -> Any: # Creates a representation block using jax.stax # create first layer if nonlin == "elu": NL = Elu elif nonlin == "relu": NL = Relu elif nonlin == "sigmoid": NL = Sigmoid else: raise ValueError("Unknown nonlinearity") layers: Tuple layers = (Dense(n_units), NL) # add required number of layers for i in range(n_layers - 1): layers = (*layers, Dense(n_units), NL) return stax.serial(*layers) def OutputHead( n_layers_out: int = DEFAULT_LAYERS_OUT, n_units_out: int = DEFAULT_UNITS_OUT, binary_y: bool = False, n_layers_r: int = 0, n_units_r: int = DEFAULT_UNITS_R, nonlin: str = DEFAULT_NONLIN, ) -> Any: # Creates an output head using jax.stax if nonlin == "elu": NL = Elu elif nonlin == "relu": NL = Relu elif nonlin == "sigmoid": NL = Sigmoid else: raise ValueError("Unknown nonlinearity") layers: Tuple = () # add required number of layers for i in range(n_layers_r): layers = (*layers, Dense(n_units_r), NL) # add required number of layers for i in range(n_layers_out): layers = (*layers, Dense(n_units_out), NL) # return final architecture if not binary_y: return stax.serial(*layers, Dense(1)) else: return stax.serial(*layers, Dense(1), Sigmoid) 
class BaseCATENet(BaseEstimator, RegressorMixin, abc.ABC): """ Base CATENet class to serve as template for all other nets """ def score( self, X: jnp.ndarray, y: jnp.ndarray, sample_weight: Optional[jnp.ndarray] = None, ) -> float: """ Return the sqrt PEHE error (Oracle metric). Parameters ---------- X: pd.DataFrame or np.array Covariate matrix y: np.array Expected potential outcome vector """ X = check_X_is_np(X) y = check_X_is_np(y) if len(X) != len(y): raise ValueError("X/y length mismatch for score") if y.shape[-1] != 2: raise ValueError(f"y has invalid shape {y.shape}") hat_te = self.predict(X) return jnp.sqrt(jnp.mean(((y[:, 1] - y[:, 0]) - hat_te) ** 2)) @abc.abstractmethod def _get_train_function(self) -> Callable: ... def fit( self, X: jnp.ndarray, y: jnp.ndarray, w: jnp.ndarray, p: Optional[jnp.ndarray] = None, ) -> "BaseCATENet": """ Fit method for a CATENet. Takes covariates, outcome variable and treatment indicator as input Parameters ---------- X: pd.DataFrame or np.array Covariate matrix y: np.array Outcome vector w: np.array Treatment indicator p: np.array Vector of (known) treatment propensities. Currently only supported for TwoStepNets. """ # some quick input checks if p is not None: raise NotImplementedError("Only two-step-nets take p as input. ") X = check_X_is_np(X) self._check_inputs(w, p) train_func = self._get_train_function() train_params = self.get_params() self._params, self._predict_funs = train_func(X, y, w, **train_params) return self @abc.abstractmethod def _get_predict_function(self) -> Callable: ... def predict( self, X: jnp.ndarray, return_po: bool = False, return_prop: bool = False ) -> jnp.ndarray: """ Predict treatment effect estimates using a CATENet. Depending on method, can also return potential outcome estimate and propensity score estimate. 
Parameters ---------- X: pd.DataFrame or np.array Covariate matrix return_po: bool, default False Whether to return potential outcome estimate return_prop: bool, default False Whether to return propensity estimate Returns ------- array of CATE estimates, optionally also potential outcomes and propensity """ X = check_X_is_np(X) predict_func = self._get_predict_function() return predict_func( X, trained_params=self._params, predict_funs=self._predict_funs, return_po=return_po, return_prop=return_prop, ) @staticmethod def _check_inputs(w: jnp.ndarray, p: jnp.ndarray) -> None: if p is not None: if onp.sum(p > 1) > 0 or onp.sum(p < 0) > 0: raise ValueError("p should be in [0,1]") if not ((w == 0) | (w == 1)).all(): raise ValueError("W should be binary") def fit_and_select_params( self, X: jnp.ndarray, y: jnp.ndarray, w: jnp.ndarray, p: Optional[jnp.ndarray] = None, param_grid: dict = {}, ) -> "BaseCATENet": # some quick input checks if param_grid is None: raise ValueError("No param_grid to evaluate. 
") X = check_X_is_np(X) self._check_inputs(w, p) param_grid = ParameterGrid(param_grid) self_param_dict = self.get_params() train_function = self._get_train_function() models = [] losses = [] param_settings: list = [] for param_setting in param_grid: log.debug( "Testing parameter setting: " + " ".join( [key + ": " + str(value) for key, value in param_setting.items()] ) ) # replace params train_param_dict = { key: (val if key not in param_setting.keys() else param_setting[key]) for key, val in self_param_dict.items() } if p is not None: params, funs, val_loss = train_function( X, y, w, p=p, return_val_loss=True, **train_param_dict ) else: params, funs, val_loss = train_function( X, y, w, return_val_loss=True, **train_param_dict ) models.append((params, funs)) losses.append(val_loss) # save results param_settings.extend(param_grid) self._selection_results = { "param_settings": param_settings, "val_losses": losses, } # find lowest loss and set params best_idx = jnp.array(losses).argmin() self._params, self._predict_funs = models[best_idx] self.set_params(**param_settings[best_idx]) return self def train_output_net_only( X: jnp.ndarray, y: jnp.ndarray, binary_y: bool = False, n_layers_out: int = DEFAULT_LAYERS_OUT, n_units_out: int = DEFAULT_UNITS_OUT, n_layers_r: int = 0, n_units_r: int = DEFAULT_UNITS_R, penalty_l2: float = DEFAULT_PENALTY_L2, step_size: float = DEFAULT_STEP_SIZE, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, val_split_prop: float = DEFAULT_VAL_SPLIT, early_stopping: bool = True, patience: int = DEFAULT_PATIENCE, n_iter_min: int = DEFAULT_N_ITER_MIN, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, return_val_loss: bool = False, nonlin: str = DEFAULT_NONLIN, avg_objective: bool = False, ) -> Any: # function to train a single output head # input check y = check_shape_1d_data(y) d = X.shape[1] input_shape = (-1, d) rng_key = random.PRNGKey(seed) onp.random.seed(seed) # set seed for data generation via numpy as 
well # get validation split (can be none) X, y, X_val, y_val, val_string = make_val_split( X, y, val_split_prop=val_split_prop, seed=seed ) n = X.shape[0] # could be different from before due to split # get output head init_fun, predict_fun = OutputHead( n_layers_out=n_layers_out, n_units_out=n_units_out, binary_y=binary_y, n_layers_r=n_layers_r, n_units_r=n_units_r, nonlin=nonlin, ) # get functions if not binary_y: # define loss and grad @jit def loss( params: List, batch: Tuple[jnp.ndarray, jnp.ndarray], penalty: float ) -> jnp.ndarray: # mse loss function inputs, targets = batch preds = predict_fun(params, inputs) weightsq = sum( [ jnp.sum(params[i][0] ** 2) for i in range(0, 2 * (n_layers_out + n_layers_r) + 1, 2) ] ) if not avg_objective: return jnp.sum((preds - targets) ** 2) + 0.5 * penalty * weightsq else: return jnp.average((preds - targets) ** 2) + 0.5 * penalty * weightsq else: # get loss and grad @jit def loss( params: List, batch: Tuple[jnp.ndarray, jnp.ndarray], penalty: float ) -> jnp.ndarray: # mse loss function inputs, targets = batch preds = predict_fun(params, inputs) weightsq = sum( [ jnp.sum(params[i][0] ** 2) for i in range(0, 2 * (n_layers_out + n_layers_r) + 1, 2) ] ) if not avg_objective: return ( -jnp.sum( targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds) ) + 0.5 * penalty * weightsq ) else: return ( -jnp.average( targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds) ) + 0.5 * penalty * weightsq ) # set optimization routine # set optimizer opt_init, opt_update, get_params = optimizers.adam(step_size=step_size) # set update function @jit def update(i: int, state: dict, batch: jnp.ndarray, penalty: float) -> jnp.ndarray: params = get_params(state) g_params = grad(loss)(params, batch, penalty) return opt_update(i, g_params, state) # initialise states _, init_params = init_fun(rng_key, input_shape) opt_state = opt_init(init_params) # calculate number of batches per epoch batch_size = batch_size if batch_size < n else n 
n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1 train_indices = onp.arange(n) l_best = LARGE_VAL p_curr = 0 # do training for i in range(n_iter): # shuffle data for minibatches onp.random.shuffle(train_indices) for b in range(n_batches): idx_next = train_indices[ (b * batch_size) : min((b + 1) * batch_size, n - 1) ] next_batch = X[idx_next, :], y[idx_next, :] opt_state = update(i * n_batches + b, opt_state, next_batch, penalty_l2) if (i % n_iter_print == 0) or early_stopping: params_curr = get_params(opt_state) l_curr = loss(params_curr, (X_val, y_val), penalty_l2) if i % n_iter_print == 0: log.info(f"Epoch: {i}, current {val_string} loss: {l_curr}") if early_stopping and ((i + 1) * n_batches > n_iter_min): # check if loss updated if l_curr < l_best: l_best = l_curr p_curr = 0 else: p_curr = p_curr + 1 if p_curr > patience: trained_params = get_params(opt_state) if return_val_loss: # return loss without penalty l_final = loss(trained_params, (X_val, y_val), 0) return trained_params, predict_fun, l_final return trained_params, predict_fun # get final parameters trained_params = get_params(opt_state) if return_val_loss: # return loss without penalty l_final = loss(trained_params, (X_val, y_val), 0) return trained_params, predict_fun, l_final return trained_params, predict_fun ================================================ FILE: catenets/models/jax/disentangled_nets.py ================================================ """ Class implements SNet-3, a variation on DR-CFR discussed in Hassanpour and Greiner (2020) and Wu et al (2020). 
""" # Author: Alicia Curth from typing import Any, Callable, List, Tuple import jax.numpy as jnp import numpy as onp from jax import grad, jit, random from jax.example_libraries import optimizers import catenets.logger as log from catenets.models.constants import ( DEFAULT_AVG_OBJECTIVE, DEFAULT_BATCH_SIZE, DEFAULT_LAYERS_OUT, DEFAULT_LAYERS_R, DEFAULT_N_ITER, DEFAULT_N_ITER_MIN, DEFAULT_N_ITER_PRINT, DEFAULT_NONLIN, DEFAULT_PATIENCE, DEFAULT_PENALTY_DISC, DEFAULT_PENALTY_L2, DEFAULT_PENALTY_ORTHOGONAL, DEFAULT_SEED, DEFAULT_STEP_SIZE, DEFAULT_UNITS_OUT, DEFAULT_UNITS_R_BIG_S3, DEFAULT_UNITS_R_SMALL_S3, DEFAULT_VAL_SPLIT, LARGE_VAL, ) from catenets.models.jax.base import BaseCATENet, OutputHead, ReprBlock from catenets.models.jax.model_utils import ( check_shape_1d_data, heads_l2_penalty, make_val_split, ) from catenets.models.jax.representation_nets import mmd2_lin # helper functions to avoid abstract tracer values in jit def _get_absolute_rowsums(mat: jnp.ndarray) -> jnp.ndarray: return jnp.sum(jnp.abs(mat), axis=1) def _concatenate_representations(reps: jnp.ndarray) -> jnp.ndarray: return jnp.concatenate(reps, axis=1) class SNet3(BaseCATENet): """ Class implements SNet-3, which is based on Hassanpour & Greiner (2020)'s DR-CFR (Without propensity weighting), using an orthogonal regularizer to enforce decomposition similar to Wu et al (2020). 
Parameters ---------- binary_y: bool, default False Whether the outcome is binary n_layers_out: int Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer) n_layers_out_prop: int Number of hypothesis layers for propensity score(n_layers_out x n_units_out + 1 x Dense layer) n_units_out: int Number of hidden units in each hypothesis layer n_units_out_prop: int Number of hidden units in each propensity score hypothesis layer n_layers_r: int Number of shared & private representation layers before hypothesis layers n_units_r: int Number of hidden units in representation layer shared by propensity score and outcome function (the 'confounding factor') n_units_r_small: int Number of hidden units in representation layer NOT shared by propensity score and outcome functions (the 'outcome factor' and the 'instrumental factor') penalty_l2: float l2 (ridge) penalty step_size: float learning rate for optimizer n_iter: int Maximum number of iterations batch_size: int Batch size val_split_prop: float Proportion of samples used for validation split (can be 0) early_stopping: bool, default True Whether to use early stopping patience: int Number of iterations to wait before early stopping after decrease in validation loss n_iter_min: int Minimum number of iterations to go through before starting early stopping n_iter_print: int Number of iterations after which to print updates seed: int Seed used reg_diff: bool, default False Whether to regularize the difference between the two potential outcome heads penalty_diff: float l2-penalty for regularizing the difference between output heads. used only if train_separate=False same_init: bool, False Whether to initialise the two output heads with same values nonlin: string, default 'elu' Nonlinearity to use in NN penalty_disc: float, default zero Discrepancy penalty. Defaults to zero as this feature is not tested. 
""" def __init__( self, binary_y: bool = False, n_layers_r: int = DEFAULT_LAYERS_R, n_units_r: int = DEFAULT_UNITS_R_BIG_S3, n_layers_out: int = DEFAULT_LAYERS_OUT, n_units_r_small: int = DEFAULT_UNITS_R_SMALL_S3, n_units_out: int = DEFAULT_UNITS_OUT, n_units_out_prop: int = DEFAULT_UNITS_OUT, n_layers_out_prop: int = DEFAULT_LAYERS_OUT, penalty_l2: float = DEFAULT_PENALTY_L2, penalty_orthogonal: float = DEFAULT_PENALTY_ORTHOGONAL, penalty_disc: float = DEFAULT_PENALTY_DISC, step_size: float = DEFAULT_STEP_SIZE, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, val_split_prop: float = DEFAULT_VAL_SPLIT, early_stopping: bool = True, patience: int = DEFAULT_PATIENCE, n_iter_min: int = DEFAULT_N_ITER_MIN, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, nonlin: str = DEFAULT_NONLIN, reg_diff: bool = False, penalty_diff: float = DEFAULT_PENALTY_L2, same_init: bool = False, ) -> None: self.binary_y = binary_y self.n_layers_r = n_layers_r self.n_layers_out = n_layers_out self.n_layers_out_prop = n_layers_out_prop self.n_units_r = n_units_r self.n_units_r_small = n_units_r_small self.n_units_out = n_units_out self.n_units_out_prop = n_units_out_prop self.nonlin = nonlin self.penalty_l2 = penalty_l2 self.penalty_orthogonal = penalty_orthogonal self.penalty_disc = penalty_disc self.reg_diff = reg_diff self.penalty_diff = penalty_diff self.same_init = same_init self.step_size = step_size self.n_iter = n_iter self.batch_size = batch_size self.val_split_prop = val_split_prop self.early_stopping = early_stopping self.patience = patience self.n_iter_min = n_iter_min self.seed = seed self.n_iter_print = n_iter_print def _get_predict_function(self) -> Callable: return predict_snet3 def _get_train_function(self) -> Callable: return train_snet3 # SNET-3 ------------------------------------------------------------- def train_snet3( X: jnp.ndarray, y: jnp.ndarray, w: jnp.ndarray, binary_y: bool = False, n_layers_r: int = DEFAULT_LAYERS_R, 
n_units_r: int = DEFAULT_UNITS_R_BIG_S3, n_units_r_small: int = DEFAULT_UNITS_R_SMALL_S3, n_layers_out: int = DEFAULT_LAYERS_OUT, n_units_out: int = DEFAULT_UNITS_OUT, n_units_out_prop: int = DEFAULT_UNITS_OUT, n_layers_out_prop: int = DEFAULT_LAYERS_OUT, penalty_l2: float = DEFAULT_PENALTY_L2, penalty_disc: float = DEFAULT_PENALTY_DISC, penalty_orthogonal: float = DEFAULT_PENALTY_ORTHOGONAL, step_size: float = DEFAULT_STEP_SIZE, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, val_split_prop: float = DEFAULT_VAL_SPLIT, early_stopping: bool = True, n_iter_min: int = DEFAULT_N_ITER_MIN, patience: int = DEFAULT_PATIENCE, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, return_val_loss: bool = False, reg_diff: bool = False, penalty_diff: float = DEFAULT_PENALTY_L2, nonlin: str = DEFAULT_NONLIN, avg_objective: bool = DEFAULT_AVG_OBJECTIVE, same_init: bool = False, ) -> Any: """ SNet-3, based on the decompostion used in Hassanpour and Greiner (2020) """ # function to train a net with 3 representations y, w = check_shape_1d_data(y), check_shape_1d_data(w) d = X.shape[1] input_shape = (-1, d) rng_key = random.PRNGKey(seed) onp.random.seed(seed) # set seed for data generation via numpy as well if not reg_diff: penalty_diff = penalty_l2 # get validation split (can be none) X, y, w, X_val, y_val, w_val, val_string = make_val_split( X, y, w, val_split_prop=val_split_prop, seed=seed ) n = X.shape[0] # could be different from before due to split # get representation layers init_fun_repr, predict_fun_repr = ReprBlock( n_layers=n_layers_r, n_units=n_units_r, nonlin=nonlin ) init_fun_repr_small, predict_fun_repr_small = ReprBlock( n_layers=n_layers_r, n_units=n_units_r_small, nonlin=nonlin ) # get output head functions (output heads share same structure) init_fun_head_po, predict_fun_head_po = OutputHead( n_layers_out=n_layers_out, n_units_out=n_units_out, binary_y=binary_y, nonlin=nonlin, ) # add propensity head init_fun_head_prop, 
predict_fun_head_prop = OutputHead( n_layers_out=n_layers_out_prop, n_units_out=n_units_out_prop, binary_y=True, nonlin=nonlin, ) def init_fun_snet3(rng: float, input_shape: Tuple) -> Tuple[Tuple, List]: # chain together the layers # param should look like [repr_c, repr_o, repr_t, po_0, po_1, prop] # initialise representation layers rng, layer_rng = random.split(rng) input_shape_repr, param_repr_c = init_fun_repr(layer_rng, input_shape) rng, layer_rng = random.split(rng) input_shape_repr_small, param_repr_o = init_fun_repr_small( layer_rng, input_shape ) rng, layer_rng = random.split(rng) _, param_repr_w = init_fun_repr_small(layer_rng, input_shape) # each head gets two representations input_shape_repr = input_shape_repr[:-1] + ( input_shape_repr[-1] + input_shape_repr_small[-1], ) # initialise output heads rng, layer_rng = random.split(rng) if same_init: # initialise both on same values input_shape, param_0 = init_fun_head_po(layer_rng, input_shape_repr) input_shape, param_1 = init_fun_head_po(layer_rng, input_shape_repr) else: input_shape, param_0 = init_fun_head_po(layer_rng, input_shape_repr) rng, layer_rng = random.split(rng) input_shape, param_1 = init_fun_head_po(layer_rng, input_shape_repr) rng, layer_rng = random.split(rng) input_shape, param_prop = init_fun_head_prop(layer_rng, input_shape_repr) return input_shape, [ param_repr_c, param_repr_o, param_repr_w, param_0, param_1, param_prop, ] # Define loss functions # loss functions for the head if not binary_y: def loss_head( params: List, batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], penalty: float, ) -> jnp.ndarray: # mse loss function inputs, targets, weights = batch preds = predict_fun_head_po(params, inputs) return jnp.sum(weights * ((preds - targets) ** 2)) else: def loss_head( params: List, batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], penalty: float, ) -> jnp.ndarray: # log loss function inputs, targets, weights = batch preds = predict_fun_head_po(params, inputs) return -jnp.sum( 
weights * (targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds)) ) def loss_head_prop( params: List, batch: Tuple[jnp.ndarray, jnp.ndarray], penalty: float, ) -> jnp.ndarray: # log loss function for propensities inputs, targets = batch preds = predict_fun_head_prop(params, inputs) return -jnp.sum(targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds)) # complete loss function for all parts @jit def loss_snet3( params: List, batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], penalty_l2: float, penalty_orthogonal: float, penalty_disc: float, ) -> jnp.ndarray: # params: list[repr_c, repr_o, repr_t, po_0, po_1, prop] # batch: (X, y, w) X, y, w = batch # get representation reps_c = predict_fun_repr(params[0], X) reps_o = predict_fun_repr_small(params[1], X) reps_w = predict_fun_repr_small(params[2], X) # concatenate reps_po = _concatenate_representations((reps_c, reps_o)) reps_prop = _concatenate_representations((reps_c, reps_w)) # pass down to heads loss_0 = loss_head(params[3], (reps_po, y, 1 - w), penalty_l2) loss_1 = loss_head(params[4], (reps_po, y, w), penalty_l2) # pass down to propensity head loss_prop = loss_head_prop(params[5], (reps_prop, w), penalty_l2) weightsq_prop = sum( [ jnp.sum(params[5][i][0] ** 2) for i in range(0, 2 * n_layers_out_prop + 1, 2) ] ) # which variable has impact on which representation col_c = _get_absolute_rowsums(params[0][0][0]) col_o = _get_absolute_rowsums(params[1][0][0]) col_w = _get_absolute_rowsums(params[2][0][0]) loss_o = penalty_orthogonal * ( jnp.sum(col_c * col_o + col_c * col_w + col_w * col_o) ) # is rep_o balanced between groups? 
loss_disc = penalty_disc * mmd2_lin(reps_o, w) # weight decay on representations weightsq_body = sum( [ sum( [jnp.sum(params[j][i][0] ** 2) for i in range(0, 2 * n_layers_r, 2)] ) for j in range(3) ] ) weightsq_head = heads_l2_penalty( params[3], params[4], n_layers_out, reg_diff, penalty_l2, penalty_diff ) if not avg_objective: return ( loss_0 + loss_1 + loss_prop + loss_o + loss_disc + 0.5 * (penalty_l2 * (weightsq_body + weightsq_prop) + weightsq_head) ) else: n_batch = y.shape[0] return ( (loss_0 + loss_1) / n_batch + loss_prop / n_batch + loss_o + loss_disc + 0.5 * (penalty_l2 * (weightsq_body + weightsq_prop) + weightsq_head) ) # Define optimisation routine opt_init, opt_update, get_params = optimizers.adam(step_size=step_size) @jit def update( i: int, state: dict, batch: jnp.ndarray, penalty_l2: float, penalty_orthogonal: float, penalty_disc: float, ) -> jnp.ndarray: # updating function params = get_params(state) return opt_update( i, grad(loss_snet3)( params, batch, penalty_l2, penalty_orthogonal, penalty_disc ), state, ) # initialise states _, init_params = init_fun_snet3(rng_key, input_shape) opt_state = opt_init(init_params) # calculate number of batches per epoch batch_size = batch_size if batch_size < n else n n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1 train_indices = onp.arange(n) l_best = LARGE_VAL p_curr = 0 # do training for i in range(n_iter): # shuffle data for minibatches onp.random.shuffle(train_indices) for b in range(n_batches): idx_next = train_indices[ (b * batch_size) : min((b + 1) * batch_size, n - 1) ] next_batch = X[idx_next, :], y[idx_next, :], w[idx_next] opt_state = update( i * n_batches + b, opt_state, next_batch, penalty_l2, penalty_orthogonal, penalty_disc, ) if (i % n_iter_print == 0) or early_stopping: params_curr = get_params(opt_state) l_curr = loss_snet3( params_curr, (X_val, y_val, w_val), penalty_l2, penalty_orthogonal, penalty_disc, ) if i % n_iter_print == 0: log.info(f"Epoch: {i}, current 
{val_string} loss {l_curr}") if early_stopping and ((i + 1) * n_batches > n_iter_min): # check if loss updated if l_curr < l_best: l_best = l_curr p_curr = 0 params_best = params_curr else: if onp.isnan(l_curr): # if diverged, return best return params_best, ( predict_fun_repr, predict_fun_head_po, predict_fun_head_prop, ) p_curr = p_curr + 1 if p_curr > patience: if return_val_loss: # return loss without penalty l_final = loss_snet3(params_curr, (X_val, y_val, w_val), 0, 0, 0) return ( params_curr, (predict_fun_repr, predict_fun_head_po, predict_fun_head_prop), l_final, ) return params_curr, ( predict_fun_repr, predict_fun_head_po, predict_fun_head_prop, ) # return the parameters trained_params = get_params(opt_state) if return_val_loss: # return loss without penalty l_final = loss_snet3(get_params(opt_state), (X_val, y_val, w_val), 0, 0) return ( trained_params, (predict_fun_repr, predict_fun_head_po, predict_fun_head_prop), l_final, ) return trained_params, ( predict_fun_repr, predict_fun_head_po, predict_fun_head_prop, ) def predict_snet3( X: jnp.ndarray, trained_params: dict, predict_funs: list, return_po: bool = False, return_prop: bool = False, ) -> jnp.ndarray: # unpack inputs predict_fun_repr, predict_fun_head, predict_fun_prop = predict_funs param_repr_c, param_repr_o, param_repr_t = ( trained_params[0], trained_params[1], trained_params[2], ) param_0, param_1, param_prop = ( trained_params[3], trained_params[4], trained_params[5], ) # get representations rep_c = predict_fun_repr(param_repr_c, X) rep_o = predict_fun_repr(param_repr_o, X) rep_w = predict_fun_repr(param_repr_t, X) # concatenate reps_po = jnp.concatenate((rep_c, rep_o), axis=1) reps_prop = jnp.concatenate((rep_c, rep_w), axis=1) # get potential outcomes mu_0 = predict_fun_head(param_0, reps_po) mu_1 = predict_fun_head(param_1, reps_po) te = mu_1 - mu_0 if return_prop: # get propensity prop = predict_fun_prop(param_prop, reps_prop) # stack other outputs if return_po: if return_prop: return 
te, mu_0, mu_1, prop else: return te, mu_0, mu_1 else: if return_prop: return te, prop else: return te ================================================ FILE: catenets/models/jax/flextenet.py ================================================ """ Module implements FlexTENet, also referred to as the 'flexible approach' in "On inductive biases for heterogeneous treatment effect estimation", Curth & vd Schaar (2021). """ # Author: Alicia Curth from typing import Any, Callable, Optional, Tuple import jax.numpy as jnp import numpy as onp from jax import grad, jit, random from jax.example_libraries import optimizers from jax.example_libraries.stax import ( Dense, Sigmoid, elu, glorot_normal, normal, serial, ) import catenets.logger as log from catenets.models.constants import ( DEFAULT_BATCH_SIZE, DEFAULT_DIM_P_OUT, DEFAULT_DIM_P_R, DEFAULT_DIM_S_OUT, DEFAULT_DIM_S_R, DEFAULT_LAYERS_OUT, DEFAULT_LAYERS_R, DEFAULT_N_ITER, DEFAULT_N_ITER_MIN, DEFAULT_N_ITER_PRINT, DEFAULT_NONLIN, DEFAULT_PATIENCE, DEFAULT_PENALTY_L2, DEFAULT_PENALTY_ORTHOGONAL, DEFAULT_SEED, DEFAULT_STEP_SIZE, DEFAULT_VAL_SPLIT, LARGE_VAL, N_SUBSPACES, ) from catenets.models.jax.base import BaseCATENet from catenets.models.jax.model_utils import check_shape_1d_data, make_val_split class FlexTENet(BaseCATENet): """ Module implements FlexTENet, an architecture for treatment effect estimation that allows for both shared and private information in each layer of the network. 
Parameters ---------- binary_y: bool, default False Whether the outcome is binary n_layers_out: int Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer) n_units_s_out: int Number of hidden units in each shared hypothesis layer n_units_p_out: int Number of hidden units in each private hypothesis layer n_layers_r: int Number of representation layers before hypothesis layers (distinction between hypothesis layers and representation layers is made to match TARNet & SNets) n_units_s_r: int Number of hidden units in each shared representation layer n_units_s_r: int Number of hidden units in each private representation layer private_out: bool, False Whether the final prediction layer should be fully private, or retain a shared component. penalty_l2: float l2 (ridge) penalty penalty_l2_p: float l2 (ridge) penalty for private layers penalty_orthogonal: float orthogonalisation penalty step_size: float learning rate for optimizer n_iter: int Maximum number of iterations batch_size: int Batch size val_split_prop: float Proportion of samples used for validation split (can be 0) early_stopping: bool, default True Whether to use early stopping patience: int Number of iterations to wait before early stopping after decrease in validation loss n_iter_min: int Minimum number of iterations to go through before starting early stopping n_iter_print: int Number of iterations after which to print updates seed: int Seed used opt: str, default 'adam' Optimizer to use, accepts 'adam' and 'sgd' shared_repr: bool, False Whether to use a shared representation block as TARNet pretrain_shared: bool, False Whether to pretrain the shared component of the network while freezing the private parameters same_init: bool, True Whether to use the same initialisation for all private spaces lr_scale: float Whether to scale down the learning rate after unfreezing the private components of the network (only used if pretrain_shared=True) normalize_ortho: bool, False Whether to normalize 
the orthogonality penalty (by depth of network) """ def __init__( self, binary_y: bool = False, n_layers_out: int = DEFAULT_LAYERS_OUT, n_units_s_out: int = DEFAULT_DIM_S_OUT, n_units_p_out: int = DEFAULT_DIM_P_OUT, n_layers_r: int = DEFAULT_LAYERS_R, n_units_s_r: int = DEFAULT_DIM_S_R, n_units_p_r: int = DEFAULT_DIM_P_R, private_out: bool = False, penalty_l2: float = DEFAULT_PENALTY_L2, penalty_l2_p: float = DEFAULT_PENALTY_L2, penalty_orthogonal: float = DEFAULT_PENALTY_ORTHOGONAL, step_size: float = DEFAULT_STEP_SIZE, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, val_split_prop: float = DEFAULT_VAL_SPLIT, early_stopping: bool = True, patience: int = DEFAULT_PATIENCE, n_iter_min: int = DEFAULT_N_ITER_MIN, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, return_val_loss: bool = False, opt: str = "adam", shared_repr: bool = False, pretrain_shared: bool = False, same_init: bool = True, lr_scale: float = 10, normalize_ortho: bool = False, ) -> None: self.binary_y = binary_y self.n_layers_r = n_layers_r self.n_layers_out = n_layers_out self.n_units_s_out = n_units_s_out self.n_units_p_out = n_units_p_out self.n_units_s_r = n_units_s_r self.n_units_p_r = n_units_p_r self.private_out = private_out self.penalty_orthogonal = penalty_orthogonal self.penalty_l2 = penalty_l2 self.penalty_l2_p = penalty_l2_p self.step_size = step_size self.n_iter = n_iter self.batch_size = batch_size self.val_split_prop = val_split_prop self.early_stopping = early_stopping self.patience = patience self.n_iter_min = n_iter_min self.opt = opt self.same_init = same_init self.shared_repr = shared_repr self.normalize_ortho = normalize_ortho self.pretrain_shared = pretrain_shared self.lr_scale = lr_scale self.seed = seed self.n_iter_print = n_iter_print self.return_val_loss = return_val_loss def _get_train_function(self) -> Callable: return train_flextenet def _get_predict_function(self) -> Callable: return predict_flextenet def train_flextenet( X: 
jnp.ndarray, y: jnp.ndarray, w: jnp.ndarray, binary_y: bool = False, n_layers_out: int = DEFAULT_LAYERS_OUT, n_units_s_out: int = DEFAULT_DIM_S_OUT, n_units_p_out: int = DEFAULT_DIM_P_OUT, n_layers_r: int = DEFAULT_LAYERS_R, n_units_s_r: int = DEFAULT_DIM_S_R, n_units_p_r: int = DEFAULT_DIM_P_R, private_out: bool = False, penalty_l2: float = DEFAULT_PENALTY_L2, penalty_l2_p: float = DEFAULT_PENALTY_L2, penalty_orthogonal: float = DEFAULT_PENALTY_ORTHOGONAL, step_size: float = DEFAULT_STEP_SIZE, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, val_split_prop: float = DEFAULT_VAL_SPLIT, early_stopping: bool = True, patience: int = DEFAULT_PATIENCE, n_iter_min: int = DEFAULT_N_ITER_MIN, avg_objective: bool = True, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, return_val_loss: bool = False, opt: str = "adam", shared_repr: bool = False, pretrain_shared: bool = False, same_init: bool = True, lr_scale: float = 10, normalize_ortho: bool = False, nonlin: str = DEFAULT_NONLIN, n_units_r: Optional[int] = None, n_units_out: Optional[int] = None, ) -> Tuple: # TODO incorporate different nonlins here # function to train a single output head # input check y, w = check_shape_1d_data(y), check_shape_1d_data(w) d = X.shape[1] input_shape = (-1, d) rng_key = random.PRNGKey(seed) onp.random.seed(seed) # set seed for data generation via numpy as well # get validation split (can be none) X, y, w, X_val, y_val, w_val, val_string = make_val_split( X, y, w, val_split_prop=val_split_prop, seed=seed ) n = X.shape[0] # could be different from before due to split # get output head init_fun, predict_fun = FlexTENetArchitecture( n_layers_out=n_layers_out, n_layers_r=n_layers_r, n_units_p_r=n_units_p_r, n_units_p_out=n_units_p_out, n_units_s_r=n_units_s_r, n_units_s_out=n_units_s_out, private_out=private_out, shared_repr=shared_repr, same_init=same_init, binary_y=binary_y, ) # get functions if not binary_y: # define loss and grad @jit def loss( params: 
jnp.ndarray, batch: jnp.ndarray, penalty_l2: float, penalty_l2_p: float, penalty_orthogonal: float, mode: int, ) -> jnp.ndarray: # mse loss function inputs, targets = batch preds = predict_fun(params, inputs, mode=mode) penalty = _compute_penalty( params, n_layers_out, n_layers_r, private_out, penalty_l2, penalty_l2_p, penalty_orthogonal, shared_repr, normalize_ortho, mode, ) if not avg_objective: return jnp.sum((preds - targets) ** 2) + penalty else: return jnp.average((preds - targets) ** 2) + penalty else: # get loss and grad @jit def loss( params: jnp.ndarray, batch: jnp.ndarray, penalty_l2: float, penalty_l2_p: float, penalty_orthogonal: float, mode: int, ) -> jnp.ndarray: # mse loss function inputs, targets = batch preds = predict_fun(params, inputs, mode=mode) penalty = _compute_penalty( params, n_layers_out, n_layers_r, private_out, penalty_l2, penalty_l2_p, penalty_orthogonal, shared_repr, normalize_ortho, mode, ) if not avg_objective: return ( -jnp.sum( targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds) ) + penalty ) else: return ( -jnp.average( targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds) ) + penalty ) # set optimization routine # set optimizer if opt == "adam": opt_init, opt_update, get_params = optimizers.adam(step_size=step_size) elif opt == "sgd": opt_init, opt_update, get_params = optimizers.sgd(step_size=step_size) else: raise ValueError("opt should be adam or sgd") # set update function @jit def update( i: int, state: dict, batch: jnp.ndarray, penalty_l2: float, penalty_l2_p: float, penalty_orthogonal: float, mode: int, ) -> jnp.ndarray: params = get_params(state) g_params = grad(loss)( params, batch, penalty_l2, penalty_l2_p, penalty_orthogonal, mode ) return opt_update(i, g_params, state) # initialise states _, init_params = init_fun(rng_key, input_shape) opt_state = opt_init(init_params) # calculate number of batches per epoch batch_size = batch_size if batch_size < n else n n_batches = int(onp.round(n / batch_size)) 
if batch_size < n else 1 train_indices = onp.arange(n) l_best = LARGE_VAL p_curr = 0 # do training if not pretrain_shared: # train entire model together for i in range(n_iter): # shuffle data for minibatches onp.random.shuffle(train_indices) for b in range(n_batches): idx_next = train_indices[ (b * batch_size) : min((b + 1) * batch_size, n - 1) ] next_batch = (X[idx_next, :], w[idx_next]), y[idx_next, :] opt_state = update( i * n_batches + b, opt_state, next_batch, penalty_l2, penalty_l2_p, penalty_orthogonal, mode=1, ) if (i % n_iter_print == 0) or early_stopping: params_curr = get_params(opt_state) l_curr = loss( params_curr, ((X_val, w_val), y_val), penalty_l2, penalty_l2_p, penalty_orthogonal, mode=1, ) if i % n_iter_print == 0: log.debug(f"Epoch: {i}, current {val_string} loss: {l_curr}") if early_stopping and ((i + 1) * n_batches > n_iter_min): # check if loss updated if l_curr < l_best: l_best = l_curr p_curr = 0 else: p_curr = p_curr + 1 if p_curr > patience: trained_params = get_params(opt_state) if return_val_loss: # return loss without penalty l_final = loss( trained_params, ((X_val, w_val), y_val), 0, 0, 0, mode=1 ) return trained_params, predict_fun, l_final return trained_params, predict_fun # get final parameters trained_params = get_params(opt_state) if return_val_loss: # return loss without penalty l_final = loss(trained_params, ((X_val, w_val), y_val), 0, 0, 0, mode=1) return trained_params, predict_fun, l_final return trained_params, predict_fun else: # Step 1: pretrain only shared bit of network (mode=0) for i in range(n_iter): # shuffle data for minibatches onp.random.shuffle(train_indices) for b in range(n_batches): idx_next = train_indices[ (b * batch_size) : min((b + 1) * batch_size, n - 1) ] next_batch = (X[idx_next, :], w[idx_next]), y[idx_next, :] opt_state = update( i * n_batches + b, opt_state, next_batch, penalty_l2, penalty_l2_p, penalty_orthogonal, mode=0, ) if (i % n_iter_print == 0) or early_stopping: params_curr = 
get_params(opt_state) l_curr = loss( params_curr, ((X_val, w_val), y_val), penalty_l2, penalty_l2_p, penalty_orthogonal, mode=0, ) if i % n_iter_print == 0: log.debug( f"Pre-training epoch: {i}, current {val_string} loss: {l_curr}" ) if early_stopping and ((i + 1) * n_batches > n_iter_min): # check if loss updated if l_curr < l_best: l_best = l_curr p_curr = 0 else: p_curr = p_curr + 1 if p_curr > patience: break # get final parameters pre_trained_params = get_params(opt_state) # Step 2: train also private parts of network (mode=1) # set new optimizer if opt == "adam": opt_init2, opt_update2, get_params2 = optimizers.adam( step_size=step_size / lr_scale ) elif opt == "sgd": opt_init2, opt_update2, get_params2 = optimizers.sgd( step_size=step_size / lr_scale ) else: raise ValueError("opt should be adam or sgd") # set update function @jit def update2( i: int, state: dict, batch: jnp.ndarray, penalty_l2: float, penalty_l2_p: float, penalty_orthogonal: float, mode: int, ) -> Any: params = get_params(state) g_params = grad(loss)( params, batch, penalty_l2, penalty_l2_p, penalty_orthogonal, mode ) return opt_update2(i, g_params, state) opt_state = opt_init2(pre_trained_params) l_best = LARGE_VAL p_curr = 0 # train full for i in range(n_iter): # shuffle data for minibatches onp.random.shuffle(train_indices) for b in range(n_batches): idx_next = train_indices[ (b * batch_size) : min((b + 1) * batch_size, n - 1) ] next_batch = (X[idx_next, :], w[idx_next]), y[idx_next, :] opt_state = update2( i * n_batches + b, opt_state, next_batch, penalty_l2, penalty_l2_p, penalty_orthogonal, mode=1, ) if (i % n_iter_print == 0) or early_stopping: params_curr = get_params2(opt_state) l_curr = loss( params_curr, ((X_val, w_val), y_val), penalty_l2, penalty_l2_p, penalty_orthogonal, mode=1, ) if i % n_iter_print == 0: log.debug(f"Epoch: {i}, current {val_string} loss: {l_curr}") if early_stopping and ((i + 1) * n_batches > n_iter_min): # check if loss updated if l_curr < l_best: l_best = 
l_curr p_curr = 0 else: p_curr = p_curr + 1 if p_curr > patience: trained_params = get_params2(opt_state) if return_val_loss: # return loss without penalty l_final = loss( trained_params, ((X_val, w_val), y_val), 0, 0, 0, mode=1 ) return trained_params, predict_fun, l_final return trained_params, predict_fun # get final parameters trained_params = get_params2(opt_state) if return_val_loss: # return loss without penalty l_final = loss(trained_params, ((X_val, w_val), y_val), 0, 0, 0, mode=1) return trained_params, predict_fun, l_final return trained_params, predict_fun def predict_flextenet( X: jnp.ndarray, trained_params: jnp.ndarray, predict_funs: Callable, return_po: bool = False, return_prop: bool = False, ) -> Any: # unpack inputs n, _ = X.shape W1 = check_shape_1d_data(jnp.ones(n)) W0 = check_shape_1d_data(jnp.zeros(n)) # get potential outcomes mu_0 = predict_funs(trained_params, (X, W0)) mu_1 = predict_funs(trained_params, (X, W1)) te = mu_1 - mu_0 if return_prop: raise ValueError("does not have propensity score estimator") # stack other outputs if return_po: return te, mu_0, mu_1 else: return te # helper functions for training def _get_cos_reg( params_0: jnp.ndarray, params_1: jnp.ndarray, normalize: bool ) -> jnp.ndarray: if normalize: params_0 = params_0 / jnp.linalg.norm(params_0, axis=0) params_1 = params_1 / jnp.linalg.norm(params_1, axis=0) return jnp.linalg.norm(jnp.dot(jnp.transpose(params_0), params_1), "fro") ** 2 def _compute_ortho_penalty_asymmetric( params: jnp.ndarray, n_layers_out: int, n_layers_r: int, private_out: int, penalty_orthogonal: float, shared_repr: bool, normalize_ortho: bool, mode: int = 1, ) -> float: # where to start counting: is there a fully shared representation? 
if shared_repr: lb = 2 * n_layers_r else: lb = 0 n_in = [ params[i][0][0].shape[0] for i in range(lb, 2 * (n_layers_out + n_layers_r), 2) ] ortho_body = _get_cos_reg(params[lb][1][0], params[lb][2][0], normalize_ortho) ortho_body = ortho_body + sum( [ _get_cos_reg( params[i][0][0], params[i][1][0][: n_in[int(i / 2 - lb / 2)], :], normalize_ortho, ) + _get_cos_reg( params[i][0][0], params[i][2][0][: n_in[int(i / 2 - lb / 2)], :], normalize_ortho, ) for i in range(lb, 2 * (n_layers_out + n_layers_r), 2) ] ) if not private_out: # add also orthogonal regularization on final layer idx_out = 2 * (n_layers_r + n_layers_out) n_idx = params[idx_out][0][0].shape[0] ortho_body = ( ortho_body + _get_cos_reg( params[idx_out][0][0], params[idx_out][1][0][:n_idx, :], normalize_ortho, ) + _get_cos_reg( params[idx_out][0][0], params[idx_out][2][0][:n_idx, :], normalize_ortho ) ) return mode * penalty_orthogonal * ortho_body def _compute_penalty_l2( params: jnp.ndarray, n_layers_out: int, n_layers_r: int, private_out: int, penalty_l2: float, penalty_l2_p: float, shared_repr: bool, mode: int = 1, ) -> jnp.ndarray: n_bodys = N_SUBSPACES # compute l2 penalty if shared_repr: # get representation and then heads weightsq_body = penalty_l2 * sum( [jnp.sum(params[i][0] ** 2) for i in range(0, 2 * n_layers_r, 2)] ) weightsq_body = weightsq_body + penalty_l2 * sum( [ jnp.sum(params[i][0][0] ** 2) for i in range(2 * n_layers_r, 2 * (n_layers_out + n_layers_r), 2) ] ) weightsq_body = weightsq_body + penalty_l2_p * mode * sum( [ sum( [ jnp.sum(params[i][j][0] ** 2) for i in range( 2 * n_layers_r, 2 * (n_layers_out + n_layers_r), 2 ) ] ) for j in range(1, n_bodys) ] ) else: weightsq_body = penalty_l2 * sum( [ jnp.sum(params[i][0][0] ** 2) for i in range(0, 2 * (n_layers_out + n_layers_r), 2) ] ) weightsq_body = weightsq_body + penalty_l2_p * mode * sum( [ sum( [ jnp.sum(params[i][j][0] ** 2) for i in range(0, 2 * (n_layers_out + n_layers_r), 2) ] ) for j in range(1, n_bodys) ] ) idx_out = 2 * 
(n_layers_r + n_layers_out) if private_out: weightsq = ( weightsq_body + penalty_l2 * jnp.sum(params[idx_out][0][0] ** 2) + jnp.sum(params[idx_out][1][0] ** 2) ) else: weightsq = ( weightsq_body + penalty_l2 * jnp.sum(params[idx_out][0][0] ** 2) + penalty_l2_p * mode * jnp.sum(params[idx_out][1][0] ** 2) + penalty_l2_p * mode * jnp.sum(params[idx_out][2][0] ** 2) ) return 0.5 * weightsq def _compute_penalty( params: jnp.ndarray, n_layers_out: int, n_layers_r: int, private_out: int, penalty_l2: float, penalty_l2_p: float, penalty_orthogonal: float, shared_repr: bool, normalize_ortho: bool, mode: int = 1, ) -> jnp.ndarray: l2_penalty = _compute_penalty_l2( params, n_layers_out, n_layers_r, private_out, penalty_l2, penalty_l2_p, shared_repr, mode, ) ortho_penalty = _compute_ortho_penalty_asymmetric( params, n_layers_out, n_layers_r, private_out, penalty_orthogonal, shared_repr, normalize_ortho, mode, ) return l2_penalty + ortho_penalty # ------------------------------------------------------------ # construction of FlexTENetlayers/architecture def SplitLayerAsymmetric( n_units_s: int, n_units_p: int, first_layer: bool = False, same_init: bool = True ) -> Tuple: # create multitask layer has shape [shared, private_0, private_1] init_s, apply_s = Dense(n_units_s) init_p, apply_p = Dense(n_units_p) def init_fun(rng: float, input_shape: Tuple) -> Tuple: if first_layer: # put input shape in expected format input_shape = (input_shape, input_shape, input_shape) out_shape = ( input_shape[0][:-1] + (n_units_s,), input_shape[1][:-1] + (n_units_p + n_units_s,), input_shape[2][:-1] + (n_units_p + n_units_s,), ) rng_1, rng_2, rng_3 = random.split(rng, N_SUBSPACES) if same_init: # use same init for the two private layers return out_shape, ( init_s(rng_1, input_shape[0])[1], init_p(rng_2, input_shape[1])[1], init_p(rng_2, input_shape[2])[1], ) else: # initialise all separately return out_shape, ( init_s(rng_1, input_shape[0])[1], init_p(rng_2, input_shape[1])[1], init_p(rng_3, 
input_shape[2])[1], ) def apply_fun(params: jnp.ndarray, inputs: jnp.ndarray, **kwargs: Any) -> Tuple: mode = kwargs["mode"] if "mode" in kwargs.keys() else 1 if first_layer: # X is the only input X, W = inputs rep_s = apply_s(params[0], X) rep_p0 = mode * apply_p(params[1], X) rep_p1 = mode * apply_p(params[2], X) else: X_s, X_p0, X_p1, W = inputs rep_s = apply_s(params[0], X_s) rep_p0 = mode * apply_p(params[1], jnp.concatenate([X_s, X_p0], axis=1)) rep_p1 = mode * apply_p(params[2], jnp.concatenate([X_s, X_p1], axis=1)) return (rep_s, rep_p0, rep_p1, W) return init_fun, apply_fun def TEOutputLayerAsymmetric(private: bool = True, same_init: bool = True) -> Tuple: init_f, apply_f = Dense(1) if private: # the two output layers are private def init_fun(rng: float, input_shape: Tuple) -> Tuple: out_shape = input_shape[1][:-1] + (1,) rng_1, rng_2 = random.split(rng, N_SUBSPACES - 1) return out_shape, ( init_f(rng_1, input_shape[1])[1], init_f(rng_2, input_shape[2])[1], ) def apply_fun(params: jnp.ndarray, inputs: Tuple, **kwargs: Any) -> jnp.ndarray: X_s, X_p0, X_p1, W = inputs rep_p0 = apply_f(params[0], jnp.concatenate([X_s, X_p0], axis=1)) rep_p1 = apply_f(params[1], jnp.concatenate([X_s, X_p1], axis=1)) return (1 - W) * rep_p0 + W * rep_p1 else: # also have a shared piece of output layer def init_fun(rng: float, input_shape: Tuple) -> Tuple: out_shape = input_shape[1][:-1] + (1,) rng_1, rng_2, rng_3 = random.split(rng, N_SUBSPACES) if same_init: return out_shape, ( init_f(rng_1, input_shape[0])[1], init_f(rng_2, input_shape[1])[1], init_f(rng_2, input_shape[2])[1], ) else: return out_shape, ( init_f(rng_1, input_shape[0])[1], init_f(rng_2, input_shape[1])[1], init_f(rng_3, input_shape[2])[1], ) def apply_fun(params: jnp.ndarray, inputs: Tuple, **kwargs: Any) -> jnp.ndarray: mode = kwargs["mode"] if "mode" in kwargs.keys() else 1 X_s, X_p0, X_p1, W = inputs rep_s = apply_f(params[0], X_s) rep_p0 = mode * apply_f(params[1], jnp.concatenate([X_s, X_p0], axis=1)) 
rep_p1 = mode * apply_f(params[2], jnp.concatenate([X_s, X_p1], axis=1)) return (1 - W) * rep_p0 + W * rep_p1 + rep_s return init_fun, apply_fun def FlexTENetArchitecture( n_layers_out: int = DEFAULT_LAYERS_OUT, n_units_s_out: int = DEFAULT_DIM_S_OUT, n_units_p_out: int = DEFAULT_DIM_P_OUT, n_layers_r: int = DEFAULT_LAYERS_R, n_units_s_r: int = DEFAULT_DIM_S_R, n_units_p_r: int = DEFAULT_DIM_P_R, private_out: bool = False, binary_y: bool = False, shared_repr: bool = False, same_init: bool = True, ) -> Any: if n_layers_out < 1: raise ValueError( "FlexTENet needs at least one hidden output layer (else there are no " "parameters to be shared)" ) Nonlin_Elu = Elu_parallel Layer = SplitLayerAsymmetric Head = TEOutputLayerAsymmetric # give broader body (as in e.g. TARNet) has_body = n_layers_r > 0 layers: Tuple = () if has_body: # representation block first if shared_repr: # fully shared representation as in TARNet layers = (DenseW(n_units_s_r), Elu_split) # add required number of layers for i in range(n_layers_r - 1): layers = (*layers, DenseW(n_units_s_r), Elu_split) else: # shared AND private representations layers = ( Layer(n_units_s_r, n_units_p_r, first_layer=True, same_init=same_init), Nonlin_Elu, ) # add required number of layers for i in range(n_layers_r - 1): layers = ( *layers, Layer(n_units_s_r, n_units_p_r, same_init=same_init), Nonlin_Elu, ) else: layers = () # add output layers first_layer = (has_body is False) | (shared_repr is True) layers = ( *layers, Layer( n_units_s_out, n_units_p_out, first_layer=first_layer, same_init=same_init ), Nonlin_Elu, ) if n_layers_out > 1: # add required number of layers for i in range(n_layers_out - 1): layers = ( *layers, Layer(n_units_s_out, n_units_p_out, same_init=same_init), Nonlin_Elu, ) # return final architecture if not binary_y: return serial(*layers, Head(private=private_out, same_init=same_init)) else: return serial(*layers, Head(private=private_out, same_init=same_init), Sigmoid) # 
------------------------------------------------ # rewrite some jax.stax code to allow different input types to be passed def elementwise_split(fun: Callable, **fun_kwargs: Any) -> Tuple: """Layer that applies a scalar function elementwise on its inputs. Adapted from original jax.stax to skip treatment indicator. Input looks like: X, t = inputs""" def init_fun(rng: float, input_shape: Tuple) -> Tuple: return (input_shape, ()) def apply_fun(params: jnp.ndarray, inputs: jnp.ndarray, **kwargs: Any) -> Tuple: return fun(inputs[0], **fun_kwargs), inputs[1] return init_fun, apply_fun Elu_split = elementwise_split(elu) def elementwise_parallel(fun: Callable, **fun_kwargs: Any) -> Tuple: """Layer that applies a scalar function elementwise on its inputs. Adapted from original jax.stax to allow three inputs and to skip treatment indicator. Input looks like: X_s, X_p0, X_p1, t = inputs """ def init_fun(rng: float, input_shape: Tuple) -> Tuple: return input_shape, () def apply_fun(params: jnp.ndarray, inputs: jnp.ndarray, **kwargs: Any) -> Tuple: return ( fun(inputs[0], **fun_kwargs), fun(inputs[1], **fun_kwargs), fun(inputs[2], **fun_kwargs), inputs[3], ) return init_fun, apply_fun Elu_parallel = elementwise_parallel(elu) def DenseW( out_dim: int, W_init: Callable = glorot_normal(), b_init: Callable = normal() ) -> Tuple: """Layer constructor function for a dense (fully-connected) layer. 
Adapted to allow passing treatment indicator through layer without using it""" def init_fun(rng: float, input_shape: Tuple) -> Tuple: output_shape = input_shape[:-1] + (out_dim,) k1, k2 = random.split(rng) W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim,)) return output_shape, (W, b) def apply_fun( params: jnp.ndarray, inputs: jnp.ndarray, **kwargs: Any ) -> jnp.ndarray: W, b = params x, t = inputs return (jnp.dot(x, W) + b, t) return init_fun, apply_fun ================================================ FILE: catenets/models/jax/model_utils.py ================================================ """ Model utils shared across different nets """ # Author: Alicia Curth from typing import Any, Optional import jax.numpy as jnp import pandas as pd from sklearn.model_selection import train_test_split from catenets.models.constants import DEFAULT_SEED, DEFAULT_VAL_SPLIT TRAIN_STRING = "training" VALIDATION_STRING = "validation" def check_shape_1d_data(y: jnp.ndarray) -> jnp.ndarray: # helper func to ensure that output shape won't clash # with jax internally shape_y = y.shape if len(shape_y) == 1: # should be shape (n_obs, 1), not (n_obs,) return y.reshape((shape_y[0], 1)) return y def check_X_is_np(X: pd.DataFrame) -> jnp.ndarray: # function to make sure we are using arrays only return jnp.asarray(X) def make_val_split( X: jnp.ndarray, y: jnp.ndarray, w: Optional[jnp.ndarray] = None, val_split_prop: float = DEFAULT_VAL_SPLIT, seed: int = DEFAULT_SEED, stratify_w: bool = True, ) -> Any: if val_split_prop == 0: # return original data if w is None: return X, y, X, y, TRAIN_STRING return X, y, w, X, y, w, TRAIN_STRING # make actual split if w is None: X_t, X_val, y_t, y_val = train_test_split( X, y, test_size=val_split_prop, random_state=seed, shuffle=True ) return X_t, y_t, X_val, y_val, VALIDATION_STRING if stratify_w: # split to stratify by group X_t, X_val, y_t, y_val, w_t, w_val = train_test_split( X, y, w, test_size=val_split_prop, random_state=seed, 
stratify=w, shuffle=True, ) else: X_t, X_val, y_t, y_val, w_t, w_val = train_test_split( X, y, w, test_size=val_split_prop, random_state=seed, shuffle=True ) return X_t, y_t, w_t, X_val, y_val, w_val, VALIDATION_STRING def heads_l2_penalty( params_0: jnp.ndarray, params_1: jnp.ndarray, n_layers_out: jnp.ndarray, reg_diff: jnp.ndarray, penalty_0: jnp.ndarray, penalty_1: jnp.ndarray, ) -> jnp.ndarray: # Compute l2 penalty for output heads. Either seperately, or regularizing their difference # get l2-penalty for first head weightsq_0 = penalty_0 * sum( [jnp.sum(params_0[i][0] ** 2) for i in range(0, 2 * n_layers_out + 1, 2)] ) # get l2-penalty for second head if reg_diff: weightsq_1 = penalty_1 * sum( [ jnp.sum((params_1[i][0] - params_0[i][0]) ** 2) for i in range(0, 2 * n_layers_out + 1, 2) ] ) else: weightsq_1 = penalty_1 * sum( [jnp.sum(params_1[i][0] ** 2) for i in range(0, 2 * n_layers_out + 1, 2)] ) return weightsq_1 + weightsq_0 ================================================ FILE: catenets/models/jax/offsetnet.py ================================================ """ Module implements OffsetNet, also referred to as the 'reparametrization approach' and 'hard approach' in "On inductive biases for heterogeneous treatment effect estimation", Curth & vd Schaar (2021); modeling the POs using a shared prognostic function and an offset (treatment effect) """ # Author: Alicia Curth from typing import Any, Callable, List, Tuple import jax.numpy as jnp import numpy as onp from jax import grad, jit, random from jax.example_libraries import optimizers from jax.example_libraries.stax import sigmoid import catenets.logger as log from catenets.models.constants import ( DEFAULT_BATCH_SIZE, DEFAULT_LAYERS_OUT, DEFAULT_LAYERS_R, DEFAULT_N_ITER, DEFAULT_N_ITER_MIN, DEFAULT_N_ITER_PRINT, DEFAULT_NONLIN, DEFAULT_PATIENCE, DEFAULT_PENALTY_L2, DEFAULT_SEED, DEFAULT_STEP_SIZE, DEFAULT_UNITS_OUT, DEFAULT_UNITS_R, DEFAULT_VAL_SPLIT, LARGE_VAL, ) from catenets.models.jax.base import 
BaseCATENet, OutputHead from catenets.models.jax.model_utils import ( check_shape_1d_data, heads_l2_penalty, make_val_split, ) class OffsetNet(BaseCATENet): """ Module implements OffsetNet, also referred to as the 'reparametrization approach' and 'hard approach' in Curth & vd Schaar (2021); modeling the POs using a shared prognostic function and an offset (treatment effect). Parameters ---------- binary_y: bool, default False Whether the outcome is binary n_layers_out: int Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer) n_units_out: int Number of hidden units in each hypothesis layer n_layers_r: int Number of representation layers before hypothesis layers (distinction between hypothesis layers and representation layers is made to match TARNet & SNets) n_units_r: int Number of hidden units in each representation layer penalty_l2: float l2 (ridge) penalty step_size: float learning rate for optimizer n_iter: int Maximum number of iterations batch_size: int Batch size val_split_prop: float Proportion of samples used for validation split (can be 0) early_stopping: bool, default True Whether to use early stopping patience: int Number of iterations to wait before early stopping after decrease in validation loss n_iter_min: int Minimum number of iterations to go through before starting early stopping n_iter_print: int Number of iterations after which to print updates seed: int Seed used penalty_l2_p: float l2-penalty for regularizing the offset nonlin: string, default 'elu' Nonlinearity to use in NN """ def __init__( self, binary_y: bool = False, n_layers_r: int = DEFAULT_LAYERS_R, n_units_r: int = DEFAULT_UNITS_R, n_layers_out: int = DEFAULT_LAYERS_OUT, n_units_out: int = DEFAULT_UNITS_OUT, penalty_l2: float = DEFAULT_PENALTY_L2, penalty_l2_p: float = DEFAULT_PENALTY_L2, step_size: float = DEFAULT_STEP_SIZE, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, val_split_prop: float = DEFAULT_VAL_SPLIT, early_stopping: bool = True, 
patience: int = DEFAULT_PATIENCE, n_iter_min: int = DEFAULT_N_ITER_MIN, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, nonlin: str = DEFAULT_NONLIN, ): # structure of net self.binary_y = binary_y self.n_layers_r = n_layers_r self.n_layers_out = n_layers_out self.n_units_r = n_units_r self.n_units_out = n_units_out self.nonlin = nonlin # penalties self.penalty_l2 = penalty_l2 self.penalty_l2_p = penalty_l2_p # training params self.step_size = step_size self.n_iter = n_iter self.batch_size = batch_size self.n_iter_print = n_iter_print self.seed = seed self.val_split_prop = val_split_prop self.early_stopping = early_stopping self.patience = patience self.n_iter_min = n_iter_min def _get_train_function(self) -> Callable: return train_offsetnet def _get_predict_function(self) -> Callable: return predict_offsetnet def predict_offsetnet( X: jnp.ndarray, trained_params: jnp.ndarray, predict_funs: List[Any], return_po: bool = False, return_prop: bool = False, ) -> jnp.ndarray: if return_prop: raise NotImplementedError("OffsetNet does not implement a propensity model.") # unpack inputs predict_fun_head = predict_funs[0] binary_y = predict_funs[1] param_0, param_1 = trained_params[0], trained_params[1] # get potential outcomes mu_0 = predict_fun_head(param_0, X) offset = predict_fun_head(param_1, X) if not binary_y: if return_po: return offset, mu_0, mu_0 + offset else: return offset else: # still need to sigmoid po_0 = sigmoid(mu_0) po_1 = sigmoid(mu_0 + offset) if return_po: return po_1 - po_0, po_0, po_1 else: return po_1 - po_0 def train_offsetnet( X: jnp.ndarray, y: jnp.ndarray, w: jnp.ndarray, binary_y: bool = False, n_layers_r: int = DEFAULT_LAYERS_R, n_units_r: int = DEFAULT_UNITS_R, n_layers_out: int = DEFAULT_LAYERS_OUT, n_units_out: int = DEFAULT_UNITS_OUT, penalty_l2: float = DEFAULT_PENALTY_L2, penalty_l2_p: float = DEFAULT_PENALTY_L2, step_size: float = DEFAULT_STEP_SIZE, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, 
val_split_prop: float = DEFAULT_VAL_SPLIT, early_stopping: bool = True, patience: int = DEFAULT_PATIENCE, n_iter_min: int = DEFAULT_N_ITER_MIN, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, return_val_loss: bool = False, nonlin: str = DEFAULT_NONLIN, avg_objective: bool = True, ) -> Tuple: # input check y, w = check_shape_1d_data(y), check_shape_1d_data(w) d = X.shape[1] input_shape = (-1, d) rng_key = random.PRNGKey(seed) onp.random.seed(seed) # set seed for data generation via numpy as well # get validation split (can be none) X, y, w, X_val, y_val, w_val, val_string = make_val_split( X, y, w, val_split_prop=val_split_prop, seed=seed ) n = X.shape[0] # could be different from before due to split # get output head functions (both heads share same structure) init_fun_head, predict_fun_head = OutputHead( n_layers_out=n_layers_out, n_units_out=n_units_out, binary_y=False, n_layers_r=n_layers_r, n_units_r=n_units_r, nonlin=nonlin, ) def init_fun_offset(rng: float, input_shape: Tuple) -> Tuple: # chain together the layers # param should look like [param_base, param_offset] rng, layer_rng = random.split(rng) _, param_base = init_fun_head(layer_rng, input_shape) rng, layer_rng = random.split(rng) input_shape, param_offset = init_fun_head(layer_rng, input_shape) return input_shape, [param_base, param_offset] # Define loss functions if not binary_y: @jit def loss_offsetnet( params: jnp.ndarray, batch: jnp.ndarray, penalty: float, penalty_l2_p: float ) -> jnp.ndarray: # params: list[representation, head_0, head_1] # batch: (X, y, w) inputs, targets, w = batch preds_0 = predict_fun_head(params[0], inputs) offset = predict_fun_head(params[1], inputs) preds = preds_0 + w * offset weightsq_head = heads_l2_penalty( params[0], params[1], n_layers_out + n_layers_r, False, penalty, penalty_l2_p, ) if not avg_objective: return jnp.sum((preds - targets) ** 2) + 0.5 * weightsq_head else: return jnp.average((preds - targets) ** 2) + 0.5 * weightsq_head else: def 
loss_offsetnet( params: jnp.ndarray, batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], penalty: float, penalty_l2_p: float, ) -> jnp.ndarray: # params: list[representation, head_0, head_1] # batch: (X, y, w) inputs, targets, w = batch preds_0 = predict_fun_head(params[0], inputs) offset = predict_fun_head(params[1], inputs) preds = sigmoid(preds_0 + w * offset) weightsq_head = heads_l2_penalty( params[0], params[1], n_layers_out + n_layers_r, False, penalty, penalty_l2_p, ) if not avg_objective: return ( -jnp.sum( (targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds)) ) + 0.5 * weightsq_head ) else: n_batch = y.shape[0] return ( -jnp.sum( (targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds)) ) / n_batch + 0.5 * weightsq_head ) # Define optimisation routine opt_init, opt_update, get_params = optimizers.adam(step_size=step_size) @jit def update( i: int, state: dict, batch: jnp.ndarray, penalty_l2: float, penalty_l2_p: float ) -> jnp.ndarray: # updating function params = get_params(state) return opt_update( i, grad(loss_offsetnet)(params, batch, penalty_l2, penalty_l2_p), state ) # initialise states _, init_params = init_fun_offset(rng_key, input_shape) opt_state = opt_init(init_params) # calculate number of batches per epoch batch_size = batch_size if batch_size < n else n n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1 train_indices = onp.arange(n) l_best = LARGE_VAL p_curr = 0 pred_funs = predict_fun_head, binary_y # do training for i in range(n_iter): # shuffle data for minibatches onp.random.shuffle(train_indices) for b in range(n_batches): idx_next = train_indices[ (b * batch_size) : min((b + 1) * batch_size, n - 1) ] next_batch = X[idx_next, :], y[idx_next, :], w[idx_next] opt_state = update( i * n_batches + b, opt_state, next_batch, penalty_l2, penalty_l2_p ) if (i % n_iter_print == 0) or early_stopping: params_curr = get_params(opt_state) l_curr = loss_offsetnet( params_curr, (X_val, y_val, w_val), penalty_l2, 
penalty_l2_p ) if i % n_iter_print == 0: log.info(f"Epoch: {i}, current {val_string} loss {l_curr}") if early_stopping and ((i + 1) * n_batches > n_iter_min): if l_curr < l_best: l_best = l_curr p_curr = 0 else: p_curr = p_curr + 1 if p_curr > patience: if return_val_loss: # return loss without penalty l_final = loss_offsetnet(params_curr, (X_val, y_val, w_val), 0, 0) return params_curr, pred_funs, l_final return params_curr, pred_funs # return the parameters trained_params = get_params(opt_state) if return_val_loss: # return loss without penalty l_final = loss_offsetnet(get_params(opt_state), (X_val, y_val, w_val), 0, 0) return trained_params, pred_funs, l_final return trained_params, pred_funs ================================================ FILE: catenets/models/jax/pseudo_outcome_nets.py ================================================ """ Implements Pseudo-outcome based Two-step Nets, namely the DR-learner, the PW-learner and the RA-learner. """ # Author: Alicia Curth from typing import Callable, Optional, Tuple import jax.numpy as jnp import numpy as onp import pandas as pd from sklearn.model_selection import StratifiedKFold import catenets.logger as log from catenets.models.constants import ( DEFAULT_AVG_OBJECTIVE, DEFAULT_BATCH_SIZE, DEFAULT_CF_FOLDS, DEFAULT_LAYERS_OUT, DEFAULT_LAYERS_OUT_T, DEFAULT_LAYERS_R, DEFAULT_LAYERS_R_T, DEFAULT_N_ITER, DEFAULT_N_ITER_MIN, DEFAULT_N_ITER_PRINT, DEFAULT_NONLIN, DEFAULT_PATIENCE, DEFAULT_PENALTY_L2, DEFAULT_SEED, DEFAULT_STEP_SIZE, DEFAULT_STEP_SIZE_T, DEFAULT_UNITS_OUT, DEFAULT_UNITS_OUT_T, DEFAULT_UNITS_R, DEFAULT_UNITS_R_T, DEFAULT_VAL_SPLIT, ) from catenets.models.jax.base import BaseCATENet, train_output_net_only from catenets.models.jax.disentangled_nets import predict_snet3, train_snet3 from catenets.models.jax.flextenet import predict_flextenet, train_flextenet from catenets.models.jax.model_utils import check_shape_1d_data, check_X_is_np from catenets.models.jax.offsetnet import predict_offsetnet, 
train_offsetnet from catenets.models.jax.representation_nets import ( predict_snet1, predict_snet2, train_snet1, train_snet2, ) from catenets.models.jax.snet import predict_snet, train_snet from catenets.models.jax.tnet import predict_t_net, train_tnet from catenets.models.jax.transformation_utils import ( DR_TRANSFORMATION, PW_TRANSFORMATION, RA_TRANSFORMATION, _get_transformation_function, ) T_STRATEGY = "T" S1_STRATEGY = "Tar" S2_STRATEGY = "S2" S3_STRATEGY = "S3" S_STRATEGY = "S" OFFSET_STRATEGY = "Offset" FLEX_STRATEGY = "Flex" ALL_STRATEGIES = [ T_STRATEGY, S1_STRATEGY, S2_STRATEGY, S3_STRATEGY, S_STRATEGY, FLEX_STRATEGY, OFFSET_STRATEGY, ] class PseudoOutcomeNet(BaseCATENet): """ Class implements TwoStepLearners based on pseudo-outcome regression as discussed in Curth &vd Schaar (2021): RA-learner, PW-learner and DR-learner Parameters ---------- first_stage_strategy: str, default 't' which nuisance estimator to use in first stage first_stage_args: dict Any additional arguments to pass to first stage training function data_split: bool, default False Whether to split the data in two folds for estimation cross_fit: bool, default False Whether to perform cross fitting n_cf_folds: int Number of crossfitting folds to use transformation: str, default 'AIPW' pseudo-outcome to use ('AIPW' for DR-learner, 'HT' for PW learner, 'RA' for RA-learner) binary_y: bool, default False Whether the outcome is binary n_layers_out: int First stage Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer) n_units_out: int First stage Number of hidden units in each hypothesis layer n_layers_r: int First stage Number of representation layers before hypothesis layers (distinction between hypothesis layers and representation layers is made to match TARNet & SNets) n_units_r: int First stage Number of hidden units in each representation layer n_layers_out_t: int Second stage Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer) n_units_out_t: int 
Second stage Number of hidden units in each hypothesis layer n_layers_r_t: int Second stage Number of representation layers before hypothesis layers (distinction between hypothesis layers and representation layers is made to match TARNet & SNets) n_units_r_t: int Second stage Number of hidden units in each representation layer penalty_l2: float First stage l2 (ridge) penalty penalty_l2_t: float Second stage l2 (ridge) penalty step_size: float First stage learning rate for optimizer step_size_t: float Second stage learning rate for optimizer n_iter: int Maximum number of iterations batch_size: int Batch size val_split_prop: float Proportion of samples used for validation split (can be 0) early_stopping: bool, default True Whether to use early stopping patience: int Number of iterations to wait before early stopping after decrease in validation loss n_iter_min: int Minimum number of iterations to go through before starting early stopping n_iter_print: int Number of iterations after which to print updates seed: int Seed used nonlin: string, default 'elu' Nonlinearity to use in NN """ def __init__( self, first_stage_strategy: str = T_STRATEGY, first_stage_args: Optional[dict] = None, data_split: bool = False, cross_fit: bool = False, n_cf_folds: int = DEFAULT_CF_FOLDS, transformation: str = DR_TRANSFORMATION, binary_y: bool = False, n_layers_out: int = DEFAULT_LAYERS_OUT, n_layers_r: int = DEFAULT_LAYERS_R, n_layers_out_t: int = DEFAULT_LAYERS_OUT_T, n_layers_r_t: int = DEFAULT_LAYERS_R_T, n_units_out: int = DEFAULT_UNITS_OUT, n_units_r: int = DEFAULT_UNITS_R, n_units_out_t: int = DEFAULT_UNITS_OUT_T, n_units_r_t: int = DEFAULT_UNITS_R_T, penalty_l2: float = DEFAULT_PENALTY_L2, penalty_l2_t: float = DEFAULT_PENALTY_L2, step_size: float = DEFAULT_STEP_SIZE, step_size_t: float = DEFAULT_STEP_SIZE_T, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, n_iter_min: int = DEFAULT_N_ITER_MIN, val_split_prop: float = DEFAULT_VAL_SPLIT, early_stopping: bool = 
True, patience: int = DEFAULT_PATIENCE, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, rescale_transformation: bool = False, nonlin: str = DEFAULT_NONLIN, ) -> None: # settings self.first_stage_strategy = first_stage_strategy self.first_stage_args = first_stage_args self.binary_y = binary_y self.transformation = transformation self.data_split = data_split self.cross_fit = cross_fit self.n_cf_folds = n_cf_folds # model architecture hyperparams self.n_layers_out = n_layers_out self.n_layers_out_t = n_layers_out_t self.n_layers_r = n_layers_r self.n_layers_r_t = n_layers_r_t self.n_units_out = n_units_out self.n_units_out_t = n_units_out_t self.n_units_r = n_units_r self.n_units_r_t = n_units_r_t self.nonlin = nonlin # other hyperparameters self.penalty_l2 = penalty_l2 self.penalty_l2_t = penalty_l2_t self.step_size = step_size self.step_size_t = step_size_t self.n_iter = n_iter self.batch_size = batch_size self.n_iter_print = n_iter_print self.seed = seed self.val_split_prop = val_split_prop self.early_stopping = early_stopping self.patience = patience self.n_iter_min = n_iter_min self.rescale_transformation = rescale_transformation def _get_train_function(self) -> Callable: return train_pseudooutcome_net def fit( self, X: jnp.ndarray, y: jnp.ndarray, w: jnp.ndarray, p: Optional[jnp.ndarray] = None, ) -> "PseudoOutcomeNet": # overwrite super so we can pass p as extra param # some quick input checks X = check_X_is_np(X) self._check_inputs(w, p) train_func = self._get_train_function() train_params = self.get_params() if "transformation" not in train_params.keys(): train_params.update({"transformation": self.transformation}) if self.rescale_transformation: self._params, self._predict_funs, self._scale_factor = train_func( X, y, w, p, **train_params ) else: self._params, self._predict_funs = train_func(X, y, w, p, **train_params) return self def _get_predict_function(self) -> Callable: # Two step nets do not need this pass def predict( self, X: 
jnp.ndarray, return_po: bool = False, return_prop: bool = False ) -> jnp.ndarray: # check input if return_po: raise NotImplementedError( "TwoStepNets have no Potential outcome predictors." ) if return_prop: raise NotImplementedError("TwoStepNets have no Propensity predictors.") if isinstance(X, pd.DataFrame): X = X.values if self.rescale_transformation: return 1 / self._scale_factor * self._predict_funs(self._params, X) else: return self._predict_funs(self._params, X) class DRNet(PseudoOutcomeNet): """Wrapper for DR-learner using PseudoOutcomeNet""" def __init__( self, first_stage_strategy: str = T_STRATEGY, data_split: bool = False, cross_fit: bool = False, n_cf_folds: int = DEFAULT_CF_FOLDS, binary_y: bool = False, n_layers_out: int = DEFAULT_LAYERS_OUT, n_layers_r: int = DEFAULT_LAYERS_R, n_layers_out_t: int = DEFAULT_LAYERS_OUT_T, n_layers_r_t: int = DEFAULT_LAYERS_R_T, n_units_out: int = DEFAULT_UNITS_OUT, n_units_r: int = DEFAULT_UNITS_R, n_units_out_t: int = DEFAULT_UNITS_OUT_T, n_units_r_t: int = DEFAULT_UNITS_R_T, penalty_l2: float = DEFAULT_PENALTY_L2, penalty_l2_t: float = DEFAULT_PENALTY_L2, step_size: float = DEFAULT_STEP_SIZE, step_size_t: float = DEFAULT_STEP_SIZE_T, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, n_iter_min: int = DEFAULT_N_ITER_MIN, val_split_prop: float = DEFAULT_VAL_SPLIT, early_stopping: bool = True, patience: int = DEFAULT_PATIENCE, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, rescale_transformation: bool = False, nonlin: str = DEFAULT_NONLIN, first_stage_args: Optional[dict] = None, ) -> None: super().__init__( first_stage_strategy=first_stage_strategy, data_split=data_split, cross_fit=cross_fit, n_cf_folds=n_cf_folds, transformation=DR_TRANSFORMATION, binary_y=binary_y, n_layers_out=n_layers_out, n_layers_r=n_layers_r, n_layers_out_t=n_layers_out_t, n_layers_r_t=n_layers_r_t, n_units_out=n_units_out, n_units_r=n_units_r, n_units_out_t=n_units_out_t, n_units_r_t=n_units_r_t, 
penalty_l2=penalty_l2, penalty_l2_t=penalty_l2_t, step_size=step_size, step_size_t=step_size_t, n_iter=n_iter, batch_size=batch_size, n_iter_min=n_iter_min, val_split_prop=val_split_prop, early_stopping=early_stopping, patience=patience, n_iter_print=n_iter_print, seed=seed, nonlin=nonlin, rescale_transformation=rescale_transformation, first_stage_args=first_stage_args, ) class RANet(PseudoOutcomeNet): """Wrapper for RA-learner using PseudoOutcomeNet""" def __init__( self, first_stage_strategy: str = T_STRATEGY, data_split: bool = False, cross_fit: bool = False, n_cf_folds: int = DEFAULT_CF_FOLDS, binary_y: bool = False, n_layers_out: int = DEFAULT_LAYERS_OUT, n_layers_r: int = DEFAULT_LAYERS_R, n_layers_out_t: int = DEFAULT_LAYERS_OUT_T, n_layers_r_t: int = DEFAULT_LAYERS_R_T, n_units_out: int = DEFAULT_UNITS_OUT, n_units_r: int = DEFAULT_UNITS_R, n_units_out_t: int = DEFAULT_UNITS_OUT_T, n_units_r_t: int = DEFAULT_UNITS_R_T, penalty_l2: float = DEFAULT_PENALTY_L2, penalty_l2_t: float = DEFAULT_PENALTY_L2, step_size: float = DEFAULT_STEP_SIZE, step_size_t: float = DEFAULT_STEP_SIZE_T, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, n_iter_min: int = DEFAULT_N_ITER_MIN, val_split_prop: float = DEFAULT_VAL_SPLIT, early_stopping: bool = True, patience: int = DEFAULT_PATIENCE, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, rescale_transformation: bool = False, nonlin: str = DEFAULT_NONLIN, first_stage_args: Optional[dict] = None, ) -> None: super().__init__( first_stage_strategy=first_stage_strategy, data_split=data_split, cross_fit=cross_fit, n_cf_folds=n_cf_folds, transformation=RA_TRANSFORMATION, binary_y=binary_y, n_layers_out=n_layers_out, n_layers_r=n_layers_r, n_layers_out_t=n_layers_out_t, n_layers_r_t=n_layers_r_t, n_units_out=n_units_out, n_units_r=n_units_r, n_units_out_t=n_units_out_t, n_units_r_t=n_units_r_t, penalty_l2=penalty_l2, penalty_l2_t=penalty_l2_t, step_size=step_size, step_size_t=step_size_t, 
n_iter=n_iter, batch_size=batch_size, n_iter_min=n_iter_min, val_split_prop=val_split_prop, early_stopping=early_stopping, patience=patience, n_iter_print=n_iter_print, seed=seed, nonlin=nonlin, rescale_transformation=rescale_transformation, first_stage_args=first_stage_args, ) class PWNet(PseudoOutcomeNet): """Wrapper for PW-learner using PseudoOutcomeNet""" def __init__( self, first_stage_strategy: str = T_STRATEGY, data_split: bool = False, cross_fit: bool = False, n_cf_folds: int = DEFAULT_CF_FOLDS, binary_y: bool = False, n_layers_out: int = DEFAULT_LAYERS_OUT, n_layers_r: int = DEFAULT_LAYERS_R, n_layers_out_t: int = DEFAULT_LAYERS_OUT_T, n_layers_r_t: int = DEFAULT_LAYERS_R_T, n_units_out: int = DEFAULT_UNITS_OUT, n_units_r: int = DEFAULT_UNITS_R, n_units_out_t: int = DEFAULT_UNITS_OUT_T, n_units_r_t: int = DEFAULT_UNITS_R_T, penalty_l2: float = DEFAULT_PENALTY_L2, penalty_l2_t: float = DEFAULT_PENALTY_L2, step_size: float = DEFAULT_STEP_SIZE, step_size_t: float = DEFAULT_STEP_SIZE_T, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, n_iter_min: int = DEFAULT_N_ITER_MIN, val_split_prop: float = DEFAULT_VAL_SPLIT, early_stopping: bool = True, patience: int = DEFAULT_PATIENCE, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, rescale_transformation: bool = False, nonlin: str = DEFAULT_NONLIN, first_stage_args: Optional[dict] = None, ) -> None: super().__init__( first_stage_strategy=first_stage_strategy, data_split=data_split, cross_fit=cross_fit, n_cf_folds=n_cf_folds, transformation=PW_TRANSFORMATION, binary_y=binary_y, n_layers_out=n_layers_out, n_layers_r=n_layers_r, n_layers_out_t=n_layers_out_t, n_layers_r_t=n_layers_r_t, n_units_out=n_units_out, n_units_r=n_units_r, n_units_out_t=n_units_out_t, n_units_r_t=n_units_r_t, penalty_l2=penalty_l2, penalty_l2_t=penalty_l2_t, step_size=step_size, step_size_t=step_size_t, n_iter=n_iter, batch_size=batch_size, n_iter_min=n_iter_min, val_split_prop=val_split_prop, 
early_stopping=early_stopping, patience=patience, n_iter_print=n_iter_print, seed=seed, nonlin=nonlin, rescale_transformation=rescale_transformation, first_stage_args=first_stage_args, ) def train_pseudooutcome_net( X: jnp.ndarray, y: jnp.ndarray, w: jnp.ndarray, p: Optional[jnp.ndarray] = None, first_stage_strategy: str = T_STRATEGY, data_split: bool = False, cross_fit: bool = False, n_cf_folds: int = DEFAULT_CF_FOLDS, transformation: str = DR_TRANSFORMATION, binary_y: bool = False, n_layers_out: int = DEFAULT_LAYERS_OUT, n_layers_r: int = DEFAULT_LAYERS_R, n_layers_r_t: int = DEFAULT_LAYERS_R_T, n_layers_out_t: int = DEFAULT_LAYERS_OUT_T, n_units_out: int = DEFAULT_UNITS_OUT, n_units_r: int = DEFAULT_UNITS_R, n_units_out_t: int = DEFAULT_UNITS_OUT_T, n_units_r_t: int = DEFAULT_UNITS_R_T, penalty_l2: float = DEFAULT_PENALTY_L2, penalty_l2_t: float = DEFAULT_PENALTY_L2, step_size: float = DEFAULT_STEP_SIZE, step_size_t: float = DEFAULT_STEP_SIZE_T, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, val_split_prop: float = DEFAULT_VAL_SPLIT, early_stopping: bool = True, patience: int = DEFAULT_PATIENCE, n_iter_min: int = DEFAULT_N_ITER_MIN, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, rescale_transformation: bool = False, return_val_loss: bool = False, nonlin: str = DEFAULT_NONLIN, avg_objective: bool = DEFAULT_AVG_OBJECTIVE, first_stage_args: Optional[dict] = None, ) -> Tuple: # get shape of data n, d = X.shape if p is not None: p = check_shape_1d_data(p) # get transformation function transformation_function = _get_transformation_function(transformation) # get strategy name if first_stage_strategy not in ALL_STRATEGIES: raise ValueError( "Parameter first stage should be in " "catenets.models.pseudo_outcome_nets.ALL_STRATEGIES. 
" "You passed {}".format(first_stage_strategy) ) # split data as wanted if p is None or transformation is not PW_TRANSFORMATION: if not cross_fit: if not data_split: log.debug("Training first stage with all data (no data splitting)") # use all data for both fit_mask = onp.ones(n, dtype=bool) pred_mask = onp.ones(n, dtype=bool) else: log.debug("Training first stage with half of the data (data splitting)") # split data in half fit_idx = onp.random.choice(n, int(onp.round(n / 2))) fit_mask = onp.zeros(n, dtype=bool) fit_mask[fit_idx] = 1 pred_mask = ~fit_mask mu_0, mu_1, pi_hat = _train_and_predict_first_stage( X, y, w, fit_mask, pred_mask, first_stage_strategy=first_stage_strategy, binary_y=binary_y, n_layers_out=n_layers_out, n_layers_r=n_layers_r, n_units_out=n_units_out, n_units_r=n_units_r, penalty_l2=penalty_l2, step_size=step_size, n_iter=n_iter, batch_size=batch_size, val_split_prop=val_split_prop, early_stopping=early_stopping, patience=patience, n_iter_min=n_iter_min, n_iter_print=n_iter_print, seed=seed, nonlin=nonlin, avg_objective=avg_objective, transformation=transformation, first_stage_args=first_stage_args, ) if data_split: # keep only prediction data X, y, w = X[pred_mask, :], y[pred_mask, :], w[pred_mask, :] if p is not None: p = p[pred_mask, :] else: log.debug(f"Training first stage in {n_cf_folds} folds (cross-fitting)") # do cross fitting mu_0, mu_1, pi_hat = onp.zeros((n, 1)), onp.zeros((n, 1)), onp.zeros((n, 1)) splitter = StratifiedKFold( n_splits=n_cf_folds, shuffle=True, random_state=seed ) fold_count = 1 for train_idx, test_idx in splitter.split(X, w): log.debug(f"Training fold {fold_count}.") fold_count = fold_count + 1 pred_mask = onp.zeros(n, dtype=bool) pred_mask[test_idx] = 1 fit_mask = ~pred_mask ( mu_0[pred_mask], mu_1[pred_mask], pi_hat[pred_mask], ) = _train_and_predict_first_stage( X, y, w, fit_mask, pred_mask, first_stage_strategy=first_stage_strategy, binary_y=binary_y, n_layers_out=n_layers_out, n_layers_r=n_layers_r, 
n_units_out=n_units_out, n_units_r=n_units_r, penalty_l2=penalty_l2, step_size=step_size, n_iter=n_iter, batch_size=batch_size, val_split_prop=val_split_prop, early_stopping=early_stopping, patience=patience, n_iter_min=n_iter_min, n_iter_print=n_iter_print, seed=seed, nonlin=nonlin, avg_objective=avg_objective, transformation=transformation, first_stage_args=first_stage_args, ) log.debug("Training second stage.") if p is not None: # use known propensity score p = check_shape_1d_data(p) pi_hat = p # second stage y, w = check_shape_1d_data(y), check_shape_1d_data(w) # transform data and fit on transformed data if transformation is PW_TRANSFORMATION: mu_0 = None mu_1 = None pseudo_outcome = transformation_function(y=y, w=w, p=pi_hat, mu_0=mu_0, mu_1=mu_1) if rescale_transformation: scale_factor = onp.std(y) / onp.std(pseudo_outcome) if scale_factor > 1: scale_factor = 1 else: pseudo_outcome = scale_factor * pseudo_outcome params, predict_funs = train_output_net_only( X, pseudo_outcome, binary_y=False, n_layers_out=n_layers_out_t, n_units_out=n_units_out_t, n_layers_r=n_layers_r_t, n_units_r=n_units_r_t, penalty_l2=penalty_l2_t, step_size=step_size_t, n_iter=n_iter, batch_size=batch_size, val_split_prop=val_split_prop, early_stopping=early_stopping, patience=patience, n_iter_min=n_iter_min, n_iter_print=n_iter_print, seed=seed, return_val_loss=return_val_loss, nonlin=nonlin, avg_objective=avg_objective, ) return params, predict_funs, scale_factor else: return train_output_net_only( X, pseudo_outcome, binary_y=False, n_layers_out=n_layers_out_t, n_units_out=n_units_out_t, n_layers_r=n_layers_r_t, n_units_r=n_units_r_t, penalty_l2=penalty_l2_t, step_size=step_size_t, n_iter=n_iter, batch_size=batch_size, val_split_prop=val_split_prop, early_stopping=early_stopping, patience=patience, n_iter_min=n_iter_min, n_iter_print=n_iter_print, seed=seed, return_val_loss=return_val_loss, nonlin=nonlin, avg_objective=avg_objective, ) def _train_and_predict_first_stage( X: 
jnp.ndarray, y: jnp.ndarray, w: jnp.ndarray, fit_mask: jnp.ndarray, pred_mask: jnp.ndarray, first_stage_strategy: str, binary_y: bool = False, n_layers_out: int = DEFAULT_LAYERS_OUT, n_layers_r: int = DEFAULT_LAYERS_R, n_units_out: int = DEFAULT_UNITS_OUT, n_units_r: int = DEFAULT_UNITS_R, penalty_l2: float = DEFAULT_PENALTY_L2, step_size: float = DEFAULT_STEP_SIZE, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, val_split_prop: float = DEFAULT_VAL_SPLIT, early_stopping: bool = True, patience: int = DEFAULT_PATIENCE, n_iter_min: int = DEFAULT_N_ITER_MIN, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, nonlin: str = DEFAULT_NONLIN, avg_objective: bool = False, transformation: str = DR_TRANSFORMATION, first_stage_args: Optional[dict] = None, ) -> Tuple: if len(w.shape) > 1: w = w.reshape((len(w),)) if first_stage_args is None: first_stage_args = {} # split the data X_fit, y_fit, w_fit = X[fit_mask, :], y[fit_mask], w[fit_mask] X_pred = X[pred_mask, :] train_fun: Callable predict_fun: Callable if first_stage_strategy == T_STRATEGY: train_fun, predict_fun = train_tnet, predict_t_net elif first_stage_strategy == S_STRATEGY: train_fun, predict_fun = train_snet, predict_snet elif first_stage_strategy == S1_STRATEGY: train_fun, predict_fun = train_snet1, predict_snet1 elif first_stage_strategy == S2_STRATEGY: train_fun, predict_fun = train_snet2, predict_snet2 elif first_stage_strategy == S3_STRATEGY: train_fun, predict_fun = train_snet3, predict_snet3 elif first_stage_strategy == OFFSET_STRATEGY: train_fun, predict_fun = train_offsetnet, predict_offsetnet elif first_stage_strategy == FLEX_STRATEGY: train_fun, predict_fun = train_flextenet, predict_flextenet else: raise ValueError( "{} is not a valid first stage strategy for a PseudoOutcomeNet".format( first_stage_strategy ) ) log.debug("Training PO estimators") trained_params, pred_fun = train_fun( X_fit, y_fit, w_fit, binary_y=binary_y, n_layers_r=n_layers_r, n_units_r=n_units_r, 
n_layers_out=n_layers_out, n_units_out=n_units_out, penalty_l2=penalty_l2, step_size=step_size, n_iter=n_iter, batch_size=batch_size, val_split_prop=val_split_prop, early_stopping=early_stopping, patience=patience, n_iter_min=n_iter_min, n_iter_print=n_iter_print, seed=seed, nonlin=nonlin, avg_objective=avg_objective, **first_stage_args, ) if first_stage_strategy in [S_STRATEGY, S2_STRATEGY, S3_STRATEGY]: _, mu_0, mu_1, pi_hat = predict_fun( X_pred, trained_params, pred_fun, return_po=True, return_prop=True ) else: if transformation is not PW_TRANSFORMATION: _, mu_0, mu_1 = predict_fun( X_pred, trained_params, pred_fun, return_po=True ) else: mu_0, mu_1 = onp.nan, onp.nan if transformation is not RA_TRANSFORMATION: log.debug("Training propensity net") params_prop, predict_fun_prop = train_output_net_only( X_fit, w_fit, binary_y=True, n_layers_out=n_layers_out, n_units_out=n_units_out, n_layers_r=n_layers_r, n_units_r=n_units_r, penalty_l2=penalty_l2, step_size=step_size, n_iter=n_iter, batch_size=batch_size, val_split_prop=val_split_prop, early_stopping=early_stopping, patience=patience, n_iter_min=n_iter_min, n_iter_print=n_iter_print, seed=seed, nonlin=nonlin, avg_objective=avg_objective, ) pi_hat = predict_fun_prop(params_prop, X_pred) else: pi_hat = onp.nan return mu_0, mu_1, pi_hat ================================================ FILE: catenets/models/jax/representation_nets.py ================================================ """ Module implements SNet1 and SNet2, which are based on CFRNet/TARNet from Shalit et al (2017) and DragonNet from Shi et al (2019), respectively. 
""" # Author: Alicia Curth from typing import Any, Callable, List, Tuple import jax.numpy as jnp import numpy as onp from jax import grad, jit, random from jax.example_libraries import optimizers import catenets.logger as log from catenets.models.constants import ( DEFAULT_AVG_OBJECTIVE, DEFAULT_BATCH_SIZE, DEFAULT_LAYERS_OUT, DEFAULT_LAYERS_R, DEFAULT_N_ITER, DEFAULT_N_ITER_MIN, DEFAULT_N_ITER_PRINT, DEFAULT_NONLIN, DEFAULT_PATIENCE, DEFAULT_PENALTY_DISC, DEFAULT_PENALTY_L2, DEFAULT_SEED, DEFAULT_STEP_SIZE, DEFAULT_UNITS_OUT, DEFAULT_UNITS_R, DEFAULT_VAL_SPLIT, LARGE_VAL, ) from catenets.models.jax.base import BaseCATENet, OutputHead, ReprBlock from catenets.models.jax.model_utils import ( check_shape_1d_data, heads_l2_penalty, make_val_split, ) class SNet1(BaseCATENet): """ Class implements Shalit et al (2017)'s TARNet & CFR (discrepancy regularization is NOT TESTED). Also referred to as SNet-1 in our paper. Parameters ---------- binary_y: bool, default False Whether the outcome is binary n_layers_out: int Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer) n_units_out: int Number of hidden units in each hypothesis layer n_layers_r: int Number of shared representation layers before hypothesis layers n_units_r: int Number of hidden units in each representation layer penalty_l2: float l2 (ridge) penalty step_size: float learning rate for optimizer n_iter: int Maximum number of iterations batch_size: int Batch size val_split_prop: float Proportion of samples used for validation split (can be 0) early_stopping: bool, default True Whether to use early stopping patience: int Number of iterations to wait before early stopping after decrease in validation loss n_iter_min: int Minimum number of iterations to go through before starting early stopping n_iter_print: int Number of iterations after which to print updates seed: int Seed used reg_diff: bool, default False Whether to regularize the difference between the two potential outcome heads 
penalty_diff: float l2-penalty for regularizing the difference between output heads. used only if train_separate=False same_init: bool, False Whether to initialise the two output heads with same values nonlin: string, default 'elu' Nonlinearity to use in NN penalty_disc: float, default zero Discrepancy penalty. Defaults to zero as this feature is not tested. """ def __init__( self, binary_y: bool = False, n_layers_r: int = DEFAULT_LAYERS_R, n_units_r: int = DEFAULT_UNITS_R, n_layers_out: int = DEFAULT_LAYERS_OUT, n_units_out: int = DEFAULT_UNITS_OUT, penalty_l2: float = DEFAULT_PENALTY_L2, step_size: float = DEFAULT_STEP_SIZE, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, val_split_prop: float = DEFAULT_VAL_SPLIT, early_stopping: bool = True, patience: int = DEFAULT_PATIENCE, n_iter_min: int = DEFAULT_N_ITER_MIN, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, reg_diff: bool = False, penalty_diff: float = DEFAULT_PENALTY_L2, same_init: bool = False, nonlin: str = DEFAULT_NONLIN, penalty_disc: float = DEFAULT_PENALTY_DISC, ) -> None: # structure of net self.binary_y = binary_y self.n_layers_r = n_layers_r self.n_layers_out = n_layers_out self.n_units_r = n_units_r self.n_units_out = n_units_out self.nonlin = nonlin # penalties self.penalty_l2 = penalty_l2 self.penalty_disc = penalty_disc self.reg_diff = reg_diff self.penalty_diff = penalty_diff self.same_init = same_init # training params self.step_size = step_size self.n_iter = n_iter self.batch_size = batch_size self.n_iter_print = n_iter_print self.seed = seed self.val_split_prop = val_split_prop self.early_stopping = early_stopping self.patience = patience self.n_iter_min = n_iter_min def _get_train_function(self) -> Callable: return train_snet1 def _get_predict_function(self) -> Callable: return predict_snet1 class TARNet(SNet1): """Wrapper for TARNet""" def __init__( self, binary_y: bool = False, n_layers_r: int = DEFAULT_LAYERS_R, n_units_r: int = DEFAULT_UNITS_R, 
n_layers_out: int = DEFAULT_LAYERS_OUT, n_units_out: int = DEFAULT_UNITS_OUT, penalty_l2: float = DEFAULT_PENALTY_L2, step_size: float = DEFAULT_STEP_SIZE, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, val_split_prop: float = DEFAULT_VAL_SPLIT, early_stopping: bool = True, patience: int = DEFAULT_PATIENCE, n_iter_min: int = DEFAULT_N_ITER_MIN, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, reg_diff: bool = False, penalty_diff: float = DEFAULT_PENALTY_L2, same_init: bool = False, nonlin: str = DEFAULT_NONLIN, ): super().__init__( binary_y=binary_y, n_layers_r=n_layers_r, n_units_r=n_units_r, n_layers_out=n_layers_out, n_units_out=n_units_out, penalty_l2=penalty_l2, step_size=step_size, n_iter=n_iter, batch_size=batch_size, val_split_prop=val_split_prop, early_stopping=early_stopping, patience=patience, n_iter_min=n_iter_min, n_iter_print=n_iter_print, seed=seed, reg_diff=reg_diff, penalty_diff=penalty_diff, same_init=same_init, nonlin=nonlin, penalty_disc=0, ) class SNet2(BaseCATENet): """ Class implements SNet-2, which is based on Shi et al (2019)'s DragonNet (this version does NOT use targeted regularization and has a (possibly deeper) propensity head. 
Parameters ---------- binary_y: bool, default False Whether the outcome is binary n_layers_out: int Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer) n_layers_out_prop: int Number of hypothesis layers for propensity score(n_layers_out x n_units_out + 1 x Dense layer) n_units_out: int Number of hidden units in each hypothesis layer n_units_out_prop: int Number of hidden units in each propensity score hypothesis layer n_layers_r: int Number of shared representation layers before hypothesis layers n_units_r: int Number of hidden units in each representation layer penalty_l2: float l2 (ridge) penalty step_size: float learning rate for optimizer n_iter: int Maximum number of iterations batch_size: int Batch size val_split_prop: float Proportion of samples used for validation split (can be 0) early_stopping: bool, default True Whether to use early stopping patience: int Number of iterations to wait before early stopping after decrease in validation loss n_iter_min: int Minimum number of iterations to go through before starting early stopping n_iter_print: int Number of iterations after which to print updates seed: int Seed used reg_diff: bool, default False Whether to regularize the difference between the two potential outcome heads penalty_diff: float l2-penalty for regularizing the difference between output heads. 
used only if train_separate=False same_init: bool, False Whether to initialise the two output heads with same values nonlin: string, default 'elu' Nonlinearity to use in NN """ def __init__( self, binary_y: bool = False, n_layers_r: int = DEFAULT_LAYERS_R, n_units_r: int = DEFAULT_UNITS_R, n_layers_out: int = DEFAULT_LAYERS_OUT, n_units_out: int = DEFAULT_UNITS_OUT, penalty_l2: float = DEFAULT_PENALTY_L2, n_units_out_prop: int = DEFAULT_UNITS_OUT, n_layers_out_prop: int = DEFAULT_LAYERS_OUT, step_size: float = DEFAULT_STEP_SIZE, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, val_split_prop: float = DEFAULT_VAL_SPLIT, early_stopping: bool = True, patience: int = DEFAULT_PATIENCE, n_iter_min: int = DEFAULT_N_ITER_MIN, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, reg_diff: bool = False, same_init: bool = False, penalty_diff: float = DEFAULT_PENALTY_L2, nonlin: str = DEFAULT_NONLIN, ) -> None: self.binary_y = binary_y self.n_layers_r = n_layers_r self.n_layers_out = n_layers_out self.n_layers_out_prop = n_layers_out_prop self.n_units_r = n_units_r self.n_units_out = n_units_out self.n_units_out_prop = n_units_out_prop self.nonlin = nonlin self.penalty_l2 = penalty_l2 self.step_size = step_size self.n_iter = n_iter self.batch_size = batch_size self.val_split_prop = val_split_prop self.early_stopping = early_stopping self.patience = patience self.n_iter_min = n_iter_min self.reg_diff = reg_diff self.penalty_diff = penalty_diff self.same_init = same_init self.seed = seed self.n_iter_print = n_iter_print def _get_train_function(self) -> Callable: return train_snet2 def _get_predict_function(self) -> Callable: return predict_snet2 class DragonNet(SNet2): """Wrapper for DragonNet""" def __init__( self, binary_y: bool = False, n_layers_r: int = DEFAULT_LAYERS_R, n_units_r: int = DEFAULT_UNITS_R, n_layers_out: int = DEFAULT_LAYERS_OUT, n_units_out: int = DEFAULT_UNITS_OUT, penalty_l2: float = DEFAULT_PENALTY_L2, n_units_out_prop: int 
= DEFAULT_UNITS_OUT, n_layers_out_prop: int = 0, step_size: float = DEFAULT_STEP_SIZE, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, val_split_prop: float = DEFAULT_VAL_SPLIT, early_stopping: bool = True, patience: int = DEFAULT_PATIENCE, n_iter_min: int = DEFAULT_N_ITER_MIN, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, reg_diff: bool = False, same_init: bool = False, penalty_diff: float = DEFAULT_PENALTY_L2, nonlin: str = DEFAULT_NONLIN, ): super().__init__( binary_y=binary_y, n_layers_r=n_layers_r, n_units_r=n_units_r, n_layers_out=n_layers_out, n_units_out=n_units_out, penalty_l2=penalty_l2, n_units_out_prop=n_units_out_prop, n_layers_out_prop=n_layers_out_prop, step_size=step_size, n_iter=n_iter, batch_size=batch_size, val_split_prop=val_split_prop, early_stopping=early_stopping, patience=patience, n_iter_min=n_iter_min, n_iter_print=n_iter_print, seed=seed, reg_diff=reg_diff, penalty_diff=penalty_diff, same_init=same_init, nonlin=nonlin, ) # Training functions for SNet1 ------------------------------------------------- def mmd2_lin(X: jnp.ndarray, w: jnp.ndarray) -> jnp.ndarray: # Squared Linear MMD as implemented in CFR # jax does not support indexing, so this is a workaround with reweighting in means n = w.shape[0] n_t = jnp.sum(w) # normalize X so scale matters X = X / jnp.sqrt(jnp.var(X, axis=0)) mean_control = (n / (n - n_t)) * jnp.mean((1 - w) * X, axis=0) mean_treated = (n / n_t) * jnp.mean(w * X, axis=0) return jnp.sum((mean_treated - mean_control) ** 2) def predict_snet1( X: jnp.ndarray, trained_params: dict, predict_funs: list, return_po: bool = False, return_prop: bool = False, ) -> jnp.ndarray: if return_prop: raise NotImplementedError("SNet1 does not implement a propensity model.") # unpack inputs predict_fun_repr, predict_fun_head = predict_funs param_repr, param_0, param_1 = ( trained_params[0], trained_params[1], trained_params[2], ) # get representation representation = predict_fun_repr(param_repr, 
X) # get potential outcomes mu_0 = predict_fun_head(param_0, representation) mu_1 = predict_fun_head(param_1, representation) if return_po: return mu_1 - mu_0, mu_0, mu_1 else: return mu_1 - mu_0 def train_snet1( X: jnp.ndarray, y: jnp.ndarray, w: jnp.ndarray, binary_y: bool = False, n_layers_r: int = DEFAULT_LAYERS_R, n_units_r: int = DEFAULT_UNITS_R, n_layers_out: int = DEFAULT_LAYERS_OUT, n_units_out: int = DEFAULT_UNITS_OUT, penalty_l2: float = DEFAULT_PENALTY_L2, penalty_disc: int = DEFAULT_PENALTY_DISC, step_size: float = DEFAULT_STEP_SIZE, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, val_split_prop: float = DEFAULT_VAL_SPLIT, early_stopping: bool = True, patience: int = DEFAULT_PATIENCE, n_iter_min: int = DEFAULT_N_ITER_MIN, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, return_val_loss: bool = False, reg_diff: bool = False, same_init: bool = False, penalty_diff: float = DEFAULT_PENALTY_L2, nonlin: str = DEFAULT_NONLIN, avg_objective: bool = DEFAULT_AVG_OBJECTIVE, ) -> Any: # function to train TARNET (Johansson et al) using jax # input check y, w = check_shape_1d_data(y), check_shape_1d_data(w) d = X.shape[1] input_shape = (-1, d) rng_key = random.PRNGKey(seed) onp.random.seed(seed) # set seed for data generation via numpy as well if not reg_diff: penalty_diff = penalty_l2 # get validation split (can be none) X, y, w, X_val, y_val, w_val, val_string = make_val_split( X, y, w, val_split_prop=val_split_prop, seed=seed ) n = X.shape[0] # could be different from before due to split # get representation layer init_fun_repr, predict_fun_repr = ReprBlock( n_layers=n_layers_r, n_units=n_units_r, nonlin=nonlin ) # get output head functions (both heads share same structure) init_fun_head, predict_fun_head = OutputHead( n_layers_out=n_layers_out, n_units_out=n_units_out, binary_y=binary_y, nonlin=nonlin, ) def init_fun_snet1(rng: float, input_shape: Tuple) -> Tuple[Tuple, List]: # chain together the layers # param should look 
like [repr, po_0, po_1] rng, layer_rng = random.split(rng) input_shape_repr, param_repr = init_fun_repr(layer_rng, input_shape) rng, layer_rng = random.split(rng) if same_init: # initialise both on same values input_shape, param_0 = init_fun_head(layer_rng, input_shape_repr) input_shape, param_1 = init_fun_head(layer_rng, input_shape_repr) else: input_shape, param_0 = init_fun_head(layer_rng, input_shape_repr) rng, layer_rng = random.split(rng) input_shape, param_1 = init_fun_head(layer_rng, input_shape_repr) return input_shape, [param_repr, param_0, param_1] # Define loss functions # loss functions for the head if not binary_y: def loss_head( params: List, batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray] ) -> jnp.ndarray: # mse loss function inputs, targets, weights = batch preds = predict_fun_head(params, inputs) return jnp.sum(weights * ((preds - targets) ** 2)) else: def loss_head( params: List, batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray] ) -> jnp.ndarray: # mse loss function inputs, targets, weights = batch preds = predict_fun_head(params, inputs) return -jnp.sum( weights * (targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds)) ) # complete loss function for all parts @jit def loss_snet1( params: List, batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], penalty_l2: float, penalty_disc: float, penalty_diff: float, ) -> jnp.ndarray: # params: list[representation, head_0, head_1] # batch: (X, y, w) X, y, w = batch # get representation reps = predict_fun_repr(params[0], X) # get mmd disc = mmd2_lin(reps, w) # pass down to two heads loss_0 = loss_head(params[1], (reps, y, 1 - w)) loss_1 = loss_head(params[2], (reps, y, w)) # regularization on representation weightsq_body = sum( [jnp.sum(params[0][i][0] ** 2) for i in range(0, 2 * n_layers_r, 2)] ) weightsq_head = heads_l2_penalty( params[1], params[2], n_layers_out, reg_diff, penalty_l2, penalty_diff ) if not avg_objective: return ( loss_0 + loss_1 + penalty_disc * disc + 0.5 * (penalty_l2 * 
weightsq_body + weightsq_head) ) else: n_batch = y.shape[0] return ( (loss_0 + loss_1) / n_batch + penalty_disc * disc + 0.5 * (penalty_l2 * weightsq_body + weightsq_head) ) # Define optimisation routine opt_init, opt_update, get_params = optimizers.adam(step_size=step_size) @jit def update( i: int, state: dict, batch: jnp.ndarray, penalty_l2: float, penalty_disc: float ) -> jnp.ndarray: # updating function params = get_params(state) return opt_update( i, grad(loss_snet1)(params, batch, penalty_l2, penalty_disc, penalty_diff), state, ) # initialise states _, init_params = init_fun_snet1(rng_key, input_shape) opt_state = opt_init(init_params) # calculate number of batches per epoch batch_size = batch_size if batch_size < n else n n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1 train_indices = onp.arange(n) l_best = LARGE_VAL p_curr = 0 # do training for i in range(n_iter): # shuffle data for minibatches onp.random.shuffle(train_indices) for b in range(n_batches): idx_next = train_indices[ (b * batch_size) : min((b + 1) * batch_size, n - 1) ] next_batch = X[idx_next, :], y[idx_next, :], w[idx_next] opt_state = update( i * n_batches + b, opt_state, next_batch, penalty_l2, penalty_disc ) if (i % n_iter_print == 0) or early_stopping: params_curr = get_params(opt_state) l_curr = loss_snet1( params_curr, (X_val, y_val, w_val), penalty_l2, penalty_disc, penalty_diff, ) if i % n_iter_print == 0: log.info(f"Epoch: {i}, current {val_string} loss {l_curr}") if early_stopping: if l_curr < l_best: l_best = l_curr p_curr = 0 params_best = params_curr else: if onp.isnan(l_curr): # if diverged, return best return params_best, (predict_fun_repr, predict_fun_head) p_curr = p_curr + 1 if p_curr > patience and ((i + 1) * n_batches > n_iter_min): if return_val_loss: # return loss without penalty l_final = loss_snet1(params_curr, (X_val, y_val, w_val), 0, 0, 0) return params_curr, (predict_fun_repr, predict_fun_head), l_final return params_curr, (predict_fun_repr, 
predict_fun_head) # return the parameters trained_params = get_params(opt_state) if return_val_loss: # return loss without penalty l_final = loss_snet1(get_params(opt_state), (X_val, y_val, w_val), 0, 0, 0) return trained_params, (predict_fun_repr, predict_fun_head), l_final return trained_params, (predict_fun_repr, predict_fun_head) # SNET-2 ----------------------------------------------------------------------------------------- def train_snet2( X: jnp.ndarray, y: jnp.ndarray, w: jnp.ndarray, binary_y: bool = False, n_layers_r: int = DEFAULT_LAYERS_R, n_units_r: int = DEFAULT_UNITS_R, n_layers_out: int = DEFAULT_LAYERS_OUT, n_units_out: int = DEFAULT_UNITS_OUT, penalty_l2: float = DEFAULT_PENALTY_L2, n_units_out_prop: int = DEFAULT_UNITS_OUT, n_layers_out_prop: int = DEFAULT_LAYERS_OUT, step_size: float = DEFAULT_STEP_SIZE, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, val_split_prop: float = DEFAULT_VAL_SPLIT, early_stopping: bool = True, patience: int = DEFAULT_PATIENCE, n_iter_min: int = DEFAULT_N_ITER_MIN, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, return_val_loss: bool = False, reg_diff: bool = False, penalty_diff: float = DEFAULT_PENALTY_L2, nonlin: str = DEFAULT_NONLIN, avg_objective: bool = DEFAULT_AVG_OBJECTIVE, same_init: bool = False, ) -> Any: """ SNet2 corresponds to DragonNet (Shi et al, 2019) [without TMLE regularisation term]. 
""" y, w = check_shape_1d_data(y), check_shape_1d_data(w) d = X.shape[1] input_shape = (-1, d) rng_key = random.PRNGKey(seed) onp.random.seed(seed) # set seed for data generation via numpy as well if not reg_diff: penalty_diff = penalty_l2 # get validation split (can be none) X, y, w, X_val, y_val, w_val, val_string = make_val_split( X, y, w, val_split_prop=val_split_prop, seed=seed ) n = X.shape[0] # could be different from before due to split # get representation layer init_fun_repr, predict_fun_repr = ReprBlock( n_layers=n_layers_r, n_units=n_units_r, nonlin=nonlin ) # get output head functions (output heads share same structure) init_fun_head_po, predict_fun_head_po = OutputHead( n_layers_out=n_layers_out, n_units_out=n_units_out, binary_y=binary_y, nonlin=nonlin, ) # add propensity head init_fun_head_prop, predict_fun_head_prop = OutputHead( n_layers_out=n_layers_out_prop, n_units_out=n_units_out_prop, binary_y=True, nonlin=nonlin, ) def init_fun_snet2(rng: float, input_shape: Tuple) -> Tuple[Tuple, List]: # chain together the layers # param should look like [repr, po_0, po_1, prop] rng, layer_rng = random.split(rng) input_shape_repr, param_repr = init_fun_repr(layer_rng, input_shape) rng, layer_rng = random.split(rng) if same_init: # initialise both on same values input_shape, param_0 = init_fun_head_po(layer_rng, input_shape_repr) input_shape, param_1 = init_fun_head_po(layer_rng, input_shape_repr) else: input_shape, param_0 = init_fun_head_po(layer_rng, input_shape_repr) rng, layer_rng = random.split(rng) input_shape, param_1 = init_fun_head_po(layer_rng, input_shape_repr) rng, layer_rng = random.split(rng) input_shape, param_prop = init_fun_head_prop(layer_rng, input_shape_repr) return input_shape, [param_repr, param_0, param_1, param_prop] # Define loss functions # loss functions for the head if not binary_y: def loss_head( params: List, batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray] ) -> jnp.ndarray: # mse loss function inputs, targets, weights = 
batch preds = predict_fun_head_po(params, inputs) return jnp.sum(weights * ((preds - targets) ** 2)) else: def loss_head( params: List, batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray] ) -> jnp.ndarray: # log loss function inputs, targets, weights = batch preds = predict_fun_head_po(params, inputs) return -jnp.sum( weights * (targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds)) ) def loss_head_prop( params: List, batch: Tuple[jnp.ndarray, jnp.ndarray], penalty: float ) -> jnp.ndarray: # log loss function for propensities inputs, targets = batch preds = predict_fun_head_prop(params, inputs) return -jnp.sum(targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds)) # complete loss function for all parts @jit def loss_snet2( params: List, batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], penalty_l2: float, penalty_diff: float, ) -> jnp.ndarray: # params: list[representation, head_0, head_1, head_prop] # batch: (X, y, w) X, y, w = batch # get representation reps = predict_fun_repr(params[0], X) # pass down to heads loss_0 = loss_head(params[1], (reps, y, 1 - w)) loss_1 = loss_head(params[2], (reps, y, w)) # pass down to propensity head loss_prop = loss_head_prop(params[3], (reps, w), penalty_l2) weightsq_prop = sum( [ jnp.sum(params[3][i][0] ** 2) for i in range(0, 2 * n_layers_out_prop + 1, 2) ] ) weightsq_body = sum( [jnp.sum(params[0][i][0] ** 2) for i in range(0, 2 * n_layers_r, 2)] ) weightsq_head = heads_l2_penalty( params[1], params[2], n_layers_out, reg_diff, penalty_l2, penalty_diff ) if not avg_objective: return ( loss_0 + loss_1 + loss_prop + 0.5 * (penalty_l2 * (weightsq_body + weightsq_prop) + weightsq_head) ) else: n_batch = y.shape[0] return ( (loss_0 + loss_1) / n_batch + loss_prop / n_batch + 0.5 * (penalty_l2 * (weightsq_body + weightsq_prop) + weightsq_head) ) # Define optimisation routine opt_init, opt_update, get_params = optimizers.adam(step_size=step_size) @jit def update( i: int, state: dict, batch: jnp.ndarray, penalty_l2: 
float, penalty_diff: float ) -> jnp.ndarray: # updating function params = get_params(state) return opt_update( i, grad(loss_snet2)(params, batch, penalty_l2, penalty_diff), state ) # initialise states _, init_params = init_fun_snet2(rng_key, input_shape) opt_state = opt_init(init_params) # calculate number of batches per epoch batch_size = batch_size if batch_size < n else n n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1 train_indices = onp.arange(n) l_best = LARGE_VAL p_curr = 0 # do training for i in range(n_iter): # shuffle data for minibatches onp.random.shuffle(train_indices) for b in range(n_batches): idx_next = train_indices[ (b * batch_size) : min((b + 1) * batch_size, n - 1) ] next_batch = X[idx_next, :], y[idx_next, :], w[idx_next] opt_state = update( i * n_batches + b, opt_state, next_batch, penalty_l2, penalty_diff ) if (i % n_iter_print == 0) or early_stopping: params_curr = get_params(opt_state) l_curr = loss_snet2( params_curr, (X_val, y_val, w_val), penalty_l2, penalty_diff ) if i % n_iter_print == 0: log.info(f"Epoch: {i}, current {val_string} loss {l_curr}") if early_stopping and ((i + 1) * n_batches > n_iter_min): # check if loss updated if l_curr < l_best: l_best = l_curr p_curr = 0 params_best = params_curr else: if onp.isnan(l_curr): # if diverged, return best return params_best, ( predict_fun_repr, predict_fun_head_po, predict_fun_head_prop, ) p_curr = p_curr + 1 if p_curr > patience: if return_val_loss: # return loss without penalty l_final = loss_snet2(params_curr, (X_val, y_val, w_val), 0, 0) return ( params_curr, (predict_fun_repr, predict_fun_head_po, predict_fun_head_prop), l_final, ) return params_curr, ( predict_fun_repr, predict_fun_head_po, predict_fun_head_prop, ) # return the parameters trained_params = get_params(opt_state) if return_val_loss: # return loss without penalty l_final = loss_snet2(get_params(opt_state), (X_val, y_val, w_val), 0, 0) return ( trained_params, (predict_fun_repr, predict_fun_head_po, 
predict_fun_head_prop), l_final, ) return trained_params, ( predict_fun_repr, predict_fun_head_po, predict_fun_head_prop, ) def predict_snet2( X: jnp.ndarray, trained_params: dict, predict_funs: list, return_po: bool = False, return_prop: bool = False, ) -> jnp.ndarray: # unpack inputs predict_fun_repr, predict_fun_head, predict_fun_prop = predict_funs param_repr, param_0, param_1, param_prop = ( trained_params[0], trained_params[1], trained_params[2], trained_params[3], ) # get representation representation = predict_fun_repr(param_repr, X) # get potential outcomes mu_0 = predict_fun_head(param_0, representation) mu_1 = predict_fun_head(param_1, representation) te = mu_1 - mu_0 if return_prop: # get propensity prop = predict_fun_prop(param_prop, representation) # stack other outputs if return_po: if return_prop: return te, mu_0, mu_1, prop else: return te, mu_0, mu_1 else: if return_prop: return te, prop else: return te ================================================ FILE: catenets/models/jax/rnet.py ================================================ """ Implements NN based on R-learner and U-learner (as discussed in Nie & Wager (2017)) """ # Author: Alicia Curth from typing import Any, Callable, Optional import jax.numpy as jnp import numpy as onp import pandas as pd from jax import grad, jit, random from jax.example_libraries import optimizers from sklearn.model_selection import StratifiedKFold import catenets.logger as log from catenets.models.constants import ( DEFAULT_AVG_OBJECTIVE, DEFAULT_BATCH_SIZE, DEFAULT_CF_FOLDS, DEFAULT_LAYERS_OUT, DEFAULT_LAYERS_OUT_T, DEFAULT_LAYERS_R, DEFAULT_LAYERS_R_T, DEFAULT_N_ITER, DEFAULT_N_ITER_MIN, DEFAULT_N_ITER_PRINT, DEFAULT_NONLIN, DEFAULT_PATIENCE, DEFAULT_PENALTY_L2, DEFAULT_SEED, DEFAULT_STEP_SIZE, DEFAULT_STEP_SIZE_T, DEFAULT_UNITS_OUT, DEFAULT_UNITS_OUT_T, DEFAULT_UNITS_R, DEFAULT_UNITS_R_T, DEFAULT_VAL_SPLIT, LARGE_VAL, ) from catenets.models.jax.base import ( BaseCATENet, OutputHead, make_val_split, 
train_output_net_only, ) from catenets.models.jax.model_utils import check_shape_1d_data, check_X_is_np R_STRATEGY_NAME = "R" U_STRATEGY_NAME = "U" class RNet(BaseCATENet): """ Class implements R-learner and U-learner using NNs Parameters ---------- second_stage_strategy: str, default 'R' Which strategy to use in the second stage ('R' for R-learner, 'U' for U-learner) data_split: bool, default False Whether to split the data in two folds for estimation cross_fit: bool, default False Whether to perform cross fitting n_cf_folds: int Number of crossfitting folds to use n_layers_out: int First stage Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer) n_units_out: int First stage Number of hidden units in each hypothesis layer n_layers_r: int First stage Number of representation layers before hypothesis layers (distinction between hypothesis layers and representation layers is made to match TARNet & SNets) n_units_r: int First stage Number of hidden units in each representation layer n_layers_out_t: int Second stage Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer) n_units_out_t: int Second stage Number of hidden units in each hypothesis layer n_layers_r_t: int Second stage Number of representation layers before hypothesis layers (distinction between hypothesis layers and representation layers is made to match TARNet & SNets) n_units_r_t: int Second stage Number of hidden units in each representation layer penalty_l2: float First stage l2 (ridge) penalty penalty_l2_t: float Second stage l2 (ridge) penalty step_size: float First stage learning rate for optimizer step_size_t: float Second stage learning rate for optimizer n_iter: int Maximum number of iterations batch_size: int Batch size val_split_prop: float Proportion of samples used for validation split (can be 0) early_stopping: bool, default True Whether to use early stopping patience: int Number of iterations to wait before early stopping after decrease in validation 
loss n_iter_min: int Minimum number of iterations to go through before starting early stopping n_iter_print: int Number of iterations after which to print updates seed: int Seed used nonlin: string, default 'elu' Nonlinearity to use in NN """ def __init__( self, second_stage_strategy: str = R_STRATEGY_NAME, data_split: bool = False, cross_fit: bool = False, n_cf_folds: int = DEFAULT_CF_FOLDS, n_layers_out: int = DEFAULT_LAYERS_OUT, n_layers_r: int = DEFAULT_LAYERS_R, n_layers_out_t: int = DEFAULT_LAYERS_OUT_T, n_layers_r_t: int = DEFAULT_LAYERS_R_T, n_units_out: int = DEFAULT_UNITS_OUT, n_units_r: int = DEFAULT_UNITS_R, n_units_out_t: int = DEFAULT_UNITS_OUT_T, n_units_r_t: int = DEFAULT_UNITS_R_T, penalty_l2: float = DEFAULT_PENALTY_L2, penalty_l2_t: float = DEFAULT_PENALTY_L2, step_size: float = DEFAULT_STEP_SIZE, step_size_t: float = DEFAULT_STEP_SIZE_T, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, n_iter_min: int = DEFAULT_N_ITER_MIN, val_split_prop: float = DEFAULT_VAL_SPLIT, early_stopping: bool = True, patience: int = DEFAULT_PATIENCE, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, nonlin: str = DEFAULT_NONLIN, binary_y: bool = False, ) -> None: # settings self.binary_y = binary_y self.second_stage_strategy = second_stage_strategy self.data_split = data_split self.cross_fit = cross_fit self.n_cf_folds = n_cf_folds # model architecture hyperparams self.n_layers_out = n_layers_out self.n_layers_out_t = n_layers_out_t self.n_layers_r = n_layers_r self.n_layers_r_t = n_layers_r_t self.n_units_out = n_units_out self.n_units_out_t = n_units_out_t self.n_units_r = n_units_r self.n_units_r_t = n_units_r_t self.nonlin = nonlin # other hyperparameters self.penalty_l2 = penalty_l2 self.penalty_l2_t = penalty_l2_t self.step_size = step_size self.step_size_t = step_size_t self.n_iter = n_iter self.batch_size = batch_size self.n_iter_print = n_iter_print self.seed = seed self.val_split_prop = val_split_prop self.early_stopping = 
early_stopping self.patience = patience self.n_iter_min = n_iter_min def _get_train_function(self) -> Callable: return train_r_net def fit( self, X: jnp.ndarray, y: jnp.ndarray, w: jnp.ndarray, p: Optional[jnp.ndarray] = None, ) -> "RNet": # overwrite super so we can pass p as extra param # some quick input checks X = check_X_is_np(X) self._check_inputs(w, p) train_func = self._get_train_function() train_params = self.get_params() self._params, self._predict_funs = train_func(X, y, w, p, **train_params) return self def _get_predict_function(self) -> Callable: # Two step nets do not need this pass def predict( self, X: jnp.ndarray, return_po: bool = False, return_prop: bool = False ) -> jnp.ndarray: # check input if return_po: raise NotImplementedError( "TwoStepNets have no Potential outcome predictors." ) if return_prop: raise NotImplementedError("TwoStepNets have no Propensity predictors.") if isinstance(X, pd.DataFrame): X = X.values return self._predict_funs(self._params, X) def train_r_net( X: jnp.ndarray, y: jnp.ndarray, w: jnp.ndarray, p: Optional[jnp.ndarray] = None, second_stage_strategy: str = R_STRATEGY_NAME, data_split: bool = False, cross_fit: bool = False, n_cf_folds: int = DEFAULT_CF_FOLDS, n_layers_out: int = DEFAULT_LAYERS_OUT, n_layers_r: int = DEFAULT_LAYERS_R, n_layers_r_t: int = DEFAULT_LAYERS_R_T, n_layers_out_t: int = DEFAULT_LAYERS_OUT_T, n_units_out: int = DEFAULT_UNITS_OUT, n_units_r: int = DEFAULT_UNITS_R, n_units_out_t: int = DEFAULT_UNITS_OUT_T, n_units_r_t: int = DEFAULT_UNITS_R_T, penalty_l2: float = DEFAULT_PENALTY_L2, penalty_l2_t: float = DEFAULT_PENALTY_L2, step_size: float = DEFAULT_STEP_SIZE, step_size_t: float = DEFAULT_STEP_SIZE_T, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, val_split_prop: float = DEFAULT_VAL_SPLIT, early_stopping: bool = True, patience: int = DEFAULT_PATIENCE, n_iter_min: int = DEFAULT_N_ITER_MIN, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, return_val_loss: bool 
= False, nonlin: str = DEFAULT_NONLIN, binary_y: bool = False, ) -> Any: # get shape of data n, d = X.shape if p is not None: p = check_shape_1d_data(p) # split data as wanted if not cross_fit: if not data_split: log.debug("Training first stage with all data (no data splitting)") # use all data for both fit_mask = onp.ones(n, dtype=bool) pred_mask = onp.ones(n, dtype=bool) else: log.debug("Training first stage with half of the data (data splitting)") # split data in half fit_idx = onp.random.choice(n, int(onp.round(n / 2))) fit_mask = onp.zeros(n, dtype=bool) fit_mask[fit_idx] = 1 pred_mask = ~fit_mask mu_hat, pi_hat = _train_and_predict_r_stage1( X, y, w, fit_mask, pred_mask, n_layers_out=n_layers_out, n_layers_r=n_layers_r, n_units_out=n_units_out, n_units_r=n_units_r, penalty_l2=penalty_l2, step_size=step_size, n_iter=n_iter, batch_size=batch_size, val_split_prop=val_split_prop, early_stopping=early_stopping, patience=patience, n_iter_min=n_iter_min, n_iter_print=n_iter_print, seed=seed, nonlin=nonlin, binary_y=binary_y, ) if data_split: # keep only prediction data X, y, w = X[pred_mask, :], y[pred_mask, :], w[pred_mask, :] if p is not None: p = p[pred_mask, :] else: log.debug(f"Training first stage in {n_cf_folds} folds (cross-fitting)") # do cross fitting mu_hat, pi_hat = onp.zeros((n, 1)), onp.zeros((n, 1)) splitter = StratifiedKFold(n_splits=n_cf_folds, shuffle=True, random_state=seed) fold_count = 1 for train_idx, test_idx in splitter.split(X, w): log.debug(f"Training fold {fold_count}.") fold_count = fold_count + 1 pred_mask = onp.zeros(n, dtype=bool) pred_mask[test_idx] = 1 fit_mask = ~pred_mask mu_hat[pred_mask], pi_hat[pred_mask] = _train_and_predict_r_stage1( X, y, w, fit_mask, pred_mask, n_layers_out=n_layers_out, n_layers_r=n_layers_r, n_units_out=n_units_out, n_units_r=n_units_r, penalty_l2=penalty_l2, step_size=step_size, n_iter=n_iter, batch_size=batch_size, val_split_prop=val_split_prop, early_stopping=early_stopping, patience=patience, 
n_iter_min=n_iter_min, n_iter_print=n_iter_print, seed=seed, nonlin=nonlin, binary_y=binary_y, ) log.debug("Training second stage.") if p is not None: # use known propensity score p = check_shape_1d_data(p) pi_hat = p y, w = check_shape_1d_data(y), check_shape_1d_data(w) w_ortho = w - pi_hat y_ortho = y - mu_hat if second_stage_strategy == R_STRATEGY_NAME: return train_r_stage2( X, y_ortho, w_ortho, n_layers_out=n_layers_out_t, n_units_out=n_units_out_t, n_layers_r=n_layers_r_t, n_units_r=n_units_r_t, penalty_l2=penalty_l2_t, step_size=step_size_t, n_iter=n_iter, batch_size=batch_size, val_split_prop=val_split_prop, early_stopping=early_stopping, patience=patience, n_iter_min=n_iter_min, n_iter_print=n_iter_print, seed=seed, return_val_loss=return_val_loss, nonlin=nonlin, ) elif second_stage_strategy == U_STRATEGY_NAME: return train_output_net_only( X, y_ortho / w_ortho, n_layers_out=n_layers_out_t, n_units_out=n_units_out_t, n_layers_r=n_layers_r_t, n_units_r=n_units_r_t, penalty_l2=penalty_l2_t, step_size=step_size_t, n_iter=n_iter, batch_size=batch_size, val_split_prop=val_split_prop, early_stopping=early_stopping, patience=patience, n_iter_min=n_iter_min, n_iter_print=n_iter_print, seed=seed, return_val_loss=return_val_loss, nonlin=nonlin, ) else: raise ValueError("R-learner only supports strategies R and U.") def _train_and_predict_r_stage1( X: jnp.ndarray, y: jnp.ndarray, w: jnp.ndarray, fit_mask: jnp.ndarray, pred_mask: jnp.ndarray, n_layers_out: int = DEFAULT_LAYERS_OUT, n_units_out: int = DEFAULT_UNITS_OUT, n_layers_r: int = DEFAULT_LAYERS_R, n_units_r: int = DEFAULT_UNITS_R, penalty_l2: float = DEFAULT_PENALTY_L2, step_size: float = DEFAULT_STEP_SIZE, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, val_split_prop: float = DEFAULT_VAL_SPLIT, early_stopping: bool = True, patience: int = DEFAULT_PATIENCE, n_iter_min: int = DEFAULT_N_ITER_MIN, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, nonlin: str = DEFAULT_NONLIN, 
binary_y: bool = False, ) -> Any: if len(w.shape) > 1: w = w.reshape((len(w),)) # split the data X_fit, y_fit, w_fit = X[fit_mask, :], y[fit_mask], w[fit_mask] X_pred = X[pred_mask, :] log.debug("Training output Net") params_out, predict_fun_out = train_output_net_only( X_fit, y_fit, n_layers_out=n_layers_out, n_units_out=n_units_out, n_layers_r=n_layers_r, n_units_r=n_units_r, penalty_l2=penalty_l2, step_size=step_size, n_iter=n_iter, batch_size=batch_size, val_split_prop=val_split_prop, early_stopping=early_stopping, patience=patience, n_iter_min=n_iter_min, n_iter_print=n_iter_print, seed=seed, nonlin=nonlin, binary_y=binary_y, ) mu_hat = predict_fun_out(params_out, X_pred) log.debug("Training propensity net") params_prop, predict_fun_prop = train_output_net_only( X_fit, w_fit, binary_y=True, n_layers_out=n_layers_out, n_units_out=n_units_out, n_layers_r=n_layers_r, n_units_r=n_units_r, penalty_l2=penalty_l2, step_size=step_size, n_iter=n_iter, batch_size=batch_size, val_split_prop=val_split_prop, early_stopping=early_stopping, patience=patience, n_iter_min=n_iter_min, n_iter_print=n_iter_print, seed=seed, nonlin=nonlin, ) pi_hat = predict_fun_prop(params_prop, X_pred) return mu_hat, pi_hat def train_r_stage2( X: jnp.ndarray, y_ortho: jnp.ndarray, w_ortho: jnp.ndarray, n_layers_out: int = DEFAULT_LAYERS_OUT, n_units_out: int = DEFAULT_UNITS_OUT, n_layers_r: int = 0, n_units_r: int = DEFAULT_UNITS_R, penalty_l2: float = DEFAULT_PENALTY_L2, step_size: float = DEFAULT_STEP_SIZE, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, val_split_prop: float = DEFAULT_VAL_SPLIT, early_stopping: bool = True, patience: int = DEFAULT_PATIENCE, n_iter_min: int = DEFAULT_N_ITER_MIN, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, return_val_loss: bool = False, nonlin: str = DEFAULT_NONLIN, avg_objective: bool = DEFAULT_AVG_OBJECTIVE, ) -> Any: # function to train a single output head # input check y_ortho, w_ortho = 
check_shape_1d_data(y_ortho), check_shape_1d_data(w_ortho) d = X.shape[1] input_shape = (-1, d) rng_key = random.PRNGKey(seed) onp.random.seed(seed) # set seed for data generation via numpy as well # get validation split (can be none) X, y_ortho, w_ortho, X_val, y_val, w_val, val_string = make_val_split( X, y_ortho, w_ortho, val_split_prop=val_split_prop, seed=seed, stratify_w=False ) n = X.shape[0] # could be different from before due to split # get output head init_fun, predict_fun = OutputHead( n_layers_out=n_layers_out, n_units_out=n_units_out, n_layers_r=n_layers_r, n_units_r=n_units_r, nonlin=nonlin, ) # define loss and grad @jit def loss(params: dict, batch: jnp.ndarray, penalty: float) -> jnp.ndarray: # mse loss function inputs, ortho_targets, ortho_treats = batch preds = predict_fun(params, inputs) weightsq = sum( [ jnp.sum(params[i][0] ** 2) for i in range(0, 2 * (n_layers_out + n_layers_r) + 1, 2) ] ) if not avg_objective: return ( jnp.sum((ortho_targets - ortho_treats * preds) ** 2) + 0.5 * penalty * weightsq ) else: return ( jnp.average((ortho_targets - ortho_treats * preds) ** 2) + 0.5 * penalty * weightsq ) # set optimization routine # set optimizer opt_init, opt_update, get_params = optimizers.adam(step_size=step_size) # set update function @jit def update(i: int, state: dict, batch: jnp.ndarray, penalty: float) -> jnp.ndarray: params = get_params(state) g_params = grad(loss)(params, batch, penalty) # g_params = optimizers.clip_grads(g_params, 1.0) return opt_update(i, g_params, state) # initialise states _, init_params = init_fun(rng_key, input_shape) opt_state = opt_init(init_params) # calculate number of batches per epoch batch_size = batch_size if batch_size < n else n n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1 train_indices = onp.arange(n) l_best = LARGE_VAL p_curr = 0 # do training for i in range(n_iter): # shuffle data for minibatches onp.random.shuffle(train_indices) for b in range(n_batches): idx_next = 
train_indices[ (b * batch_size) : min((b + 1) * batch_size, n - 1) ] next_batch = X[idx_next, :], y_ortho[idx_next, :], w_ortho[idx_next, :] opt_state = update(i * n_batches + b, opt_state, next_batch, penalty_l2) if (i % n_iter_print == 0) or early_stopping: params_curr = get_params(opt_state) l_curr = loss(params_curr, (X_val, y_val, w_val), penalty_l2) if i % n_iter_print == 0: log.debug(f"Epoch: {i}, current {val_string} loss: {l_curr}") if early_stopping and ((i + 1) * n_batches > n_iter_min): # check if loss updated if l_curr < l_best: l_best = l_curr p_curr = 0 else: p_curr = p_curr + 1 if p_curr > patience: trained_params = get_params(opt_state) if return_val_loss: # return loss without penalty l_final = loss(trained_params, (X_val, y_val, w_val), 0) return trained_params, predict_fun, l_final return trained_params, predict_fun # get final parameters trained_params = get_params(opt_state) if return_val_loss: # return loss without penalty l_final = loss(trained_params, (X_val, y_val, w_val), 0) return trained_params, predict_fun, l_final return trained_params, predict_fun ================================================ FILE: catenets/models/jax/snet.py ================================================ """ Module implements SNet class as discussed in Curth & van der Schaar (2021) """ # Author: Alicia Curth from typing import Callable, List, Tuple import jax.numpy as jnp import numpy as onp from jax import grad, jit, random from jax.example_libraries import optimizers import catenets.logger as log from catenets.models.constants import ( DEFAULT_AVG_OBJECTIVE, DEFAULT_BATCH_SIZE, DEFAULT_LAYERS_OUT, DEFAULT_LAYERS_R, DEFAULT_N_ITER, DEFAULT_N_ITER_MIN, DEFAULT_N_ITER_PRINT, DEFAULT_NONLIN, DEFAULT_PATIENCE, DEFAULT_PENALTY_DISC, DEFAULT_PENALTY_L2, DEFAULT_PENALTY_ORTHOGONAL, DEFAULT_SEED, DEFAULT_STEP_SIZE, DEFAULT_UNITS_OUT, DEFAULT_UNITS_R_BIG_S3, DEFAULT_UNITS_R_SMALL_S3, DEFAULT_VAL_SPLIT, LARGE_VAL, ) from catenets.models.jax.base import BaseCATENet, 
OutputHead, ReprBlock from catenets.models.jax.disentangled_nets import ( _concatenate_representations, _get_absolute_rowsums, ) from catenets.models.jax.flextenet import _get_cos_reg from catenets.models.jax.model_utils import ( check_shape_1d_data, heads_l2_penalty, make_val_split, ) from catenets.models.jax.representation_nets import mmd2_lin DEFAULT_UNITS_R_BIG_S = 100 DEFAULT_UNITS_R_SMALL_S = 50 class SNet(BaseCATENet): """ Class implements SNet as discussed in Curth & van der Schaar (2021). Additionally to the version implemented in the AISTATS paper, we also include an implementation that does not have propensity heads (set with_prop=False) Parameters ---------- with_prop: bool, True Whether to include propensity head binary_y: bool, default False Whether the outcome is binary n_layers_out: int Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer) n_layers_out_prop: int Number of hypothesis layers for propensity score(n_layers_out x n_units_out + 1 x Dense layer) n_units_out: int Number of hidden units in each hypothesis layer n_units_out_prop: int Number of hidden units in each propensity score hypothesis layer n_layers_r: int Number of shared & private representation layers before hypothesis layers n_units_r: int If withprop=True: Number of hidden units in representation layer shared by propensity score and outcome function (the 'confounding factor') and in the ('instrumental factor') If withprop=False: Number of hidden units in representation shared across PO function n_units_r_small: int If withprop=True: Number of hidden units in representation layer of the 'outcome factor' and each PO functions private representation if withprop=False: Number of hidden units in each PO functions private representation penalty_l2: float l2 (ridge) penalty step_size: float learning rate for optimizer n_iter: int Maximum number of iterations batch_size: int Batch size val_split_prop: float Proportion of samples used for validation split (can be 0) 
early_stopping: bool, default True Whether to use early stopping patience: int Number of iterations to wait before early stopping after decrease in validation loss n_iter_min: int Minimum number of iterations to go through before starting early stopping n_iter_print: int Number of iterations after which to print updates seed: int Seed used reg_diff: bool, default False Whether to regularize the difference between the two potential outcome heads penalty_diff: float l2-penalty for regularizing the difference between output heads. used only if train_separate=False same_init: bool, False Whether to initialise the two output heads with same values nonlin: string, default 'elu' Nonlinearity to use in NN penalty_disc: float, default zero Discrepancy penalty. Defaults to zero as this feature is not tested. ortho_reg_type: str, 'abs' Which type of orthogonalization to use. 'abs' uses the (hard) disentanglement described in AISTATS paper, 'fro' uses frobenius norm as in FlexTENet """ def __init__( self, with_prop: bool = True, binary_y: bool = False, n_layers_r: int = DEFAULT_LAYERS_R, n_units_r: int = DEFAULT_UNITS_R_BIG_S, n_layers_out: int = DEFAULT_LAYERS_OUT, n_units_r_small: int = DEFAULT_UNITS_R_SMALL_S, n_units_out: int = DEFAULT_UNITS_OUT, n_units_out_prop: int = DEFAULT_UNITS_OUT, n_layers_out_prop: int = DEFAULT_LAYERS_OUT, penalty_l2: float = DEFAULT_PENALTY_L2, penalty_orthogonal: float = DEFAULT_PENALTY_ORTHOGONAL, penalty_disc: float = DEFAULT_PENALTY_DISC, step_size: float = DEFAULT_STEP_SIZE, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, val_split_prop: float = DEFAULT_VAL_SPLIT, early_stopping: bool = True, patience: int = DEFAULT_PATIENCE, n_iter_min: int = DEFAULT_N_ITER_MIN, n_iter_print: int = DEFAULT_N_ITER_PRINT, reg_diff: bool = False, penalty_diff: float = DEFAULT_PENALTY_L2, seed: int = DEFAULT_SEED, nonlin: str = DEFAULT_NONLIN, same_init: bool = False, ortho_reg_type: str = "abs", ): self.with_prop = with_prop self.binary_y 
= binary_y self.n_layers_r = n_layers_r self.n_layers_out = n_layers_out self.n_layers_out_prop = n_layers_out_prop self.n_units_r = n_units_r self.n_units_r_small = n_units_r_small self.n_units_out = n_units_out self.n_units_out_prop = n_units_out_prop self.nonlin = nonlin self.penalty_l2 = penalty_l2 self.penalty_orthogonal = penalty_orthogonal self.penalty_disc = penalty_disc self.reg_diff = reg_diff self.penalty_diff = penalty_diff self.same_init = same_init self.ortho_reg_type = ortho_reg_type self.step_size = step_size self.n_iter = n_iter self.batch_size = batch_size self.val_split_prop = val_split_prop self.early_stopping = early_stopping self.patience = patience self.n_iter_min = n_iter_min self.seed = seed self.n_iter_print = n_iter_print def _get_predict_function(self) -> Callable: if self.with_prop: return predict_snet else: return predict_snet_noprop def _get_train_function(self) -> Callable: if self.with_prop: return train_snet else: return train_snet_noprop def train_snet( X: jnp.ndarray, y: jnp.ndarray, w: jnp.ndarray, binary_y: bool = False, n_layers_r: int = DEFAULT_LAYERS_R, n_units_r: int = DEFAULT_UNITS_R_BIG_S, n_units_r_small: int = DEFAULT_UNITS_R_SMALL_S, n_layers_out: int = DEFAULT_LAYERS_OUT, n_units_out: int = DEFAULT_UNITS_OUT, n_units_out_prop: int = DEFAULT_UNITS_OUT, n_layers_out_prop: int = DEFAULT_LAYERS_OUT, penalty_l2: float = DEFAULT_PENALTY_L2, penalty_disc: float = DEFAULT_PENALTY_DISC, penalty_orthogonal: float = DEFAULT_PENALTY_ORTHOGONAL, step_size: float = DEFAULT_STEP_SIZE, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, val_split_prop: float = DEFAULT_VAL_SPLIT, early_stopping: bool = True, patience: int = DEFAULT_PATIENCE, n_iter_min: int = DEFAULT_N_ITER_MIN, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, return_val_loss: bool = False, reg_diff: bool = False, penalty_diff: float = DEFAULT_PENALTY_L2, nonlin: str = DEFAULT_NONLIN, avg_objective: bool = DEFAULT_AVG_OBJECTIVE, 
with_prop: bool = True, same_init: bool = False, ortho_reg_type: str = "abs", ) -> Tuple: # function to train a net with 5 representations if not with_prop: raise ValueError("train_snet works only withprop=True") y, w = check_shape_1d_data(y), check_shape_1d_data(w) d = X.shape[1] input_shape = (-1, d) rng_key = random.PRNGKey(seed) onp.random.seed(seed) # set seed for data generation via numpy as well if not reg_diff: penalty_diff = penalty_l2 # get validation split (can be none) X, y, w, X_val, y_val, w_val, val_string = make_val_split( X, y, w, val_split_prop=val_split_prop, seed=seed ) n = X.shape[0] # could be different from before due to split # get representation layers init_fun_repr, predict_fun_repr = ReprBlock( n_layers=n_layers_r, n_units=n_units_r, nonlin=nonlin ) init_fun_repr_small, predict_fun_repr_small = ReprBlock( n_layers=n_layers_r, n_units=n_units_r_small, nonlin=nonlin ) # get output head functions (output heads share same structure) init_fun_head_po, predict_fun_head_po = OutputHead( n_layers_out=n_layers_out, n_units_out=n_units_out, binary_y=binary_y, nonlin=nonlin, ) # add propensity head init_fun_head_prop, predict_fun_head_prop = OutputHead( n_layers_out=n_layers_out_prop, n_units_out=n_units_out_prop, binary_y=True, nonlin=nonlin, ) def init_fun_snet(rng: float, input_shape: Tuple) -> Tuple[Tuple, List]: # chain together the layers # param should look like [param_repr_c, param_repr_o, param_repr_mu0, param_repr_mu1, # param_repr_w, param_0, param_1, param_prop] # initialise representation layers rng, layer_rng = random.split(rng) input_shape_repr, param_repr_c = init_fun_repr(layer_rng, input_shape) rng, layer_rng = random.split(rng) input_shape_repr_small, param_repr_o = init_fun_repr_small( layer_rng, input_shape ) rng, layer_rng = random.split(rng) _, param_repr_mu0 = init_fun_repr_small(layer_rng, input_shape) rng, layer_rng = random.split(rng) _, param_repr_mu1 = init_fun_repr_small(layer_rng, input_shape) rng, layer_rng = 
random.split(rng) _, param_repr_w = init_fun_repr(layer_rng, input_shape) # prop and mu_0 each get two representations, mu_1 gets 3 input_shape_repr_prop = input_shape_repr[:-1] + (2 * input_shape_repr[-1],) input_shape_repr_mu = input_shape_repr[:-1] + ( input_shape_repr[-1] + (2 * input_shape_repr_small[-1]), ) # initialise output heads rng, layer_rng = random.split(rng) if same_init: # initialise both on same values input_shape, param_0 = init_fun_head_po(layer_rng, input_shape_repr_mu) input_shape, param_1 = init_fun_head_po(layer_rng, input_shape_repr_mu) else: input_shape, param_0 = init_fun_head_po(layer_rng, input_shape_repr_mu) rng, layer_rng = random.split(rng) input_shape, param_1 = init_fun_head_po(layer_rng, input_shape_repr_mu) rng, layer_rng = random.split(rng) input_shape, param_prop = init_fun_head_prop(layer_rng, input_shape_repr_prop) return input_shape, [ param_repr_c, param_repr_o, param_repr_mu0, param_repr_mu1, param_repr_w, param_0, param_1, param_prop, ] # Define loss functions # loss functions for the head if not binary_y: def loss_head( params: jnp.ndarray, batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], penalty: float, ) -> jnp.ndarray: # mse loss function inputs, targets, weights = batch preds = predict_fun_head_po(params, inputs) return jnp.sum(weights * ((preds - targets) ** 2)) else: def loss_head( params: jnp.ndarray, batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], penalty: float, ) -> jnp.ndarray: # log loss function inputs, targets, weights = batch preds = predict_fun_head_po(params, inputs) return -jnp.sum( weights * (targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds)) ) def loss_head_prop( params: jnp.ndarray, batch: Tuple[jnp.ndarray, jnp.ndarray], penalty: float ) -> jnp.ndarray: # log loss function for propensities inputs, targets = batch preds = predict_fun_head_prop(params, inputs) return -jnp.sum(targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds)) # define ortho-reg function if 
ortho_reg_type == "abs": def ortho_reg(params: jnp.ndarray) -> jnp.ndarray: col_c = _get_absolute_rowsums(params[0][0][0]) col_o = _get_absolute_rowsums(params[1][0][0]) col_mu0 = _get_absolute_rowsums(params[2][0][0]) col_mu1 = _get_absolute_rowsums(params[3][0][0]) col_w = _get_absolute_rowsums(params[4][0][0]) return jnp.sum( col_c * col_o + col_c * col_w + col_c * col_mu1 + col_c * col_mu0 + col_w * col_o + col_mu0 * col_o + col_o * col_mu1 + col_mu0 * col_mu1 + col_mu0 * col_w + col_w * col_mu1 ) elif ortho_reg_type == "fro": def ortho_reg(params: jnp.ndarray) -> jnp.ndarray: return ( _get_cos_reg(params[0][0][0], params[1][0][0], False) + _get_cos_reg(params[0][0][0], params[2][0][0], False) + _get_cos_reg(params[0][0][0], params[3][0][0], False) + _get_cos_reg(params[0][0][0], params[4][0][0], False) + _get_cos_reg(params[1][0][0], params[2][0][0], False) + _get_cos_reg(params[1][0][0], params[3][0][0], False) + _get_cos_reg(params[1][0][0], params[4][0][0], False) + _get_cos_reg(params[2][0][0], params[3][0][0], False) + _get_cos_reg(params[2][0][0], params[4][0][0], False) + _get_cos_reg(params[3][0][0], params[4][0][0], False) ) else: raise NotImplementedError( "train_snet_noprop supports only orthogonal regularization " "using absolute values or frobenious norms." 
) # complete loss function for all parts @jit def loss_snet( params: jnp.ndarray, batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], penalty_l2: float, penalty_orthogonal: float, penalty_disc: float, ) -> jnp.ndarray: # params: # param should look like [param_repr_c, param_repr_o, param_repr_mu0, # param_repr_mu1, param_repr_w, param_0, param_1, param_prop] # batch: (X, y, w) X, y, w = batch # get representation reps_c = predict_fun_repr(params[0], X) reps_o = predict_fun_repr_small(params[1], X) reps_mu0 = predict_fun_repr_small(params[2], X) reps_mu1 = predict_fun_repr_small(params[3], X) reps_w = predict_fun_repr(params[4], X) # concatenate reps_po_0 = _concatenate_representations((reps_c, reps_o, reps_mu0)) reps_po_1 = _concatenate_representations((reps_c, reps_o, reps_mu1)) reps_prop = _concatenate_representations((reps_c, reps_w)) # pass down to heads loss_0 = loss_head(params[5], (reps_po_0, y, 1 - w), penalty_l2) loss_1 = loss_head(params[6], (reps_po_1, y, w), penalty_l2) # pass down to propensity head loss_prop = loss_head_prop(params[7], (reps_prop, w), penalty_l2) # is rep_o balanced between groups? 
loss_disc = penalty_disc * mmd2_lin(reps_o, w) # which variable has impact on which representation -- orthogonal loss loss_o = penalty_orthogonal * ortho_reg(params) # weight decay on representations weightsq_body = sum( [ sum( [jnp.sum(params[j][i][0] ** 2) for i in range(0, 2 * n_layers_r, 2)] ) for j in range(5) ] ) weightsq_head = heads_l2_penalty( params[5], params[6], n_layers_out, reg_diff, penalty_l2, penalty_diff ) weightsq_prop = sum( [ jnp.sum(params[7][i][0] ** 2) for i in range(0, 2 * n_layers_out_prop + 1, 2) ] ) if not avg_objective: return ( loss_0 + loss_1 + loss_prop + loss_disc + loss_o + 0.5 * (penalty_l2 * (weightsq_body + weightsq_prop) + weightsq_head) ) else: n_batch = y.shape[0] return ( (loss_0 + loss_1) / n_batch + loss_prop / n_batch + loss_disc + loss_o + 0.5 * (penalty_l2 * (weightsq_body + weightsq_prop) + weightsq_head) ) # Define optimisation routine opt_init, opt_update, get_params = optimizers.adam(step_size=step_size) @jit def update( i: int, state: dict, batch: jnp.ndarray, penalty_l2: float, penalty_orthogonal: float, penalty_disc: float, ) -> jnp.ndarray: # updating function params = get_params(state) return opt_update( i, grad(loss_snet)( params, batch, penalty_l2, penalty_orthogonal, penalty_disc ), state, ) # initialise states _, init_params = init_fun_snet(rng_key, input_shape) opt_state = opt_init(init_params) # calculate number of batches per epoch batch_size = batch_size if batch_size < n else n n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1 train_indices = onp.arange(n) l_best = LARGE_VAL p_curr = 0 # do training for i in range(n_iter): # shuffle data for minibatches onp.random.shuffle(train_indices) for b in range(n_batches): idx_next = train_indices[ (b * batch_size) : min((b + 1) * batch_size, n - 1) ] next_batch = X[idx_next, :], y[idx_next, :], w[idx_next] opt_state = update( i * n_batches + b, opt_state, next_batch, penalty_l2, penalty_orthogonal, penalty_disc, ) if (i % n_iter_print == 0) or 
early_stopping: params_curr = get_params(opt_state) l_curr = loss_snet( params_curr, (X_val, y_val, w_val), penalty_l2, penalty_orthogonal, penalty_disc, ) if i % n_iter_print == 0: log.info(f"Epoch: {i}, current {val_string} loss {l_curr}") if early_stopping and ((i + 1) * n_batches > n_iter_min): # check if loss updated if l_curr < l_best: l_best = l_curr p_curr = 0 params_best = params_curr else: if onp.isnan(l_curr): # if diverged, return best return params_best, ( predict_fun_repr, predict_fun_head_po, predict_fun_head_prop, ) p_curr = p_curr + 1 if p_curr > patience: if return_val_loss: # return loss without penalty l_final = loss_snet(params_curr, (X_val, y_val, w_val), 0, 0, 0) return ( params_curr, (predict_fun_repr, predict_fun_head_po, predict_fun_head_prop), l_final, ) return params_curr, ( predict_fun_repr, predict_fun_head_po, predict_fun_head_prop, ) # return the parameters trained_params = get_params(opt_state) if return_val_loss: # return loss without penalty l_final = loss_snet(get_params(opt_state), (X_val, y_val, w_val), 0, 0) return ( trained_params, (predict_fun_repr, predict_fun_head_po, predict_fun_head_prop), l_final, ) return trained_params, ( predict_fun_repr, predict_fun_head_po, predict_fun_head_prop, ) def predict_snet( X: jnp.ndarray, trained_params: jnp.ndarray, predict_funs: list, return_po: bool = False, return_prop: bool = False, ) -> jnp.ndarray: # unpack inputs predict_fun_repr, predict_fun_head, predict_fun_prop = predict_funs param_0, param_1, param_prop = ( trained_params[5], trained_params[6], trained_params[7], ) reps_c = predict_fun_repr(trained_params[0], X) reps_o = predict_fun_repr(trained_params[1], X) reps_mu0 = predict_fun_repr(trained_params[2], X) reps_mu1 = predict_fun_repr(trained_params[3], X) reps_w = predict_fun_repr(trained_params[4], X) # concatenate reps_po_0 = _concatenate_representations((reps_c, reps_o, reps_mu0)) reps_po_1 = _concatenate_representations((reps_c, reps_o, reps_mu1)) reps_prop = 
_concatenate_representations((reps_c, reps_w)) # get potential outcomes mu_0 = predict_fun_head(param_0, reps_po_0) mu_1 = predict_fun_head(param_1, reps_po_1) te = mu_1 - mu_0 if return_prop: # get propensity prop = predict_fun_prop(param_prop, reps_prop) # stack other outputs if return_po: if return_prop: return te, mu_0, mu_1, prop else: return te, mu_0, mu_1 else: if return_prop: return te, prop else: return te # SNet without propensity head ---------------------------------------- def train_snet_noprop( X: jnp.ndarray, y: jnp.ndarray, w: jnp.ndarray, binary_y: bool = False, n_layers_r: int = DEFAULT_LAYERS_R, n_units_r: int = DEFAULT_UNITS_R_BIG_S3, n_units_r_small: int = DEFAULT_UNITS_R_SMALL_S3, n_layers_out: int = DEFAULT_LAYERS_OUT, n_units_out: int = DEFAULT_UNITS_OUT, n_units_out_prop: int = DEFAULT_UNITS_OUT, n_layers_out_prop: int = DEFAULT_LAYERS_OUT, penalty_l2: float = DEFAULT_PENALTY_L2, penalty_orthogonal: float = DEFAULT_PENALTY_ORTHOGONAL, step_size: float = DEFAULT_STEP_SIZE, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, val_split_prop: float = DEFAULT_VAL_SPLIT, early_stopping: bool = True, n_iter_min: int = DEFAULT_N_ITER_MIN, patience: int = DEFAULT_PATIENCE, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, return_val_loss: bool = False, reg_diff: bool = False, penalty_diff: float = DEFAULT_PENALTY_L2, nonlin: str = DEFAULT_NONLIN, avg_objective: bool = DEFAULT_AVG_OBJECTIVE, with_prop: bool = False, same_init: bool = False, ortho_reg_type: str = "abs", ) -> Tuple: """ SNet but without the propensity head """ if with_prop: raise ValueError("train_snet_noprop works only with_prop=False") # function to train a net with 3 representations y, w = check_shape_1d_data(y), check_shape_1d_data(w) d = X.shape[1] input_shape = (-1, d) rng_key = random.PRNGKey(seed) onp.random.seed(seed) # set seed for data generation via numpy as well if not reg_diff: penalty_diff = penalty_l2 # get validation split (can be 
none) X, y, w, X_val, y_val, w_val, val_string = make_val_split( X, y, w, val_split_prop=val_split_prop, seed=seed ) n = X.shape[0] # could be different from before due to split # get representation layers init_fun_repr, predict_fun_repr = ReprBlock( n_layers=n_layers_r, n_units=n_units_r, nonlin=nonlin ) init_fun_repr_small, predict_fun_repr_small = ReprBlock( n_layers=n_layers_r, n_units=n_units_r_small, nonlin=nonlin ) # get output head functions (output heads share same structure) init_fun_head_po, predict_fun_head_po = OutputHead( n_layers_out=n_layers_out, n_units_out=n_units_out, binary_y=binary_y, nonlin=nonlin, ) def init_fun_snet_noprop(rng: float, input_shape: Tuple) -> Tuple[Tuple, List]: # chain together the layers # param should look like [repr_o, repr_p0, repr_p1, po_0, po_1] # initialise representation layers rng, layer_rng = random.split(rng) input_shape_repr, param_repr_o = init_fun_repr(layer_rng, input_shape) rng, layer_rng = random.split(rng) input_shape_repr_small, param_repr_p0 = init_fun_repr_small( layer_rng, input_shape ) rng, layer_rng = random.split(rng) _, param_repr_p1 = init_fun_repr_small(layer_rng, input_shape) # each head gets two representations input_shape_repr = input_shape_repr[:-1] + ( input_shape_repr[-1] + input_shape_repr_small[-1], ) # initialise output heads rng, layer_rng = random.split(rng) if same_init: # initialise both on same values input_shape, param_0 = init_fun_head_po(layer_rng, input_shape_repr) input_shape, param_1 = init_fun_head_po(layer_rng, input_shape_repr) else: input_shape, param_0 = init_fun_head_po(layer_rng, input_shape_repr) rng, layer_rng = random.split(rng) input_shape, param_1 = init_fun_head_po(layer_rng, input_shape_repr) return input_shape, [ param_repr_o, param_repr_p0, param_repr_p1, param_0, param_1, ] # Define loss functions # loss functions for the head if not binary_y: def loss_head( params: jnp.ndarray, batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], penalty: float, ) -> 
jnp.ndarray: # mse loss function inputs, targets, weights = batch preds = predict_fun_head_po(params, inputs) return jnp.sum(weights * ((preds - targets) ** 2)) else: def loss_head( params: jnp.ndarray, batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], penalty: float, ) -> jnp.ndarray: # log loss function inputs, targets, weights = batch preds = predict_fun_head_po(params, inputs) return -jnp.sum( weights * (targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds)) ) # define ortho-reg function if ortho_reg_type == "abs": def ortho_reg(params: jnp.ndarray) -> jnp.ndarray: col_o = _get_absolute_rowsums(params[0][0][0]) col_p0 = _get_absolute_rowsums(params[1][0][0]) col_p1 = _get_absolute_rowsums(params[2][0][0]) return jnp.sum(col_o * col_p0 + col_o * col_p1 + col_p1 * col_p0) elif ortho_reg_type == "fro": def ortho_reg(params: jnp.ndarray) -> jnp.ndarray: return ( _get_cos_reg(params[0][0][0], params[1][0][0], False) + _get_cos_reg(params[0][0][0], params[2][0][0], False) + _get_cos_reg(params[1][0][0], params[2][0][0], False) ) else: raise NotImplementedError( "train_snet_noprop supports only orthogonal regularization " "using absolute values or frobenious norms." 
) # complete loss function for all parts @jit def loss_snet_noprop( params: jnp.ndarray, batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], penalty_l2: float, penalty_orthogonal: float, ) -> jnp.ndarray: # params: list[repr_o, repr_p0, repr_p1, po_0, po_1] # batch: (X, y, w) X, y, w = batch # get representation reps_o = predict_fun_repr(params[0], X) reps_p0 = predict_fun_repr_small(params[1], X) reps_p1 = predict_fun_repr_small(params[2], X) # concatenate reps_po0 = _concatenate_representations((reps_o, reps_p0)) reps_po1 = _concatenate_representations((reps_o, reps_p1)) # pass down to heads loss_0 = loss_head(params[3], (reps_po0, y, 1 - w), penalty_l2) loss_1 = loss_head(params[4], (reps_po1, y, w), penalty_l2) # which variable has impact on which representation loss_o = penalty_orthogonal * ortho_reg(params) # weight decay on representations weightsq_body = sum( [ sum( [jnp.sum(params[j][i][0] ** 2) for i in range(0, 2 * n_layers_r, 2)] ) for j in range(3) ] ) weightsq_head = heads_l2_penalty( params[3], params[4], n_layers_out, reg_diff, penalty_l2, penalty_diff ) if not avg_objective: return ( loss_0 + loss_1 + loss_o + 0.5 * (penalty_l2 * weightsq_body + weightsq_head) ) else: n_batch = y.shape[0] return ( (loss_0 + loss_1) / n_batch + loss_o + 0.5 * (penalty_l2 * weightsq_body + weightsq_head) ) # Define optimisation routine opt_init, opt_update, get_params = optimizers.adam(step_size=step_size) @jit def update( i: int, state: dict, batch: jnp.ndarray, penalty_l2: float, penalty_orthogonal: float, ) -> jnp.ndarray: # updating function params = get_params(state) return opt_update( i, grad(loss_snet_noprop)(params, batch, penalty_l2, penalty_orthogonal), state, ) # initialise states _, init_params = init_fun_snet_noprop(rng_key, input_shape) opt_state = opt_init(init_params) # calculate number of batches per epoch batch_size = batch_size if batch_size < n else n n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1 train_indices = 
onp.arange(n) l_best = LARGE_VAL p_curr = 0 # do training for i in range(n_iter): # shuffle data for minibatches onp.random.shuffle(train_indices) for b in range(n_batches): idx_next = train_indices[ (b * batch_size) : min((b + 1) * batch_size, n - 1) ] next_batch = X[idx_next, :], y[idx_next, :], w[idx_next] opt_state = update( i * n_batches + b, opt_state, next_batch, penalty_l2, penalty_orthogonal ) if (i % n_iter_print == 0) or early_stopping: params_curr = get_params(opt_state) l_curr = loss_snet_noprop( params_curr, (X_val, y_val, w_val), penalty_l2, penalty_orthogonal ) if i % n_iter_print == 0: log.info(f"Epoch: {i}, current {val_string} loss {l_curr}") if early_stopping and ((i + 1) * n_batches > n_iter_min): # check if loss updated if l_curr < l_best: l_best = l_curr p_curr = 0 params_best = params_curr else: if onp.isnan(l_curr): # if diverged, return best return params_best, (predict_fun_repr, predict_fun_head_po) p_curr = p_curr + 1 if p_curr > patience: if return_val_loss: # return loss without penalty l_final = loss_snet_noprop(params_curr, (X_val, y_val, w_val), 0, 0) return params_curr, (predict_fun_repr, predict_fun_head_po), l_final return params_curr, (predict_fun_repr, predict_fun_head_po) # return the parameters trained_params = get_params(opt_state) if return_val_loss: # return loss without penalty l_final = loss_snet_noprop(get_params(opt_state), (X_val, y_val, w_val), 0, 0) return trained_params, (predict_fun_repr, predict_fun_head_po), l_final return trained_params, (predict_fun_repr, predict_fun_head_po) def predict_snet_noprop( X: jnp.ndarray, trained_params: jnp.ndarray, predict_funs: list, return_po: bool = False, return_prop: bool = False, ) -> jnp.ndarray: if return_prop: raise NotImplementedError("SNet5 does not have propensity estimator") # unpack inputs predict_fun_repr, predict_fun_head = predict_funs param_repr_o, param_repr_po0, param_repr_po1 = ( trained_params[0], trained_params[1], trained_params[2], ) param_0, param_1 = 
trained_params[3], trained_params[4] # get representations rep_o = predict_fun_repr(param_repr_o, X) rep_po0 = predict_fun_repr(param_repr_po0, X) rep_po1 = predict_fun_repr(param_repr_po1, X) # concatenate reps_po0 = jnp.concatenate((rep_o, rep_po0), axis=1) reps_po1 = jnp.concatenate((rep_o, rep_po1), axis=1) # get potential outcomes mu_0 = predict_fun_head(param_0, reps_po0) mu_1 = predict_fun_head(param_1, reps_po1) te = mu_1 - mu_0 # stack other outputs if return_po: return te, mu_0, mu_1 else: return te ================================================ FILE: catenets/models/jax/tnet.py ================================================ """ Implements a T-Net: T-learner for CATE based on a dense NN """ # Author: Alicia Curth from typing import Any, Callable, List, Tuple import jax.numpy as jnp import numpy as onp from jax import grad, jit, random from jax.example_libraries import optimizers import catenets.logger as log from catenets.models.constants import ( DEFAULT_AVG_OBJECTIVE, DEFAULT_BATCH_SIZE, DEFAULT_LAYERS_OUT, DEFAULT_LAYERS_R, DEFAULT_N_ITER, DEFAULT_N_ITER_MIN, DEFAULT_N_ITER_PRINT, DEFAULT_NONLIN, DEFAULT_PATIENCE, DEFAULT_PENALTY_L2, DEFAULT_SEED, DEFAULT_STEP_SIZE, DEFAULT_UNITS_OUT, DEFAULT_UNITS_R, DEFAULT_VAL_SPLIT, LARGE_VAL, ) from catenets.models.jax.base import BaseCATENet, OutputHead, train_output_net_only from catenets.models.jax.model_utils import ( check_shape_1d_data, heads_l2_penalty, make_val_split, ) class TNet(BaseCATENet): """ TNet class -- two separate functions learned for each Potential Outcome function Parameters ---------- binary_y: bool, default False Whether the outcome is binary n_layers_out: int Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer) n_units_out: int Number of hidden units in each hypothesis layer n_layers_r: int Number of representation layers before hypothesis layers (distinction between hypothesis layers and representation layers is made to match TARNet & SNets) n_units_r: int Number 
of hidden units in each representation layer penalty_l2: float l2 (ridge) penalty step_size: float learning rate for optimizer n_iter: int Maximum number of iterations batch_size: int Batch size val_split_prop: float Proportion of samples used for validation split (can be 0) early_stopping: bool, default True Whether to use early stopping patience: int Number of iterations to wait before early stopping after decrease in validation loss n_iter_min: int Minimum number of iterations to go through before starting early stopping n_iter_print: int Number of iterations after which to print updates seed: int Seed used train_separate: bool, default True Whether to train the two output heads completely separately or whether to regularize their difference penalty_diff: float l2-penalty for regularizing the difference between output heads. used only if train_separate=False nonlin: string, default 'elu' Nonlinearity to use in NN """ def __init__( self, binary_y: bool = False, n_layers_out: int = DEFAULT_LAYERS_OUT, n_units_out: int = DEFAULT_UNITS_OUT, n_layers_r: int = DEFAULT_LAYERS_R, n_units_r: int = DEFAULT_UNITS_R, penalty_l2: float = DEFAULT_PENALTY_L2, step_size: float = DEFAULT_STEP_SIZE, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, val_split_prop: float = DEFAULT_VAL_SPLIT, early_stopping: bool = True, patience: int = DEFAULT_PATIENCE, n_iter_min: int = DEFAULT_N_ITER_MIN, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, train_separate: bool = True, penalty_diff: float = DEFAULT_PENALTY_L2, nonlin: str = DEFAULT_NONLIN, ) -> None: self.binary_y = binary_y self.n_layers_out = n_layers_out self.n_units_out = n_units_out self.n_layers_r = n_layers_r self.n_units_r = n_units_r self.penalty_l2 = penalty_l2 self.step_size = step_size self.n_iter = n_iter self.batch_size = batch_size self.val_split_prop = val_split_prop self.early_stopping = early_stopping self.patience = patience self.n_iter_min = n_iter_min self.n_iter_print = 
n_iter_print self.seed = seed self.train_separate = train_separate self.penalty_diff = penalty_diff self.nonlin = nonlin def _get_predict_function(self) -> Callable: return predict_t_net def _get_train_function(self) -> Callable: return train_tnet def train_tnet( X: jnp.ndarray, y: jnp.ndarray, w: jnp.ndarray, binary_y: bool = False, n_layers_out: int = DEFAULT_LAYERS_OUT, n_units_out: int = DEFAULT_UNITS_OUT, n_layers_r: int = DEFAULT_LAYERS_R, n_units_r: int = DEFAULT_UNITS_R, penalty_l2: float = DEFAULT_PENALTY_L2, step_size: float = DEFAULT_STEP_SIZE, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, val_split_prop: float = DEFAULT_VAL_SPLIT, early_stopping: bool = True, patience: int = DEFAULT_PATIENCE, n_iter_min: int = DEFAULT_N_ITER_MIN, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, return_val_loss: bool = False, train_separate: bool = True, penalty_diff: float = DEFAULT_PENALTY_L2, nonlin: str = DEFAULT_NONLIN, avg_objective: bool = DEFAULT_AVG_OBJECTIVE, ) -> Any: # w should be 1-D for indexing if len(w.shape) > 1: w = w.reshape((len(w),)) if train_separate: # train two heads completely independently log.debug("Training PO_0 Net") out_0 = train_output_net_only( X[w == 0], y[w == 0], binary_y=binary_y, n_layers_out=n_layers_out, n_units_out=n_units_out, n_layers_r=n_layers_r, n_units_r=n_units_r, penalty_l2=penalty_l2, step_size=step_size, n_iter=n_iter, batch_size=batch_size, val_split_prop=val_split_prop, early_stopping=early_stopping, patience=patience, n_iter_min=n_iter_min, n_iter_print=n_iter_print, seed=seed, return_val_loss=return_val_loss, nonlin=nonlin, avg_objective=avg_objective, ) log.debug("Training PO_1 Net") out_1 = train_output_net_only( X[w == 1], y[w == 1], binary_y=binary_y, n_layers_out=n_layers_out, n_units_out=n_units_out, n_layers_r=n_layers_r, n_units_r=n_units_r, penalty_l2=penalty_l2, step_size=step_size, n_iter=n_iter, batch_size=batch_size, val_split_prop=val_split_prop, 
early_stopping=early_stopping, patience=patience, n_iter_min=n_iter_min, n_iter_print=n_iter_print, seed=seed, return_val_loss=return_val_loss, nonlin=nonlin, avg_objective=avg_objective, ) if return_val_loss: params_0, predict_fun_0, loss_0 = out_0 params_1, predict_fun_1, loss_1 = out_1 return (params_0, params_1), (predict_fun_0, predict_fun_1), loss_1 + loss_0 params_0, predict_fun_0 = out_0 params_1, predict_fun_1 = out_1 else: # train jointly by regularizing similarity params, predict_fun = _train_tnet_jointly( X, y, w, binary_y=binary_y, n_layers_out=n_layers_out, n_units_out=n_units_out, n_layers_r=n_layers_r, n_units_r=n_units_r, penalty_l2=penalty_l2, step_size=step_size, n_iter=n_iter, batch_size=batch_size, val_split_prop=val_split_prop, early_stopping=early_stopping, patience=patience, n_iter_min=n_iter_min, n_iter_print=n_iter_print, seed=seed, return_val_loss=return_val_loss, penalty_diff=penalty_diff, nonlin=nonlin, ) params_0, params_1 = params[0], params[1] predict_fun_0, predict_fun_1 = predict_fun, predict_fun return (params_0, params_1), (predict_fun_0, predict_fun_1) def predict_t_net( X: jnp.ndarray, trained_params: dict, predict_funs: list, return_po: bool = False, return_prop: bool = False, ) -> jnp.ndarray: if return_prop: raise NotImplementedError("TNet does not implement a propensity model.") # return CATE predictions using T-net params params_0, params_1 = trained_params predict_fun_0, predict_fun_1 = predict_funs mu_0 = predict_fun_0(params_0, X) mu_1 = predict_fun_1(params_1, X) if return_po: return mu_1 - mu_0, mu_0, mu_1 else: return mu_1 - mu_0 def _train_tnet_jointly( X: jnp.ndarray, y: jnp.ndarray, w: jnp.ndarray, binary_y: bool = False, n_layers_out: int = DEFAULT_LAYERS_OUT, n_units_out: int = DEFAULT_UNITS_OUT, n_layers_r: int = DEFAULT_LAYERS_R, n_units_r: int = DEFAULT_UNITS_R, penalty_l2: float = DEFAULT_PENALTY_L2, step_size: float = DEFAULT_STEP_SIZE, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, 
val_split_prop: float = DEFAULT_VAL_SPLIT, early_stopping: bool = True, patience: int = DEFAULT_PATIENCE, n_iter_min: int = DEFAULT_N_ITER_MIN, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, return_val_loss: bool = False, same_init: bool = True, penalty_diff: float = DEFAULT_PENALTY_L2, nonlin: str = DEFAULT_NONLIN, avg_objective: bool = DEFAULT_AVG_OBJECTIVE, ) -> jnp.ndarray: # input check y, w = check_shape_1d_data(y), check_shape_1d_data(w) d = X.shape[1] input_shape = (-1, d) rng_key = random.PRNGKey(seed) onp.random.seed(seed) # set seed for data generation via numpy as well # get validation split (can be none) X, y, w, X_val, y_val, w_val, val_string = make_val_split( X, y, w, val_split_prop=val_split_prop, seed=seed ) n = X.shape[0] # could be different from before due to split # get output head functions (both heads share same structure) init_fun_head, predict_fun_head = OutputHead( n_layers_out=n_layers_out, n_units_out=n_units_out, binary_y=binary_y, n_layers_r=n_layers_r, n_units_r=n_units_r, nonlin=nonlin, ) # Define loss functions # loss functions for the head if not binary_y: def loss_head( params: List, batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray] ) -> jnp.ndarray: # mse loss function inputs, targets, weights = batch preds = predict_fun_head(params, inputs) return jnp.sum(weights * ((preds - targets) ** 2)) else: def loss_head( params: List, batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray] ) -> jnp.ndarray: # mse loss function inputs, targets, weights = batch preds = predict_fun_head(params, inputs) return -jnp.sum( weights * (targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds)) ) @jit def loss_tnet( params: List, batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], penalty_l2: float, penalty_diff: float, ) -> jnp.ndarray: # params: list[representation, head_0, head_1] # batch: (X, y, w) X, y, w = batch # pass down to two heads loss_0 = loss_head(params[0], (X, y, 1 - w)) loss_1 = loss_head(params[1], (X, y, 
w)) # regularization weightsq_head = heads_l2_penalty( params[0], params[1], n_layers_r + n_layers_out, True, penalty_l2, penalty_diff, ) if not avg_objective: return loss_0 + loss_1 + 0.5 * (weightsq_head) else: n_batch = y.shape[0] return (loss_0 + loss_1) / n_batch + 0.5 * (weightsq_head) # Define optimisation routine opt_init, opt_update, get_params = optimizers.adam(step_size=step_size) @jit def update( i: int, state: dict, batch: jnp.ndarray, penalty_l2: float, penalty_diff: float ) -> jnp.ndarray: # updating function params = get_params(state) return opt_update( i, grad(loss_tnet)(params, batch, penalty_l2, penalty_diff), state ) # initialise states if same_init: _, init_head = init_fun_head(rng_key, input_shape) init_params = [init_head, init_head] else: rng_key, rng_key_2 = random.split(rng_key) _, init_head_0 = init_fun_head(rng_key, input_shape) _, init_head_1 = init_fun_head(rng_key_2, input_shape) init_params = [init_head_0, init_head_1] opt_state = opt_init(init_params) # calculate number of batches per epoch batch_size = batch_size if batch_size < n else n n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1 train_indices = onp.arange(n) l_best = LARGE_VAL p_curr = 0 # do training for i in range(n_iter): # shuffle data for minibatches onp.random.shuffle(train_indices) for b in range(n_batches): idx_next = train_indices[ (b * batch_size) : min((b + 1) * batch_size, n - 1) ] next_batch = X[idx_next, :], y[idx_next, :], w[idx_next] opt_state = update( i * n_batches + b, opt_state, next_batch, penalty_l2, penalty_diff ) if (i % n_iter_print == 0) or early_stopping: params_curr = get_params(opt_state) l_curr = loss_tnet( params_curr, (X_val, y_val, w_val), penalty_l2, penalty_diff ) if i % n_iter_print == 0: log.debug(f"Epoch: {i}, current {val_string} loss {l_curr}") if early_stopping and ((i + 1) * n_batches > n_iter_min): if l_curr < l_best: l_best = l_curr p_curr = 0 params_best = params_curr else: if onp.isnan(l_curr): # if diverged, 
return best return params_best, predict_fun_head p_curr = p_curr + 1 if p_curr > patience: if return_val_loss: # return loss without penalty l_final = loss_tnet(params_curr, (X_val, y_val, w_val), 0, 0) return params_curr, predict_fun_head, l_final return params_curr, predict_fun_head # return the parameters trained_params = get_params(opt_state) if return_val_loss: # return loss without penalty l_final = loss_tnet(get_params(opt_state), (X_val, y_val, w_val), 0, 0) return trained_params, predict_fun_head, l_final return trained_params, predict_fun_head ================================================ FILE: catenets/models/jax/transformation_utils.py ================================================ """ Utils for transformations """ # Author: Alicia Curth from typing import Any, Optional import numpy as np PW_TRANSFORMATION = "PW" DR_TRANSFORMATION = "DR" RA_TRANSFORMATION = "RA" ALL_TRANSFORMATIONS = [PW_TRANSFORMATION, DR_TRANSFORMATION, RA_TRANSFORMATION] def aipw_te_transformation( y: np.ndarray, w: np.ndarray, p: Optional[np.ndarray], mu_0: np.ndarray, mu_1: np.ndarray, ) -> np.ndarray: """ Transforms data to efficient influence function pseudo-outcome for CATE estimation Parameters ---------- y : array-like of shape (n_samples,) or (n_samples, ) The observed outcome variable w: array-like of shape (n_samples,) The observed treatment indicator p: array-like of shape (n_samples,) The treatment propensity, estimated or known. 
Can be None, then p=0.5 is assumed mu_0: array-like of shape (n_samples,) Estimated or known potential outcome mean of the control group mu_1: array-like of shape (n_samples,) Estimated or known potential outcome mean of the treatment group Returns ------- d_hat: EIF transformation for CATE """ if p is None: # assume equal p = np.full(len(y), 0.5) w_1 = w / p w_0 = (1 - w) / (1 - p) return (w_1 - w_0) * y + ((1 - w_1) * mu_1 - (1 - w_0) * mu_0) def ht_te_transformation( y: np.ndarray, w: np.ndarray, p: Optional[np.ndarray] = None, mu_0: Optional[np.ndarray] = None, mu_1: Optional[np.ndarray] = None, ) -> np.ndarray: """ Transform data to Horvitz-Thompson transformation for CATE Parameters ---------- y : array-like of shape (n_samples,) or (n_samples, ) The observed outcome variable w: array-like of shape (n_samples,) The observed treatment indicator p: array-like of shape (n_samples,) The treatment propensity, estimated or known. Can be None, then p=0.5 is assumed mu_0: array-like of shape (n_samples,) Placeholder, not used. Estimated or known potential outcome mean of the control group mu_1: array-like of shape (n_samples,) Placerholder, not used. Estimated or known potential outcome mean of the treatment group Returns ------- res: array-like of shape (n_samples,) Horvitz-Thompson transformed data """ if p is None: # assume equal propensities p = np.full(len(y), 0.5) return (w / p - (1 - w) / (1 - p)) * y def ra_te_transformation( y: np.ndarray, w: np.ndarray, p: Optional[np.ndarray], mu_0: np.ndarray, mu_1: np.ndarray, ) -> np.ndarray: """ Transform data to regression adjustment for CATE Parameters ---------- y : array-like of shape (n_samples,) or (n_samples, ) The observed outcome variable w: array-like of shape (n_samples,) The observed treatment indicator p: array-like of shape (n_samples,) Placeholder, not used. The treatment propensity, estimated or known. 
mu_0: array-like of shape (n_samples,) Estimated or known potential outcome mean of the control group mu_1: array-like of shape (n_samples,) Estimated or known potential outcome mean of the treatment group Returns ------- res: array-like of shape (n_samples,) Regression adjusted transformation """ return w * (y - mu_0) + (1 - w) * (mu_1 - y) TRANSFORMATION_DICT = { PW_TRANSFORMATION: ht_te_transformation, RA_TRANSFORMATION: ra_te_transformation, DR_TRANSFORMATION: aipw_te_transformation, } def _get_transformation_function(transformation_name: str) -> Any: """ Get transformation function associated with a name """ if transformation_name not in ALL_TRANSFORMATIONS: raise ValueError( "Parameter first stage should be in " "catenets.models.transformations.ALL_TRANSFORMATIONS." " You passed {}".format(transformation_name) ) return TRANSFORMATION_DICT[transformation_name] ================================================ FILE: catenets/models/jax/xnet.py ================================================ """ Module implements X-learner from Kuenzel et al (2019) using NNs """ # Author: Alicia Curth from typing import Callable, Optional, Tuple import jax.numpy as jnp import catenets.logger as log from catenets.models.constants import ( DEFAULT_AVG_OBJECTIVE, DEFAULT_BATCH_SIZE, DEFAULT_LAYERS_OUT, DEFAULT_LAYERS_OUT_T, DEFAULT_LAYERS_R, DEFAULT_LAYERS_R_T, DEFAULT_N_ITER, DEFAULT_N_ITER_MIN, DEFAULT_N_ITER_PRINT, DEFAULT_NONLIN, DEFAULT_PATIENCE, DEFAULT_PENALTY_L2, DEFAULT_SEED, DEFAULT_STEP_SIZE, DEFAULT_STEP_SIZE_T, DEFAULT_UNITS_OUT, DEFAULT_UNITS_OUT_T, DEFAULT_UNITS_R, DEFAULT_UNITS_R_T, DEFAULT_VAL_SPLIT, ) from catenets.models.jax.base import BaseCATENet, train_output_net_only from catenets.models.jax.model_utils import check_shape_1d_data, check_X_is_np from catenets.models.jax.pseudo_outcome_nets import ( # same strategies as other nets ALL_STRATEGIES, FLEX_STRATEGY, OFFSET_STRATEGY, S1_STRATEGY, S2_STRATEGY, S3_STRATEGY, S_STRATEGY, T_STRATEGY, predict_flextenet, 
predict_offsetnet, predict_snet, predict_snet1, predict_snet2, predict_snet3, predict_t_net, train_flextenet, train_offsetnet, train_snet, train_snet1, train_snet2, train_snet3, train_tnet, ) class XNet(BaseCATENet): """ Class implements X-learner using NNs. Parameters ---------- weight_strategy: int, default None Which strategy to use to weight the two CATE estimators in the second stage. weight_strategy is coded as follows: for tau(x)=g(x)tau_0(x) + (1-g(x))tau_1(x) [eq 9, kuenzel et al (2019)] weight_strategy=0 sets g(x)=0, weight_strategy=1 sets g(x)=1, weight_strategy=None sets g(x)=pi(x) [propensity score], weight_strategy=-1 sets g(x)=(1-pi(x)) binary_y: bool, default False Whether the outcome is binary n_layers_out: int First stage Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer) n_units_out: int First stage Number of hidden units in each hypothesis layer n_layers_r: int First stage Number of representation layers before hypothesis layers (distinction between hypothesis layers and representation layers is made to match TARNet & SNets) n_units_r: int First stage Number of hidden units in each representation layer n_layers_out_t: int Second stage Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer) n_units_out_t: int Second stage Number of hidden units in each hypothesis layer n_layers_r_t: int Second stage Number of representation layers before hypothesis layers (distinction between hypothesis layers and representation layers is made to match TARNet & SNets) n_units_r_t: int Second stage Number of hidden units in each representation layer penalty_l2: float First stage l2 (ridge) penalty penalty_l2_t: float Second stage l2 (ridge) penalty step_size: float First stage learning rate for optimizer step_size_t: float Second stage learning rate for optimizer n_iter: int Maximum number of iterations batch_size: int Batch size val_split_prop: float Proportion of samples used for validation split (can be 0) 
early_stopping: bool, default True Whether to use early stopping patience: int Number of iterations to wait before early stopping after decrease in validation loss n_iter_min: int Minimum number of iterations to go through before starting early stopping n_iter_print: int Number of iterations after which to print updates seed: int Seed used nonlin: string, default 'elu' Nonlinearity to use in NN """ def __init__( self, weight_strategy: Optional[int] = None, first_stage_strategy: str = T_STRATEGY, first_stage_args: Optional[dict] = None, binary_y: bool = False, n_layers_out: int = DEFAULT_LAYERS_OUT, n_layers_r: int = DEFAULT_LAYERS_R, n_layers_out_t: int = DEFAULT_LAYERS_OUT_T, n_layers_r_t: int = DEFAULT_LAYERS_R_T, n_units_out: int = DEFAULT_UNITS_OUT, n_units_r: int = DEFAULT_UNITS_R, n_units_out_t: int = DEFAULT_UNITS_OUT_T, n_units_r_t: int = DEFAULT_UNITS_R_T, penalty_l2: float = DEFAULT_PENALTY_L2, penalty_l2_t: float = DEFAULT_PENALTY_L2, step_size: float = DEFAULT_STEP_SIZE, step_size_t: float = DEFAULT_STEP_SIZE_T, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, n_iter_min: int = DEFAULT_N_ITER_MIN, val_split_prop: float = DEFAULT_VAL_SPLIT, early_stopping: bool = True, patience: int = DEFAULT_PATIENCE, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, nonlin: str = DEFAULT_NONLIN, ): # settings self.weight_strategy = weight_strategy self.first_stage_strategy = first_stage_strategy self.first_stage_args = first_stage_args self.binary_y = binary_y # model architecture hyperparams self.n_layers_out = n_layers_out self.n_layers_out_t = n_layers_out_t self.n_layers_r = n_layers_r self.n_layers_r_t = n_layers_r_t self.n_units_out = n_units_out self.n_units_out_t = n_units_out_t self.n_units_r = n_units_r self.n_units_r_t = n_units_r_t self.nonlin = nonlin # other hyperparameters self.penalty_l2 = penalty_l2 self.penalty_l2_t = penalty_l2_t self.step_size = step_size self.step_size_t = step_size_t self.n_iter = n_iter 
self.batch_size = batch_size self.n_iter_print = n_iter_print self.seed = seed self.val_split_prop = val_split_prop self.early_stopping = early_stopping self.patience = patience self.n_iter_min = n_iter_min def _get_train_function(self) -> Callable: return train_x_net def _get_predict_function(self) -> Callable: # Two step nets do not need this return predict_x_net def predict( self, X: jnp.ndarray, return_po: bool = False, return_prop: bool = False ) -> jnp.ndarray: """ Predict treatment effect estimates using a CATENet. Depending on method, can also return potential outcome estimate and propensity score estimate. Parameters ---------- X: pd.DataFrame or np.array Covariate matrix return_po: bool, default False Whether to return potential outcome estimate return_prop: bool, default False Whether to return propensity estimate Returns ------- array of CATE estimates, optionally also potential outcomes and propensity """ X = check_X_is_np(X) predict_func = self._get_predict_function() return predict_func( X, trained_params=self._params, predict_funs=self._predict_funs, return_po=return_po, return_prop=return_prop, weight_strategy=self.weight_strategy, ) def train_x_net( X: jnp.ndarray, y: jnp.ndarray, w: jnp.ndarray, weight_strategy: Optional[int] = None, first_stage_strategy: str = T_STRATEGY, first_stage_args: Optional[dict] = None, binary_y: bool = False, n_layers_out: int = DEFAULT_LAYERS_OUT, n_layers_r: int = DEFAULT_LAYERS_R, n_layers_out_t: int = DEFAULT_LAYERS_OUT_T, n_layers_r_t: int = DEFAULT_LAYERS_R_T, n_units_out: int = DEFAULT_UNITS_OUT, n_units_r: int = DEFAULT_UNITS_R, n_units_out_t: int = DEFAULT_UNITS_OUT_T, n_units_r_t: int = DEFAULT_UNITS_R_T, penalty_l2: float = DEFAULT_PENALTY_L2, penalty_l2_t: float = DEFAULT_PENALTY_L2, step_size: float = DEFAULT_STEP_SIZE, step_size_t: float = DEFAULT_STEP_SIZE_T, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, n_iter_min: int = DEFAULT_N_ITER_MIN, val_split_prop: float = 
DEFAULT_VAL_SPLIT, early_stopping: bool = True, patience: int = DEFAULT_PATIENCE, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, nonlin: str = DEFAULT_NONLIN, return_val_loss: bool = False, avg_objective: bool = DEFAULT_AVG_OBJECTIVE, ) -> Tuple: y = check_shape_1d_data(y) if len(w.shape) > 1: w = w.reshape((len(w),)) if weight_strategy not in [0, 1, -1, None]: # weight_strategy is coded as follows: # for tau(x)=g(x)tau_0(x) + (1-g(x))tau_1(x) [eq 9, kuenzel et al (2019)] # weight_strategy=0 sets g(x)=0, weight_strategy=1 sets g(x)=1, # weight_strategy=None sets g(x)=pi(x) [propensity score], # weight_strategy=-1 sets g(x)=(1-pi(x)) raise ValueError("XNet only implements weight_strategy in [0, 1, -1, None]") if first_stage_strategy not in ALL_STRATEGIES: raise ValueError( "Parameter first stage should be in " "catenets.models.twostep_nets.ALL_STRATEGIES. " "You passed {}".format(first_stage_strategy) ) # first stage: get estimates of PO regression log.debug("Training first stage") mu_hat_0, mu_hat_1 = _get_first_stage_pos( X, y, w, binary_y=binary_y, n_layers_out=n_layers_out, n_units_out=n_units_out, n_layers_r=n_layers_r, n_units_r=n_units_r, penalty_l2=penalty_l2, step_size=step_size, n_iter=n_iter, batch_size=batch_size, val_split_prop=val_split_prop, early_stopping=early_stopping, patience=patience, n_iter_min=n_iter_min, n_iter_print=n_iter_print, seed=seed, nonlin=nonlin, avg_objective=avg_objective, first_stage_strategy=first_stage_strategy, first_stage_args=first_stage_args, ) if weight_strategy is None or weight_strategy == -1: # also fit propensity estimator log.debug("Training propensity net") params_prop, predict_fun_prop = train_output_net_only( X, w, binary_y=True, n_layers_out=n_layers_out, n_units_out=n_units_out, n_layers_r=n_layers_r, n_units_r=n_units_r, penalty_l2=penalty_l2, step_size=step_size, n_iter=n_iter, batch_size=batch_size, val_split_prop=val_split_prop, early_stopping=early_stopping, patience=patience, 
n_iter_min=n_iter_min, n_iter_print=n_iter_print, seed=seed, nonlin=nonlin, avg_objective=avg_objective, ) else: params_prop, predict_fun_prop = None, None # second stage log.debug("Training second stage") if not weight_strategy == 0: # fit tau_0 log.debug("Fitting tau_0") pseudo_outcome0 = mu_hat_1 - y[w == 0] params_tau0, predict_fun_tau0 = train_output_net_only( X[w == 0], pseudo_outcome0, binary_y=False, n_layers_out=n_layers_out_t, n_units_out=n_units_out_t, n_layers_r=n_layers_r_t, n_units_r=n_units_r_t, penalty_l2=penalty_l2_t, step_size=step_size_t, n_iter=n_iter, batch_size=batch_size, val_split_prop=val_split_prop, early_stopping=early_stopping, patience=patience, n_iter_min=n_iter_min, n_iter_print=n_iter_print, seed=seed, return_val_loss=return_val_loss, nonlin=nonlin, avg_objective=avg_objective, ) else: params_tau0, predict_fun_tau0 = None, None if not weight_strategy == 1: # fit tau_1 log.debug("Fitting tau_1") pseudo_outcome1 = y[w == 1] - mu_hat_0 params_tau1, predict_fun_tau1 = train_output_net_only( X[w == 1], pseudo_outcome1, binary_y=False, n_layers_out=n_layers_out_t, n_units_out=n_units_out_t, n_layers_r=n_layers_r_t, n_units_r=n_units_r_t, penalty_l2=penalty_l2_t, step_size=step_size_t, n_iter=n_iter, batch_size=batch_size, val_split_prop=val_split_prop, early_stopping=early_stopping, patience=patience, n_iter_min=n_iter_min, n_iter_print=n_iter_print, seed=seed, return_val_loss=return_val_loss, nonlin=nonlin, avg_objective=avg_objective, ) else: params_tau1, predict_fun_tau1 = None, None params = params_tau0, params_tau1, params_prop predict_funs = predict_fun_tau0, predict_fun_tau1, predict_fun_prop return params, predict_funs def _get_first_stage_pos( X: jnp.ndarray, y: jnp.ndarray, w: jnp.ndarray, first_stage_strategy: str = T_STRATEGY, first_stage_args: Optional[dict] = None, binary_y: bool = False, n_layers_out: int = DEFAULT_LAYERS_OUT, n_layers_r: int = DEFAULT_LAYERS_R, n_units_out: int = DEFAULT_UNITS_OUT, n_units_r: int = 
DEFAULT_UNITS_R, penalty_l2: float = DEFAULT_PENALTY_L2, step_size: float = DEFAULT_STEP_SIZE, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, n_iter_min: int = DEFAULT_N_ITER_MIN, val_split_prop: float = DEFAULT_VAL_SPLIT, early_stopping: bool = True, patience: int = DEFAULT_PATIENCE, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, nonlin: str = DEFAULT_NONLIN, avg_objective: bool = DEFAULT_AVG_OBJECTIVE, ) -> Tuple[jnp.ndarray, jnp.ndarray]: if first_stage_args is None: first_stage_args = {} train_fun: Callable predict_fun: Callable if first_stage_strategy == T_STRATEGY: train_fun, predict_fun = train_tnet, predict_t_net elif first_stage_strategy == S_STRATEGY: train_fun, predict_fun = train_snet, predict_snet elif first_stage_strategy == S1_STRATEGY: train_fun, predict_fun = train_snet1, predict_snet1 elif first_stage_strategy == S2_STRATEGY: train_fun, predict_fun = train_snet2, predict_snet2 elif first_stage_strategy == S3_STRATEGY: train_fun, predict_fun = train_snet3, predict_snet3 elif first_stage_strategy == OFFSET_STRATEGY: train_fun, predict_fun = train_offsetnet, predict_offsetnet elif first_stage_strategy == FLEX_STRATEGY: train_fun, predict_fun = train_flextenet, predict_flextenet trained_params, pred_fun = train_fun( X, y, w, binary_y=binary_y, n_layers_r=n_layers_r, n_units_r=n_units_r, n_layers_out=n_layers_out, n_units_out=n_units_out, penalty_l2=penalty_l2, step_size=step_size, n_iter=n_iter, batch_size=batch_size, val_split_prop=val_split_prop, early_stopping=early_stopping, patience=patience, n_iter_min=n_iter_min, n_iter_print=n_iter_print, seed=seed, nonlin=nonlin, avg_objective=avg_objective, **first_stage_args ) _, mu_0, mu_1 = predict_fun(X, trained_params, pred_fun, return_po=True) return mu_0[w == 1], mu_1[w == 0] def predict_x_net( X: jnp.ndarray, trained_params: dict, predict_funs: list, return_po: bool = False, return_prop: bool = False, weight_strategy: Optional[int] = None, ) -> jnp.ndarray: 
if return_po: raise NotImplementedError("TwoStepNets have no Potential outcome predictors.") if return_prop: raise NotImplementedError("TwoStepNets have no Propensity predictors.") params_tau0, params_tau1, params_prop = trained_params predict_fun_tau0, predict_fun_tau1, predict_fun_prop = predict_funs tau0_pred: jnp.ndarray tau1_pred: jnp.ndarray if not weight_strategy == 0: tau0_pred = predict_fun_tau0(params_tau0, X) else: tau0_pred = 0 if not weight_strategy == 1: tau1_pred = predict_fun_tau1(params_tau1, X) else: tau1_pred = 0 if weight_strategy is None or weight_strategy == -1: prop_pred = predict_fun_prop(params_prop, X) if weight_strategy is None: weight = prop_pred elif weight_strategy == -1: weight = 1 - prop_pred elif weight_strategy == 0: weight = 0 elif weight_strategy == 1: weight = 1 return weight * tau0_pred + (1 - weight) * tau1_pred ================================================ FILE: catenets/models/torch/__init__.py ================================================ """ PyTorch-based implementations for the CATE estimators. 
""" from .flextenet import FlexTENet from .pseudo_outcome_nets import ( DRLearner, PWLearner, RALearner, RLearner, ULearner, XLearner, ) from .representation_nets import DragonNet, TARNet from .slearner import SLearner from .snet import SNet from .tlearner import TLearner __all__ = [ "TLearner", "SLearner", "TARNet", "DragonNet", "XLearner", "RLearner", "ULearner", "RALearner", "PWLearner", "DRLearner", "SNet", "FlexTENet", ] ================================================ FILE: catenets/models/torch/base.py ================================================ import abc from typing import Optional import numpy as np import torch from torch import nn import catenets.logger as log from catenets.models.constants import ( DEFAULT_BATCH_SIZE, DEFAULT_LAYERS_OUT, DEFAULT_LAYERS_R, DEFAULT_N_ITER, DEFAULT_N_ITER_MIN, DEFAULT_N_ITER_PRINT, DEFAULT_NONLIN, DEFAULT_PATIENCE, DEFAULT_PENALTY_L2, DEFAULT_SEED, DEFAULT_STEP_SIZE, DEFAULT_UNITS_OUT, DEFAULT_UNITS_R, DEFAULT_VAL_SPLIT, LARGE_VAL, ) from catenets.models.torch.utils.decorators import benchmark, check_input_train from catenets.models.torch.utils.model_utils import make_val_split from catenets.models.torch.utils.weight_utils import compute_importance_weights DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") EPS = 1e-8 NONLIN = { "elu": nn.ELU, "relu": nn.ReLU, "leaky_relu": nn.LeakyReLU, "selu": nn.SELU, "sigmoid": nn.Sigmoid, } class BasicNet(nn.Module): """ Basic hypothesis neural net. Parameters ---------- n_unit_in: int Number of features n_layers_out: int Number of hypothesis layers (n_layers_out x n_units_out + 1 x Linear layer) n_units_out: int Number of hidden units in each hypothesis layer binary_y: bool, default False Whether the outcome is binary. Impacts the loss function. nonlin: string, default 'elu' Nonlinearity to use in NN. Can be 'elu', 'relu', 'selu' or 'leaky_relu'. lr: float learning rate for optimizer. step_size equivalent in the JAX version. 
weight_decay: float l2 (ridge) penalty for the weights. n_iter: int Maximum number of iterations. batch_size: int Batch size n_iter_print: int Number of iterations after which to print updates and check the validation loss. seed: int Seed used val_split_prop: float Proportion of samples used for validation split (can be 0) patience: int Number of iterations to wait before early stopping after decrease in validation loss n_iter_min: int Minimum number of iterations to go through before starting early stopping clipping_value: int, default 1 Gradients clipping value """ def __init__( self, name: str, n_unit_in: int, n_layers_out: int = DEFAULT_LAYERS_OUT, n_units_out: int = DEFAULT_UNITS_OUT, binary_y: bool = False, nonlin: str = DEFAULT_NONLIN, lr: float = DEFAULT_STEP_SIZE, weight_decay: float = DEFAULT_PENALTY_L2, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, val_split_prop: float = DEFAULT_VAL_SPLIT, patience: int = DEFAULT_PATIENCE, n_iter_min: int = DEFAULT_N_ITER_MIN, clipping_value: int = 1, batch_norm: bool = True, early_stopping: bool = True, dropout: bool = False, dropout_prob: float = 0.2, ) -> None: super(BasicNet, self).__init__() self.name = name if nonlin not in list(NONLIN.keys()): raise ValueError("Unknown nonlinearity") NL = NONLIN[nonlin] if n_layers_out > 0: if batch_norm: layers = [ nn.Linear(n_unit_in, n_units_out), nn.BatchNorm1d(n_units_out), NL(), ] else: layers = [nn.Linear(n_unit_in, n_units_out), NL()] # add required number of layers for i in range(n_layers_out - 1): if dropout: layers.extend([nn.Dropout(dropout_prob)]) if batch_norm: layers.extend( [ nn.Linear(n_units_out, n_units_out), nn.BatchNorm1d(n_units_out), NL(), ] ) else: layers.extend( [ nn.Linear(n_units_out, n_units_out), NL(), ] ) # add final layers layers.append(nn.Linear(n_units_out, 1)) else: layers = [nn.Linear(n_unit_in, 1)] if binary_y: layers.append(nn.Sigmoid()) # return final 
architecture self.model = nn.Sequential(*layers).to(DEVICE) self.binary_y = binary_y self.n_iter = n_iter self.batch_size = batch_size self.n_iter_print = n_iter_print self.seed = seed self.val_split_prop = val_split_prop self.patience = patience self.n_iter_min = n_iter_min self.clipping_value = clipping_value self.early_stopping = early_stopping self.optimizer = torch.optim.Adam( self.parameters(), lr=lr, weight_decay=weight_decay ) def forward(self, X: torch.Tensor) -> torch.Tensor: return self.model(X) def fit( self, X: torch.Tensor, y: torch.Tensor, weight: Optional[torch.Tensor] = None ) -> "BasicNet": self.train() X = self._check_tensor(X) y = self._check_tensor(y).squeeze() # get validation split (can be none) X, y, X_val, y_val, val_string = make_val_split( X, y, val_split_prop=self.val_split_prop, seed=self.seed ) y_val = y_val.squeeze() n = X.shape[0] # could be different from before due to split # calculate number of batches per epoch batch_size = self.batch_size if self.batch_size < n else n n_batches = int(np.round(n / batch_size)) if batch_size < n else 1 train_indices = np.arange(n) # do training val_loss_best = LARGE_VAL patience = 0 for i in range(self.n_iter): # shuffle data for minibatches np.random.shuffle(train_indices) train_loss = [] for b in range(n_batches): self.optimizer.zero_grad() idx_next = train_indices[ (b * batch_size) : min((b + 1) * batch_size, n - 1) ] X_next = X[idx_next] y_next = y[idx_next] weight_next = None if weight is not None: weight_next = weight[idx_next].detach() loss = nn.BCELoss(weight=weight_next) if self.binary_y else nn.MSELoss() preds = self.forward(X_next).squeeze() batch_loss = loss(preds, y_next) batch_loss.backward() torch.nn.utils.clip_grad_norm_(self.parameters(), self.clipping_value) self.optimizer.step() train_loss.append(batch_loss.detach()) train_loss = torch.Tensor(train_loss).to(DEVICE) if self.early_stopping or i % self.n_iter_print == 0: loss = nn.BCELoss() if self.binary_y else nn.MSELoss() with 
torch.no_grad(): preds = self.forward(X_val).squeeze() val_loss = loss(preds, y_val) if self.early_stopping: if val_loss_best > val_loss: val_loss_best = val_loss patience = 0 else: patience += 1 if patience > self.patience and i > self.n_iter_min: break if i % self.n_iter_print == 0: log.info( f"[{self.name}] Epoch: {i}, current {val_string} loss: {val_loss}, train_loss: {torch.mean(train_loss)}" ) return self def _check_tensor(self, X: torch.Tensor) -> torch.Tensor: if isinstance(X, torch.Tensor): return X.to(DEVICE) else: return torch.from_numpy(np.asarray(X)).to(DEVICE) class RepresentationNet(nn.Module): """ Basic representation neural net Parameters ---------- n_unit_in: int Number of features n_layers: int Number of shared representation layers before hypothesis layers n_units: int Number of hidden units in each representation layer nonlin: string, default 'elu' Nonlinearity to use in NN. Can be 'elu', 'relu', 'selu' or 'leaky_relu'. """ def __init__( self, n_unit_in: int, n_layers: int = DEFAULT_LAYERS_R, n_units: int = DEFAULT_UNITS_R, nonlin: str = DEFAULT_NONLIN, batch_norm: bool = True, ) -> None: super(RepresentationNet, self).__init__() if nonlin not in list(NONLIN.keys()): raise ValueError("Unknown nonlinearity") NL = NONLIN[nonlin] if batch_norm: layers = [nn.Linear(n_unit_in, n_units), nn.BatchNorm1d(n_units), NL()] else: layers = [nn.Linear(n_unit_in, n_units), NL()] # add required number of layers for i in range(n_layers - 1): if batch_norm: layers.extend( [nn.Linear(n_units, n_units), nn.BatchNorm1d(n_units), NL()] ) else: layers.extend([nn.Linear(n_units, n_units), NL()]) self.model = nn.Sequential(*layers).to(DEVICE) def forward(self, X: torch.Tensor) -> torch.Tensor: return self.model(X) class PropensityNet(nn.Module): """ Basic propensity neural net Parameters ---------- name: str Display name n_unit_in: int Number of features n_unit_out: int Number of output features weighting_strategy: str Weighting strategy n_units_out_prop: int Number of 
hidden units in each propensity score hypothesis layer n_layers_out_prop: int Number of hypothesis layers for propensity score(n_layers_out x n_units_out + 1 x Dense layer) nonlin: string, default 'elu' Nonlinearity to use in NN. Can be 'elu', 'relu', 'selu' or 'leaky_relu'. lr: float learning rate for optimizer. step_size equivalent in the JAX version. weight_decay: float l2 (ridge) penalty for the weights. n_iter: int Maximum number of iterations. batch_size: int Batch size n_iter_print: int Number of iterations after which to print updates and check the validation loss. seed: int Seed used val_split_prop: float Proportion of samples used for validation split (can be 0) patience: int Number of iterations to wait before early stopping after decrease in validation loss n_iter_min: int Minimum number of iterations to go through before starting early stopping clipping_value: int, default 1 Gradients clipping value """ def __init__( self, name: str, n_unit_in: int, n_unit_out: int, weighting_strategy: str, n_units_out_prop: int = DEFAULT_UNITS_OUT, n_layers_out_prop: int = 0, nonlin: str = DEFAULT_NONLIN, lr: float = DEFAULT_STEP_SIZE, weight_decay: float = DEFAULT_PENALTY_L2, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, val_split_prop: float = DEFAULT_VAL_SPLIT, patience: int = DEFAULT_PATIENCE, n_iter_min: int = DEFAULT_N_ITER_MIN, clipping_value: int = 1, batch_norm: bool = True, early_stopping: bool = True, dropout: bool = False, dropout_prob: float = 0.2, ) -> None: super(PropensityNet, self).__init__() if nonlin not in list(NONLIN.keys()): raise ValueError("Unknown nonlinearity") NL = NONLIN[nonlin] if batch_norm: layers = [ nn.Linear(in_features=n_unit_in, out_features=n_units_out_prop), nn.BatchNorm1d(n_units_out_prop), NL(), ] else: layers = [ nn.Linear(in_features=n_unit_in, out_features=n_units_out_prop), NL(), ] for i in range(n_layers_out_prop - 1): if dropout: 
layers.extend([nn.Dropout(dropout_prob)]) if batch_norm: layers.extend( [ nn.Linear( in_features=n_units_out_prop, out_features=n_units_out_prop ), nn.BatchNorm1d(n_units_out_prop), NL(), ] ) else: layers.extend( [ nn.Linear( in_features=n_units_out_prop, out_features=n_units_out_prop ), NL(), ] ) layers.extend( [ nn.Linear(in_features=n_units_out_prop, out_features=n_unit_out), nn.Softmax(dim=-1), ] ) self.model = nn.Sequential(*layers).to(DEVICE) self.name = name self.weighting_strategy = weighting_strategy self.n_iter = n_iter self.batch_size = batch_size self.n_iter_print = n_iter_print self.seed = seed self.val_split_prop = val_split_prop self.patience = patience self.n_iter_min = n_iter_min self.clipping_value = clipping_value self.early_stopping = early_stopping self.optimizer = torch.optim.Adam( self.parameters(), lr=lr, weight_decay=weight_decay ) def forward(self, X: torch.Tensor) -> torch.Tensor: return self.model(X) def get_importance_weights( self, X: torch.Tensor, w: Optional[torch.Tensor] = None ) -> torch.Tensor: p_pred = self.forward(X).squeeze()[:, 1] return compute_importance_weights(p_pred, w, self.weighting_strategy, {}) def loss(self, y_pred: torch.Tensor, y_target: torch.Tensor) -> torch.Tensor: return nn.NLLLoss()(torch.log(y_pred + EPS), y_target) def fit(self, X: torch.Tensor, y: torch.Tensor) -> "PropensityNet": self.train() X = self._check_tensor(X) y = self._check_tensor(y).long() # get validation split (can be none) X, y, X_val, y_val, val_string = make_val_split( X, y, val_split_prop=self.val_split_prop, seed=self.seed ) y_val = y_val.squeeze() n = X.shape[0] # could be different from before due to split # calculate number of batches per epoch batch_size = self.batch_size if self.batch_size < n else n n_batches = int(np.round(n / batch_size)) if batch_size < n else 1 train_indices = np.arange(n) # do training val_loss_best = LARGE_VAL patience = 0 for i in range(self.n_iter): # shuffle data for minibatches 
np.random.shuffle(train_indices) train_loss = [] for b in range(n_batches): self.optimizer.zero_grad() idx_next = train_indices[ (b * batch_size) : min((b + 1) * batch_size, n - 1) ] X_next = X[idx_next] y_next = y[idx_next].squeeze() preds = self.forward(X_next).squeeze() batch_loss = self.loss(preds, y_next) batch_loss.backward() torch.nn.utils.clip_grad_norm_(self.parameters(), self.clipping_value) self.optimizer.step() train_loss.append(batch_loss.detach()) train_loss = torch.Tensor(train_loss).to(DEVICE) if self.early_stopping or i % self.n_iter_print == 0: with torch.no_grad(): preds = self.forward(X_val).squeeze() val_loss = self.loss(preds, y_val) if self.early_stopping: if val_loss_best > val_loss: val_loss_best = val_loss patience = 0 else: patience += 1 if patience > self.patience and ( (i + 1) * n_batches > self.n_iter_min ): break if i % self.n_iter_print == 0: log.info( f"[{self.name}] Epoch: {i}, current {val_string} loss: {val_loss}, train_loss: {torch.mean(train_loss)}" ) return self def _check_tensor(self, X: torch.Tensor) -> torch.Tensor: if isinstance(X, torch.Tensor): return X.to(DEVICE) else: return torch.from_numpy(np.asarray(X)).to(DEVICE) class BaseCATEEstimator(nn.Module): """ Interface for estimators of CATE. The interface has train/forward API for PyTorch-based models and fit/predict API for sklearn-based models. """ def __init__( self, ) -> None: super(BaseCATEEstimator, self).__init__() def score( self, X: torch.Tensor, y: torch.Tensor, ) -> float: """ Return the sqrt PEHE error (oracle metric). 
Parameters ---------- X: torch.Tensor Covariate matrix y: torch.Tensor Expected potential outcome vector """ X = self._check_tensor(X) y = self._check_tensor(y) if len(X) != len(y): raise ValueError("X/y length mismatch for score") if y.shape[-1] != 2: raise ValueError(f"y has invalid shape {y.shape}") hat_te = self.predict(X) return torch.sqrt(torch.mean(((y[:, 1] - y[:, 0]) - hat_te) ** 2)) @abc.abstractmethod @check_input_train @benchmark def fit( self, X: torch.Tensor, y: torch.Tensor, w: torch.Tensor, ) -> "BaseCATEEstimator": """ Train method for a CATEModel Parameters ---------- X: torch.Tensor Covariate matrix y: torch.Tensor Outcome vector w: torch.Tensor Treatment indicator """ ... @benchmark def forward(self, X: torch.Tensor) -> torch.Tensor: """ Predict treatment effect estimates using a CATE estimator. Parameters ---------- X: pd.DataFrame or np.array Covariate matrix Returns ------- potential outcomes probabilities """ return self.predict(X, return_po=False, training=True) @abc.abstractmethod @benchmark def predict( self, X: torch.Tensor, return_po: bool = False, training: bool = False ) -> torch.Tensor: """ Predict treatment effect estimates using a CATE estimator. Parameters ---------- X: pd.DataFrame or np.array Covariate matrix return_po: bool, optional Return the potential outcomes too Returns ------- potential outcomes probabilities """ ... 
def _check_tensor(self, X: torch.Tensor) -> torch.Tensor: if isinstance(X, torch.Tensor): return X.to(DEVICE) else: return torch.from_numpy(np.asarray(X)).to(DEVICE) ================================================ FILE: catenets/models/torch/flextenet.py ================================================ from typing import Any, Callable, List import numpy as np import torch from torch import nn import catenets.logger as log from catenets.models.constants import ( DEFAULT_BATCH_SIZE, DEFAULT_DIM_P_OUT, DEFAULT_DIM_P_R, DEFAULT_DIM_S_OUT, DEFAULT_DIM_S_R, DEFAULT_LAYERS_OUT, DEFAULT_LAYERS_R, DEFAULT_N_ITER, DEFAULT_N_ITER_MIN, DEFAULT_N_ITER_PRINT, DEFAULT_PATIENCE, DEFAULT_PENALTY_L2, DEFAULT_PENALTY_ORTHOGONAL, DEFAULT_SEED, DEFAULT_STEP_SIZE, DEFAULT_VAL_SPLIT, LARGE_VAL, ) from catenets.models.torch.base import DEVICE, BaseCATEEstimator from catenets.models.torch.utils.model_utils import make_val_split class FlexTELinearLayer(nn.Module): """Layer constructor function for a fully-connected layer. Adapted to allow passing treatment indicator through layer without using it""" def __init__( self, name: str, dropout: bool = False, dropout_prob: float = 0.5, *args: Any, **kwargs: Any, ) -> None: super(FlexTELinearLayer, self).__init__() self.name = name if dropout: self.model = nn.Sequential( nn.Dropout(dropout_prob), nn.Linear(*args, **kwargs) ).to(DEVICE) else: self.model = nn.Sequential(nn.Linear(*args, **kwargs)).to(DEVICE) def forward(self, tensors: List[torch.Tensor]) -> List: if len(tensors) != 2: raise ValueError( "Invalid number of tensor for the FlexLinearLayer layer. 
It requires the features vector and the treatments vector" ) features_vector = tensors[0] treatments_vector = tensors[1] return [self.model(features_vector), treatments_vector] class FlexTESplitLayer(nn.Module): """ Create multitask layer has shape [shared, private_0, private_1] """ def __init__( self, name: str, n_units_in: int, n_units_in_p: int, n_units_s: int, n_units_p: int, first_layer: bool, dropout: bool = False, dropout_prob: float = 0.5, ) -> None: super(FlexTESplitLayer, self).__init__() self.name = name self.first_layer = first_layer self.n_units_in = n_units_in self.n_units_in_p = n_units_in_p self.n_units_s = n_units_s self.n_units_p = n_units_p if dropout: self.net_shared = nn.Sequential( nn.Dropout(dropout_prob), nn.Linear(n_units_in, n_units_s) ).to(DEVICE) self.net_p0 = nn.Sequential( nn.Dropout(dropout_prob), nn.Linear(n_units_in_p, n_units_p) ).to(DEVICE) self.net_p1 = nn.Sequential( nn.Dropout(dropout_prob), nn.Linear(n_units_in_p, n_units_p) ).to(DEVICE) else: self.net_shared = nn.Sequential(nn.Linear(n_units_in, n_units_s)).to(DEVICE) self.net_p0 = nn.Sequential(nn.Linear(n_units_in_p, n_units_p)).to(DEVICE) self.net_p1 = nn.Sequential(nn.Linear(n_units_in_p, n_units_p)).to(DEVICE) def forward(self, tensors: List[torch.Tensor]) -> List: if self.first_layer and len(tensors) != 2: raise ValueError( "Invalid number of tensor for the FlexSplitLayer layer. It requires the features vector and the treatments vector" ) if not self.first_layer and len(tensors) != 4: raise ValueError( "Invalid number of tensor for the FlexSplitLayer layer. 
It requires X_s, X_p0, X_p1 and W as input" ) if self.first_layer: X = tensors[0] W = tensors[1] rep_s = self.net_shared(X) rep_p0 = self.net_p0(X) rep_p1 = self.net_p1(X) else: X_s = tensors[0] X_p0 = tensors[1] X_p1 = tensors[2] W = tensors[3] rep_s = self.net_shared(X_s) rep_p0 = self.net_p0(torch.cat([X_s, X_p0], dim=1)) rep_p1 = self.net_p1(torch.cat([X_s, X_p1], dim=1)) return [rep_s, rep_p0, rep_p1, W] class FlexTEOutputLayer(nn.Module): def __init__( self, n_units_in: int, n_units_in_p: int, private: bool, dropout: bool = False, dropout_prob: float = 0.5, ) -> None: super(FlexTEOutputLayer, self).__init__() self.private = private if dropout: self.net_shared = nn.Sequential( nn.Dropout(dropout_prob), nn.Linear(n_units_in, 1) ).to(DEVICE) self.net_p0 = nn.Sequential( nn.Dropout(dropout_prob), nn.Linear(n_units_in_p, 1) ).to(DEVICE) self.net_p1 = nn.Sequential( nn.Dropout(dropout_prob), nn.Linear(n_units_in_p, 1) ).to(DEVICE) else: self.net_shared = nn.Sequential(nn.Linear(n_units_in, 1)).to(DEVICE) self.net_p0 = nn.Sequential(nn.Linear(n_units_in_p, 1)).to(DEVICE) self.net_p1 = nn.Sequential(nn.Linear(n_units_in_p, 1)).to(DEVICE) def forward(self, tensors: List[torch.Tensor]) -> torch.Tensor: if len(tensors) != 4: raise ValueError( "Invalid number of tensor for the FlexSplitLayer layer. It requires X_s, X_p0, X_p1 and W as input" ) X_s = tensors[0] X_p0 = tensors[1] X_p1 = tensors[2] W = tensors[3] if self.private: rep_p0 = self.net_p0(torch.cat([X_s, X_p0], dim=1)).squeeze() rep_p1 = self.net_p1(torch.cat([X_s, X_p1], dim=1)).squeeze() return (1 - W) * rep_p0 + W * rep_p1 else: rep_s = self.net_shared(X_s).squeeze() rep_p0 = self.net_p0(torch.cat([X_s, X_p0], dim=1)).squeeze() rep_p1 = self.net_p1(torch.cat([X_s, X_p1], dim=1)).squeeze() return (1 - W) * rep_p0 + W * rep_p1 + rep_s class ElementWiseParallelActivation(nn.Module): """Layer that applies a scalar function elementwise on its inputs. 
Input looks like: X_s, X_p0, X_p1, t = inputs """ def __init__(self, act: Callable, **act_kwargs: Any) -> None: super(ElementWiseParallelActivation, self).__init__() self.act = act self.act_kwargs = act_kwargs def forward(self, tensors: List[torch.Tensor]) -> List: if len(tensors) != 4: raise ValueError( "Invalid number of tensor for the ElementWiseParallelActivation layer. It requires X_s, X_p0, X_p1, t as input" ) return [ self.act(tensors[0], **self.act_kwargs), self.act(tensors[1], **self.act_kwargs), self.act(tensors[2], **self.act_kwargs), tensors[3], ] class ElementWiseSplitActivation(nn.Module): """Layer that applies a scalar function elementwise on its inputs. Input looks like: X, t = inputs """ def __init__(self, act: Callable, **act_kwargs: Any) -> None: super(ElementWiseSplitActivation, self).__init__() self.act = act self.act_kwargs = act_kwargs def forward(self, tensors: List[torch.Tensor]) -> List: if len(tensors) != 2: raise ValueError( "Invalid number of tensor for the ElementWiseSplitActivation layer. It requires X, t as input" ) return [ self.act(tensors[0], **self.act_kwargs), tensors[1], ] class FlexTENet(BaseCATEEstimator): """ CLass implements FlexTENet, an architecture for treatment effect estimation that allows for both shared and private information in each layer of the network. 
Parameters ---------- n_unit_in: int Number of features binary_y: bool, default False Whether the outcome is binary n_layers_out: int Number of hypothesis layers (n_layers_out x n_units_out + 1 x Linear layer) n_units_s_out: int Number of hidden units in each shared hypothesis layer n_units_p_out: int Number of hidden units in each private hypothesis layer n_layers_r: int Number of representation layers before hypothesis layers (distinction between hypothesis layers and representation layers is made to match TARNet & SNets) n_units_s_r: int Number of hidden units in each shared representation layer n_units_s_r: int Number of hidden units in each private representation layer private_out: bool, False Whether the final prediction layer should be fully private, or retain a shared component. weight_decay: float l2 (ridge) penalty penalty_orthogonal: float orthogonalisation penalty lr: float learning rate for optimizer n_iter: int Maximum number of iterations batch_size: int Batch size val_split_prop: float Proportion of samples used for validation split (can be 0) early_stopping: bool, default True Whether to use early stopping patience: int Number of iterations to wait before early stopping after decrease in validation loss n_iter_min: int Minimum number of iterations to go through before starting early stopping n_iter_print: int Number of iterations after which to print updates seed: int Seed used opt: str, default 'adam' Optimizer to use, accepts 'adam' and 'sgd' shared_repr: bool, False Whether to use a shared representation block as TARNet lr_scale: float Whether to scale down the learning rate after unfreezing the private components of the network (only used if pretrain_shared=True) normalize_ortho: bool, False Whether to normalize the orthogonality penalty (by depth of network) clipping_value: int, default 1 Gradients clipping value """ def __init__( self, n_unit_in: int, binary_y: bool, n_layers_out: int = DEFAULT_LAYERS_OUT, n_units_s_out: int = 
DEFAULT_DIM_S_OUT, n_units_p_out: int = DEFAULT_DIM_P_OUT, n_layers_r: int = DEFAULT_LAYERS_R, n_units_s_r: int = DEFAULT_DIM_S_R, n_units_p_r: int = DEFAULT_DIM_P_R, private_out: bool = False, weight_decay: float = DEFAULT_PENALTY_L2, penalty_orthogonal: float = DEFAULT_PENALTY_ORTHOGONAL, lr: float = DEFAULT_STEP_SIZE, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, val_split_prop: float = DEFAULT_VAL_SPLIT, early_stopping: bool = True, patience: int = DEFAULT_PATIENCE, n_iter_min: int = DEFAULT_N_ITER_MIN, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, shared_repr: bool = False, normalize_ortho: bool = False, mode: int = 1, clipping_value: int = 1, dropout: bool = False, dropout_prob: float = 0.5, ) -> None: super(FlexTENet, self).__init__() self.binary_y = binary_y self.n_layers_r = n_layers_r if n_layers_r else 1 self.n_layers_out = n_layers_out self.n_units_s_out = n_units_s_out self.n_units_p_out = n_units_p_out self.n_units_s_r = n_units_s_r self.n_units_p_r = n_units_p_r self.private_out = private_out self.mode = mode self.penalty_orthogonal = penalty_orthogonal self.weight_decay = weight_decay self.lr = lr self.n_iter = n_iter self.batch_size = batch_size self.val_split_prop = val_split_prop self.early_stopping = early_stopping self.patience = patience self.n_iter_min = n_iter_min self.shared_repr = shared_repr self.normalize_ortho = normalize_ortho self.clipping_value = clipping_value self.early_stopping = early_stopping self.dropout = dropout self.dropout_prob = dropout_prob self.seed = seed self.n_iter_print = n_iter_print layers = [] if shared_repr: # fully shared representation as in TARNet layers.extend( [ FlexTELinearLayer( "shared_repr_layer_0", dropout, dropout_prob, n_unit_in, n_units_s_r, ), ElementWiseSplitActivation(nn.SELU(inplace=True)), ] ) # add required number of layers for i in range(self.n_layers_r - 1): layers.extend( [ FlexTELinearLayer( f"shared_repr_layer_{i + 1}", dropout, dropout_prob, 
n_units_s_r, n_units_s_r, ), ElementWiseSplitActivation(nn.SELU(inplace=True)), ] ) else: # shared AND private representations layers.extend( [ FlexTESplitLayer( "shared_private_layer_0", n_unit_in, n_unit_in, n_units_s_r, n_units_p_r, first_layer=True, dropout=dropout, dropout_prob=dropout_prob, ), ElementWiseParallelActivation(nn.SELU(inplace=True)), ] ) # add required number of layers for i in range(n_layers_r - 1): layers.extend( [ FlexTESplitLayer( f"shared_private_layer_{i + 1}", n_units_s_r, n_units_s_r + n_units_p_r, n_units_s_r, n_units_p_r, first_layer=False, dropout=dropout, dropout_prob=dropout_prob, ), ElementWiseParallelActivation(nn.SELU(inplace=True)), ] ) # add output layers layers.extend( [ FlexTESplitLayer( "output_layer_0", n_units_s_r, n_units_s_r if shared_repr else n_units_s_r + n_units_p_r, n_units_s_out, n_units_p_out, first_layer=(shared_repr), dropout=dropout, dropout_prob=dropout_prob, ), ElementWiseParallelActivation(nn.SELU(inplace=True)), ] ) # add required number of layers for i in range(n_layers_out - 1): layers.extend( [ FlexTESplitLayer( f"output_layer_{i + 1}", n_units_s_out, n_units_s_out + n_units_p_out, n_units_s_out, n_units_p_out, first_layer=False, dropout=dropout, dropout_prob=dropout_prob, ), ElementWiseParallelActivation(nn.SELU(inplace=True)), ] ) # append final layer layers.append( FlexTEOutputLayer( n_units_s_out, n_units_s_out + n_units_p_out, private=self.private_out, dropout=dropout, dropout_prob=dropout_prob, ) ) if binary_y: layers.append(nn.Sigmoid()) self.model = nn.Sequential(*layers).to(DEVICE) def _ortho_penalty_asymmetric(self) -> torch.Tensor: def _get_cos_reg( params_0: torch.Tensor, params_1: torch.Tensor, normalize: bool ) -> torch.Tensor: if normalize: params_0 = params_0 / torch.linalg.norm(params_0, dim=0) params_1 = params_1 / torch.linalg.norm(params_1, dim=0) x_min = min(params_0.shape[0], params_1.shape[0]) y_min = min(params_0.shape[1], params_1.shape[1]) return ( torch.linalg.norm( 
params_0[:x_min, :y_min] * params_1[:x_min, :y_min], "fro" ) ** 2 ) def _apply_reg_split_layer( layer: FlexTESplitLayer, full: bool = True ) -> torch.Tensor: _ortho_body = 0 if full: _ortho_body = _get_cos_reg( layer.net_p0[-1].weight, layer.net_p1[-1].weight, self.normalize_ortho, ) _ortho_body += torch.sum( _get_cos_reg( layer.net_shared[-1].weight, layer.net_p0[-1].weight, self.normalize_ortho, ) + _get_cos_reg( layer.net_shared[-1].weight, layer.net_p1[-1].weight, self.normalize_ortho, ) ) return _ortho_body ortho_body = 0 for layer in self.model: if not isinstance(layer, (FlexTESplitLayer, FlexTEOutputLayer)): continue if isinstance(layer, FlexTESplitLayer): ortho_body += _apply_reg_split_layer(layer, full=True) if self.private_out: continue ortho_body += _apply_reg_split_layer(layer, full=False) return self.penalty_orthogonal * ortho_body def loss( self, y0_pred: torch.Tensor, y1_pred: torch.Tensor, y_true: torch.Tensor, t_true: torch.Tensor, ) -> torch.Tensor: def head_loss(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: if self.binary_y: return nn.BCELoss()(y_pred, y_true) else: return (y_pred - y_true) ** 2 def po_loss() -> torch.Tensor: loss0 = torch.mean((1.0 - t_true) * head_loss(y0_pred, y_true)) loss1 = torch.mean(t_true * head_loss(y1_pred, y_true)) return loss0 + loss1 return po_loss() + self._ortho_penalty_asymmetric() def fit( self, X: torch.Tensor, y: torch.Tensor, w: torch.Tensor, ) -> "FlexTENet": """ Fit treatment models. 
Parameters ---------- X : torch.Tensor of shape (n_samples, n_features) The features to fit to y : torch.Tensor of shape (n_samples,) or (n_samples, ) The outcome variable w: torch.Tensor of shape (n_samples,) The treatment indicator """ self.model.train() X = torch.Tensor(X).to(DEVICE) y = torch.Tensor(y).squeeze().to(DEVICE) w = torch.Tensor(w).squeeze().long().to(DEVICE) X, y, w, X_val, y_val, w_val, val_string = make_val_split( X, y, w=w, val_split_prop=self.val_split_prop, seed=self.seed ) n = X.shape[0] # could be different from before due to split # calculate number of batches per epoch batch_size = self.batch_size if self.batch_size < n else n n_batches = int(np.round(n / batch_size)) if batch_size < n else 1 train_indices = np.arange(n) optimizer = torch.optim.Adam( self.parameters(), lr=self.lr, weight_decay=self.weight_decay ) # training val_loss_best = LARGE_VAL patience = 0 for i in range(self.n_iter): # shuffle data for minibatches np.random.shuffle(train_indices) train_loss = [] for b in range(n_batches): optimizer.zero_grad() idx_next = train_indices[ (b * batch_size) : min((b + 1) * batch_size, n - 1) ] X_next = X[idx_next] y_next = y[idx_next].squeeze() w_next = w[idx_next].squeeze() _, mu0, mu1 = self.predict(X_next, return_po=True, training=True) batch_loss = self.loss(mu0, mu1, y_next, w_next) batch_loss.backward() torch.nn.utils.clip_grad_norm_(self.parameters(), self.clipping_value) optimizer.step() train_loss.append(batch_loss.detach()) train_loss = torch.Tensor(train_loss).to(DEVICE) if self.early_stopping or i % self.n_iter_print == 0: with torch.no_grad(): _, mu0, mu1 = self.predict(X_val, return_po=True, training=True) val_loss = self.loss(mu0, mu1, y_val, w_val).detach().cpu() if self.early_stopping: if val_loss_best > val_loss: val_loss_best = val_loss patience = 0 else: patience += 1 if patience > self.patience and ( (i + 1) * n_batches > self.n_iter_min ): break if i % self.n_iter_print == 0: log.info( f"[FlexTENet] Epoch: {i}, 
current {val_string} loss: {val_loss} train_loss: {torch.mean(train_loss)}" ) return self def predict( self, X: torch.Tensor, return_po: bool = False, training: bool = False ) -> torch.Tensor: """ Predict treatment effects and potential outcomes Parameters ---------- X: array-like of shape (n_samples, n_features) Test-sample features Returns ------- y: array-like of shape (n_samples,) """ if not training: self.model.eval() X = self._check_tensor(X).float() W0 = torch.zeros(X.shape[0]).to(DEVICE) W1 = torch.ones(X.shape[0]).to(DEVICE) mu0 = self.model([X, W0]) mu1 = self.model([X, W1]) te = mu1 - mu0 if return_po: return te, mu0, mu1 return te ================================================ FILE: catenets/models/torch/pseudo_outcome_nets.py ================================================ import abc import copy from typing import Any, Optional, Tuple import numpy as np import torch from sklearn.model_selection import StratifiedKFold from torch import nn from catenets.models.constants import ( DEFAULT_BATCH_SIZE, DEFAULT_CF_FOLDS, DEFAULT_LAYERS_OUT, DEFAULT_LAYERS_OUT_T, DEFAULT_N_ITER, DEFAULT_N_ITER_MIN, DEFAULT_N_ITER_PRINT, DEFAULT_NONLIN, DEFAULT_PATIENCE, DEFAULT_PENALTY_L2, DEFAULT_SEED, DEFAULT_STEP_SIZE, DEFAULT_STEP_SIZE_T, DEFAULT_UNITS_OUT, DEFAULT_UNITS_OUT_T, DEFAULT_VAL_SPLIT, ) from catenets.models.torch.base import ( DEVICE, BaseCATEEstimator, BasicNet, PropensityNet, ) from catenets.models.torch.utils.model_utils import predict_wrapper, train_wrapper from catenets.models.torch.utils.transformations import ( dr_transformation_cate, pw_transformation_cate, ra_transformation_cate, u_transformation_cate, ) class PseudoOutcomeLearner(BaseCATEEstimator): """ Class implements TwoStepLearners based on pseudo-outcome regression as discussed in Curth &vd Schaar (2021): RA-learner, PW-learner and DR-learner Parameters ---------- n_unit_in: int Number of features binary_y: bool, default False Whether the outcome is binary po_estimator: sklearn/PyTorch model, 
default: None Custom potential outcome model. If this parameter is set, the rest of the parameters are ignored. te_estimator: sklearn/PyTorch model, default: None Custom treatment effects model. If this parameter is set, the rest of the parameters are ignored. n_folds: int, default 1 Number of cross-fitting folds. If 1, no cross-fitting n_layers_out: int First stage Number of hypothesis layers (n_layers_out x n_units_out + 1 x Linear layer) n_units_out: int First stage Number of hidden units in each hypothesis layer n_layers_r: int Number of shared & private representation layers before hypothesis layers n_units_r: int Number of hidden units in representation shared before the hypothesis layers. n_layers_out_t: int Second stage Number of hypothesis layers (n_layers_out x n_units_out + 1 x Linear layer) n_units_out_t: int Second stage Number of hidden units in each hypothesis layer n_layers_out_prop: int Number of hypothesis layers for propensity score(n_layers_out x n_units_out + 1 x Dense layer) n_units_out_prop: int Number of hidden units in each propensity score hypothesis layer weight_decay: float First stage l2 (ridge) penalty weight_decay_t: float Second stage l2 (ridge) penalty lr: float First stage learning rate for optimizer lr_: float Second stage learning rate for optimizer n_iter: int Maximum number of iterations batch_size: int Batch size val_split_prop: float Proportion of samples used for validation split (can be 0) n_iter_print: int Number of iterations after which to print updates seed: int Seed used nonlin: string, default 'elu' Nonlinearity to use in NN. Can be 'elu', 'relu', 'selu' or 'leaky_relu'. weighting_strategy: str, default "prop" Weighting strategy. Can be "prop" or "1-prop". 
patience: int Number of iterations to wait before early stopping after decrease in validation loss n_iter_min: int Minimum number of iterations to go through before starting early stopping """ def __init__( self, n_unit_in: int, binary_y: bool, po_estimator: Any = None, te_estimator: Any = None, n_folds: int = DEFAULT_CF_FOLDS, n_layers_out: int = DEFAULT_LAYERS_OUT, n_layers_out_t: int = DEFAULT_LAYERS_OUT_T, n_units_out: int = DEFAULT_UNITS_OUT, n_units_out_t: int = DEFAULT_UNITS_OUT_T, n_units_out_prop: int = DEFAULT_UNITS_OUT, n_layers_out_prop: int = 0, weight_decay: float = DEFAULT_PENALTY_L2, weight_decay_t: float = DEFAULT_PENALTY_L2, lr: float = DEFAULT_STEP_SIZE, lr_t: float = DEFAULT_STEP_SIZE_T, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, val_split_prop: float = DEFAULT_VAL_SPLIT, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, nonlin: str = DEFAULT_NONLIN, weighting_strategy: Optional[str] = "prop", patience: int = DEFAULT_PATIENCE, n_iter_min: int = DEFAULT_N_ITER_MIN, batch_norm: bool = True, early_stopping: bool = True, dropout: bool = False, dropout_prob: float = 0.2, ): super(PseudoOutcomeLearner, self).__init__() self.n_unit_in = n_unit_in self.binary_y = binary_y self.n_layers_out = n_layers_out self.n_units_out = n_units_out self.n_units_out_prop = n_units_out_prop self.n_layers_out_prop = n_layers_out_prop self.weight_decay_t = weight_decay_t self.weight_decay = weight_decay self.weighting_strategy = weighting_strategy self.lr = lr self.lr_t = lr_t self.n_iter = n_iter self.batch_size = batch_size self.val_split_prop = val_split_prop self.n_iter_print = n_iter_print self.seed = seed self.nonlin = nonlin self.n_folds = n_folds self.patience = patience self.n_iter_min = n_iter_min self.n_layers_out_t = n_layers_out_t self.n_units_out_t = n_units_out_t self.n_layers_out = n_layers_out self.n_units_out = n_units_out self.batch_norm = batch_norm self.early_stopping = early_stopping self.dropout = dropout 
self.dropout_prob = dropout_prob # set estimators self._te_template = te_estimator self._po_template = po_estimator self._te_estimator = self._generate_te_estimator() self._po_estimator = self._generate_po_estimator() if weighting_strategy is not None: self._propensity_estimator = self._generate_propensity_estimator() def _generate_te_estimator(self, name: str = "te_estimator") -> nn.Module: if self._te_template is not None: return copy.deepcopy(self._te_template) return BasicNet( name, self.n_unit_in, binary_y=False, n_layers_out=self.n_layers_out_t, n_units_out=self.n_units_out_t, weight_decay=self.weight_decay_t, lr=self.lr_t, n_iter=self.n_iter, batch_size=self.batch_size, val_split_prop=self.val_split_prop, n_iter_print=self.n_iter_print, seed=self.seed, nonlin=self.nonlin, patience=self.patience, n_iter_min=self.n_iter_min, batch_norm=self.batch_norm, early_stopping=self.early_stopping, dropout=self.dropout, dropout_prob=self.dropout_prob, ).to(DEVICE) def _generate_po_estimator(self, name: str = "po_estimator") -> nn.Module: if self._po_template is not None: return copy.deepcopy(self._po_template) return BasicNet( name, self.n_unit_in, binary_y=self.binary_y, n_layers_out=self.n_layers_out, n_units_out=self.n_units_out, weight_decay=self.weight_decay, lr=self.lr, n_iter=self.n_iter, batch_size=self.batch_size, val_split_prop=self.val_split_prop, n_iter_print=self.n_iter_print, seed=self.seed, nonlin=self.nonlin, patience=self.patience, n_iter_min=self.n_iter_min, batch_norm=self.batch_norm, early_stopping=self.early_stopping, dropout=self.dropout, dropout_prob=self.dropout_prob, ).to(DEVICE) def _generate_propensity_estimator( self, name: str = "propensity_estimator" ) -> nn.Module: if self.weighting_strategy is None: raise ValueError("Invalid weighting_strategy for PropensityNet") return PropensityNet( name, self.n_unit_in, 2, # number of treatments self.weighting_strategy, n_units_out_prop=self.n_units_out_prop, n_layers_out_prop=self.n_layers_out_prop, 
weight_decay=self.weight_decay, lr=self.lr, n_iter=self.n_iter, batch_size=self.batch_size, n_iter_print=self.n_iter_print, seed=self.seed, nonlin=self.nonlin, val_split_prop=self.val_split_prop, batch_norm=self.batch_norm, early_stopping=self.early_stopping, dropout_prob=self.dropout_prob, dropout=self.dropout, ).to(DEVICE) def fit( self, X: torch.Tensor, y: torch.Tensor, w: torch.Tensor ) -> "PseudoOutcomeLearner": """ Train treatment effects nets. Parameters ---------- X: array-like of shape (n_samples, n_features) Train-sample features y: array-like of shape (n_samples,) Train-sample labels w: array-like of shape (n_samples,) Train-sample treatments """ self.train() X = self._check_tensor(X).float() y = self._check_tensor(y).squeeze().float() w = self._check_tensor(w).squeeze().float() n = len(y) # STEP 1: fit plug-in estimators via cross-fitting if self.n_folds == 1: pred_mask = np.ones(n, dtype=bool) # fit plug-in models mu_0_pred, mu_1_pred, p_pred = self._first_step( X, y, w, pred_mask, pred_mask ) else: mu_0_pred, mu_1_pred, p_pred = ( torch.zeros(n).to(DEVICE), torch.zeros(n).to(DEVICE), torch.zeros(n).to(DEVICE), ) # create folds stratified by treatment assignment to ensure balance splitter = StratifiedKFold( n_splits=self.n_folds, shuffle=True, random_state=self.seed ) for train_index, test_index in splitter.split(X.cpu(), w.cpu()): # create masks pred_mask = torch.zeros(n, dtype=bool).to(DEVICE) pred_mask[test_index] = 1 # fit plug-in te_estimator ( mu_0_pred[pred_mask], mu_1_pred[pred_mask], p_pred[pred_mask], ) = self._first_step(X, y, w, ~pred_mask, pred_mask) # use estimated propensity scores if self.weighting_strategy is not None: p = p_pred # STEP 2: direct TE estimation self._second_step(X, y, w, p, mu_0_pred, mu_1_pred) return self def predict( self, X: torch.Tensor, return_po: bool = False, training: bool = False ) -> torch.Tensor: """ Predict treatment effects Parameters ---------- X: array-like of shape (n_samples, n_features) Test-sample 
features Returns ------- te_est: array-like of shape (n_samples,) Predicted treatment effects """ if return_po: raise NotImplementedError( "PseudoOutcomeLearners have no Potential outcome predictors." ) if not training: self.eval() X = self._check_tensor(X).float() return predict_wrapper(self._te_estimator, X) @abc.abstractmethod def _first_step( self, X: torch.Tensor, y: torch.Tensor, w: torch.Tensor, fit_mask: torch.Tensor, pred_mask: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: pass @abc.abstractmethod def _second_step( self, X: torch.Tensor, y: torch.Tensor, w: torch.Tensor, p: torch.Tensor, mu_0: torch.Tensor, mu_1: torch.Tensor, ) -> None: pass def _impute_pos( self, X: torch.Tensor, y: torch.Tensor, w: torch.Tensor, fit_mask: torch.Tensor, pred_mask: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: # split sample X_fit, Y_fit, W_fit = X[fit_mask, :], y[fit_mask], w[fit_mask] # fit two separate (standard) models # untreated model temp_model_0 = self._generate_po_estimator("po_estimator_0_impute_pos") train_wrapper(temp_model_0, X_fit[W_fit == 0], Y_fit[W_fit == 0]) # treated model temp_model_1 = self._generate_po_estimator("po_estimator_1_impute_pos") train_wrapper(temp_model_1, X_fit[W_fit == 1], Y_fit[W_fit == 1]) mu_0_pred = predict_wrapper(temp_model_0, X[pred_mask, :]) mu_1_pred = predict_wrapper(temp_model_1, X[pred_mask, :]) return mu_0_pred, mu_1_pred def _impute_propensity( self, X: torch.Tensor, w: torch.Tensor, fit_mask: torch.tensor, pred_mask: torch.Tensor, ) -> torch.Tensor: # split sample X_fit, W_fit = X[fit_mask, :], w[fit_mask] # fit propensity estimator temp_propensity_estimator = self._generate_propensity_estimator( "prop_estimator_impute_propensity" ) train_wrapper(temp_propensity_estimator, X_fit, W_fit) # predict propensity on hold out return temp_propensity_estimator.get_importance_weights( X[pred_mask, :], w[pred_mask] ) def _impute_unconditional_mean( self, X: torch.Tensor, y: torch.Tensor, fit_mask: 
torch.Tensor, pred_mask: torch.Tensor, ) -> torch.Tensor: # R-learner and U-learner need to impute unconditional mean X_fit, Y_fit = X[fit_mask, :], y[fit_mask] # fit model temp_model = self._generate_po_estimator("po_est_impute_unconditional_mean") train_wrapper(temp_model, X_fit, Y_fit) return predict_wrapper(temp_model, X[pred_mask, :]) class DRLearner(PseudoOutcomeLearner): """ DR-learner for CATE estimation, based on doubly robust AIPW pseudo-outcome """ def _first_step( self, X: torch.Tensor, y: torch.Tensor, w: torch.Tensor, fit_mask: torch.Tensor, pred_mask: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: mu0_pred, mu1_pred = self._impute_pos(X, y, w, fit_mask, pred_mask) p_pred = self._impute_propensity(X, w, fit_mask, pred_mask).squeeze() return ( mu0_pred.squeeze().to(DEVICE), mu1_pred.squeeze().to(DEVICE), p_pred.to(DEVICE), ) def _second_step( self, X: torch.Tensor, y: torch.Tensor, w: torch.Tensor, p: torch.Tensor, mu_0: torch.Tensor, mu_1: torch.Tensor, ) -> None: pseudo_outcome = dr_transformation_cate(y, w, p, mu_0, mu_1) train_wrapper(self._te_estimator, X, pseudo_outcome.detach()) class PWLearner(PseudoOutcomeLearner): """ PW-learner for CATE estimation, based on singly robust Horvitz Thompson pseudo-outcome """ def _first_step( self, X: torch.Tensor, y: torch.Tensor, w: torch.Tensor, fit_mask: torch.Tensor, pred_mask: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: mu0_pred, mu1_pred = np.nan, np.nan # not needed p_pred = self._impute_propensity(X, w, fit_mask, pred_mask).squeeze() return mu0_pred.to(DEVICE), mu1_pred.to(DEVICE), p_pred.to(DEVICE) def _second_step( self, X: torch.Tensor, y: torch.Tensor, w: torch.Tensor, p: torch.Tensor, mu_0: torch.Tensor, mu_1: torch.Tensor, ) -> None: pseudo_outcome = pw_transformation_cate(y, w, p) train_wrapper(self._te_estimator, X, pseudo_outcome.detach()) class RALearner(PseudoOutcomeLearner): """ RA-learner for CATE estimation, based on singly robust 
regression-adjusted pseudo-outcome
    """

    def _first_step(
        self,
        X: torch.Tensor,
        y: torch.Tensor,
        w: torch.Tensor,
        fit_mask: torch.Tensor,
        pred_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Impute both potential-outcome regressions (mu_0, mu_1) on the
        # prediction fold. No propensity is estimated here; a NaN placeholder
        # fills the third slot of the returned triple.
        mu0_pred, mu1_pred = self._impute_pos(X, y, w, fit_mask, pred_mask)
        p_pred = np.nan  # not needed
        return mu0_pred.squeeze().to(DEVICE), mu1_pred.squeeze().to(DEVICE), p_pred

    def _second_step(
        self,
        X: torch.Tensor,
        y: torch.Tensor,
        w: torch.Tensor,
        p: torch.Tensor,
        mu_0: torch.Tensor,
        mu_1: torch.Tensor,
    ) -> None:
        # Build the regression-adjusted pseudo-outcome from the first-step
        # estimates and regress it on X with the CATE estimator. detach()
        # drops the autograd graph of the first-step predictions.
        pseudo_outcome = ra_transformation_cate(y, w, p, mu_0, mu_1)
        train_wrapper(self._te_estimator, X, pseudo_outcome.detach())


class ULearner(PseudoOutcomeLearner):
    """
    U-learner for CATE estimation. Based on pseudo-outcome (Y-mu(x))/(w-pi(x))
    """

    def _first_step(
        self,
        X: torch.Tensor,
        y: torch.Tensor,
        w: torch.Tensor,
        fit_mask: torch.Tensor,
        pred_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # mu(x) is fitted from X and y only (treatment is not an input), and
        # pi(x) from X and w. The second slot of the triple is unused (NaN).
        mu_pred = self._impute_unconditional_mean(X, y, fit_mask, pred_mask).squeeze()
        mu1_pred = np.nan  # only have one thing to impute here
        p_pred = self._impute_propensity(X, w, fit_mask, pred_mask).squeeze()
        return mu_pred.to(DEVICE), mu1_pred, p_pred.to(DEVICE)

    def _second_step(
        self,
        X: torch.Tensor,
        y: torch.Tensor,
        w: torch.Tensor,
        p: torch.Tensor,
        mu_0: torch.Tensor,
        mu_1: torch.Tensor,
    ) -> None:
        # mu_0 carries mu(x) from _first_step; mu_1 is the NaN placeholder.
        pseudo_outcome = u_transformation_cate(y, w, p, mu_0)
        train_wrapper(self._te_estimator, X, pseudo_outcome.detach())


class RLearner(PseudoOutcomeLearner):
    """
    R-learner for CATE estimation. Based on pseudo-outcome (Y-mu(x))/(w-pi(x)) and sample weight
    (w-pi(x))^2 -- the weights are handed to train_wrapper via its `weight` argument, so the
    te_estimator must support per-sample weights (historically: a 'sample_weight' argument to
    .fit -- TODO confirm against train_wrapper).
    """

    def _first_step(
        self,
        X: torch.Tensor,
        y: torch.Tensor,
        w: torch.Tensor,
        fit_mask: torch.Tensor,
        pred_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Same nuisances as the U-learner: mu(x) (from X, y) and pi(x) (from X, w).
        mu_pred = self._impute_unconditional_mean(X, y, fit_mask, pred_mask).squeeze()
        mu1_pred = np.nan  # only have one thing to impute here
        p_pred = self._impute_propensity(X, w, fit_mask, pred_mask).squeeze()
        return mu_pred.to(DEVICE), mu1_pred, p_pred.to(DEVICE)

    def _second_step(
        self,
        X: torch.Tensor,
        y: torch.Tensor,
        w: torch.Tensor,
        p: torch.Tensor,
        mu_0: torch.Tensor,
        mu_1: torch.Tensor,
    ) -> None:
        # Same pseudo-outcome as the U-learner, but each sample is weighted by
        # (w - p)^2, which is what distinguishes the R-learner objective.
        pseudo_outcome = u_transformation_cate(y, w, p, mu_0)
        train_wrapper(
            self._te_estimator, X, pseudo_outcome.detach(), weight=(w - p) ** 2
        )


class XLearner(PseudoOutcomeLearner):
    """
    X-learner for CATE estimation. Combines two CATE estimates via a weighting function g(x):
    tau(x) = g(x) tau_0(x) + (1-g(x)) tau_1(x)
    In predict(), g(x) is taken from the fitted propensity estimator's importance weights.
    """

    def __init__(
        self,
        *args: Any,
        weighting_strategy: str = "prop",
        **kwargs: Any,
    ) -> None:
        super().__init__(
            *args,
            **kwargs,
        )
        # Not read by the XLearner methods visible here; presumably consumed by
        # the base class when configuring the propensity estimator -- TODO confirm.
        self.weighting_strategy = weighting_strategy

    def _first_step(
        self,
        X: torch.Tensor,
        y: torch.Tensor,
        w: torch.Tensor,
        fit_mask: torch.Tensor,
        pred_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Impute both potential outcomes; no propensity is produced here (NaN).
        mu0_pred, mu1_pred = self._impute_pos(X, y, w, fit_mask, pred_mask)
        p_pred = np.nan
        return mu0_pred.squeeze().to(DEVICE), mu1_pred.squeeze().to(DEVICE), p_pred

    def _second_step(
        self,
        X: torch.Tensor,
        y: torch.Tensor,
        w: torch.Tensor,
        p: torch.Tensor,
        mu_0: torch.Tensor,
        mu_1: torch.Tensor,
    ) -> None:
        # split by treatment status, fit one model per group
        # Controls: imputed effect = mu_1(x) - observed y.
        pseudo_0 = mu_1[w == 0] - y[w == 0]
        self._te_estimator_0 = self._generate_te_estimator("te_estimator_0_xnet")
        train_wrapper(self._te_estimator_0, X[w == 0], pseudo_0.detach())

        # Treated: imputed effect = observed y - mu_0(x).
        pseudo_1 = y[w == 1] - mu_0[w == 1]
        self._te_estimator_1 = self._generate_te_estimator("te_estimator_1_xnet")
        train_wrapper(self._te_estimator_1, X[w == 1], pseudo_1.detach())

        # Propensity model on all units; predict() derives g(x) from it.
        train_wrapper(self._propensity_estimator, X, w)

    def predict(
        self, X: torch.Tensor, return_po: bool = False, training: bool = False
    ) -> torch.Tensor:
        """
        Predict treatment effects

        Parameters
        ----------
        X: array-like of shape (n_samples, n_features)
            Test-sample features
        return_po: bool, default False
            Whether to return potential outcome predictions. Placeholder, can only accept False.
        training: bool, default False
            If False, the module is switched to eval mode before predicting.

        Returns
        -------
        te_est: array-like of shape (n_samples,)
            Predicted treatment effects

        Raises
        ------
        NotImplementedError
            If return_po is True.
        """
        if return_po:
            raise NotImplementedError(
                "PseudoOutcomeLearners have no Potential outcome predictors."
            )
        if not training:
            self.eval()
        X = self._check_tensor(X).float().to(DEVICE)
        tau0_pred = predict_wrapper(self._te_estimator_0, X)
        tau1_pred = predict_wrapper(self._te_estimator_1, X)

        # g(x): importance weights from the propensity estimator fitted in _second_step.
        weight = self._propensity_estimator.get_importance_weights(X)

        return weight * tau0_pred + (1 - weight) * tau1_pred
================================================
FILE: catenets/models/torch/representation_nets.py
================================================
import abc
from typing import Any, Optional, Tuple

import numpy as np
import torch
from torch import nn

import catenets.logger as log
from catenets.models.constants import (
    DEFAULT_BATCH_SIZE,
    DEFAULT_LAYERS_OUT,
    DEFAULT_LAYERS_R,
    DEFAULT_N_ITER,
    DEFAULT_N_ITER_MIN,
    DEFAULT_N_ITER_PRINT,
    DEFAULT_NONLIN,
    DEFAULT_PATIENCE,
    DEFAULT_PENALTY_DISC,
    DEFAULT_PENALTY_L2,
    DEFAULT_SEED,
    DEFAULT_STEP_SIZE,
    DEFAULT_UNITS_OUT,
    DEFAULT_UNITS_R,
    DEFAULT_VAL_SPLIT,
    LARGE_VAL,
)
from catenets.models.torch.base import (
    DEVICE,
    BaseCATEEstimator,
    BasicNet,
    PropensityNet,
    RepresentationNet,
)
from catenets.models.torch.utils.model_utils import make_val_split

# Small constant added to predictions / variances below for numerical safety.
EPS = 1e-8


class BasicDragonNet(BaseCATEEstimator):
    """
    Base class for TARNet and DragonNet.
Parameters ---------- name: str Estimator name n_unit_in: int Number of features propensity_estimator: nn.Module Propensity estimator binary_y: bool, default False Whether the outcome is binary n_layers_out: int Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer) n_units_out: int Number of hidden units in each hypothesis layer n_layers_r: int Number of shared & private representation layers before the hypothesis layers. n_units_r: int Number of hidden units in representation before the hypothesis layers. weight_decay: float l2 (ridge) penalty lr: float learning rate for optimizer n_iter: int Maximum number of iterations batch_size: int Batch size val_split_prop: float Proportion of samples used for validation split (can be 0) n_iter_print: int Number of iterations after which to print updates seed: int Seed used nonlin: string, default 'elu' Nonlinearity to use in the neural net. Can be 'elu', 'relu', 'selu', 'leaky_relu'. weighting_strategy: optional str, None Whether to include propensity head and which weightening strategy to use penalty_disc: float, default zero Discrepancy penalty. 
""" def __init__( self, name: str, n_unit_in: int, propensity_estimator: nn.Module, binary_y: bool = False, n_layers_r: int = DEFAULT_LAYERS_R, n_units_r: int = DEFAULT_UNITS_R, n_layers_out: int = DEFAULT_LAYERS_OUT, n_units_out: int = DEFAULT_UNITS_OUT, weight_decay: float = DEFAULT_PENALTY_L2, lr: float = DEFAULT_STEP_SIZE, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, val_split_prop: float = DEFAULT_VAL_SPLIT, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, nonlin: str = DEFAULT_NONLIN, weighting_strategy: Optional[str] = None, penalty_disc: float = 0, batch_norm: bool = True, early_stopping: bool = True, prop_loss_multiplier: float = 1, n_iter_min: int = DEFAULT_N_ITER_MIN, patience: int = DEFAULT_PATIENCE, dropout: bool = False, dropout_prob: float = 0.2, ) -> None: super(BasicDragonNet, self).__init__() self.name = name self.val_split_prop = val_split_prop self.seed = seed self.batch_size = batch_size self.n_iter = n_iter self.n_iter_print = n_iter_print self.lr = lr self.weight_decay = weight_decay self.binary_y = binary_y self.penalty_disc = penalty_disc self.early_stopping = early_stopping self.prop_loss_multiplier = prop_loss_multiplier self.n_iter_min = n_iter_min self.patience = patience self.dropout = dropout self.dropout_prob = dropout_prob self._repr_estimator = RepresentationNet( n_unit_in, n_units=n_units_r, n_layers=n_layers_r, nonlin=nonlin, batch_norm=batch_norm, ) self._po_estimators = [] for idx in range(2): self._po_estimators.append( BasicNet( f"{name}_po_estimator_{idx}", n_units_r, binary_y=binary_y, n_layers_out=n_layers_out, n_units_out=n_units_out, nonlin=nonlin, batch_norm=batch_norm, dropout=dropout, dropout_prob=dropout_prob, ) ) self._propensity_estimator = propensity_estimator def loss( self, po_pred: torch.Tensor, t_pred: torch.Tensor, y_true: torch.Tensor, t_true: torch.Tensor, discrepancy: torch.Tensor, ) -> torch.Tensor: def head_loss(y_pred: torch.Tensor, y_true: torch.Tensor) -> 
torch.Tensor: if self.binary_y: return nn.BCELoss()(y_pred, y_true) else: return (y_pred - y_true) ** 2 def po_loss( po_pred: torch.Tensor, y_true: torch.Tensor, t_true: torch.Tensor ) -> torch.Tensor: y0_pred = po_pred[:, 0] y1_pred = po_pred[:, 1] loss0 = torch.mean((1.0 - t_true) * head_loss(y0_pred, y_true)) loss1 = torch.mean(t_true * head_loss(y1_pred, y_true)) return loss0 + loss1 def prop_loss(t_pred: torch.Tensor, t_true: torch.Tensor) -> torch.Tensor: t_pred = t_pred + EPS return nn.CrossEntropyLoss()(t_pred, t_true) return ( po_loss(po_pred, y_true, t_true) + self.prop_loss_multiplier * prop_loss(t_pred, t_true) + discrepancy ) def fit( self, X: torch.Tensor, y: torch.Tensor, w: torch.Tensor, ) -> "BasicDragonNet": """ Fit the treatment models. Parameters ---------- X : torch.Tensor of shape (n_samples, n_features) The features to fit to y : torch.Tensor of shape (n_samples,) or (n_samples, ) The outcome variable w: torch.Tensor of shape (n_samples,) The treatment indicator """ self.train() X = torch.Tensor(X).to(DEVICE) y = torch.Tensor(y).squeeze().to(DEVICE) w = torch.Tensor(w).squeeze().long().to(DEVICE) X, y, w, X_val, y_val, w_val, val_string = make_val_split( X, y, w=w, val_split_prop=self.val_split_prop, seed=self.seed ) n = X.shape[0] # could be different from before due to split # calculate number of batches per epoch batch_size = self.batch_size if self.batch_size < n else n n_batches = int(np.round(n / batch_size)) if batch_size < n else 1 train_indices = np.arange(n) params = ( list(self._repr_estimator.parameters()) + list(self._po_estimators[0].parameters()) + list(self._po_estimators[1].parameters()) + list(self._propensity_estimator.parameters()) ) optimizer = torch.optim.Adam(params, lr=self.lr, weight_decay=self.weight_decay) # training val_loss_best = LARGE_VAL patience = 0 for i in range(self.n_iter): # shuffle data for minibatches np.random.shuffle(train_indices) train_loss = [] for b in range(n_batches): optimizer.zero_grad() 
idx_next = train_indices[ (b * batch_size) : min((b + 1) * batch_size, n - 1) ] X_next = X[idx_next] y_next = y[idx_next].squeeze() w_next = w[idx_next].squeeze() po_preds, prop_preds, discr = self._step(X_next, w_next) batch_loss = self.loss(po_preds, prop_preds, y_next, w_next, discr) batch_loss.backward() optimizer.step() train_loss.append(batch_loss.detach()) train_loss = torch.Tensor(train_loss).to(DEVICE) if self.early_stopping or i % self.n_iter_print == 0: with torch.no_grad(): po_preds, prop_preds, discr = self._step(X_val, w_val) val_loss = self.loss(po_preds, prop_preds, y_val, w_val, discr) if self.early_stopping: if val_loss_best > val_loss: val_loss_best = val_loss patience = 0 else: patience += 1 if patience > self.patience and ( (i + 1) * n_batches > self.n_iter_min ): break if i % self.n_iter_print == 0: log.info( f"[{self.name}] Epoch: {i}, current {val_string} loss: {val_loss} train_loss: {torch.mean(train_loss)}" ) return self @abc.abstractmethod def _step( self, X: torch.Tensor, w: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ... 
def _forward(self, X: torch.Tensor) -> torch.Tensor: X = self._check_tensor(X) repr_preds = self._repr_estimator(X).squeeze() y0_preds = self._po_estimators[0](repr_preds).squeeze() y1_preds = self._po_estimators[1](repr_preds).squeeze() return torch.vstack((y0_preds, y1_preds)).T def predict( self, X: torch.Tensor, return_po: bool = False, training: bool = False ) -> torch.Tensor: """ Predict the treatment effects Parameters ---------- X: array-like of shape (n_samples, n_features) Test-sample features Returns ------- y: array-like of shape (n_samples,) """ if not training: self.eval() X = self._check_tensor(X).float() preds = self._forward(X) y0_preds = preds[:, 0] y1_preds = preds[:, 1] outcome = y1_preds - y0_preds if return_po: return outcome, y0_preds, y1_preds return outcome def _maximum_mean_discrepancy( self, X: torch.Tensor, w: torch.Tensor ) -> torch.Tensor: n = w.shape[0] n_t = torch.sum(w) X = X / torch.sqrt(torch.var(X, dim=0) + EPS) w = w.unsqueeze(dim=0) mean_control = (n / (n - n_t)) * torch.mean((1 - w).T * X, dim=0) mean_treated = (n / n_t) * torch.mean(w.T * X, dim=0) return self.penalty_disc * torch.sum((mean_treated - mean_control) ** 2) class TARNet(BasicDragonNet): """ Class implements Shalit et al (2017)'s TARNet """ def __init__( self, n_unit_in: int, binary_y: bool = False, n_units_out_prop: int = DEFAULT_UNITS_OUT, n_layers_out_prop: int = 0, nonlin: str = DEFAULT_NONLIN, penalty_disc: float = DEFAULT_PENALTY_DISC, batch_norm: bool = True, dropout: bool = False, dropout_prob: float = 0.2, **kwargs: Any, ) -> None: propensity_estimator = PropensityNet( "tarnet_propensity_estimator", n_unit_in, 2, "prop", n_layers_out_prop=n_layers_out_prop, n_units_out_prop=n_units_out_prop, nonlin=nonlin, batch_norm=batch_norm, dropout_prob=dropout_prob, dropout=dropout, ).to(DEVICE) super(TARNet, self).__init__( "TARNet", n_unit_in, propensity_estimator, binary_y=binary_y, nonlin=nonlin, penalty_disc=penalty_disc, batch_norm=batch_norm, dropout=dropout, 
dropout_prob=dropout_prob, **kwargs, ) self.prop_loss_multiplier = 0 def _step( self, X: torch.Tensor, w: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: repr_preds = self._repr_estimator(X).squeeze() y0_preds = self._po_estimators[0](repr_preds).squeeze() y1_preds = self._po_estimators[1](repr_preds).squeeze() po_preds = torch.vstack((y0_preds, y1_preds)).T prop_preds = self._propensity_estimator(X) return po_preds, prop_preds, self._maximum_mean_discrepancy(repr_preds, w) class DragonNet(BasicDragonNet): """ Class implements a variant based on Shi et al (2019)'s DragonNet. """ def __init__( self, n_unit_in: int, binary_y: bool = False, n_units_out_prop: int = DEFAULT_UNITS_OUT, n_layers_out_prop: int = 0, nonlin: str = DEFAULT_NONLIN, n_units_r: int = DEFAULT_UNITS_R, batch_norm: bool = True, dropout: bool = False, dropout_prob: float = 0.2, **kwargs: Any, ) -> None: propensity_estimator = PropensityNet( "dragonnet_propensity_estimator", n_units_r, 2, "prop", n_layers_out_prop=n_layers_out_prop, n_units_out_prop=n_units_out_prop, nonlin=nonlin, batch_norm=batch_norm, dropout=dropout, dropout_prob=dropout_prob, ).to(DEVICE) super(DragonNet, self).__init__( "DragonNet", n_unit_in, propensity_estimator, binary_y=binary_y, nonlin=nonlin, batch_norm=batch_norm, dropout=dropout, dropout_prob=dropout_prob, **kwargs, ) def _step( self, X: torch.Tensor, w: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: repr_preds = self._repr_estimator(X).squeeze() y0_preds = self._po_estimators[0](repr_preds).squeeze() y1_preds = self._po_estimators[1](repr_preds).squeeze() po_preds = torch.vstack((y0_preds, y1_preds)).T prop_preds = self._propensity_estimator(repr_preds) return po_preds, prop_preds, self._maximum_mean_discrepancy(repr_preds, w) ================================================ FILE: catenets/models/torch/slearner.py ================================================ from typing import Any, Optional import torch import catenets.logger 
as log from catenets.models.constants import ( DEFAULT_BATCH_SIZE, DEFAULT_LAYERS_OUT, DEFAULT_N_ITER, DEFAULT_N_ITER_PRINT, DEFAULT_NONLIN, DEFAULT_PENALTY_L2, DEFAULT_SEED, DEFAULT_STEP_SIZE, DEFAULT_UNITS_OUT, DEFAULT_VAL_SPLIT, ) from catenets.models.torch.base import ( DEVICE, BaseCATEEstimator, BasicNet, PropensityNet, ) from catenets.models.torch.utils.model_utils import predict_wrapper class SLearner(BaseCATEEstimator): """ S-learner for treatment effect estimation (single learner, treatment indicator just another feature). Parameters ---------- n_unit_in: int Number of features binary_y: bool Whether the outcome is binary po_estimator: sklearn/PyTorch model, default: None Custom potential outcome model. If this parameter is set, the rest of the parameters are ignored. n_layers_out: int Number of hypothesis layers (n_layers_out x n_units_out + 1 x Linear layer) n_layers_out_prop: int Number of hypothesis layers for propensity score(n_layers_out x n_units_out + 1 x Linear layer) n_units_out: int Number of hidden units in each hypothesis layer n_units_out_prop: int Number of hidden units in each propensity score hypothesis layer weight_decay: float l2 (ridge) penalty lr: float learning rate for optimizer n_iter: int Maximum number of iterations batch_size: int Batch size val_split_prop: float Proportion of samples used for validation split (can be 0) n_iter_print: int Number of iterations after which to print updates seed: int Seed used nonlin: string, default 'elu' Nonlinearity to use in the neural net. Can be 'elu', 'relu', 'selu' or 'leaky_relu'. 
weighting_strategy: optional str, None Whether to include propensity head and which weightening strategy to use """ def __init__( self, n_unit_in: int, binary_y: bool, po_estimator: Any = None, n_layers_out: int = DEFAULT_LAYERS_OUT, n_units_out: int = DEFAULT_UNITS_OUT, n_units_out_prop: int = DEFAULT_UNITS_OUT, n_layers_out_prop: int = DEFAULT_LAYERS_OUT, weight_decay: float = DEFAULT_PENALTY_L2, lr: float = DEFAULT_STEP_SIZE, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, val_split_prop: float = DEFAULT_VAL_SPLIT, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, nonlin: str = DEFAULT_NONLIN, weighting_strategy: Optional[str] = None, batch_norm: bool = True, early_stopping: bool = True, dropout: bool = False, dropout_prob: float = 0.2, ) -> None: super(SLearner, self).__init__() self._weighting_strategy = weighting_strategy if po_estimator is not None: self._po_estimator = po_estimator else: self._po_estimator = BasicNet( "slearner_po_estimator", n_unit_in + 1, binary_y=binary_y, n_layers_out=n_layers_out, n_units_out=n_units_out, weight_decay=weight_decay, lr=lr, n_iter=n_iter, batch_size=batch_size, val_split_prop=val_split_prop, n_iter_print=n_iter_print, seed=seed, nonlin=nonlin, batch_norm=batch_norm, early_stopping=early_stopping, dropout_prob=dropout_prob, dropout=dropout, ).to(DEVICE) if weighting_strategy is not None: self._propensity_estimator = PropensityNet( "slearner_prop_estimator", n_unit_in, 2, # number of treatments weighting_strategy, n_units_out_prop=n_units_out_prop, n_layers_out_prop=n_layers_out_prop, weight_decay=weight_decay, lr=lr, n_iter=n_iter, batch_size=batch_size, n_iter_print=n_iter_print, seed=seed, nonlin=nonlin, val_split_prop=val_split_prop, batch_norm=batch_norm, early_stopping=early_stopping, dropout=dropout, dropout_prob=dropout_prob, ).to(DEVICE) def fit( self, X: torch.Tensor, y: torch.Tensor, w: torch.Tensor, ) -> "SLearner": """ Fit treatment models. 
Parameters ---------- X : torch.Tensor of shape (n_samples, n_features) The features to fit to y : torch.Tensor of shape (n_samples,) or (n_samples, ) The outcome variable w: torch.Tensor of shape (n_samples,) The treatment indicator """ self.train() X = torch.Tensor(X).to(DEVICE) y = torch.Tensor(y).to(DEVICE) w = torch.Tensor(w).to(DEVICE) # add indicator as additional variable X_ext = torch.cat((X, w.reshape((-1, 1))), dim=1).to(DEVICE) if not ( hasattr(self._po_estimator, "train") or hasattr(self._po_estimator, "fit") ): raise NotImplementedError("invalid po_estimator for the slearner") if hasattr(self._po_estimator, "fit"): log.info("Fit the sklearn po_estimator") self._po_estimator.fit( X_ext.detach().cpu().numpy(), y.detach().cpu().numpy() ) return self if self._weighting_strategy is None: # fit standard S-learner log.info("Fit the PyTorch po_estimator") self._po_estimator.fit(X_ext, y) return self # use reweighting within the outcome model log.info("Fit the PyTorch po_estimator with the propensity estimator") self._propensity_estimator.fit(X, w) weights = self._propensity_estimator.get_importance_weights(X, w) self._po_estimator.fit(X_ext, y, weight=weights) return self def _create_extended_matrices(self, X: torch.Tensor) -> torch.Tensor: n = X.shape[0] X = self._check_tensor(X) # create extended matrices w_1 = torch.ones((n, 1)).to(DEVICE) w_0 = torch.zeros((n, 1)).to(DEVICE) X_ext_0 = torch.cat((X, w_0), dim=1).to(DEVICE) X_ext_1 = torch.cat((X, w_1), dim=1).to(DEVICE) return [X_ext_0, X_ext_1] def predict( self, X: torch.Tensor, return_po: bool = False, training: bool = False ) -> torch.Tensor: """ Predict treatment effects and potential outcomes Parameters ---------- X: array-like of shape (n_samples, n_features) Test-sample features Returns ------- y: array-like of shape (n_samples,) """ if not training: self.eval() X = self._check_tensor(X).float() X_ext = self._create_extended_matrices(X) y = [] for ext_mat in X_ext: 
y.append(predict_wrapper(self._po_estimator, ext_mat).to(DEVICE)) outcome = y[1] - y[0] if return_po: return outcome, y[0], y[1] return outcome ================================================ FILE: catenets/models/torch/snet.py ================================================ from typing import Tuple import numpy as np import torch from torch import nn import catenets.logger as log from catenets.models.constants import ( DEFAULT_BATCH_SIZE, DEFAULT_LAYERS_OUT, DEFAULT_LAYERS_R, DEFAULT_N_ITER, DEFAULT_N_ITER_MIN, DEFAULT_N_ITER_PRINT, DEFAULT_NONLIN, DEFAULT_PATIENCE, DEFAULT_PENALTY_DISC, DEFAULT_PENALTY_L2, DEFAULT_PENALTY_ORTHOGONAL, DEFAULT_SEED, DEFAULT_STEP_SIZE, DEFAULT_UNITS_OUT, DEFAULT_UNITS_R_BIG_S, DEFAULT_UNITS_R_SMALL_S, DEFAULT_VAL_SPLIT, LARGE_VAL, ) from catenets.models.torch.base import ( DEVICE, BaseCATEEstimator, BasicNet, PropensityNet, RepresentationNet, ) from catenets.models.torch.utils.model_utils import make_val_split EPS = 1e-8 class SNet(BaseCATEEstimator): """ Class implements SNet as discussed in Curth & van der Schaar (2021). Additionally to the version implemented in the AISTATS paper, we also include an implementation that does not have propensity heads (set with_prop=False) Parameters ---------- n_unit_in: int Number of features binary_y: bool, default False Whether the outcome is binary n_layers_r: int Number of shared & private representation layers before the hypothesis layers. n_units_r: int Number of hidden units in representation shared before the hypothesis layer. 
n_layers_out: int Number of hypothesis layers (n_layers_out x n_units_out + 1 x Linear layer) n_layers_out_prop: int Number of hypothesis layers for propensity score(n_layers_out x n_units_out + 1 x Linear layer) n_units_out: int Number of hidden units in each hypothesis layer n_units_out_prop: int Number of hidden units in each propensity score hypothesis layer n_units_r_small: int Number of hidden units in each PO functions private representation weight_decay: float l2 (ridge) penalty lr: float learning rate for optimizer n_iter: int Maximum number of iterations batch_size: int Batch size val_split_prop: float Proportion of samples used for validation split (can be 0) patience: int Number of iterations to wait before early stopping after decrease in validation loss n_iter_min: int Minimum number of iterations to go through before starting early stopping n_iter_print: int Number of iterations after which to print updates seed: int Seed used nonlin: string, default 'elu' Nonlinearity to use in the neural net. Can be 'elu', 'relu', 'selu' or 'leaky_relu'. penalty_disc: float, default zero Discrepancy penalty. Defaults to zero as this feature is not tested. 
clipping_value: int, default 1 Gradients clipping value """ def __init__( self, n_unit_in: int, binary_y: bool = False, n_layers_r: int = DEFAULT_LAYERS_R, n_units_r: int = DEFAULT_UNITS_R_BIG_S, n_layers_out: int = DEFAULT_LAYERS_OUT, n_units_r_small: int = DEFAULT_UNITS_R_SMALL_S, n_units_out: int = DEFAULT_UNITS_OUT, n_units_out_prop: int = DEFAULT_UNITS_OUT, n_layers_out_prop: int = DEFAULT_LAYERS_OUT, weight_decay: float = DEFAULT_PENALTY_L2, penalty_orthogonal: float = DEFAULT_PENALTY_ORTHOGONAL, penalty_disc: float = DEFAULT_PENALTY_DISC, lr: float = DEFAULT_STEP_SIZE, n_iter: int = DEFAULT_N_ITER, n_iter_min: int = DEFAULT_N_ITER_MIN, batch_size: int = DEFAULT_BATCH_SIZE, val_split_prop: float = DEFAULT_VAL_SPLIT, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, nonlin: str = DEFAULT_NONLIN, ortho_reg_type: str = "abs", patience: int = DEFAULT_PATIENCE, clipping_value: int = 1, batch_norm: bool = True, with_prop: bool = True, early_stopping: bool = True, prop_loss_multiplier: float = 1, dropout: bool = False, dropout_prob: float = 0.2, ) -> None: super(SNet, self).__init__() self.n_unit_in = n_unit_in self.binary_y = binary_y self.penalty_orthogonal = penalty_orthogonal self.penalty_disc = penalty_disc self.n_iter = n_iter self.batch_size = batch_size self.val_split_prop = val_split_prop self.n_iter_print = n_iter_print self.seed = seed self.ortho_reg_type = ortho_reg_type self.clipping_value = clipping_value self.patience = patience self.with_prop = with_prop self.early_stopping = early_stopping self.n_iter_min = n_iter_min self.prop_loss_multiplier = prop_loss_multiplier self.dropout = dropout self.dropout_prob = dropout_prob self._reps_mu0 = RepresentationNet( n_unit_in, n_units=n_units_r_small, n_layers=n_layers_r, nonlin=nonlin, batch_norm=batch_norm, ) self._reps_mu1 = RepresentationNet( n_unit_in, n_units=n_units_r_small, n_layers=n_layers_r, nonlin=nonlin, batch_norm=batch_norm, ) self._po_estimators = [] if self.with_prop: 
self._reps_c = RepresentationNet( n_unit_in, n_units=n_units_r, n_layers=n_layers_r, nonlin=nonlin, batch_norm=batch_norm, ) self._reps_o = RepresentationNet( n_unit_in, n_units=n_units_r_small, n_layers=n_layers_r, nonlin=nonlin, batch_norm=batch_norm, ) self._reps_prop = RepresentationNet( n_unit_in, n_units=n_units_r, n_layers=n_layers_r, nonlin=nonlin, batch_norm=batch_norm, ) for idx in range(2): self._po_estimators.append( BasicNet( f"snet_po_estimator_{idx}", n_units_r + n_units_r_small + n_units_r_small, # (reps_c, reps_o, reps_mu{idx}) binary_y=binary_y, n_layers_out=n_layers_out, n_units_out=n_units_out, nonlin=nonlin, batch_norm=batch_norm, dropout_prob=dropout_prob, dropout=dropout, ) ) self._propensity_estimator = PropensityNet( "snet_propensity_estimator", n_units_r + n_units_r, # reps_c, reps_w 2, "prop", n_layers_out_prop=n_layers_out_prop, n_units_out_prop=n_units_out_prop, nonlin=nonlin, batch_norm=batch_norm, dropout=dropout, dropout_prob=dropout_prob, ).to(DEVICE) params = ( list(self._reps_c.parameters()) + list(self._reps_o.parameters()) + list(self._reps_mu0.parameters()) + list(self._reps_mu1.parameters()) + list(self._reps_prop.parameters()) + list(self._po_estimators[0].parameters()) + list(self._po_estimators[1].parameters()) + list(self._propensity_estimator.parameters()) ) else: self._reps_o = RepresentationNet( n_unit_in, n_units=n_units_r, n_layers=n_layers_r, nonlin=nonlin, batch_norm=batch_norm, ) for idx in range(2): self._po_estimators.append( BasicNet( f"snet_po_estimator_{idx}", n_units_r + n_units_r_small, # (reps_o, reps_mu{idx}) binary_y=binary_y, n_layers_out=n_layers_out, n_units_out=n_units_out, nonlin=nonlin, batch_norm=batch_norm, ) ) params = ( list(self._reps_o.parameters()) + list(self._reps_mu0.parameters()) + list(self._reps_mu1.parameters()) + list(self._po_estimators[0].parameters()) + list(self._po_estimators[1].parameters()) ) self.optimizer = torch.optim.Adam(params, lr=lr, weight_decay=weight_decay) def loss( 
self, y0_pred: torch.Tensor, y1_pred: torch.Tensor, t_pred: torch.Tensor, discrepancy: torch.Tensor, y_true: torch.Tensor, t_true: torch.Tensor, ) -> torch.Tensor: def head_loss(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: if self.binary_y: return nn.BCELoss()(y_pred, y_true) else: return (y_pred - y_true) ** 2 def po_loss( y0_pred: torch.Tensor, y1_pred: torch.Tensor, y_true: torch.Tensor, t_true: torch.Tensor, ) -> torch.Tensor: loss0 = torch.mean((1.0 - t_true) * head_loss(y0_pred, y_true)) loss1 = torch.mean(t_true * head_loss(y1_pred, y_true)) return loss0 + loss1 def prop_loss( t_pred: torch.Tensor, t_true: torch.Tensor, ) -> torch.Tensor: if self.with_prop: t_pred = t_pred + EPS return nn.CrossEntropyLoss()(t_pred, t_true) else: return 0 return ( po_loss(y0_pred, y1_pred, y_true, t_true) + self.prop_loss_multiplier * prop_loss(t_pred, t_true) + discrepancy + self._ortho_reg() ) def fit( self, X: torch.Tensor, y: torch.Tensor, w: torch.Tensor, ) -> "SNet": """ Fit treatment models. 
Parameters ---------- X : torch.Tensor of shape (n_samples, n_features) The features to fit to y : torch.Tensor of shape (n_samples,) or (n_samples, ) The outcome variable w: torch.Tensor of shape (n_samples,) The treatment indicator """ self.train() X = torch.Tensor(X).to(DEVICE) y = torch.Tensor(y).squeeze().to(DEVICE) w = torch.Tensor(w).squeeze().long().to(DEVICE) X, y, w, X_val, y_val, w_val, val_string = make_val_split( X, y, w=w, val_split_prop=self.val_split_prop, seed=self.seed ) n = X.shape[0] # could be different from before due to split # calculate number of batches per epoch batch_size = self.batch_size if self.batch_size < n else n n_batches = int(np.round(n / batch_size)) if batch_size < n else 1 train_indices = np.arange(n) # training val_loss_best = LARGE_VAL patience = 0 for i in range(self.n_iter): # shuffle data for minibatches np.random.shuffle(train_indices) train_loss = [] for b in range(n_batches): self.optimizer.zero_grad() idx_next = train_indices[ (b * batch_size) : min((b + 1) * batch_size, n - 1) ] X_next = X[idx_next] y_next = y[idx_next].squeeze() w_next = w[idx_next].squeeze() y0_preds, y1_preds, prop_preds, discrepancy = self._step(X_next, w_next) batch_loss = self.loss( y0_preds, y1_preds, prop_preds, discrepancy, y_next, w_next ) batch_loss.backward() torch.nn.utils.clip_grad_norm_(self.parameters(), self.clipping_value) self.optimizer.step() train_loss.append(batch_loss.detach()) train_loss = torch.Tensor(train_loss).to(DEVICE) if self.early_stopping or i % self.n_iter_print == 0: with torch.no_grad(): y0_preds, y1_preds, prop_preds, discrepancy = self._step( X_val, w_val ) val_loss = ( self.loss( y0_preds, y1_preds, prop_preds, discrepancy, y_val, w_val ) .detach() .cpu() ) if self.early_stopping: if val_loss_best > val_loss: val_loss_best = val_loss patience = 0 else: patience += 1 if patience > self.patience and ( (i + 1) * n_batches > self.n_iter_min ): break if i % self.n_iter_print == 0: log.info( f"[SNet] Epoch: {i}, 
current {val_string} loss: {val_loss} train_loss: {torch.mean(train_loss)}" ) return self def _ortho_reg(self) -> float: def _get_absolute_rowsums(mat: torch) -> torch.Tensor: return torch.sum(torch.abs(mat), dim=0) def _get_cos_reg( params_0: torch.Tensor, params_1: torch.Tensor, normalize: bool = False ) -> torch.Tensor: if normalize: params_0 = params_0 / torch.linalg.norm(params_0, dim=0) params_1 = params_1 / torch.linalg.norm(params_1, dim=0) x_min = min(params_0.shape[0], params_1.shape[0]) y_min = min(params_0.shape[1], params_1.shape[1]) return ( torch.linalg.norm( params_0[:x_min, :y_min] * params_1[:x_min, :y_min], "fro" ) ** 2 ) reps_o_params = self._reps_o.model[0].weight reps_mu0_params = self._reps_mu0.model[0].weight reps_mu1_params = self._reps_mu1.model[0].weight if self.with_prop: reps_c_params = self._reps_c.model[0].weight reps_prop_params = self._reps_prop.model[0].weight # define ortho-reg function if self.ortho_reg_type == "abs": col_o = _get_absolute_rowsums(reps_o_params) col_mu0 = _get_absolute_rowsums(reps_mu0_params) col_mu1 = _get_absolute_rowsums(reps_mu1_params) if self.with_prop: col_c = _get_absolute_rowsums(reps_c_params) col_w = _get_absolute_rowsums(reps_prop_params) return self.penalty_orthogonal * torch.sum( col_c * col_o + col_c * col_w + col_c * col_mu1 + col_c * col_mu0 + col_w * col_o + col_mu0 * col_o + col_o * col_mu1 + col_mu0 * col_mu1 + col_mu0 * col_w + col_w * col_mu1 ) else: return self.penalty_orthogonal * torch.sum( +col_mu0 * col_o + col_o * col_mu1 + col_mu0 * col_mu1 ) elif self.ortho_reg_type == "fro": if self.with_prop: return self.penalty_orthogonal * ( _get_cos_reg(reps_c_params, reps_o_params) + _get_cos_reg(reps_c_params, reps_mu0_params) + _get_cos_reg(reps_c_params, reps_mu1_params) + _get_cos_reg(reps_c_params, reps_prop_params) + _get_cos_reg(reps_o_params, reps_mu0_params) + _get_cos_reg(reps_o_params, reps_mu1_params) + _get_cos_reg(reps_o_params, reps_prop_params) + _get_cos_reg(reps_mu0_params, 
reps_mu1_params) + _get_cos_reg(reps_mu0_params, reps_prop_params) + _get_cos_reg(reps_mu1_params, reps_prop_params) ) else: return self.penalty_orthogonal * ( +_get_cos_reg(reps_o_params, reps_mu0_params) + _get_cos_reg(reps_o_params, reps_mu1_params) + _get_cos_reg(reps_mu0_params, reps_mu1_params) ) else: raise ValueError(f"Invalid orth_reg_typ {self.ortho_reg_type}") def _maximum_mean_discrepancy( self, X: torch.Tensor, w: torch.Tensor ) -> torch.Tensor: n = w.shape[0] n_t = torch.sum(w) X = X / torch.sqrt(torch.var(X, dim=0) + EPS) w = w.unsqueeze(dim=0) mean_control = (n / (n - n_t)) * torch.mean((1 - w).T * X, dim=0) mean_treated = (n / n_t) * torch.mean(w.T * X, dim=0) return torch.sum((mean_treated - mean_control) ** 2) def _step( self, X: torch.Tensor, w: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: y0_preds, y1_preds, prop_preds, reps_o = self._forward(X) discrepancy = self.penalty_disc * self._maximum_mean_discrepancy(reps_o, w) return y0_preds, y1_preds, prop_preds, discrepancy def _forward( self, X: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: reps_o = self._reps_o(X) reps_mu0 = self._reps_mu0(X) reps_mu1 = self._reps_mu1(X) if self.with_prop: reps_c = self._reps_c(X) reps_w = self._reps_prop(X) reps_po_0 = torch.cat((reps_c, reps_o, reps_mu0), dim=1) reps_po_1 = torch.cat((reps_c, reps_o, reps_mu1), dim=1) reps_w = torch.cat((reps_c, reps_w), dim=1) prop_preds = self._propensity_estimator(reps_w) else: reps_po_0 = torch.cat((reps_o, reps_mu0), dim=1) reps_po_1 = torch.cat((reps_o, reps_mu1), dim=1) prop_preds = 0.5 * torch.ones(len(X)) # no probability predictions y0_preds = self._po_estimators[0](reps_po_0).squeeze() y1_preds = self._po_estimators[1](reps_po_1).squeeze() return y0_preds, y1_preds, prop_preds, reps_o def predict( self, X: torch.Tensor, return_po: bool = False, training: bool = False ) -> torch.Tensor: """ Predict treatment effects and potential outcomes 
Parameters ---------- X: array-like of shape (n_samples, n_features) Test-sample features Returns ------- y: array-like of shape (n_samples,) """ if not training: self.eval() X = self._check_tensor(X).float() y0_preds, y1_preds, _, _ = self._forward(X) outcome = y1_preds - y0_preds if return_po: return outcome, y0_preds, y1_preds return outcome ================================================ FILE: catenets/models/torch/tlearner.py ================================================ import copy from typing import Any import torch from catenets.models.constants import ( DEFAULT_BATCH_SIZE, DEFAULT_LAYERS_OUT, DEFAULT_N_ITER, DEFAULT_N_ITER_PRINT, DEFAULT_NONLIN, DEFAULT_PENALTY_L2, DEFAULT_SEED, DEFAULT_STEP_SIZE, DEFAULT_UNITS_OUT, DEFAULT_VAL_SPLIT, ) from catenets.models.torch.base import DEVICE, BaseCATEEstimator, BasicNet from catenets.models.torch.utils.model_utils import predict_wrapper, train_wrapper class TLearner(BaseCATEEstimator): """ TLearner class -- two separate functions learned for each Potential Outcome function Parameters ---------- n_unit_in: int Number of features binary_y: bool, default False Whether the outcome is binary po_estimator: sklearn/PyTorch model, default: None Custom plugin model. If this parameter is set, the rest of the parameters are ignored. n_layers_out: int Number of hypothesis layers (n_layers_out x n_units_out + 1 x Linear layer) n_units_out: int Number of hidden units in each hypothesis layer weight_decay: float l2 (ridge) penalty lr: float learning rate for optimizer n_iter: int Maximum number of iterations batch_size: int Batch size val_split_prop: float Proportion of samples used for validation split (can be 0) n_iter_print: int Number of iterations after which to print updates seed: int Seed used nonlin: string, default 'elu' Nonlinearity to use in the neural net. Cat be 'elu', 'relu', 'selu' or 'leaky_relu'. 
""" def __init__( self, n_unit_in: int, binary_y: bool, po_estimator: Any = None, n_layers_out: int = DEFAULT_LAYERS_OUT, n_units_out: int = DEFAULT_UNITS_OUT, weight_decay: float = DEFAULT_PENALTY_L2, lr: float = DEFAULT_STEP_SIZE, n_iter: int = DEFAULT_N_ITER, batch_size: int = DEFAULT_BATCH_SIZE, val_split_prop: float = DEFAULT_VAL_SPLIT, n_iter_print: int = DEFAULT_N_ITER_PRINT, seed: int = DEFAULT_SEED, nonlin: str = DEFAULT_NONLIN, batch_norm: bool = True, early_stopping: bool = True, dropout: bool = False, dropout_prob: float = 0.2, ) -> None: super(TLearner, self).__init__() self._plug_in: Any = [] plugins = [f"tlearner_po_estimator_{i}" for i in range(2)] if po_estimator is not None: for plugin in plugins: self._plug_in.append(copy.deepcopy(po_estimator)) else: for plugin in plugins: self._plug_in.append( BasicNet( plugin, n_unit_in, binary_y=binary_y, n_layers_out=n_layers_out, n_units_out=n_units_out, weight_decay=weight_decay, lr=lr, n_iter=n_iter, batch_size=batch_size, val_split_prop=val_split_prop, n_iter_print=n_iter_print, seed=seed, nonlin=nonlin, batch_norm=batch_norm, early_stopping=early_stopping, dropout_prob=dropout_prob, dropout=dropout, ).to(DEVICE), ) def predict( self, X: torch.Tensor, return_po: bool = False, training: bool = False ) -> torch.Tensor: """ Predict treatment effects and potential outcomes Parameters ---------- X: torch.Tensor of shape (n_samples, n_features) Test-sample features return_po: bool Return potential outcomes too Returns ------- y: torch.Tensor of shape (n_samples,) """ if not training: self.eval() X = self._check_tensor(X).float() y_hat = [] for widx, plugin in enumerate(self._plug_in): y_hat.append(predict_wrapper(plugin, X)) outcome = y_hat[1] - y_hat[0] if return_po: return outcome, y_hat[0], y_hat[1] return outcome def fit( self, X: torch.Tensor, y: torch.Tensor, w: torch.Tensor, ) -> "TLearner": """ Train plug-in models. 
Parameters ---------- X : torch.Tensor (n_samples, n_features) The features to fit to y : torch.Tensor (n_samples,) or (n_samples, ) The outcome variable w: torch.Tensor (n_samples,) The treatment indicator """ self.train() X = torch.Tensor(X).to(DEVICE) y = torch.Tensor(y).to(DEVICE) w = torch.Tensor(w).to(DEVICE) for widx, plugin in enumerate(self._plug_in): train_wrapper(plugin, X[w == widx], y[w == widx]) return self ================================================ FILE: catenets/models/torch/utils/__init__.py ================================================ ================================================ FILE: catenets/models/torch/utils/decorators.py ================================================ import time from typing import Any, Callable import torch import catenets.logger as log def check_input_train(func: Callable) -> Callable: """Decorator used for checking training params. Args: func: the function to be benchmarked. Returns: Callable: the decorator """ def wrapper(self: Any, X: torch.Tensor, y: torch.Tensor, w: torch.Tensor) -> Any: w = torch.Tensor(w) if not ((w == 0) | (w == 1)).all(): raise ValueError("W should be binary") return func(self, X, y, w) return wrapper def benchmark(func: Callable) -> Callable: """Decorator used for function duration benchmarking. It is active only with DEBUG loglevel. Args: func: the function to be benchmarked. 
Returns: Callable: the decorator """ def wrapper(*args: Any, **kwargs: Any) -> Any: start = time.time() res = func(*args, **kwargs) end = time.time() log.debug(f"{func.__qualname__} took {round(end - start, 4)} seconds") return res return wrapper ================================================ FILE: catenets/models/torch/utils/model_utils.py ================================================ """ Model utils shared across different nets """ # Author: Alicia Curth, Bogdan Cebere from typing import Any, Optional import torch from sklearn.model_selection import train_test_split import catenets.logger as log from catenets.models.constants import DEFAULT_SEED, DEFAULT_VAL_SPLIT DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") TRAIN_STRING = "training" VALIDATION_STRING = "validation" def make_val_split( X: torch.Tensor, y: torch.Tensor, w: Optional[torch.Tensor] = None, val_split_prop: float = DEFAULT_VAL_SPLIT, seed: int = DEFAULT_SEED, stratify_w: bool = True, ) -> Any: if val_split_prop == 0: # return original data if w is None: return X, y, X, y, TRAIN_STRING return X, y, w, X, y, w, TRAIN_STRING X = X.cpu() y = y.cpu() # make actual split if w is None: X_t, X_val, y_t, y_val = train_test_split( X, y, test_size=val_split_prop, random_state=seed, shuffle=True ) return ( X_t.to(DEVICE), y_t.to(DEVICE), X_val.to(DEVICE), y_val.to(DEVICE), VALIDATION_STRING, ) w = w.cpu() if stratify_w: # split to stratify by group X_t, X_val, y_t, y_val, w_t, w_val = train_test_split( X, y, w, test_size=val_split_prop, random_state=seed, stratify=w, shuffle=True, ) else: X_t, X_val, y_t, y_val, w_t, w_val = train_test_split( X, y, w, test_size=val_split_prop, random_state=seed, shuffle=True ) return ( X_t.to(DEVICE), y_t.to(DEVICE), w_t.to(DEVICE), X_val.to(DEVICE), y_val.to(DEVICE), w_val.to(DEVICE), VALIDATION_STRING, ) def train_wrapper( estimator: Any, X: torch.Tensor, y: torch.Tensor, **kwargs: Any, ) -> None: if hasattr(estimator, "train"): log.debug(f"Train 
PyTorch network {estimator}") estimator.fit(X, y, **kwargs) elif hasattr(estimator, "fit"): log.debug(f"Train sklearn estimator {estimator}") estimator.fit(X.detach().cpu().numpy(), y.detach().cpu().numpy()) else: raise NotImplementedError(f"Invalid estimator for the {estimator}") def predict_wrapper(estimator: Any, X: torch.Tensor) -> torch.Tensor: if hasattr(estimator, "forward"): return estimator(X) elif hasattr(estimator, "predict_proba"): X_np = X.detach().cpu().numpy() no_event_proba = estimator.predict_proba(X_np)[:, 0] # no event probability return torch.Tensor(no_event_proba) elif hasattr(estimator, "predict"): X_np = X.detach().cpu().numpy() no_event_proba = estimator.predict(X_np) return torch.Tensor(no_event_proba) else: raise NotImplementedError(f"Invalid estimator for the {estimator}") ================================================ FILE: catenets/models/torch/utils/transformations.py ================================================ """ Unbiased Transformations for CATE """ # Author: Alicia Curth from typing import Optional import torch def dr_transformation_cate( y: torch.Tensor, w: torch.Tensor, p: torch.Tensor, mu_0: torch.Tensor, mu_1: torch.Tensor, ) -> torch.Tensor: """ Transforms data to efficient influence function/aipw pseudo-outcome for CATE estimation Parameters ---------- y : array-like of shape (n_samples,) or (n_samples, ) The observed outcome variable w: array-like of shape (n_samples,) The observed treatment indicator p: array-like of shape (n_samples,) The treatment propensity, estimated or known. 
Can be None, then p=0.5 is assumed mu_0: array-like of shape (n_samples,) Estimated or known potential outcome mean of the control group mu_1: array-like of shape (n_samples,) Estimated or known potential outcome mean of the treatment group Returns ------- d_hat: EIF transformation for CATE """ if p is None: # assume equal p = torch.full(y.shape, 0.5) EPS = 1e-7 w_1 = w / (p + EPS) w_0 = (1 - w) / (EPS + 1 - p) return (w_1 - w_0) * y + ((1 - w_1) * mu_1 - (1 - w_0) * mu_0) def pw_transformation_cate( y: torch.Tensor, w: torch.Tensor, p: Optional[torch.Tensor] = None, mu_0: Optional[torch.Tensor] = None, mu_1: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Transform data to Horvitz-Thompson transformation for CATE Parameters ---------- y : array-like of shape (n_samples,) or (n_samples, ) The observed outcome variable w: array-like of shape (n_samples,) The observed treatment indicator p: array-like of shape (n_samples,) The treatment propensity, estimated or known. Can be None, then p=0.5 is assumed mu_0: array-like of shape (n_samples,) Estimated or known potential outcome mean of the control group. Placeholder, not used. mu_1: array-like of shape (n_samples,) Estimated or known potential outcome mean of the treatment group. Placeholder, not used. Returns ------- res: array-like of shape (n_samples,) Horvitz-Thompson transformed data """ if p is None: # assume equal propensities p = torch.full(y.shape, 0.5) return (w / p - (1 - w) / (1 - p)) * y def ra_transformation_cate( y: torch.Tensor, w: torch.Tensor, p: torch.Tensor, mu_0: torch.Tensor, mu_1: torch.Tensor, ) -> torch.Tensor: """ Transform data to regression adjustment for CATE Parameters ---------- y : array-like of shape (n_samples,) or (n_samples, ) The observed outcome variable w: array-like of shape (n_samples,) The observed treatment indicator p: array-like of shape (n_samples,) Placeholder, not used. The treatment propensity, estimated or known. 
mu_0: array-like of shape (n_samples,) Estimated or known potential outcome mean of the control group mu_1: array-like of shape (n_samples,) Estimated or known potential outcome mean of the treatment group Returns ------- res: array-like of shape (n_samples,) Regression adjusted transformation """ return w * (y - mu_0) + (1 - w) * (mu_1 - y) def u_transformation_cate( y: torch.Tensor, w: torch.Tensor, p: torch.Tensor, mu: torch.Tensor ) -> torch.Tensor: """ Transform data to U-transformation (described in Kuenzel et al, 2019, Nie & Wager, 2017) which underlies both R-learner and U-learner Parameters ---------- y : array-like of shape (n_samples,) or (n_samples, ) The observed outcome variable w: array-like of shape (n_samples,) The observed treatment indicator p: array-like of shape (n_samples,) Placeholder, not used. The treatment propensity, estimated or known. mu_0: array-like of shape (n_samples,) Estimated or known potential outcome mean of the control group mu_1: array-like of shape (n_samples,) Estimated or known potential outcome mean of the treatment group Returns ------- res: array-like of shape (n_samples,) Regression adjusted transformation """ if p is None: # assume equal propensities p = torch.full(y.shape, 0.5) return (y - mu) / (w - p) ================================================ FILE: catenets/models/torch/utils/weight_utils.py ================================================ """ Implement different reweighting/balancing strategies as in Li et al (2018) """ # Author: Alicia Curth from typing import Optional import torch IPW_NAME = "ipw" TRUNC_IPW_NAME = "truncipw" OVERLAP_NAME = "overlap" MATCHING_NAME = "match" PROP = "prop" ONE_MINUS_PROP = "1-prop" ALL_WEIGHTING_STRATEGIES = [ IPW_NAME, TRUNC_IPW_NAME, OVERLAP_NAME, MATCHING_NAME, PROP, ONE_MINUS_PROP, ] def compute_importance_weights( propensity: torch.Tensor, w: torch.Tensor, weighting_strategy: str, weight_args: Optional[dict] = None, ) -> torch.Tensor: if weighting_strategy not in 
ALL_WEIGHTING_STRATEGIES: raise ValueError( f"weighting_strategy should be in {ALL_WEIGHTING_STRATEGIES}" f"You passed {weighting_strategy}" ) if weight_args is None: weight_args = {} if weighting_strategy == PROP: return propensity elif weighting_strategy == ONE_MINUS_PROP: return 1 - propensity elif weighting_strategy == IPW_NAME: return compute_ipw(propensity, w) elif weighting_strategy == TRUNC_IPW_NAME: return compute_trunc_ipw(propensity, w, **weight_args) elif weighting_strategy == OVERLAP_NAME: return compute_overlap_weights(propensity, w) elif weighting_strategy == MATCHING_NAME: return compute_matching_weights(propensity, w) def compute_ipw(propensity: torch.Tensor, w: torch.Tensor) -> torch.Tensor: p_hat = torch.mean(w) return w * p_hat / propensity + (1 - w) * (1 - p_hat) / (1 - propensity) def compute_trunc_ipw( propensity: torch.Tensor, w: torch.Tensor, cutoff: float = 0.05 ) -> torch.Tensor: ipw = compute_ipw(propensity, w) return torch.where((propensity > cutoff) & (propensity < 1 - cutoff), ipw, 0) # TODO check normalizing these weights def compute_matching_weights(propensity: torch.Tensor, w: torch.Tensor) -> torch.Tensor: ipw = compute_ipw(propensity, w) return torch.minimum(ipw, 1 - ipw) * ipw def compute_overlap_weights(propensity: torch.Tensor, w: torch.Tensor) -> torch.Tensor: ipw = compute_ipw(propensity, w) return propensity * (1 - propensity) * ipw ================================================ FILE: catenets/version.py ================================================ __version__ = "0.2.3" ================================================ FILE: docs/Makefile ================================================ # Minimal makefile for Sphinx documentation # # You can set these variables from the command line, and also # from the environment for the first two. SPHINXOPTS ?= SPHINXBUILD ?= sphinx-build SOURCEDIR = . BUILDDIR = _build # Put it first so that "make" without argument is like "make help". 
help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
================================================
FILE: docs/conf.py
================================================
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
# import os
# import sys
# sys.path.insert(0, os.path.abspath('.'))
import datetime
import os
import shutil  # NOTE(review): not referenced anywhere else in this file
import subprocess
import sys

import sphinx_rtd_theme

# Put the repository root (parent of docs/) first on sys.path so autodoc can
# import the `catenets` package.
sys.path.insert(0, os.path.abspath(".."))

# Regenerate the API stub pages into ./generated from ../catenets/ every time
# this config is loaded ("-f" forces overwriting existing stubs).
# NOTE(review): the return code is not checked, so a failing sphinx-apidoc run
# does not stop the build.
subprocess.run(
    [
        "sphinx-apidoc",
        "--ext-autodoc",
        "--ext-doctest",
        "--ext-mathjax",
        "--ext-viewcode",
        "-e",
        "-T",
        "-M",
        "-F",
        "-P",
        "-f",
        "-o",
        "generated",
        "../catenets/",
    ]
)

# -- Project information -----------------------------------------------------
now = datetime.datetime.now()

project = "CATENets"
author = "Alicia Curth"
# copyright year follows the build date
copyright = f"{now.year}, {author}"


# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    "sphinx.ext.autodoc",
    "sphinx.ext.autosummary",
    "sphinx.ext.napoleon",
    "m2r2",
]

autodoc_default_options = {
    "members": True,
    "inherited-members": False,
    "inherit_docstrings": False,
}

add_module_names = False
autosummary_generate = True

# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]


# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = "sphinx_rtd_theme"

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]
================================================
FILE: docs/datasets.rst
================================================
Datasets ========================= Dataloaders for datasets used for experiments. .. toctree:: :glob: :maxdepth: 2 IHDP dataset Twins dataset ACIC dataset Helpers
================================================
FILE: docs/index.rst
================================================
Welcome to CATENets's documentation! ==================================== .. mdinclude:: ../README.md API documentation ================= JAX models ========== .. toctree:: :glob: :maxdepth: 2 jax_models.rst PyTorch models ============== .. toctree:: :glob: :maxdepth: 2 torch_models.rst Datasets ======== .. toctree:: :glob: :maxdepth: 2 datasets.rst
================================================
FILE: docs/jax_models.rst
================================================
JAX models ========================= JAX-based CATE estimators .. toctree:: :glob: :maxdepth: 2 T-Learners R-Learners X-Learners Pseudo-Outcome Nets Representation Nets Disentangled Nets S-Nets FlexTENet OffsetNet
================================================
FILE: docs/make.bat
================================================
@ECHO OFF pushd %~dp0 REM Command file for Sphinx documentation if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=sphinx-build ) set SOURCEDIR=. set BUILDDIR=_build if "%1" == "" goto help %SPHINXBUILD% >NUL 2>NUL if errorlevel 9009 ( echo. echo.The 'sphinx-build' command was not found. Make sure you have Sphinx echo.installed, then set the SPHINXBUILD environment variable to point echo.to the full path of the 'sphinx-build' executable. Alternatively you echo.may add the Sphinx directory to PATH. echo. echo.If you don't have Sphinx installed, grab it from echo.http://sphinx-doc.org/ exit /b 1 ) %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% goto end :help %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% :end popd
================================================
FILE: docs/requirements.txt
================================================
autodoc bandit black catboost flake8 gdown jax>=0.3.16 jaxlib>=0.3.14; sys_platform != 'win32' jupyter loguru>=0.5.3 m2r2 myst-parser notebook numpy>=1.20 pandas>=1.3 pre-commit pytest>=6.2.4 pytest pytest pytest-cov requests scikit_learn>=0.24.2 scipy>=1.2 setuptools sklearn sphinx-autopackagesummary sphinx-rtd-theme sphinxcontrib-napoleon torch>=1.9 xgboost
================================================
FILE: docs/torch_models.rst
================================================
PyTorch models ========================= PyTorch-based CATE estimators ..
toctree:: :glob: :maxdepth: 2 T-Learners S-Learners Pseudo-Outcome Nets Representation Nets S-Nets ================================================ FILE: experiments/__init__.py ================================================ ================================================ FILE: experiments/experiments_AISTATS21/ihdp_experiments.py ================================================ """ Script to run experiments on Johansson's IHDP dataset (retrieved via https://www.fredjo.com/) """ # Author: Alicia Curth import csv import os from pathlib import Path from typing import Optional, Union from sklearn import clone import catenets.logger as log from catenets.datasets.dataset_ihdp import get_one_data_set, load_raw, prepare_ihdp_data from catenets.experiment_utils.base import eval_root_mse, get_model_set from catenets.models.jax import PSEUDOOUT_NAME, PseudoOutcomeNet from catenets.models.jax.transformation_utils import RA_TRANSFORMATION # Some constants DATA_DIR = Path("catenets/datasets/data/") RESULT_DIR = Path("results/experiments_AISTATS21/ihdp/") SEP = "_" # Hyperparameters for experiments on IHDP LAYERS_OUT = 2 LAYERS_R = 3 PENALTY_L2 = 0.01 / 100 PENALTY_ORTHOGONAL_IHDP = 0 MODEL_PARAMS = { "n_layers_out": LAYERS_OUT, "n_layers_r": LAYERS_R, "penalty_l2": PENALTY_L2, "penalty_orthogonal": PENALTY_ORTHOGONAL_IHDP, "n_layers_out_t": LAYERS_OUT, "n_layers_r_t": LAYERS_R, "penalty_l2_t": PENALTY_L2, } # get basic models ALL_MODELS_IHDP = get_model_set(model_selection="all", model_params=MODEL_PARAMS) COMBINED_MODELS_IHDP = { PSEUDOOUT_NAME + SEP + RA_TRANSFORMATION + SEP + "S2": PseudoOutcomeNet( n_layers_r=LAYERS_R, n_layers_out=LAYERS_OUT, penalty_l2=PENALTY_L2, n_layers_r_t=LAYERS_R, n_layers_out_t=LAYERS_OUT, penalty_l2_t=PENALTY_L2, transformation=RA_TRANSFORMATION, first_stage_strategy="S2", ) } FULL_MODEL_SET_IHDP = dict(**ALL_MODELS_IHDP, **COMBINED_MODELS_IHDP) def do_ihdp_experiments( n_exp: Union[int, list] = 100, file_name: str = "ihdp_results_scaled", 
model_params: Optional[dict] = None, scale_cate: bool = True, models: Union[list, dict, str, None] = None, ) -> None: if models is None: models = FULL_MODEL_SET_IHDP elif isinstance(models, (list, str)): models = get_model_set(models) # make path if not os.path.exists(RESULT_DIR): os.makedirs(RESULT_DIR) # get file to write in out_file = open(RESULT_DIR / (file_name + ".csv"), "w", buffering=1) writer = csv.writer(out_file) header = [name + "_in" for name in models.keys()] + [ name + "_out" for name in models.keys() ] writer.writerow(header) # get data data_train, data_test = load_raw(DATA_DIR) if isinstance(n_exp, int): experiment_loop = list(range(1, n_exp + 1)) elif isinstance(n_exp, list): experiment_loop = n_exp else: raise ValueError("n_exp should be either an integer or a list of integers.") for i_exp in experiment_loop: pehe_in = [] pehe_out = [] # get data data_exp = get_one_data_set(data_train, i_exp=i_exp, get_po=True) data_exp_test = get_one_data_set(data_test, i_exp=i_exp, get_po=True) X, y, w, cate_true_in, X_t, cate_true_out = prepare_ihdp_data( data_exp, data_exp_test, rescale=scale_cate ) for model_name, estimator in models.items(): log.info(f"Experiment {i_exp} with {model_name}") estimator_temp = clone(estimator) if model_params is not None: estimator_temp.set_params(**model_params) # fit estimator estimator_temp.fit(X=X, y=y, w=w) cate_pred_in = estimator_temp.predict(X, return_po=False) cate_pred_out = estimator_temp.predict(X_t, return_po=False) pehe_in.append(eval_root_mse(cate_pred_in, cate_true_in)) pehe_out.append(eval_root_mse(cate_pred_out, cate_true_out)) writer.writerow(pehe_in + pehe_out) out_file.close() ================================================ FILE: experiments/experiments_AISTATS21/simulations_AISTATS.py ================================================ """ Script to generate synthetic simulations in AISTATS paper """ # Author: Alicia Curth import csv import os from typing import Any, Optional, Union from sklearn import 
clone import catenets.logger as log from catenets.experiment_utils.base import eval_root_mse, get_model_set from catenets.experiment_utils.simulation_utils import simulate_treatment_setup from catenets.models.jax import PSEUDOOUT_NAME, PseudoOutcomeNet from catenets.models.jax.pseudo_outcome_nets import S1_STRATEGY, S_STRATEGY from catenets.models.jax.snet import DEFAULT_UNITS_R_BIG_S, DEFAULT_UNITS_R_SMALL_S from catenets.models.jax.transformation_utils import ( DR_TRANSFORMATION, RA_TRANSFORMATION, ) # some constants RESULT_DIR = "results/experiments_AISTATS21/simulations/" CSV_STRING = ".csv" SEP = "_" # hyperparameters for experiments LAYERS_OUT = 2 LAYERS_R = 3 PENALTY_L2 = 0.01 / 100 PENALTY_ORTHOGONAL = 1 / 100 MODEL_PARAMS_AISTATS = { "n_layers_out": LAYERS_OUT, "n_layers_r": LAYERS_R, "penalty_l2": PENALTY_L2, "penalty_orthogonal": PENALTY_ORTHOGONAL, "n_layers_out_t": LAYERS_OUT, "n_layers_r_t": LAYERS_R, "penalty_l2_t": PENALTY_L2, } # get basic models ALL_MODELS_AISTATS = get_model_set( model_selection="all", model_params=MODEL_PARAMS_AISTATS ) # model-twostep combinations COMBINED_MODELS = { PSEUDOOUT_NAME + SEP + DR_TRANSFORMATION + SEP + S_STRATEGY: PseudoOutcomeNet( transformation=DR_TRANSFORMATION, first_stage_strategy=S_STRATEGY, n_units_r=DEFAULT_UNITS_R_BIG_S, n_layers_out=LAYERS_OUT, n_layers_r=LAYERS_R, penalty_l2_t=PENALTY_L2, penalty_l2=PENALTY_L2, n_layers_out_t=LAYERS_OUT, first_stage_args={ "n_units_r_small": DEFAULT_UNITS_R_SMALL_S, "penalty_orthogonal": PENALTY_ORTHOGONAL, }, ), PSEUDOOUT_NAME + SEP + RA_TRANSFORMATION + SEP + S_STRATEGY: PseudoOutcomeNet( transformation=RA_TRANSFORMATION, first_stage_strategy=S_STRATEGY, n_units_r=DEFAULT_UNITS_R_BIG_S, n_layers_out=LAYERS_OUT, n_layers_r=LAYERS_R, penalty_l2_t=PENALTY_L2, penalty_l2=PENALTY_L2, n_layers_out_t=LAYERS_OUT, n_layers_r_t=LAYERS_R, first_stage_args={ "n_units_r_small": DEFAULT_UNITS_R_SMALL_S, "penalty_orthogonal": PENALTY_ORTHOGONAL, }, ), PSEUDOOUT_NAME + SEP + 
DR_TRANSFORMATION + SEP + S1_STRATEGY: PseudoOutcomeNet( transformation=DR_TRANSFORMATION, first_stage_strategy=S1_STRATEGY, n_layers_out=LAYERS_OUT, n_layers_r=LAYERS_R, penalty_l2_t=PENALTY_L2, penalty_l2=PENALTY_L2, n_layers_out_t=LAYERS_OUT, n_layers_r_t=LAYERS_R, ), PSEUDOOUT_NAME + SEP + RA_TRANSFORMATION + SEP + S1_STRATEGY: PseudoOutcomeNet( transformation=RA_TRANSFORMATION, first_stage_strategy=S1_STRATEGY, n_layers_out=LAYERS_OUT, n_layers_r=LAYERS_R, penalty_l2_t=PENALTY_L2, penalty_l2=PENALTY_L2, n_layers_out_t=LAYERS_OUT, n_layers_r_t=LAYERS_R, ), } FULL_MODEL_SET_AISTATS = dict(**ALL_MODELS_AISTATS, **COMBINED_MODELS) # some more constants for experiments NTRAIN_BASE = 2000 NTEST_BASE = 500 D_BASE = 25 BASE_XI = 3 TARGET_PROP_BASE = None XI_STRING = "xi" N_STRING = "n" D_T_STRING = "dim_t" PROPENSITY_CONSTANT_STRING = "p" TARGET_STRING = "target_p" def simulation_experiment_loop( range_change: list, change_dim: str = N_STRING, n_train: int = NTRAIN_BASE, n_test: int = NTEST_BASE, n_repeats: int = 10, d: int = D_BASE, n_w: int = 0, n_c: int = 5, n_o: int = 5, n_t: int = 0, file_base: str = "results", xi: float = BASE_XI, mu_1_model: Optional[dict] = None, correlated_x: bool = False, mu_1_model_params: Optional[dict] = None, mu_0_model_params: Optional[dict] = None, models: Optional[dict] = None, nonlinear_prop: bool = True, prop_offset: Union[float, str] = "center", target_prop: Optional[float] = TARGET_PROP_BASE, ) -> None: if change_dim is N_STRING: for n in range_change: log.debug(f"Running experiments for {N_STRING} set to {n}") do_one_experiment_repeat( n_train=n, n_test=n_test, n_repeats=n_repeats, d=d, n_w=n_w, n_c=n_c, n_o=n_o, n_t=n_t, file_base=file_base, xi=xi, mu_1_model=mu_1_model, correlated_x=correlated_x, models=models, mu_1_model_params=mu_1_model_params, mu_0_model_params=mu_0_model_params, nonlinear_prop=nonlinear_prop, prop_offset=prop_offset, target_prop=target_prop, ) elif change_dim is XI_STRING: for xi_temp in range_change: 
log.debug(f"Running experiments for {XI_STRING} set to {xi_temp}") do_one_experiment_repeat( n_train=n_train, n_test=n_test, n_repeats=n_repeats, d=d, n_w=n_w, n_c=n_c, n_o=n_o, n_t=n_t, file_base=file_base, xi=xi_temp, mu_1_model=mu_1_model, correlated_x=correlated_x, models=models, mu_1_model_params=mu_1_model_params, mu_0_model_params=mu_0_model_params, nonlinear_prop=nonlinear_prop, prop_offset=prop_offset, target_prop=target_prop, ) elif change_dim is D_T_STRING: for d_t_temp in range_change: log.debug(f"Running experiments for {D_T_STRING} set to {d_t_temp}") do_one_experiment_repeat( n_train=n_train, n_test=n_test, n_repeats=n_repeats, d=d, n_w=n_w, n_c=n_c, n_o=n_o, n_t=d_t_temp, file_base=file_base, xi=xi, mu_1_model=mu_1_model, correlated_x=correlated_x, models=models, mu_1_model_params=mu_1_model_params, mu_0_model_params=mu_0_model_params, nonlinear_prop=nonlinear_prop, prop_offset=prop_offset, target_prop=target_prop, ) elif change_dim is TARGET_STRING: for target_prop_temp in range_change: log.debug( f"Running experiments for {TARGET_STRING} set to {target_prop_temp}" ) do_one_experiment_repeat( n_train=n_train, n_test=n_test, n_repeats=n_repeats, d=d, n_w=n_w, n_c=n_c, n_o=n_o, n_t=n_t, file_base=file_base, xi=xi, mu_1_model=mu_1_model, correlated_x=correlated_x, models=models, mu_1_model_params=mu_1_model_params, mu_0_model_params=mu_0_model_params, nonlinear_prop=nonlinear_prop, prop_offset=prop_offset, target_prop=target_prop_temp, ) def do_one_experiment_repeat( n_train: int = NTRAIN_BASE, n_test: int = NTEST_BASE, n_repeats: int = 10, d: int = D_BASE, n_w: int = 0, n_c: int = 0, n_o: int = 0, n_t: int = 0, file_base: str = "results", xi: float = BASE_XI, mu_1_model: Optional[dict] = None, correlated_x: bool = True, mu_1_model_params: Optional[dict] = None, mu_0_model_params: Optional[dict] = None, models: Optional[dict] = None, nonlinear_prop: bool = True, range_exp: Optional[list] = None, prop_offset: Union[float, str] = 0, target_prop: 
Optional[float] = None, ) -> None: # make path if not os.path.exists(RESULT_DIR): os.makedirs(RESULT_DIR) if range_exp is None: range_exp = list(range(1, n_repeats + 1)) if models is None: models = FULL_MODEL_SET_AISTATS if target_prop is None: prop_string = str(prop_offset) else: prop_string = str(target_prop) # create file name and file file_name = ( file_base + SEP + str(n_train) + SEP + str(d) + SEP + str(n_w) + SEP + str(n_c) + SEP + str(n_o) + SEP + str(n_t) + SEP + str(xi) + SEP + prop_string ) out_file = open(RESULT_DIR + file_name + CSV_STRING, "w", buffering=1) writer = csv.writer(out_file) header = [name for name in models.keys()] writer.writerow(header) for i in range_exp: log.debug(f"Running experiment {i}.") mses = one_simulation_experiment( n_train=n_train, n_test=n_test, d=d, n_w=n_w, n_c=n_c, n_o=n_o, n_t=n_t, seed=i, xi=xi, mu_1_model=mu_1_model, correlated_x=correlated_x, models=models, nonlinear_prop=nonlinear_prop, mu_0_model_params=mu_0_model_params, mu_1_model_params=mu_1_model_params, prop_offset=prop_offset, target_prop=target_prop, ) writer.writerow(mses) out_file.close() return None def one_simulation_experiment( n_train: int, n_test: int = NTEST_BASE, d: int = D_BASE, n_w: int = 0, n_c: int = 0, n_o: int = 0, n_t: int = 0, xi: float = BASE_XI, seed: int = 42, mu_1_model: Optional[dict] = None, propensity_model: Optional[dict] = None, correlated_x: bool = False, mu_1_model_params: Optional[dict] = None, mu_0_model_params: Optional[dict] = None, models: Optional[dict] = None, nonlinear_prop: bool = False, prop_offset: Union[float, str] = 0, target_prop: Optional[float] = None, ) -> list: if models is None: models = FULL_MODEL_SET_AISTATS # get data X, y, w, p, t = simulate_treatment_setup( n_train + n_test, d=d, n_w=n_w, n_c=n_c, n_o=n_o, n_t=n_t, propensity_model=propensity_model, propensity_model_params={ "xi": xi, "nonlinear": nonlinear_prop, "offset": prop_offset, "target_prop": target_prop, }, seed=seed, mu_1_model=mu_1_model, 
mu_0_model_params=mu_0_model_params, mu_1_model_params=mu_1_model_params, covariate_model_params={"correlated": correlated_x}, ) # split data X_train, y_train, w_train, _ = ( X[:n_train, :], y[:n_train], w[:n_train], p[:n_train], ) X_test, t_test = X[n_train:, :], t[n_train:] rmses = [] for model_name, model in models.items(): log.debug(f"Training model {model_name}") estimator = clone(model) estimator.fit(X=X_train, y=y_train, w=w_train) cate_test = estimator.predict(X_test, return_po=False) rmses.append(eval_root_mse(cate_test, t_test)) return rmses def main_AISTATS( setting: int = 1, models: Any = None, file_name: str = "res", n_repeats: int = 10, ) -> None: if models is None: models = FULL_MODEL_SET_AISTATS elif type(models) is list or type(models) is str: models = get_model_set(models) if setting == 1: # no treatment effect, with confounding, by n simulation_experiment_loop( [1000, 2000, 5000, 10000], change_dim="n", n_t=0, n_w=0, n_c=5, n_o=5, file_base=file_name, models=models, n_repeats=n_repeats, ) elif setting == 2: # with treatment effect, with confounding, by n simulation_experiment_loop( [1000, 2000, 5000, 10000], change_dim="n", n_t=5, n_w=0, n_c=5, n_o=5, file_base=file_name, models=models, n_repeats=n_repeats, ) elif setting == 3: # Potential outcomes are supported on independent covariates, no confounding, by n simulation_experiment_loop( [1000, 2000, 5000, 10000], change_dim="n", n_t=10, n_w=0, n_c=0, n_o=10, file_base=file_name, models=models, xi=0.5, mu_1_model_params={"withbase": False}, n_repeats=n_repeats, ) elif setting == 4: # vary number of predictive features at n=2000 simulation_experiment_loop( [0, 1, 3, 5, 7, 10], change_dim=D_T_STRING, n_train=2000, n_c=5, n_o=5, file_base=file_name, models=models, n_repeats=n_repeats, ) elif setting == 5: # vary percentage treated at n=2000 simulation_experiment_loop( [0.1, 0.2, 0.3, 0.4, 0.5], change_dim=TARGET_STRING, n_train=2000, n_c=5, n_o=5, n_t=0, n_repeats=n_repeats, file_base=file_name, 
models=models, ) ================================================ FILE: experiments/experiments_benchmarks_NeurIPS21/README.md ================================================ # Replication code for "Really Doing Great at Estimating CATE? A Critical Look at ML Benchmarking Practices in Treatment Effect Estimation" This folder contains the files to replicate the benchmarking studies of random forest (RF) and neural network (NN) based CATE estimators using the IHDP, ACIC2016 and Twins datasets. The code for RFs is in R and relies on the R-package ‘grf’ which is available on CRAN. The code for NNs relies on the python package ‘catenets’ in this repo. This folder provides both python and R code to replicate the results of all empirical studies. Always run the python code first, this code downloads and/or creates the datasets that are used in both python and R code. The python code can also be run using the file 'run_experiments_benchmarks_NeurIPS.py' in the root of the repo. For IHDP: Setting ‘original’ reproduces results as reported in Figure 3a and setting ‘modified’ reproduces results in Figure 3b. For ACIC: we considered the simulation numbers (`simu_num’) 2, 26 and 7. 
================================================ FILE: experiments/experiments_benchmarks_NeurIPS21/__init__.py ================================================ ================================================ FILE: experiments/experiments_benchmarks_NeurIPS21/acic_experiments_catenets.py ================================================ """ Utils to replicate ACIC2016 experiments with catenets """ # Author: Alicia Curth import csv import os from pathlib import Path import numpy as np from sklearn import clone from catenets.datasets import load from catenets.experiment_utils.base import eval_root_mse from catenets.models.jax import RNET_NAME, T_NAME, TARNET_NAME, RNet, TARNet, TNet RESULT_DIR = Path("results/experiments_benchmarking/acic2016/") SEP = "_" PARAMS_DEPTH = {"n_layers_r": 3, "n_layers_out": 2} PARAMS_DEPTH_2 = { "n_layers_r": 3, "n_layers_out": 2, "n_layers_r_t": 3, "n_layers_out_t": 2, } ALL_MODELS = { T_NAME: TNet(**PARAMS_DEPTH), TARNET_NAME: TARNet(**PARAMS_DEPTH), RNET_NAME: RNet(**PARAMS_DEPTH_2), } def do_acic_experiments( n_exp: int = 10, n_reps=5, file_name: str = "results_catenets", simu_num: int = 1, models: dict = None, train_size: int = 4000, pre_trans: bool = True, ): if models is None: models = ALL_MODELS # get file to write in if not os.path.isdir(RESULT_DIR): os.makedirs(RESULT_DIR) out_file = open( RESULT_DIR / ( file_name + SEP + str(pre_trans) + SEP + str(simu_num) + SEP + str(train_size) + ".csv" ), "w", buffering=1, ) writer = csv.writer(out_file) header = ( ["file_name", "run", "cate_var_in", "cate_var_out", "y_var_in"] + [name + "_in" for name in models.keys()] + [name + "_out" for name in models.keys()] ) writer.writerow(header) for i_exp in range(n_exp): # get data X, w, y, po_train, X_test, w_test, y_test, po_test = load( "acic2016", preprocessed=pre_trans, original_acic_outcomes=True, i_exp=i_exp, simu_num=simu_num, train_size=train_size, ) cate_in = po_train[:, 1] - po_train[:, 0] cate_out = po_test[:, 1] - po_test[:, 0] 
cate_var_in = np.var(cate_in) cate_var_out = np.var(cate_out) y_var_in = np.var(y) for k in range(n_reps): pehe_in = [] pehe_out = [] for model_name, estimator in models.items(): print(f"Experiment {i_exp}, run {k}, with {model_name}") estimator_temp = clone(estimator) estimator_temp.set_params(seed=k) # fit estimator estimator_temp.fit(X=X, y=y, w=w) cate_pred_in = estimator_temp.predict(X, return_po=False) cate_pred_out = estimator_temp.predict(X_test, return_po=False) pehe_in.append(eval_root_mse(cate_pred_in, cate_in)) pehe_out.append(eval_root_mse(cate_pred_out, cate_out)) writer.writerow( [i_exp, k, cate_var_in, cate_var_out, y_var_in] + pehe_in + pehe_out ) out_file.close() ================================================ FILE: experiments/experiments_benchmarks_NeurIPS21/acic_experiments_grf.R ================================================ library(grf) do_acic_exper_loop <- function(simnums = c(2, 26, 7), n_reps = 5, n_exp = 10, with_t = F) { # function to loop over multiple simulation settings for (k in simnums) { do_acic_exper(k, n_reps = n_reps, n_exp = n_exp, with_t = with_t) } } do_acic_exper <- function(simnum, n_reps = 5, n_exp = 10, with_t = F) { # function to do acic experiments for one simulation setting (simnum) # n_reps indicates the number of replications (random seeds used) # n_exp indicates the number of simulations to use within this setting (1-100) # with_t indicates whether to create additional results with pre-transformed data X <- data.matrix(read.csv('catenets/datasets/data/data_cf_all/x.csv')) X_trans <- data.matrix(read.csv('catenets/datasets/data/x_trans.csv')) range_train = 1:4000 range_test = 4001:4802 # get files sim_dir = paste0('catenets/datasets/data/data_cf_all/', simnum, '/') file_list <- list.files(sim_dir) for (i in 1:(n_exp)) { # loop over simulations within this setting print(paste0('Experiment number ', i)) for (k in 1:n_reps) { # loop over seeds print(paste0('Iteration number ', k)) set.seed(k * i) X_train <- 
X[range_train,] X_test <- X[range_test,] X_t_train <- X_trans[range_train,] outcomes = read.csv(paste0(sim_dir, file_list[i])) z = outcomes$z y = outcomes$z * outcomes$y1 + (1 - outcomes$z) * outcomes$y0 t = outcomes$mu1 - outcomes$mu0 z_train = z[range_train] y_train = y[range_train] t_train = t[range_train] t_test = t[range_test] # causal forest print('causal forest') cf <- causal_forest(X_train, y_train, z_train, seed = k * i) pred_cf <- predict(cf, X)$predictions rmse_cf_in <- sqrt(mean((t_train - pred_cf[range_train]) ^ 2)) rmse_cf_out <- sqrt(mean((t_test - pred_cf[range_test]) ^ 2)) if (with_t == T) { # also fit estimators using pre-transformed data cf.t <- causal_forest(X_t_train, y_train, z_train, seed = k * i) pred_cf.t <- predict(cf.t, X_trans)$predictions rmse_cf_in.t <- sqrt(mean((t_train - pred_cf.t[range_train]) ^ 2)) rmse_cf_out.t <- sqrt(mean((t_test - pred_cf.t[range_test]) ^ 2)) } # t-learner print('t learner') y0.forest <- regression_forest(subset(X_train, z_train == 0), y_train[z_train == 0], seed = k * i) y1.forest <- regression_forest(subset(X_train, z_train == 1), y_train[z_train == 1], seed = k * i) pred_t <- predict(y1.forest, X)$predictions - predict(y0.forest, X)$predictions rmse_t_in <- sqrt(mean((t_train - pred_t[range_train]) ^ 2)) rmse_t_out <- sqrt(mean((t_test - pred_t[range_test]) ^ 2)) if (with_t == T) { # also fit estimators using pre-transformed data y0.forest.t <- regression_forest(subset(X_t_train, z_train == 0), y_train[z_train == 0], seed = k * i) y1.forest.t <- regression_forest(subset(X_t_train, z_train == 1), y_train[z_train == 1], seed = k * i) pred_t.t <- predict(y1.forest.t, X_trans)$predictions - predict(y0.forest.t, X_trans)$predictions rmse_t_in.t <- sqrt(mean((t_train - pred_t.t[range_train]) ^ 2)) rmse_t_out.t <- sqrt(mean((t_test - pred_t.t[range_test]) ^ 2)) } # s-learner print('s learner') s_forest <- regression_forest(cbind(X_train, z_train), y_train, seed = k * i) n_total <- nrow(X) test_treated <- rep(1, 
n_total) test_control <- rep(0, n_total) pred_s <- predict(s_forest, cbind(X, test_treated))$predictions - predict(s_forest, cbind(X, test_control))$predictions rmse_s_in <- sqrt(mean((t_train - pred_s[range_train]) ^ 2)) rmse_s_out <- sqrt(mean((t_test - pred_s[range_test]) ^ 2)) if (with_t == T) { # also fit estimators using pre-transformed data s_forest.t <- regression_forest(data.matrix(cbind(X_t_train, z_train)), y_train, seed = k * i) pred_s.t <- predict(s_forest.t, data.matrix(cbind(X_trans, test_treated)))$predictions - predict(s_forest.t, data.matrix(cbind(X_trans, test_control)))$predictions rmse_s_in.t <- sqrt(mean((t_train - pred_s.t[range_train]) ^ 2)) rmse_s_out.t <- sqrt(mean((t_test - pred_s.t[range_test]) ^ 2)) } if (with_t == T) { df_res <- data.frame( file = file_list[i], run = k, cf_in = rmse_cf_in, cf_t_in = rmse_cf_in.t, t_in = rmse_t_in, t_t_in = rmse_t_in.t, s_in = rmse_s_in, s_in_t = rmse_s_in.t, cf_out = rmse_cf_out, cf_t_out = rmse_cf_out.t, t_out = rmse_t_out, t_t_out = rmse_t_out, s_out = rmse_s_out, s_t_out = rmse_s_out.t ) } else{ df_res <- data.frame( file = file_list[i], run = k, cf_in = rmse_cf_in, t_in = rmse_t_in, s_in = rmse_s_in, cf_out = rmse_cf_out, t_out = rmse_t_out, s_out = rmse_s_out ) } if (i * k == 1) { write.table( df_res, file = paste0( 'results/experiments_benchmarking/acic2016/grf_', simnum, '_', with_t, '_', n_exp, '_', n_reps, '.csv' ), col.names = T, sep = ',', row.names = F ) } else{ write.table( df_res, file = paste0( 'results/experiments_benchmarking/acic2016/grf_', simnum, '_', with_t, '_', n_exp, '_', n_reps, '.csv' ), col.names = F, sep = ',', row.names = F, append = T ) } } } } ================================================ FILE: experiments/experiments_benchmarks_NeurIPS21/ihdp_experiments_catenets.py ================================================ """ Utils to replicate IHDP experiments with catenets """ # Author: Alicia Curth import csv import os from pathlib import Path from typing import Optional, 
Union import numpy as np from sklearn import clone from catenets.datasets.dataset_ihdp import get_one_data_set, load_raw, prepare_ihdp_data from catenets.experiment_utils.base import eval_root_mse from catenets.models.jax import RNET_NAME, T_NAME, TARNET_NAME, RNet, TARNet, TNet DATA_DIR = Path("catenets/datasets/data/") RESULT_DIR = Path("results/experiments_benchmarking/ihdp/") SEP = "_" PARAMS_DEPTH = {"n_layers_r": 3, "n_layers_out": 2} PARAMS_DEPTH_2 = { "n_layers_r": 3, "n_layers_out": 2, "n_layers_r_t": 3, "n_layers_out_t": 2, } ALL_MODELS = { T_NAME: TNet(**PARAMS_DEPTH), TARNET_NAME: TARNet(**PARAMS_DEPTH), RNET_NAME: RNet(**PARAMS_DEPTH_2), } def do_ihdp_experiments( n_exp: Union[int, list] = 100, n_reps: int = 5, file_name: str = "ihdp_all", model_params: Optional[dict] = None, models: Optional[dict] = None, setting: str = "original", ) -> None: if models is None: models = ALL_MODELS if (setting == "original") or (setting == "C"): setting = "C" elif (setting == "modified") or (setting == "D"): setting = "D" else: raise ValueError( f"Setting should be one of original or modified. You passed {setting}." 
) # get file to write in if not os.path.isdir(RESULT_DIR): os.makedirs(RESULT_DIR) out_file = open(RESULT_DIR / (file_name + SEP + setting + ".csv"), "w", buffering=1) writer = csv.writer(out_file) header = ( ["exp", "run", "cate_var_in", "cate_var_out", "y_var_in"] + [name + "_in" for name in models.keys()] + [name + "_out" for name in models.keys()] ) writer.writerow(header) # get data data_train, data_test = load_raw(DATA_DIR) if isinstance(n_exp, int): experiment_loop = list(range(1, n_exp + 1)) elif isinstance(n_exp, list): experiment_loop = n_exp else: raise ValueError("n_exp should be either an integer or a list of integers.") for i_exp in experiment_loop: # get data data_exp = get_one_data_set(data_train, i_exp=i_exp, get_po=True) data_exp_test = get_one_data_set(data_test, i_exp=i_exp, get_po=True) X, y, w, cate_true_in, X_t, cate_true_out = prepare_ihdp_data( data_exp, data_exp_test, setting=setting ) # compute some stats cate_var_in = np.var(cate_true_in) cate_var_out = np.var(cate_true_out) y_var_in = np.var(y) for k in range(n_reps): pehe_in = [] pehe_out = [] for model_name, estimator in models.items(): print(f"Experiment {i_exp}, run {k}, with {model_name}") estimator_temp = clone(estimator) estimator_temp.set_params(seed=k) if model_params is not None: estimator_temp.set_params(**model_params) # fit estimator estimator_temp.fit(X=X, y=y, w=w) cate_pred_in = estimator_temp.predict(X, return_po=False) cate_pred_out = estimator_temp.predict(X_t, return_po=False) pehe_in.append(eval_root_mse(cate_pred_in, cate_true_in)) pehe_out.append(eval_root_mse(cate_pred_out, cate_true_out)) writer.writerow( [i_exp, k, cate_var_in, cate_var_out, y_var_in] + pehe_in + pehe_out ) out_file.close() ================================================ FILE: experiments/experiments_benchmarks_NeurIPS21/ihdp_experiments_grf.R ================================================ library(grf) library(reticulate) do_ihdp_exper <- function(n_exp = 100, n_reps = 5, setup = 'original') 
{ # read IHDP data (originally saved in numpy format) np <- import("numpy") npz_train <- np$load('catenets/datasets/data/ihdp_npci_1-100.train.npz') x_train <- npz_train$f[['x']] y_train <- npz_train$f[['yf']] w_train <- npz_train$f[['t']] mu0_train <- npz_train$f[['mu0']] mu1_train <- npz_train$f[['mu1']] npz_test <- np$load('catenets/datasets/data/ihdp_npci_1-100.test.npz') x_test <- npz_test$f[['x']] y_test <- npz_test$f[['yf']] w_test <- npz_test$f[['t']] mu0_test <- npz_test$f[['mu0']] mu1_test <- npz_test$f[['mu1']] if (setup == 'modified') { # make TE additive instead y_train[w_train == 1] = y_train[w_train == 1] + mu0_train[w_train == 1] mu1_train = mu0_train + mu1_train mu1_test = mu0_test + mu1_test } cate_train <- mu1_train - mu0_train cate_test <- mu1_test - mu0_test for (i in 1:n_exp) { # loop over runs print(paste0('Experiment number', i)) for (k in 1:n_reps) { # loop over seeds # Causal forest ------------------------------ print('causal forest') cf <- causal_forest(x_train[, , i], y_train[, i], w_train[, i], seed = k) # predict CATE pred_cf_in <- predict(cf, x_train[, , i])$predictions pred_cf_out <- predict(cf, x_test[, , i])$predictions # Evaluate rmse_cf_in <- sqrt(mean((cate_train[, i] - pred_cf_in) ^ 2)) rmse_cf_out <- sqrt(mean((cate_test[, i] - pred_cf_out) ^ 2)) # T-learner ----------------------------------------------------- print('t learner') y0.forest <- regression_forest(subset(x_train[, , i], w_train[, i] == 0), y_train[w_train[, i] == 0, i], seed = k) y1.forest <- regression_forest(subset(x_train[, , i], w_train[, i] == 1), y_train[w_train[, i] == 1, i], seed = k) # predict CATE pred_t_in <- predict(y1.forest, x_train[, , i])$predictions - predict(y0.forest, x_train[, , i])$predictions pred_t_out <- predict(y1.forest, x_test[, , i])$predictions - predict(y0.forest, x_test[, , i])$predictions # Evaluate rmse_t_in <- sqrt(mean((cate_train[, i] - pred_t_in) ^ 2)) rmse_t_out <- sqrt(mean((cate_test[, i] - pred_t_out) ^ 2)) # s-learner 
------------------------------------------------------------- print('s learner') s_forest <- regression_forest(cbind(x_train[, , i], w_train[, i]), y_train[, i], seed = k) # create extended feature matrices n_train <- nrow(x_train[, , i]) n_test <- nrow(x_test[, , i]) train_treated <- rep(1, n_train) train_control <- rep(0, n_train) test_treated <- rep(1, n_test) test_control <- rep(0, n_test) # predict CATE pred_s_in <- predict(s_forest, cbind(x_train[, , i], train_treated))$predictions - predict(s_forest, cbind(x_train[, , i], train_control))$predictions pred_s_out <- predict(s_forest, cbind(x_test[, , i], test_treated))$predictions - predict(s_forest, cbind(x_test[, , i], test_control))$predictions # evaluate rmse_s_in <- sqrt(mean((cate_train[, i] - pred_s_in) ^ 2)) rmse_s_out <- sqrt(mean((cate_test[, i] - pred_s_out) ^ 2)) df_res <- data.frame( simu = i, run = k, cf_in = rmse_cf_in, t_in = rmse_t_in, s_in = rmse_s_in, cf_out = rmse_cf_out, t_out = rmse_t_out, s_out = rmse_s_out ) if (i * k == 1) { write.table( df_res, file = paste0('results/experiments_benchmarking/ihdp/grf_', setup, '.csv'), col.names = T, sep = ',', row.names = F ) } else{ write.table( df_res, file = paste0('results/experiments_benchmarking/ihdp/grf_', setup, '.csv'), col.names = F, append = T, sep = ',', row.names = F ) } } } } ================================================ FILE: experiments/experiments_benchmarks_NeurIPS21/twins_experiments_catenets.py ================================================ """ Utils to replicate Twins experiments with catenets """ import csv # Author: Alicia Curth import os from pathlib import Path import numpy as onp import pandas as pd from sklearn import clone from sklearn.model_selection import train_test_split from catenets.datasets import load from catenets.experiment_utils.base import eval_root_mse from catenets.models.jax import RNET_NAME, T_NAME, TARNET_NAME, RNet, TARNet, TNet RESULT_DIR = Path("results/experiments_benchmarking/twins/") EXP_DIR = 
Path("experiments/experiments_benchmarks_NeurIPS21/twins_datasets/") SEP = "_" PARAMS_DEPTH = {"n_layers_r": 1, "n_layers_out": 1} PARAMS_DEPTH_2 = { "n_layers_r": 1, "n_layers_out": 1, "n_layers_r_t": 1, "n_layers_out_t": 1, } ALL_MODELS = { T_NAME: TNet(**PARAMS_DEPTH), TARNET_NAME: TARNet(**PARAMS_DEPTH), RNET_NAME: RNet(**PARAMS_DEPTH_2), } def do_twins_experiment_loop( n_train_loop=[500, 1000, 2000, 5000, None], n_exp: int = 10, file_name: str = "twins", models: dict = None, test_size=0.5, ): for n in n_train_loop: print(f"Running twins experiments for subset_train {n}") do_twins_experiments( n_exp=n_exp, file_name=file_name, models=models, subset_train=n, test_size=test_size, ) def do_twins_experiments( n_exp: int = 10, file_name: str = "twins", models: dict = None, subset_train: int = None, prop_treated=0.5, test_size=0.5, ): if models is None: models = ALL_MODELS # get file to write in if not os.path.isdir(RESULT_DIR): os.makedirs(RESULT_DIR) out_file = open( RESULT_DIR / (file_name + SEP + str(prop_treated) + SEP + str(subset_train) + ".csv"), "w", buffering=1, ) writer = csv.writer(out_file) header = [name + "_pehe" for name in models.keys()] writer.writerow(header) for i_exp in range(n_exp): pehe_out = [] # get data X, X_t, y, w, y0_out, y1_out = prepare_twins( seed=i_exp, treat_prop=prop_treated, subset_train=subset_train, test_size=test_size, ) ite_out = y1_out - y0_out # split data for model_name, estimator in models.items(): print(f"Experiment {i_exp} with {model_name}") estimator_temp = clone(estimator) estimator_temp.set_params(**{"binary_y": True, "seed": i_exp}) # fit estimator estimator_temp.fit(X=X, y=y, w=w) cate_pred_out = estimator_temp.predict(X_t) pehe_out.append(eval_root_mse(cate_pred_out, ite_out)) writer.writerow(pehe_out) out_file.close() # utils --------------------------------------------------------------------- def prepare_twins(treat_prop=0.5, seed=42, test_size=0.5, subset_train: int = None): if not os.path.isdir(EXP_DIR): 
os.makedirs(EXP_DIR) out_base = ( "preprocessed" + SEP + str(treat_prop) + SEP + str(subset_train) + SEP + str(test_size) + SEP + str(seed) ) outfile_train = EXP_DIR / (out_base + SEP + "train.csv") outfile_test = EXP_DIR / (out_base + SEP + "test.csv") feat_list = [ "dmage", "mpcb", "cigar", "drink", "wtgain", "gestat", "dmeduc", "nprevist", "dmar", "anemia", "cardiac", "lung", "diabetes", "herpes", "hydra", "hemo", "chyper", "phyper", "eclamp", "incervix", "pre4000", "dtotord", "preterm", "renal", "rh", "uterine", "othermr", "adequacy_1", "adequacy_2", "adequacy_3", "pldel_1", "pldel_2", "pldel_3", "pldel_4", "pldel_5", "resstatb_1", "resstatb_2", "resstatb_3", "resstatb_4", ] if os.path.exists(outfile_train): print(f"Reading existing preprocessed twins file {out_base}") # use existing file df_train = pd.read_csv(outfile_train) X = onp.asarray(df_train[feat_list]) y = onp.asarray(df_train[["y"]]).reshape((-1,)) w = onp.asarray(df_train[["w"]]).reshape((-1,)) df_test = pd.read_csv(outfile_test) X_t = onp.asarray(df_test[feat_list]) y0_out = onp.asarray(df_test[["y0"]]).reshape((-1,)) y1_out = onp.asarray(df_test[["y1"]]).reshape((-1,)) else: # create file print(f"Creating preprocessed twins file {out_base}") onp.random.seed(seed) x, w, y, pos, _, _ = load( "twins", seed=seed, treat_prop=treat_prop, train_ratio=1 ) X, X_t, y, y_t, w, w_t, y0_in, y0_out, y1_in, y1_out = train_test_split( x, y, w, pos[:, 0], pos[:, 1], test_size=test_size, random_state=seed ) if subset_train is not None: X, y, w, y0_in, y1_in = ( X[:subset_train, :], y[:subset_train], w[:subset_train], y0_in[:subset_train], y1_in[:subset_train], ) # save data save_df_train = pd.DataFrame(X, columns=feat_list) save_df_train["y0"] = y0_in save_df_train["y1"] = y1_in save_df_train["w"] = w save_df_train["y"] = y save_df_train.to_csv(outfile_train) save_df_train = pd.DataFrame(X_t, columns=feat_list) save_df_train["y0"] = y0_out save_df_train["y1"] = y1_out save_df_train["w"] = w_t save_df_train["y"] = 
y_t save_df_train.to_csv(outfile_test) return X, X_t, y, w, y0_out, y1_out ================================================ FILE: experiments/experiments_benchmarks_NeurIPS21/twins_experiments_grf.R ================================================ library(grf) do_twins_exper <- function( n_reps = 10, subset_train = 500, test_size = 0.5, treat_prop=0.5) { i=1 for (k in 0:(n_reps-1)) { # loop over seeds print(paste0('Iteration number ', k)) set.seed(k) # read data (need to run the catenets script first; that creates the preprocessed data) if (subset_train == 5700){ df_train <- read.csv(paste0('experiments/experiments_benchmarks_NeurIPS21/twins_datasets/preprocessed_', treat_prop, '_None_', test_size, '_', k, '_train.csv')) df_test <- read.csv(paste0('experiments/experiments_benchmarks_NeurIPS21/twins_datasets/preprocessed_', treat_prop, '_None_', test_size, '_', k, '_test.csv')) }else{ df_train <- read.csv(paste0('experiments/experiments_benchmarks_NeurIPS21/twins_datasets/preprocessed_', treat_prop, '_', subset_train, '_', test_size, '_', k, '_train.csv')) df_test <- read.csv(paste0('experiments/experiments_benchmarks_NeurIPS21/twins_datasets/preprocessed_', treat_prop, '_', subset_train, '_', test_size, '_', k, '_test.csv')) } X_train <- data.matrix(df_train[,2:40]) X_test <- data.matrix(df_test[,2:40]) z_train = df_train$w y_train = df_train$y t_train = df_train$y1 - df_train$y0 z_test = df_test$w y_test = df_test$y t_test = df_test$y1 - df_test$y0 # causal forest print('causal forest') cf <- causal_forest(X_train, y_train, z_train, seed = k) pred_cf_in <- predict(cf, X_train)$predictions pred_cf_out <- predict(cf, X_test)$predictions rmse_cf_in <- sqrt(mean((t_train - pred_cf_in) ^ 2)) rmse_cf_out <- sqrt(mean((t_test - pred_cf_out) ^ 2)) # t-learner print('t learner') y0.forest <- regression_forest(subset(X_train, z_train == 0), y_train[z_train == 0], seed = k * i) y1.forest <- regression_forest(subset(X_train, z_train == 1), y_train[z_train == 1], seed = k * i) 
pred_t_in <- predict(y1.forest, X_train)$predictions - predict(y0.forest, X_train)$predictions pred_t_out <- predict(y1.forest, X_test)$predictions - predict(y0.forest, X_test)$predictions rmse_t_in <- sqrt(mean((t_train - pred_t_in) ^ 2)) rmse_t_out <- sqrt(mean((t_test - pred_t_out) ^ 2)) # s-learner print('s learner') s_forest <- regression_forest(cbind(X_train, z_train), y_train, seed = k * i) n_train <- nrow(X_train) n_test <- nrow(X_test) train_treated <- rep(1, n_train) train_control <- rep(0, n_train) test_treated <- rep(1, n_test) test_control <- rep(0, n_test) pred_s_in <- predict(s_forest, cbind(X_train, train_treated))$predictions - predict(s_forest, cbind(X_train, train_control))$predictions pred_s_out <- predict(s_forest, cbind(X_test, test_treated))$predictions - predict(s_forest, cbind(X_test, test_control))$predictions rmse_s_in <- sqrt(mean((t_train - pred_s_in) ^ 2)) rmse_s_out <- sqrt(mean((t_test - pred_s_out) ^ 2)) df_res <- data.frame( run = k, cf_in = rmse_cf_in, t_in = rmse_t_in, s_in = rmse_s_in, cf_out = rmse_cf_out, t_out = rmse_t_out, s_out = rmse_s_out ) if (k == 0) { write.table( df_res, file = paste0( 'results/experiments_benchmarking/twins/twins_grf_', subset_train, '_', n_reps, '.csv' ), col.names = T, sep = ',', row.names = F ) } else{ write.table( df_res, file = paste0( 'results/experiments_benchmarking/twins/twins_grf_', subset_train, '_', n_reps, '.csv' ), col.names = F, sep = ',', row.names = F, append = T ) } } } ================================================ FILE: experiments/experiments_inductivebias_NeurIPS21/__init__.py ================================================ ================================================ FILE: experiments/experiments_inductivebias_NeurIPS21/experiments_AB.py ================================================ """ Utils to replicate setups A & B """ # Author: Alicia Curth import csv import os from typing import Optional, Tuple, Union import numpy as onp from sklearn import clone from 
catenets.datasets import load from catenets.experiment_utils.base import eval_root_mse from catenets.models.jax import ( DRAGON_NAME, DRNET_NAME, FLEXTE_NAME, OFFSET_NAME, RANET_NAME, RNET_NAME, SNET_NAME, T_NAME, TARNET_NAME, XNET_NAME, DragonNet, DRNet, FlexTENet, OffsetNet, RANet, RNet, SNet, TARNet, TNet, XNet, ) RESULT_DIR_SIMU = "results/experiments_inductive_bias/acic2016/simulations/" SEP = "_" # Hyperparms for all models PARAMS_DEPTH: dict = {"n_layers_r": 1, "n_layers_out": 1} PARAMS_DEPTH_2: dict = { "n_layers_r": 1, "n_layers_out": 1, "n_layers_r_t": 1, "n_layers_out_t": 1, } PENALTY_DIFF = 0.01 PENALTY_ORTHOGONAL = 0.1 # For main results ALL_MODELS = { T_NAME: TNet(**PARAMS_DEPTH), T_NAME + "_reg": TNet(train_separate=False, penalty_diff=PENALTY_DIFF, **PARAMS_DEPTH), TARNET_NAME: TARNet(**PARAMS_DEPTH), TARNET_NAME + "_reg": TARNet( reg_diff=True, penalty_diff=PENALTY_DIFF, same_init=True, **PARAMS_DEPTH ), OFFSET_NAME: OffsetNet(penalty_l2_p=PENALTY_DIFF, **PARAMS_DEPTH), FLEXTE_NAME: FlexTENet( penalty_orthogonal=PENALTY_ORTHOGONAL, penalty_l2_p=PENALTY_DIFF, **PARAMS_DEPTH ), FLEXTE_NAME + "_noortho_reg_same": FlexTENet(penalty_orthogonal=0, **PARAMS_DEPTH), DRNET_NAME: DRNet(**PARAMS_DEPTH_2), DRNET_NAME + "_TAR": DRNet(first_stage_strategy="Tar", **PARAMS_DEPTH_2), } # For figure 4 in main text ABLATIONS = { T_NAME: TNet(**PARAMS_DEPTH), T_NAME + "_reg": TNet(train_separate=False, penalty_diff=PENALTY_DIFF, **PARAMS_DEPTH), T_NAME + "_reg_same": TNet(train_separate=False, **PARAMS_DEPTH), OFFSET_NAME: OffsetNet(penalty_l2_p=PENALTY_DIFF, **PARAMS_DEPTH), OFFSET_NAME + "_reg_same": OffsetNet(**PARAMS_DEPTH), FLEXTE_NAME: FlexTENet( penalty_orthogonal=PENALTY_ORTHOGONAL, penalty_l2_p=PENALTY_DIFF, **PARAMS_DEPTH ), FLEXTE_NAME + "_reg_same": FlexTENet(penalty_orthogonal=PENALTY_ORTHOGONAL, **PARAMS_DEPTH), FLEXTE_NAME + "_noortho": FlexTENet( penalty_orthogonal=0, penalty_l2_p=PENALTY_DIFF, **PARAMS_DEPTH ), FLEXTE_NAME + "_noortho_reg_same": 
FlexTENet(penalty_orthogonal=0, **PARAMS_DEPTH), } # For results in Appendix B.3 FLEX_LAMBDA = { "FlexTENet_001": FlexTENet( penalty_orthogonal=PENALTY_ORTHOGONAL, penalty_l2_p=1 / 100, **PARAMS_DEPTH ), "FlexTENet_01": FlexTENet( penalty_orthogonal=PENALTY_ORTHOGONAL, penalty_l2_p=1 / 10, **PARAMS_DEPTH ), "FlexTENet_0001": FlexTENet( penalty_orthogonal=PENALTY_ORTHOGONAL, penalty_l2_p=1 / 1000, **PARAMS_DEPTH ), "FlexTENet_00001": FlexTENet( penalty_orthogonal=PENALTY_ORTHOGONAL, penalty_l2_p=1 / 10000, **PARAMS_DEPTH ), } T_LAMBDA = { T_NAME: TNet(**PARAMS_DEPTH), T_NAME + "_reg_01": TNet(train_separate=False, penalty_diff=1 / 10, **PARAMS_DEPTH), T_NAME + "_reg_001": TNet(train_separate=False, penalty_diff=1 / 100, **PARAMS_DEPTH), T_NAME + "_reg_0001": TNet(train_separate=False, penalty_diff=1 / 1000, **PARAMS_DEPTH), T_NAME + "_reg_00001": TNet(train_separate=False, penalty_diff=1 / 10000, **PARAMS_DEPTH), } OFFSET_LAMBDA = { OFFSET_NAME + "_reg_01": OffsetNet(penalty_l2_p=1 / 10, **PARAMS_DEPTH), OFFSET_NAME + "_reg_001": OffsetNet(penalty_l2_p=1 / 100, **PARAMS_DEPTH), OFFSET_NAME + "_reg_0001": OffsetNet(penalty_l2_p=1 / 1000, **PARAMS_DEPTH), OFFSET_NAME + "_reg_00001": OffsetNet(penalty_l2_p=1 / 10000, **PARAMS_DEPTH), } # For results in appendix D.1 TWOSTEP_LEARNERS = { XNET_NAME: XNet(**PARAMS_DEPTH_2), RANET_NAME: RANet(**PARAMS_DEPTH_2), RNET_NAME: RNet(**PARAMS_DEPTH_2), DRNET_NAME: DRNet(**PARAMS_DEPTH_2), T_NAME: TNet(**PARAMS_DEPTH), } # For results in Appendix D.2 DRAGON_VARIANTS = { DRAGON_NAME: DragonNet(**PARAMS_DEPTH), DRAGON_NAME + "_reg": DragonNet( reg_diff=True, penalty_diff=PENALTY_DIFF, same_init=True, **PARAMS_DEPTH ), } SNET_VARIANTS = { SNET_NAME: SNet( n_units_r=100, n_units_r_small=100, ortho_reg_type="fro", penalty_orthogonal=PENALTY_ORTHOGONAL, with_prop=False, **PARAMS_DEPTH, ), SNET_NAME + "_reg": SNet( n_units_r=100, n_units_r_small=100, ortho_reg_type="fro", penalty_orthogonal=PENALTY_ORTHOGONAL, with_prop=False, 
penalty_diff=PENALTY_DIFF, same_init=True, reg_diff=True, **PARAMS_DEPTH, ), } # For results in appendix D.6 DR_VARIANTS = { DRNET_NAME + "_t_reg": DRNet( first_stage_args={"train_separate": False, "penalty_diff": PENALTY_DIFF}, **PARAMS_DEPTH_2, ), DRNET_NAME + "_Flex": DRNet( first_stage_strategy="Flex", first_stage_args={ "private_out": False, "penalty_orthogonal": PENALTY_ORTHOGONAL, "penalty_l2_p": PENALTY_DIFF, "normalize_ortho": False, }, **PARAMS_DEPTH_2, ), } # results in appendix D.6 X_VARIANTS = { XNET_NAME + "_t_reg": XNet( first_stage_args={"train_separate": False, "penalty_diff": PENALTY_DIFF}, **PARAMS_DEPTH_2, ), XNET_NAME + "_Flex": XNet( first_stage_strategy="Flex", first_stage_args={ "private_out": False, "penalty_orthogonal": PENALTY_ORTHOGONAL, "penalty_l2_p": PENALTY_DIFF, "normalize_ortho": False, }, **PARAMS_DEPTH_2, ), } def do_acic_simu_loops( rho_loop: list = [0, 0.05, 0.1, 0.2, 0.5, 0.8], n1_loop: list = [200, 2000, 500], n_exp: int = 10, file_name: str = "acic_simu", models: Union[dict, str, None] = None, n_0: int = 2000, n_test: int = 500, setting: str = "A", factual_eval: bool = False, ) -> None: """Run do_acic_simu once for every combination of n_1 in n1_loop and rho in rho_loop. In setting "A" each rho is passed as prop_gamma (with prop_omega=0); in any other setting it is passed as prop_omega (with prop_gamma=0). models may be a dict of estimators, a selection string understood by do_acic_simu, or None for ALL_MODELS.""" if models is None: models = ALL_MODELS for n_1 in n1_loop: if setting == "A": for rho in rho_loop: do_acic_simu( n_1=n_1, n_exp=n_exp, file_name=file_name, models=models, n_0=n_0, n_test=n_test, prop_omega=0, prop_gamma=rho, factual_eval=factual_eval, ) else: for rho in rho_loop: do_acic_simu( n_1=n_1, n_exp=n_exp, file_name=file_name, models=models, n_0=n_0, n_test=n_test, prop_gamma=0, prop_omega=rho, factual_eval=factual_eval, ) def do_acic_simu( n_exp: Union[int, list] = 10, file_name: str = "acic_simu", models: Union[dict, str, None] = None, n_0: int = 2000, n_1: int = 200, n_test: int = 500, error_sd: float = 1, sp_lin: float = 0.6, sp_nonlin: float = 0.3, prop_gamma: float = 0, ate_goal: float = 0, inter: bool = True, prop_omega: float = 0, factual_eval: bool = False, ) -> None: """Fit every selected model on one simulated ACIC2016-based dataset per experiment index and write RMSEs to a CSV in RESULT_DIR_SIMU. models may be a dict of estimators, None (ALL_MODELS) or one of the selection strings "all", "ablations", "flex_lambda", "t_lambda", "offset_lambda", "snet", "dragon", "twostep", "dr", "x"; any other string raises ValueError. n_exp is either a count (experiments 1..n_exp) or an explicit list of experiment indices. Each row holds y_var and cate_var, the CATE RMSE of every model, then mu0/mu1 RMSEs (and factual RMSEs if factual_eval) for models whose name contains neither "R" nor "X".""" if models is None: models = ALL_MODELS elif isinstance(models,
str): if models == "all": models = ALL_MODELS elif models == "ablations": models = ABLATIONS elif models == "flex_lambda": models = FLEX_LAMBDA elif models == "t_lambda": models = T_LAMBDA elif models == "offset_lambda": models = OFFSET_LAMBDA elif models == "snet": models = SNET_VARIANTS elif models == "dragon": models = DRAGON_VARIANTS elif models == "twostep": models = TWOSTEP_LEARNERS elif models == "dr": models = DR_VARIANTS elif models == "x": models = X_VARIANTS else: raise ValueError(f"{models} is not a valid model selection string.") # get file to write in if not os.path.isdir(RESULT_DIR_SIMU): os.makedirs(RESULT_DIR_SIMU) out_file = open( RESULT_DIR_SIMU + file_name + SEP + str(n_0) + SEP + str(n_1) + SEP + str(prop_gamma) + SEP + str(prop_omega) + ".csv", "w", buffering=1, ) writer = csv.writer(out_file) header = ( ["y_var", "cate_var"] + [name + "_cate" for name in models.keys()] + [ name + "_mu0" for name in models.keys() if "R" not in name and "X" not in name ] + [ name + "_mu1" for name in models.keys() if "R" not in name and "X" not in name ] ) if factual_eval: header = header + [ name + "_factual" for name in models.keys() if "R" not in name and "X" not in name ] writer.writerow(header) if isinstance(n_exp, int): experiment_loop = list(range(1, n_exp + 1)) elif isinstance(n_exp, list): experiment_loop = n_exp else: raise ValueError("n_exp should be either an integer or a list of integers.") for i_exp in experiment_loop: rmse_cate = [] rmse_mu0 = [] rmse_mu1 = [] # get data if not factual_eval: X, y, w, X_t, mu_0_t, mu_1_t, cate_t = acic_simu( i_exp, n_0=n_0, n_1=n_1, n_test=n_test, error_sd=error_sd, sp_lin=sp_lin, sp_nonlin=sp_nonlin, prop_gamma=prop_gamma, ate_goal=ate_goal, inter=inter, prop_omega=prop_omega, ) else: rmse_factual = [] X, y, w, X_t, y_t, w_t, mu_0_t, mu_1_t, cate_t = acic_simu( i_exp, n_0=n_0, n_1=n_1, n_test=n_test, error_sd=error_sd, sp_lin=sp_lin, sp_nonlin=sp_nonlin, prop_gamma=prop_gamma, ate_goal=ate_goal, inter=inter,
prop_omega=prop_omega, return_ytest=True, ) y_var = onp.var(y) cate_var = onp.var(cate_t) # split data for model_name, estimator in models.items(): print(f"Experiment {i_exp} with {model_name}") estimator_temp = clone(estimator) # fit estimator estimator_temp.fit(X=X, y=y, w=w) if "R" not in model_name and "X" not in model_name: cate_pred_out, mu0_pred, mu1_pred = estimator_temp.predict( X_t, return_po=True ) rmse_mu0.append(eval_root_mse(mu0_pred, mu_0_t)) rmse_mu1.append(eval_root_mse(mu1_pred, mu_1_t)) if factual_eval: pred_factual = w_t * mu1_pred + (1 - w_t) * mu0_pred rmse_factual.append(eval_root_mse(pred_factual, y_t)) else: cate_pred_out = estimator_temp.predict(X_t) rmse_cate.append(eval_root_mse(cate_pred_out, cate_t)) if not factual_eval: writer.writerow([y_var, cate_var] + rmse_cate + rmse_mu0 + rmse_mu1) else: writer.writerow( [y_var, cate_var] + rmse_cate + rmse_mu0 + rmse_mu1 + rmse_factual ) out_file.close() def acic_simu( i_exp: int, n_0: int = 2000, n_1: int = 200, n_test: int = 500, error_sd: float = 1, sp_lin: float = 0.6, sp_nonlin: float = 0.3, prop_gamma: float = 0, prop_omega: float = 0, ate_goal: float = 0, inter: bool = True, return_ytest: bool = False, ) -> Tuple: """Load one simulated ACIC2016 dataset via load("acic2016", ...) and return training data plus the test-set potential outcomes and CATE. Returns (X_train, y_train, w_train, X_test, mu_0_t, mu_1_t, cate_t); if return_ytest is True, y_test and w_test are inserted after X_test.""" X_train, w_train, y_train, _, X_test, w_test, y_test, po_test = load( "acic2016", i_exp=i_exp, n_0=n_0, n_1=n_1, n_test=n_test, error_sd=error_sd, sp_lin=sp_lin, sp_nonlin=sp_nonlin, prop_gamma=prop_gamma, prop_omega=prop_omega, ate_goal=ate_goal, inter=inter, ) mu_0_t = po_test[:, 0] mu_1_t = po_test[:, 1] cate_t = mu_1_t - mu_0_t if return_ytest: return X_train, y_train, w_train, X_test, y_test, w_test, mu_0_t, mu_1_t, cate_t return X_train, y_train, w_train, X_test, mu_0_t, mu_1_t, cate_t ================================================ FILE: experiments/experiments_inductivebias_NeurIPS21/experiments_CD.py ================================================ """ Utils to replicate experiments C and D """ # Author: Alicia Curth import csv import os from pathlib import
Path from typing import Optional, Union from sklearn import clone from catenets.datasets.dataset_ihdp import get_one_data_set, load_raw, prepare_ihdp_data from catenets.experiment_utils.base import eval_root_mse from catenets.models.jax import ( DRNET_NAME, FLEXTE_NAME, OFFSET_NAME, T_NAME, TARNET_NAME, DRNet, FlexTENet, OffsetNet, TARNet, TNet, ) DATA_DIR = Path("catenets/datasets/data/") RESULT_DIR = Path("results/experiments_inductive_bias/ihdp/") SEP = "_" PARAMS_DEPTH: dict = {"n_layers_r": 2, "n_layers_out": 2} PARAMS_DEPTH_2: dict = { "n_layers_r": 2, "n_layers_out": 2, "n_layers_r_t": 2, "n_layers_out_t": 2, } PENALTY_DIFF = 0.01 PENALTY_ORTHOGONAL = 0.1 ALL_MODELS = { T_NAME: TNet(**PARAMS_DEPTH), T_NAME + "_reg": TNet(train_separate=False, penalty_diff=PENALTY_DIFF, **PARAMS_DEPTH), TARNET_NAME: TARNet(**PARAMS_DEPTH), TARNET_NAME + "_reg": TARNet( reg_diff=True, penalty_diff=PENALTY_DIFF, same_init=True, **PARAMS_DEPTH ), OFFSET_NAME: OffsetNet(penalty_l2_p=PENALTY_DIFF, **PARAMS_DEPTH), FLEXTE_NAME: FlexTENet( penalty_orthogonal=PENALTY_ORTHOGONAL, penalty_l2_p=PENALTY_DIFF, **PARAMS_DEPTH ), FLEXTE_NAME + "_noortho_reg_same": FlexTENet(penalty_orthogonal=0, **PARAMS_DEPTH), DRNET_NAME: DRNet(**PARAMS_DEPTH_2), DRNET_NAME + "_TAR": DRNet(first_stage_strategy="Tar", **PARAMS_DEPTH_2), } def do_ihdp_experiments( n_exp: Union[int, list] = 100, file_name: str = "ihdp_all", model_params: Optional[dict] = None, models: Optional[dict] = None, setting: str = "C", ) -> None: """Run the IHDP experiments for the given setting and write the in-sample and out-of-sample CATE RMSE (PEHE) of every model to RESULT_DIR / (file_name + "_" + setting + ".csv"). n_exp is either a count (experiments 1..n_exp) or an explicit list of experiment indices; model_params, if given, is applied to every cloned estimator via set_params.""" if models is None: models = ALL_MODELS # get file to write in if not os.path.isdir(RESULT_DIR): os.makedirs(RESULT_DIR) out_file = open(RESULT_DIR / (file_name + SEP + setting + ".csv"), "w", buffering=1) writer = csv.writer(out_file) header = [name + "_in" for name in models.keys()] + [ name + "_out" for name in models.keys() ] writer.writerow(header) # get data data_train, data_test = load_raw(DATA_DIR) if isinstance(n_exp, int): experiment_loop = list(range(1, n_exp + 1)) elif
isinstance(n_exp, list): experiment_loop = n_exp else: raise ValueError("n_exp should be either an integer or a list of integers.") for i_exp in experiment_loop: pehe_in = [] pehe_out = [] # get data data_exp = get_one_data_set(data_train, i_exp=i_exp, get_po=True) data_exp_test = get_one_data_set(data_test, i_exp=i_exp, get_po=True) X, y, w, cate_true_in, X_t, cate_true_out = prepare_ihdp_data( data_exp, data_exp_test, setting=setting ) for model_name, estimator in models.items(): print(f"Experiment {i_exp} with {model_name}") estimator_temp = clone(estimator) if model_params is not None: estimator_temp.set_params(**model_params) # fit estimator estimator_temp.fit(X=X, y=y, w=w) cate_pred_in = estimator_temp.predict(X, return_po=False) cate_pred_out = estimator_temp.predict(X_t, return_po=False) pehe_in.append(eval_root_mse(cate_pred_in, cate_true_in)) pehe_out.append(eval_root_mse(cate_pred_out, cate_true_out)) writer.writerow(pehe_in + pehe_out) out_file.close() ================================================ FILE: experiments/experiments_inductivebias_NeurIPS21/experiments_acic.py ================================================ """ Utils to replicate ACIC2016 experiments (Appendix E.1) """ # Author: Alicia Curth import csv import os from pathlib import Path import numpy as np from sklearn import clone from catenets.datasets import load from catenets.experiment_utils.base import eval_root_mse from catenets.models.jax import ( DRNET_NAME, FLEXTE_NAME, OFFSET_NAME, T_NAME, TARNET_NAME, DRNet, FlexTENet, OffsetNet, TARNet, TNet, ) RESULT_DIR = Path("results/experiments_inductive_bias/acic2016/original") SEP = "_" PARAMS_DEPTH = {"n_layers_r": 1, "n_layers_out": 1} PARAMS_DEPTH_2 = { "n_layers_r": 1, "n_layers_out": 1, "n_layers_r_t": 1, "n_layers_out_t": 1, } PENALTY_DIFF = 0.01 PENALTY_ORTHOGONAL = 0.1 ALL_MODELS = { T_NAME: TNet(**PARAMS_DEPTH), T_NAME + "_reg": TNet(train_separate=False, penalty_diff=PENALTY_DIFF, **PARAMS_DEPTH), TARNET_NAME:
TARNet(**PARAMS_DEPTH), TARNET_NAME + "_reg": TARNet( reg_diff=True, penalty_diff=PENALTY_DIFF, same_init=True, **PARAMS_DEPTH ), OFFSET_NAME: OffsetNet(penalty_l2_p=PENALTY_DIFF, **PARAMS_DEPTH), FLEXTE_NAME: FlexTENet( penalty_orthogonal=PENALTY_ORTHOGONAL, penalty_l2_p=PENALTY_DIFF, **PARAMS_DEPTH ), FLEXTE_NAME + "_noortho_reg_same": FlexTENet(penalty_orthogonal=0, **PARAMS_DEPTH), DRNET_NAME: DRNet(**PARAMS_DEPTH_2), DRNET_NAME + "_TAR": DRNet(first_stage_strategy="Tar", **PARAMS_DEPTH_2), } def do_acic_orig_loop( simu_nums, n_exp: int = 10, file_name: str = "results", models: dict = None, train_size: float = 0.8, ): """Run do_acic_experiments for every simulation setting in simu_nums (pre_trans keeps its default value).""" if models is None: models = ALL_MODELS for simu_num in simu_nums: print(f"Running simulation setting {simu_num}") do_acic_experiments( n_exp=n_exp, file_name=file_name, simu_num=simu_num, models=models, train_size=train_size, ) def do_acic_experiments( n_exp: int = 10, file_name: str = "results_catenets", simu_num: int = 1, models: dict = None, train_size: float = 0.8, pre_trans: bool = False, ): """Evaluate every model on n_exp random train/test splits of the original ACIC2016 data for one simulation setting and write CATE variances and in-/out-of-sample CATE RMSE (PEHE) per model to a CSV in RESULT_DIR. Note: the first CSV column is headed "file_name" but each row stores the experiment index i_exp there.""" if models is None: models = ALL_MODELS # get file to write in if not os.path.isdir(RESULT_DIR): os.makedirs(RESULT_DIR) out_file = open( RESULT_DIR / ( file_name + SEP + str(pre_trans) + SEP + str(simu_num) + SEP + str(train_size) + ".csv" ), "w", buffering=1, ) writer = csv.writer(out_file) header = ( ["file_name", "cate_var_in", "cate_var_out", "y_var_in"] + [name + "_in" for name in models.keys()] + [name + "_out" for name in models.keys()] ) writer.writerow(header) for i_exp in range(n_exp): # get data X, w, y, po_train, X_test, w_test, y_test, po_test = load( "acic2016", preprocessed=pre_trans, original_acic_outcomes=True, keep_categorical=False, random_split=True, i_exp=i_exp, simu_num=simu_num, train_size=train_size, ) cate_in = po_train[:, 1] - po_train[:, 0] cate_out = po_test[:, 1] - po_test[:, 0] cate_var_in = np.var(cate_in) cate_var_out = np.var(cate_out) y_var_in = np.var(y) pehe_in = [] pehe_out = [] for model_name, estimator in
models.items(): print(f"Experiment {i_exp} with {model_name}") estimator_temp = clone(estimator) # fit estimator estimator_temp.fit(X=X, y=y, w=w) cate_pred_in = estimator_temp.predict(X, return_po=False) cate_pred_out = estimator_temp.predict(X_test, return_po=False) pehe_in.append(eval_root_mse(cate_pred_in, cate_in)) pehe_out.append(eval_root_mse(cate_pred_out, cate_out)) writer.writerow( [i_exp, cate_var_in, cate_var_out, y_var_in] + pehe_in + pehe_out ) out_file.close() ================================================ FILE: experiments/experiments_inductivebias_NeurIPS21/experiments_twins.py ================================================ """ Utils to replicate Twins experiments (Appendix E.2) """ # Author: Alicia Curth import csv import os from pathlib import Path import numpy as np from sklearn import clone from sklearn.metrics import average_precision_score, roc_auc_score from sklearn.model_selection import train_test_split from sklearn.preprocessing import label_binarize from catenets.datasets import load from catenets.experiment_utils.base import eval_root_mse from catenets.models.jax import ( DRNET_NAME, FLEXTE_NAME, OFFSET_NAME, T_NAME, TARNET_NAME, DRNet, FlexTENet, OffsetNet, TARNet, TNet, ) from catenets.models.jax.base import check_shape_1d_data RESULT_DIR = Path("results/experiments_inductive_bias/twins") SEP = "_" PARAMS_DEPTH = {"n_layers_r": 1, "n_layers_out": 1} PARAMS_DEPTH_2 = { "n_layers_r": 1, "n_layers_out": 1, "n_layers_r_t": 1, "n_layers_out_t": 1, } PENALTY_DIFF = 0.01 PENALTY_ORTHOGONAL = 0.1 ALL_MODELS = { T_NAME: TNet(**PARAMS_DEPTH), T_NAME + "_reg": TNet(train_separate=False, penalty_diff=PENALTY_DIFF, **PARAMS_DEPTH), TARNET_NAME: TARNet(**PARAMS_DEPTH), TARNET_NAME + "_reg": TARNet( reg_diff=True, penalty_diff=PENALTY_DIFF, same_init=True, **PARAMS_DEPTH ), OFFSET_NAME: OffsetNet(penalty_l2_p=PENALTY_DIFF, **PARAMS_DEPTH), FLEXTE_NAME: FlexTENet( penalty_orthogonal=PENALTY_ORTHOGONAL, penalty_l2_p=PENALTY_DIFF, **PARAMS_DEPTH ),
FLEXTE_NAME + "_noortho_reg_same": FlexTENet(penalty_orthogonal=0, **PARAMS_DEPTH), DRNET_NAME: DRNet(**PARAMS_DEPTH_2), DRNET_NAME + "_TAR": DRNet(first_stage_strategy="Tar", **PARAMS_DEPTH_2), } def do_twins_experiment_loop( n_train_loop=[500, 1000, 2000, 5000, None], prop_loop=[0.1, 0.25, 0.5, 0.75, 0.9], n_exp: int = 10, file_name: str = "twins", models: dict = None, test_size=0.5, ): """Run do_twins_experiments for every combination of training-set size in n_train_loop (None keeps the full training split) and treated proportion in prop_loop.""" for n in n_train_loop: for prop in prop_loop: print( "Running twins experiment for {} training samples with {} treated.".format( n, prop ) ) do_twins_experiments( n_exp=n_exp, file_name=file_name, models=models, subset_train=n, prop_treated=prop, test_size=test_size, ) def do_twins_experiments( n_exp: int = 10, file_name: str = "twins", models: dict = None, subset_train: int = None, prop_treated=0.5, test_size=0.5, ): """Evaluate every model on n_exp Twins train/test splits and write one CSV row per experiment: the out-of-sample ITE RMSE of all models, then ITE ROC AUC and potential-outcome ROC AUC / average precision for models whose name contains neither "R" nor "X".""" if models is None: models = ALL_MODELS # get file to write in if not os.path.isdir(RESULT_DIR): os.makedirs(RESULT_DIR) out_file = open( RESULT_DIR / (file_name + SEP + str(prop_treated) + SEP + str(subset_train) + ".csv"), "w", buffering=1, ) writer = csv.writer(out_file) header = ( [name + "_cate" for name in models.keys()] + [ name + "_auc_ite" for name in models.keys() if "R" not in name and "X" not in name ] + [ name + "_auc_mu0" for name in models.keys() if "R" not in name and "X" not in name ] + [ name + "_auc_mu1" for name in models.keys() if "R" not in name and "X" not in name ] + [ name + "_ap_mu0" for name in models.keys() if "R" not in name and "X" not in name ] + [ name + "_ap_mu1" for name in models.keys() if "R" not in name and "X" not in name ] ) writer.writerow(header) for i_exp in range(n_exp): pehe_out = [] auc_ite = [] auc_mu0 = [] auc_mu1 = [] ap_mu0 = [] ap_mu1 = [] # get data x, w, y, pos, _, _ = load( "twins", seed=i_exp, treat_prop=prop_treated, train_ratio=1 ) # split data X, X_t, y, y_t, w, w_t, y0_in, y0_out, y1_in, y1_out = split_data( x, y, w, pos, random_state=i_exp, subset_train=subset_train, test_size=test_size, ) ite_out = y1_out -
y0_out ite_out_encoded = label_binarize(ite_out, [-1, 0, 1]) n_test = X_t.shape[0] # split data for model_name, estimator in models.items(): print(f"Experiment {i_exp} with {model_name}") estimator_temp = clone(estimator) estimator_temp.set_params(**{"binary_y": True}) # fit estimator estimator_temp.fit(X=X, y=y, w=w) if ( "DR" not in model_name and "R" not in model_name and "X" not in model_name ): cate_pred_out, mu0_pred, mu1_pred = estimator_temp.predict( X_t, return_po=True ) # create probabilities for each possible level of ITE probs = np.zeros((n_test, 3)) probs[:, 0] = (mu0_pred * (1 - mu1_pred)).reshape((-1,)) # P(Y1-Y0=-1) probs[:, 1] = ( (mu0_pred * mu1_pred) + ((1 - mu0_pred) * (1 - mu1_pred)) ).reshape( (-1,) ) # P(Y1-Y0=0) probs[:, 2] = (mu1_pred * (1 - mu0_pred)).reshape((-1,)) # P(Y1-Y0=1) auc_ite.append(roc_auc_score(ite_out_encoded, probs)) # evaluate performance on potential outcomes auc_mu0.append(eval_roc_auc(y0_out, mu0_pred)) auc_mu1.append(eval_roc_auc(y1_out, mu1_pred)) ap_mu0.append(eval_ap(y0_out, mu0_pred)) ap_mu1.append(eval_ap(y1_out, mu1_pred)) else: cate_pred_out = estimator_temp.predict(X_t) pehe_out.append(eval_root_mse(cate_pred_out, ite_out)) writer.writerow(pehe_out + auc_ite + auc_mu0 + auc_mu1 + ap_mu0 + ap_mu1) out_file.close() # utils ------- def split_data(X, y, w, pos, test_size=0.5, random_state=42, subset_train: int = None): X, X_t, y, y_t, w, w_t, y0_in, y0_out, y1_in, y1_out = train_test_split( X, y, w, pos[:, 0], pos[:, 1], test_size=test_size, random_state=random_state ) if subset_train is not None: X, y, w, y0_in, y1_in = ( X[:subset_train, :], y[:subset_train], w[:subset_train], y0_in[:subset_train], y1_in[:subset_train], ) return X, X_t, y, y_t, w, w_t, y0_in, y0_out, y1_in, y1_out def eval_roc_auc(targets, preds): preds = check_shape_1d_data(preds) targets = check_shape_1d_data(targets) return roc_auc_score(targets, preds) def eval_ap(targets, preds): preds = check_shape_1d_data(preds) targets = 
check_shape_1d_data(targets) return average_precision_score(targets, preds) ================================================ FILE: pyproject.toml ================================================ [build-system] # AVOID CHANGING REQUIRES: IT WILL BE UPDATED BY PYSCAFFOLD! requires = ["setuptools>=46.1.0", "wheel"] build-backend = "setuptools.build_meta" ================================================ FILE: pytest.ini ================================================ [pytest] markers = slow: mark a test as slow. ================================================ FILE: run_experiments_AISTATS.py ================================================ """ File to run AISTATS experiments from shell """ # Author: Alicia Curth import argparse import sys from typing import Any import catenets.logger as log from experiments.experiments_AISTATS21.ihdp_experiments import do_ihdp_experiments from experiments.experiments_AISTATS21.simulations_AISTATS import main_AISTATS log.add(sink=sys.stderr, level="DEBUG") def init_arg() -> Any: # arg parser if script is run from shell parser = argparse.ArgumentParser() parser.add_argument("--experiment", default="simulation", type=str) parser.add_argument("--setting", default=1, type=int) parser.add_argument("--models", default=None, type=str) parser.add_argument("--file_name", default="results", type=str) parser.add_argument("--n_repeats", default=10, type=int) return parser.parse_args() if __name__ == "__main__": args = init_arg() if args.experiment == "simulation": main_AISTATS( setting=args.setting, models=args.models, file_name=args.file_name, n_repeats=args.n_repeats, ) elif args.experiment == "ihdp": do_ihdp_experiments( models=args.models, file_name=args.file_name, n_exp=args.n_repeats ) ================================================ FILE: run_experiments_benchmarks_NeurIPS.py ================================================ """ File to run the catenets experiments for "Really Doing Great at Estimating CATE? 
A Critical Look at ML Benchmarking Practices in Treatment Effect Estimation" (Curth & vdS, NeurIPS21) from shell """ # Author: Alicia Curth import argparse import sys from typing import Any import catenets.logger as log from experiments.experiments_benchmarks_NeurIPS21.acic_experiments_catenets import ( do_acic_experiments, ) from experiments.experiments_benchmarks_NeurIPS21.ihdp_experiments_catenets import ( do_ihdp_experiments, ) from experiments.experiments_benchmarks_NeurIPS21.twins_experiments_catenets import ( do_twins_experiment_loop, ) log.add(sink=sys.stderr, level="DEBUG") def init_arg() -> Any: # arg parser parser = argparse.ArgumentParser() parser.add_argument("--setting", default="C", type=str) parser.add_argument("--experiment", default="ihdp", type=str) parser.add_argument("--file_name", default="results", type=str) parser.add_argument("--n_exp", default=10, type=int) parser.add_argument("--n_reps", default=5, type=int) parser.add_argument("--pre_trans", type=bool, default=False) parser.add_argument("--simu_num", type=int, default=2) return parser.parse_args() if __name__ == "__main__": args = init_arg() if (args.experiment == "ihdp") or (args.experiment == "IHDP"): do_ihdp_experiments( file_name=args.file_name, n_exp=args.n_exp, setting=args.setting, n_reps=args.n_reps, ) elif (args.experiment == "acic") or (args.experiment == "ACIC"): do_acic_experiments( file_name=args.file_name, n_reps=args.n_reps, simu_num=args.simu_num, n_exp=args.n_exp, pre_trans=args.pre_trans, ) elif (args.experiment == "twins") or (args.experiment == "Twins"): do_twins_experiment_loop(file_name=args.file_name, n_exp=args.n_reps) else: raise ValueError( f"Experiment should be one of ihdp/IHDP, acic/ACIC and twins/Twins. 
You " f"passed {args.experiment}" ) ================================================ FILE: run_experiments_inductive_bias_NeurIPS.py ================================================ """ File to run experiments for "On Inductive Biases for Heterogeneous Treatment Effect Estimation" (Curth & vdS, NeurIPS21) from shell """ # Author: Alicia Curth import argparse import sys from typing import Any import catenets.logger as log from experiments.experiments_inductivebias_NeurIPS21.experiments_AB import ( do_acic_simu_loops, ) from experiments.experiments_inductivebias_NeurIPS21.experiments_acic import ( do_acic_orig_loop, ) from experiments.experiments_inductivebias_NeurIPS21.experiments_CD import ( do_ihdp_experiments, ) from experiments.experiments_inductivebias_NeurIPS21.experiments_twins import ( do_twins_experiment_loop, ) log.add(sink=sys.stderr, level="DEBUG") def init_arg() -> Any: # arg parser parser = argparse.ArgumentParser() parser.add_argument("--setup", default="A", type=str) parser.add_argument("--file_name", default="results", type=str) parser.add_argument("--n_exp", default=10, type=int) parser.add_argument("--n_0", default=2000, type=int) parser.add_argument("--models", default=None, type=str) parser.add_argument("--n1_loop", nargs="+", default=[200, 2000, 500], type=int) parser.add_argument( "--rho_loop", nargs="+", default=[0, 0.05, 0.1, 0.2, 0.5, 0.8], type=float ) parser.add_argument("--factual_eval", default=False, type=bool) parser.add_argument( "--simu_nums", nargs="+", default=[x for x in range(1, 78)], type=int ) return parser.parse_args() if __name__ == "__main__": args = init_arg() if (args.setup == "A") or (args.setup == "B"): do_acic_simu_loops( n_exp=args.n_exp, file_name=args.file_name, setting=args.setup, n_0=args.n_0, models=args.models, n1_loop=args.n1_loop, rho_loop=args.rho_loop, factual_eval=args.factual_eval, ) elif (args.setup == "C") or (args.setup == "D"): do_ihdp_experiments( file_name=args.file_name, n_exp=args.n_exp, 
setting=args.setup ) elif (args.setup == "acic") or (args.setup == "ACIC"): # Appendix E.1 do_acic_orig_loop( simu_nums=args.simu_nums, n_exp=args.n_exp, file_name=args.file_name ) elif (args.setup == "twins") or (args.setup == "Twins"): # Appendix E.2 do_twins_experiment_loop(file_name=args.file_name, n_exp=args.n_exp) else: raise ValueError( f"Setup should be one of A, B, C, D, acic/ACIC or twins/Twins You passed" f" {args.setup}" ) ================================================ FILE: setup.py ================================================ # stdlib import os import re # third party from setuptools import setup PKG_DIR = os.path.dirname(os.path.abspath(__file__)) def read(fname: str) -> str: return open(os.path.join(os.path.dirname(__file__), fname)).read() def find_version() -> str: version_file = read("catenets/version.py").split("\n")[0] version_re = r"__version__ = \"(?P.+)\"" version_raw = re.match(version_re, version_file) if version_raw is None: return "0.0.1" version = version_raw.group("version") return version if __name__ == "__main__": try: setup( version=find_version(), ) except: # noqa print( "\n\nAn error occurred while building the project, " "please ensure you have the most updated version of setuptools, " "setuptools_scm and wheel with:\n" " pip install -U setuptools setuptools_scm wheel\n\n" ) raise ================================================ FILE: tests/conftest.py ================================================ import sys import catenets.logger as log log.add(sink=sys.stderr, level="CRITICAL") ================================================ FILE: tests/datasets/test_datasets.py ================================================ import pytest from catenets.datasets import load @pytest.mark.parametrize("train_ratio", [0.5, 0.8]) @pytest.mark.parametrize("treatment_type", ["rand", "logistic"]) @pytest.mark.parametrize("treat_prop", [0.1, 0.9]) def test_dataset_sanity_twins( train_ratio: float, treatment_type: str, treat_prop: float ) -> 
None: """Check that the Twins split honours train_ratio and that all returned arrays have consistent shapes.""" X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load( "twins", train_ratio=train_ratio, treatment_type=treatment_type, treat_prop=treat_prop, ) total = X_train.shape[0] + X_test.shape[0] assert int(total * train_ratio) == X_train.shape[0] assert X_train.shape[1] == X_test.shape[1] assert X_train.shape[0] == Y_train.shape[0] assert X_train.shape[0] == Y_train_full.shape[0] assert X_train.shape[0] == W_train.shape[0] assert X_test.shape[0] == Y_test.shape[0] def test_dataset_sanity_ihdp() -> None: """Check that the IHDP arrays returned by load have consistent shapes.""" X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load("ihdp") assert X_train.shape[1] == X_test.shape[1] assert X_train.shape[0] == Y_train.shape[0] assert X_train.shape[0] == Y_train_full.shape[0] assert X_train.shape[0] == W_train.shape[0] assert X_test.shape[0] == Y_test.shape[0] @pytest.mark.slow @pytest.mark.parametrize("preprocessed", [False, True]) def test_dataset_sanity_acic2016(preprocessed: bool) -> None: """Check that the ACIC2016 arrays returned by load (raw or preprocessed) have consistent shapes.""" X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load( "acic2016", preprocessed=preprocessed ) assert X_train.shape[1] == X_test.shape[1] assert X_train.shape[0] == Y_train.shape[0] assert X_train.shape[0] == Y_train_full.shape[0] assert X_train.shape[0] == W_train.shape[0] assert X_test.shape[0] == Y_test.shape[0] ================================================ FILE: tests/models/jax/test_jax_ite.py ================================================ from copy import deepcopy import pytest from catenets.datasets import load from catenets.experiment_utils.tester import evaluate_treatments_model from catenets.models.jax import FLEXTE_NAME, OFFSET_NAME, FlexTENet, OffsetNet LAYERS_OUT = 2 LAYERS_R = 3 PENALTY_L2 = 0.01 / 100 PENALTY_ORTHOGONAL_IHDP = 0 PARAMS_DEPTH: dict = {"n_layers_r": 2, "n_layers_out": 2, "n_iter": 10} PENALTY_DIFF = 0.01 PENALTY_ORTHOGONAL = 0.1 ALL_MODELS = { OFFSET_NAME: OffsetNet(penalty_l2_p=PENALTY_DIFF, **PARAMS_DEPTH), FLEXTE_NAME: FlexTENet( penalty_orthogonal=PENALTY_ORTHOGONAL,
penalty_l2_p=PENALTY_DIFF, **PARAMS_DEPTH ), } models = list(ALL_MODELS.keys()) @pytest.mark.parametrize("dataset, pehe_threshold", [("twins", 0.4), ("ihdp", 3)]) @pytest.mark.parametrize("model_name", models) def test_model_sanity(dataset: str, pehe_threshold: float, model_name: str) -> None: model = deepcopy(ALL_MODELS[model_name]) X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(dataset) score = evaluate_treatments_model(model, X_train, Y_train, Y_train_full, W_train) print(f"Evaluation for model jax.{model_name} on {dataset} = {score['str']}") def test_model_score() -> None: model = OffsetNet(n_iter=10) X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load("ihdp") model.fit(X_train[:10], Y_train[:10], W_train[:10]) result = model.score(X_test, Y_test) assert result > 0 with pytest.raises(ValueError): model.score(X_train, Y_train) # Y_train has just one outcome ================================================ FILE: tests/models/jax/test_jax_model_utils.py ================================================ from typing import Any import jax.numpy as jnp import numpy as np import pandas as pd import pytest from catenets.models.jax.model_utils import ( check_shape_1d_data, check_X_is_np, make_val_split, ) @pytest.mark.parametrize("data", [np.array([1, 2, 3]), np.array([[1, 2], [3, 4]])]) def test_check_shape_1d_data_sanity(data: np.ndarray) -> None: out = check_shape_1d_data(data) assert len(out.shape) == 2 @pytest.mark.parametrize("data", [np.array([1, 2, 3]), pd.DataFrame([1, 2])]) def test_check_X_is_np_sanity(data: Any) -> None: out = check_X_is_np(data) assert isinstance(out, jnp.ndarray) def test_make_val_split_sanity() -> None: X = np.random.rand(1000, 5) y = np.random.randint(0, 1, size=1000) w = np.random.randint(0, 1, size=1000) X_t, y_t, w_t, X_val, y_val, w_val, VALIDATION_STRING = make_val_split(X, y, w) assert X_t.shape[0] == 700 assert y_t.shape[0] == 700 assert w_t.shape[0] == 700 assert X_val.shape[0] == 300 assert 
y_val.shape[0] == 300 assert w_val.shape[0] == 300 assert VALIDATION_STRING == "validation" ================================================ FILE: tests/models/jax/test_jax_transformation_utils.py ================================================ from typing import Callable import numpy as np import pytest from catenets.models.jax.transformation_utils import ( ALL_TRANSFORMATIONS, DR_TRANSFORMATION, PW_TRANSFORMATION, RA_TRANSFORMATION, _get_transformation_function, aipw_te_transformation, ht_te_transformation, ra_te_transformation, ) def test_get_transformation_function_sanity() -> None: """Check that the names in ALL_TRANSFORMATIONS map, in order, to the HT, AIPW and RA transformation functions and that an unknown name raises ValueError.""" expected_fns = [ht_te_transformation, aipw_te_transformation, ra_te_transformation] for tr, expected in zip(ALL_TRANSFORMATIONS, expected_fns): assert _get_transformation_function(tr) is expected with pytest.raises(ValueError): _get_transformation_function("invalid") @pytest.mark.parametrize( "fn", [aipw_te_transformation, _get_transformation_function(DR_TRANSFORMATION)] ) def test_aipw_te_transformation_sanity(fn: Callable) -> None: """Check that the AIPW/DR pseudo-outcome has one value per sample.""" res = fn( y=np.array([0, 1]), w=np.array([1, 0]), p=None, mu_0=np.array([0.4, 0.6]), mu_1=np.array([0.6, 0.4]), ) assert res.shape[0] == 2 @pytest.mark.parametrize( "fn", [ht_te_transformation, _get_transformation_function(PW_TRANSFORMATION)] ) def test_ht_te_transformation_sanity(fn: Callable) -> None: """Check that the HT/PW pseudo-outcome has one value per sample.""" res = fn( y=np.array([0, 1]), w=np.array([1, 0]), ) assert res.shape[0] == 2 @pytest.mark.parametrize( "fn", [ra_te_transformation, _get_transformation_function(RA_TRANSFORMATION)] ) def test_ra_te_transformation_sanity(fn: Callable) -> None: """Check that the RA pseudo-outcome has one value per sample.""" res = fn( y=np.array([0, 1]), w=np.array([1, 0]), p=None, mu_0=np.array([0.4, 0.6]), mu_1=np.array([0.6, 0.4]), ) assert res.shape[0] == 2 ================================================ FILE: tests/models/torch/test_torch_flextenet.py ================================================ import numpy as np import pytest from catenets.datasets import load from catenets.experiment_utils.tester import
evaluate_treatments_model from catenets.models.torch import FlexTENet def test_flextenet_model_params() -> None: """Check that every FlexTENet constructor argument is stored on the model unchanged.""" model = FlexTENet( 2, binary_y=True, n_layers_out=1, n_layers_r=2, n_units_s_out=20, n_units_p_out=30, n_units_s_r=40, n_units_p_r=50, private_out=True, weight_decay=1e-5, penalty_orthogonal=1e-7, lr=1e-2, n_iter=123, batch_size=234, early_stopping=True, patience=5, n_iter_min=13, n_iter_print=7, seed=42, shared_repr=False, normalize_ortho=False, mode=1, ) assert model.binary_y is True assert model.n_layers_out == 1 assert model.n_layers_r == 2 assert model.n_units_s_out == 20 assert model.n_units_p_out == 30 assert model.n_units_s_r == 40 assert model.n_units_p_r == 50 assert model.private_out is True assert model.weight_decay == 1e-5 assert model.penalty_orthogonal == 1e-7 assert model.lr == 1e-2 assert model.n_iter == 123 assert model.batch_size == 234 assert model.early_stopping is True assert model.patience == 5 assert model.n_iter_min == 13 assert model.n_iter_print == 7 assert model.seed == 42 assert model.shared_repr is False assert model.normalize_ortho is False assert model.mode == 1 @pytest.mark.parametrize("dataset, pehe_threshold", [("twins", 0.4), ("ihdp", 1.5)]) def test_flextenet_model_sanity(dataset: str, pehe_threshold: float) -> None: """Fit and evaluate FlexTENet with 2-fold evaluation and print the score (pehe_threshold is currently not asserted).""" X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(dataset) W_train = W_train.ravel() model = FlexTENet( X_train.shape[1], binary_y=(len(np.unique(Y_train)) == 2), batch_size=1024, lr=1e-3, n_iter=10, ) score = evaluate_treatments_model( model, X_train, Y_train, Y_train_full, W_train, n_folds=2 ) print(f"Evaluation for model FlexTENet on {dataset} = {score['str']}") @pytest.mark.parametrize("shared_repr", [False, True]) @pytest.mark.parametrize("private_out", [False, True]) @pytest.mark.parametrize("n_units_p_r", [50, 150]) def test_flextenet_model_predict_api( shared_repr: bool, private_out: bool, n_units_p_r: int ) -> None: """Check the lengths of FlexTENet predict outputs (with and without potential outcomes) and a positive score on IHDP across shared_repr, private_out and n_units_p_r variants.""" X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(
"ihdp") W_train = W_train.ravel() model = FlexTENet( X_train.shape[1], binary_y=(len(np.unique(Y_train)) == 2), batch_size=1024, lr=1e-3, shared_repr=shared_repr, private_out=private_out, n_units_p_r=n_units_p_r, n_iter=10, ) model.fit(X_train, Y_train, W_train) out = model.predict(X_test) assert len(out) == len(X_test) out, p0, p1 = model.predict(X_test, return_po=True) assert len(out) == len(X_test) assert len(p0) == len(X_test) assert len(p1) == len(X_test) score = model.score(X_test, Y_test) assert score > 0 ================================================ FILE: tests/models/torch/test_torch_pseudo_outcome_nets.py ================================================ from typing import Any import numpy as np import pytest from sklearn.ensemble import RandomForestRegressor from torch import nn from xgboost import XGBClassifier from catenets.datasets import load from catenets.experiment_utils.tester import evaluate_treatments_model from catenets.models.torch import ( DRLearner, PWLearner, RALearner, RLearner, ULearner, XLearner, ) @pytest.mark.parametrize( "model_t", [DRLearner, PWLearner, RALearner, RLearner, ULearner, XLearner] ) def test_nn_model_params(model_t: Any) -> None: """Check that each pseudo-outcome learner builds its te, po and propensity estimators.""" model = model_t( 2, binary_y=True, ) assert model._te_estimator is not None assert model._po_estimator is not None assert model._propensity_estimator is not None @pytest.mark.parametrize("nonlin", ["elu", "relu", "sigmoid"]) @pytest.mark.parametrize( "model_t", [DRLearner, PWLearner, RALearner, RLearner, ULearner, XLearner] ) def test_nn_model_params_nonlin(nonlin: str, model_t: Any) -> None: """Check that the nonlin option sets the module at index 2 of every sub-estimator's model to the matching torch activation class.""" model = model_t(2, binary_y=True, nonlin=nonlin) nonlins = { "elu": nn.ELU, "relu": nn.ReLU, "sigmoid": nn.Sigmoid, } for mod in [model._te_estimator, model._po_estimator, model._propensity_estimator]: assert isinstance(mod.model[2], nonlins[nonlin]) @pytest.mark.parametrize("dataset, pehe_threshold", [("twins", 0.4), ("ihdp", 4)]) @pytest.mark.parametrize("model_t", [DRLearner, RALearner, XLearner])
def test_nn_model_sanity(dataset: str, pehe_threshold: float, model_t: Any) -> None: X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(dataset) W_train = W_train.ravel() model = model_t( X_train.shape[1], binary_y=(len(np.unique(Y_train)) == 2), n_iter=10 ) score = evaluate_treatments_model(model, X_train, Y_train, Y_train_full, W_train) print( f"Evaluation for model torch.{model_t} with NNs on {dataset} = {score['str']}" ) @pytest.mark.parametrize("dataset, pehe_threshold", [("twins", 0.4)]) @pytest.mark.parametrize( "po_estimator", [ XGBClassifier( n_estimators=100, reg_lambda=1e-3, reg_alpha=1e-3, colsample_bytree=0.1, colsample_bynode=0.1, colsample_bylevel=0.1, max_depth=6, tree_method="hist", learning_rate=1e-2, min_child_weight=0, max_bin=256, random_state=0, eval_metric="logloss", use_label_encoder=False, ), ], ) @pytest.mark.parametrize( "te_estimator", [ RandomForestRegressor( n_estimators=100, max_depth=6, ), ], ) @pytest.mark.parametrize("model_t", [DRLearner, RALearner]) def test_sklearn_model_pseudo_outcome_binary( dataset: str, pehe_threshold: float, po_estimator: Any, te_estimator: Any, model_t: Any, ) -> None: X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(dataset) W_train = W_train.ravel() model = model_t( X_train.shape[1], binary_y=True, po_estimator=po_estimator, te_estimator=te_estimator, batch_size=1024, n_iter=10, ) score = evaluate_treatments_model( model, X_train, Y_train, Y_train_full, W_train, n_folds=3 ) print( f"Evaluation for model {model_t} with po_estimator = {type(po_estimator)}," f"te_estimator = {type(te_estimator)} on {dataset} = {score['str']}" ) def test_model_predict_api() -> None: X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load("ihdp") W_train = W_train.ravel() model = XLearner(X_train.shape[1], binary_y=False, batch_size=1024, n_iter=10) model.fit(X_train, Y_train, W_train) out = model.predict(X_test) assert len(out) == len(X_test) score = model.score(X_test, Y_test) assert score > 
0 ================================================ FILE: tests/models/torch/test_torch_representation_net.py ================================================ from typing import Type import pytest from torch import nn from catenets.datasets import load from catenets.experiment_utils.tester import evaluate_treatments_model from catenets.models.torch import DragonNet, TARNet @pytest.mark.parametrize("snet", [TARNet, DragonNet]) def test_model_params(snet: Type) -> None: model = snet( 2, binary_y=True, n_layers_out=1, n_units_out=2, n_layers_r=3, n_units_r=4, weight_decay=0.5, lr=0.6, n_iter=700, batch_size=80, val_split_prop=0.9, n_iter_print=10, seed=11, ) assert model._repr_estimator is not None assert model._propensity_estimator is not None assert len(model._po_estimators) == 2 for mod in model._po_estimators: assert len(mod.model) == 5 # 1 in + NL + 4 * (n_layers_out - 1) + 1 out + NL assert len(model._repr_estimator.model) == 9 @pytest.mark.parametrize("nonlin", ["elu", "relu", "sigmoid"]) @pytest.mark.parametrize("snet", [TARNet, DragonNet]) def test_model_params_nonlin(nonlin: str, snet: Type) -> None: model = snet(2, nonlin=nonlin) nonlins = { "elu": nn.ELU, "relu": nn.ReLU, "sigmoid": nn.Sigmoid, } for mod in [ model._repr_estimator, model._po_estimators[0], model._po_estimators[1], model._propensity_estimator, ]: assert isinstance(mod.model[2], nonlins[nonlin]) @pytest.mark.parametrize("dataset, pehe_threshold", [("twins", 0.4)]) @pytest.mark.parametrize("snet", [TARNet, DragonNet]) def test_model_sanity(dataset: str, pehe_threshold: float, snet: Type) -> None: X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(dataset) W_train = W_train.ravel() model = snet( X_train.shape[1], batch_size=256, n_iter=10, ) score = evaluate_treatments_model( model, X_train, Y_train, Y_train_full, W_train, n_folds=3 ) print(f"Evaluation for model {snet} on {dataset} = {score['str']}") assert score["raw"]["pehe"][0] < pehe_threshold def test_model_predict_api() -> 
None: X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load("ihdp") W_train = W_train.ravel() model = TARNet(X_train.shape[1], batch_size=1024, n_iter=10) model.fit(X_train, Y_train, W_train) out = model.predict(X_test) assert len(out) == len(X_test) out, p0, p1 = model.predict(X_test, return_po=True) assert len(out) == len(X_test) assert len(p0) == len(X_test) assert len(p1) == len(X_test) score = model.score(X_test, Y_test) assert score > 0 ================================================ FILE: tests/models/torch/test_torch_slearner.py ================================================ from typing import Any, Optional import numpy as np import pytest from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor from sklearn.linear_model import LogisticRegression from torch import nn from xgboost import XGBClassifier from catenets.datasets import load from catenets.experiment_utils.tester import evaluate_treatments_model from catenets.models.torch import SLearner def test_nn_model_params() -> None: model = SLearner( 2, binary_y=True, n_layers_out=1, n_units_out=2, n_units_out_prop=33, n_layers_out_prop=12, weight_decay=0.5, lr=0.6, n_iter=700, batch_size=80, val_split_prop=0.9, n_iter_print=10, seed=11, weighting_strategy="ipw", ) assert model._weighting_strategy == "ipw" assert model._propensity_estimator is not None assert model._po_estimator is not None assert model._po_estimator.n_iter == 700 assert model._po_estimator.batch_size == 80 assert model._po_estimator.n_iter_print == 10 assert model._po_estimator.seed == 11 assert model._po_estimator.val_split_prop == 0.9 assert ( len(model._po_estimator.model) == 5 ) # 1 in + NL + 3 * (n_layers_hidden -1) + out + Sigmoid assert model._propensity_estimator.n_iter == 700 assert model._propensity_estimator.batch_size == 80 assert model._propensity_estimator.n_iter_print == 10 assert model._propensity_estimator.seed == 11 assert model._propensity_estimator.val_split_prop == 0.9 assert ( 
len(model._propensity_estimator.model) == 38 ) # 1 in + NL + 3 * (n_layers_hidden - 1) + out + Softmax @pytest.mark.parametrize("nonlin", ["elu", "relu", "sigmoid"]) def test_nn_model_params_nonlin(nonlin: str) -> None: model = SLearner(2, True, nonlin=nonlin, weighting_strategy="ipw") nonlins = { "elu": nn.ELU, "relu": nn.ReLU, "sigmoid": nn.Sigmoid, } for mod in [model._propensity_estimator, model._po_estimator]: assert isinstance(mod.model[2], nonlins[nonlin]) @pytest.mark.parametrize("weighting_strategy", ["ipw", None]) @pytest.mark.parametrize("dataset, pehe_threshold", [("twins", 0.4), ("ihdp", 1.5)]) def test_nn_model_sanity( dataset: str, pehe_threshold: float, weighting_strategy: Optional[str] ) -> None: X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(dataset) W_train = W_train.ravel() model = SLearner( X_train.shape[1], binary_y=(len(np.unique(Y_train)) == 2), weighting_strategy=weighting_strategy, n_iter=10, ) score = evaluate_treatments_model(model, X_train, Y_train, Y_train_full, W_train) print( f"Evaluation for model torch.SLearner(NN)(weighting_strategy={weighting_strategy}) on {dataset} = {score['str']}" ) @pytest.mark.parametrize("dataset, pehe_threshold", [("twins", 0.4)]) @pytest.mark.parametrize( "po_estimator", [ XGBClassifier( n_estimators=100, reg_lambda=1e-3, reg_alpha=1e-3, colsample_bytree=0.1, colsample_bynode=0.1, colsample_bylevel=0.1, max_depth=6, tree_method="hist", learning_rate=1e-2, min_child_weight=0, max_bin=256, random_state=0, eval_metric="logloss", use_label_encoder=False, ), RandomForestClassifier( n_estimators=100, max_depth=6, ), LogisticRegression( C=1.0, solver="sag", max_iter=10000, penalty="l2", ), ], ) def test_sklearn_model_sanity_binary_output( dataset: str, pehe_threshold: float, po_estimator: Any ) -> None: X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(dataset) W_train = W_train.ravel() model = SLearner( X_train.shape[1], binary_y=True, po_estimator=po_estimator, n_iter=10, ) score = 
evaluate_treatments_model(model, X_train, Y_train, Y_train_full, W_train) print( f"Evaluation for model torch.SLearner with {po_estimator.__class__} on {dataset} = {score['str']}" ) assert score["raw"]["pehe"][0] < pehe_threshold @pytest.mark.parametrize("exp", [1, 10, 40, 50, 99]) @pytest.mark.parametrize( "po_estimator", [ RandomForestRegressor( n_estimators=100, max_depth=6, ), ], ) def test_slearner_sklearn_model_ihdp(po_estimator: Any, exp: int) -> None: X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load( "ihdp", exp=exp, rescale=True ) W_train = W_train.ravel() model = SLearner( X_train.shape[1], binary_y=False, po_estimator=po_estimator, n_iter=10, ) score = evaluate_treatments_model(model, X_train, Y_train, Y_train_full, W_train) print( f"Evaluation for model torch.SLearner with {po_estimator.__class__} on ihdp[{exp}] = {score['str']}" ) assert score["raw"]["pehe"][0] < 1.5 def test_model_predict_api() -> None: X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load("ihdp") W_train = W_train.ravel() model = SLearner(X_train.shape[1], binary_y=False, batch_size=1024, n_iter=10) model.fit(X_train, Y_train, W_train) out = model.predict(X_test) assert len(out) == len(X_test) out, p0, p1 = model.predict(X_test, return_po=True) assert len(out) == len(X_test) assert len(p0) == len(X_test) assert len(p1) == len(X_test) score = model.score(X_test, Y_test) assert score > 0 ================================================ FILE: tests/models/torch/test_torch_snet.py ================================================ import numpy as np import pytest from torch import nn from catenets.datasets import load from catenets.experiment_utils.tester import evaluate_treatments_model from catenets.models.torch import SNet def test_model_params() -> None: # with propensity estimator model = SNet( 2, binary_y=True, n_layers_out=1, n_units_out=2, n_layers_r=3, n_units_r=4, weight_decay=0.5, lr=0.6, n_iter=700, batch_size=80, val_split_prop=0.9, n_iter_print=10, 
seed=11, ) assert model._reps_c is not None assert model._reps_o is not None assert model._reps_mu0 is not None assert model._reps_mu1 is not None assert model._reps_prop is not None assert model._propensity_estimator is not None assert len(model._po_estimators) == 2 for mod in model._po_estimators: assert len(mod.model) == 5 # 1 in + NL + 4 * (n_layers_out - 1) + 1 out + NL assert len(model._reps_c.model) == 9 assert len(model._reps_o.model) == 9 assert len(model._reps_mu0.model) == 9 assert len(model._reps_mu1.model) == 9 assert len(model._propensity_estimator.model) == 8 # remove propensity estimator model = SNet( 2, binary_y=True, n_layers_out=1, n_units_out=2, n_layers_r=3, n_units_r=4, weight_decay=0.5, lr=0.6, n_iter=700, batch_size=80, val_split_prop=0.9, n_iter_print=10, seed=11, with_prop=False, ) with np.testing.assert_raises(AttributeError): model._reps_c with np.testing.assert_raises(AttributeError): model._reps_prop with np.testing.assert_raises(AttributeError): model._propensity_estimator assert model._reps_o is not None assert model._reps_mu0 is not None assert model._reps_mu1 is not None assert len(model._po_estimators) == 2 for mod in model._po_estimators: assert len(mod.model) == 5 # 1 in + NL + 4 * (n_layers_out - 1) + 1 out + NL assert len(model._reps_o.model) == 9 assert len(model._reps_mu0.model) == 9 assert len(model._reps_mu1.model) == 9 @pytest.mark.parametrize("nonlin", ["elu", "relu", "sigmoid", "selu", "leaky_relu"]) def test_model_params_nonlin(nonlin: str) -> None: model = SNet(2, nonlin=nonlin) nonlins = { "elu": nn.ELU, "relu": nn.ReLU, "sigmoid": nn.Sigmoid, "selu": nn.SELU, "leaky_relu": nn.LeakyReLU, } for mod in [ model._reps_c, model._reps_o, model._reps_mu0, model._reps_mu1, model._reps_prop, model._po_estimators[0], model._po_estimators[1], model._propensity_estimator, ]: assert isinstance(mod.model[2], nonlins[nonlin]) @pytest.mark.parametrize("dataset, pehe_threshold", [("twins", 0.4)]) def test_model_sanity(dataset: str, 
pehe_threshold: float) -> None: X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(dataset) W_train = W_train.ravel() # with propensity estimator model = SNet( X_train.shape[1], binary_y=(len(np.unique(Y_train)) == 2), batch_size=1024, n_iter=10, ) score = evaluate_treatments_model( model, X_train, Y_train, Y_train_full, W_train, n_folds=3 ) print(f"Evaluation for model SNet on {dataset} = {score['str']}") model = SNet( X_train.shape[1], binary_y=(len(np.unique(Y_train)) == 2), batch_size=1024, n_iter=10, with_prop=False, ) score = evaluate_treatments_model( model, X_train, Y_train, Y_train_full, W_train, n_folds=3 ) print(f"Evaluation for model SNet (with_prop=False) on {dataset} = {score['str']}") def test_model_predict_api() -> None: X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load("ihdp") W_train = W_train.ravel() model = SNet(X_train.shape[1], batch_size=1024, n_iter=10) model.fit(X_train, Y_train, W_train) out = model.predict(X_test) assert len(out) == len(X_test) out, p0, p1 = model.predict(X_test, return_po=True) assert len(out) == len(X_test) assert len(p0) == len(X_test) assert len(p1) == len(X_test) score = model.score(X_test, Y_test) assert score > 0 ================================================ FILE: tests/models/torch/test_torch_tlearner.py ================================================ from typing import Any import numpy as np import pytest from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor from sklearn.linear_model import LogisticRegression from torch import nn from xgboost import XGBClassifier, XGBRegressor from catenets.datasets import load from catenets.experiment_utils.tester import evaluate_treatments_model from catenets.models.torch import TLearner def test_nn_model_params() -> None: model = TLearner( 2, True, n_layers_out=1, n_units_out=2, weight_decay=0.5, lr=0.6, n_iter=700, batch_size=80, val_split_prop=0.9, n_iter_print=10, seed=11, ) assert len(model._plug_in) == 2 for mod in 
model._plug_in: assert mod.n_iter == 700 assert mod.batch_size == 80 assert mod.n_iter_print == 10 assert mod.seed == 11 assert mod.val_split_prop == 0.9 assert len(mod.model) == 5 # 2 in + NL + 3 * (n_layers_hidden - 1) + 2 out @pytest.mark.parametrize("nonlin", ["elu", "relu", "sigmoid"]) def test_nn_model_params_nonlin(nonlin: str) -> None: model = TLearner(2, True, nonlin=nonlin) assert len(model._plug_in) == 2 nonlins = { "elu": nn.ELU, "relu": nn.ReLU, "sigmoid": nn.Sigmoid, } for mod in model._plug_in: assert isinstance(mod.model[2], nonlins[nonlin]) @pytest.mark.parametrize("dataset, pehe_threshold", [("twins", 0.4), ("ihdp", 1.5)]) def test_nn_model_sanity(dataset: str, pehe_threshold: float) -> None: X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(dataset) W_train = W_train.ravel() model = TLearner( X_train.shape[1], binary_y=(len(np.unique(Y_train)) == 2), n_iter=10 ) score = evaluate_treatments_model(model, X_train, Y_train, Y_train_full, W_train) print(f"Evaluation for model torch.TLearner(NN) on {dataset} = {score['str']}") @pytest.mark.parametrize("dataset, pehe_threshold", [("twins", 0.4)]) @pytest.mark.parametrize( "po_estimator", [ XGBClassifier( n_estimators=100, reg_lambda=1e-3, reg_alpha=1e-3, colsample_bytree=0.1, colsample_bynode=0.1, colsample_bylevel=0.1, max_depth=6, tree_method="hist", learning_rate=1e-2, min_child_weight=0, max_bin=256, random_state=0, eval_metric="logloss", use_label_encoder=False, ), RandomForestClassifier( n_estimators=100, max_depth=6, ), LogisticRegression( C=1.0, solver="sag", max_iter=10000, penalty="l2", ), ], ) def test_sklearn_model_sanity_binary_output( dataset: str, pehe_threshold: float, po_estimator: Any ) -> None: X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(dataset) W_train = W_train.ravel() model = TLearner( X_train.shape[1], binary_y=True, po_estimator=po_estimator, n_iter=10, ) score = evaluate_treatments_model(model, X_train, Y_train, Y_train_full, W_train) print( 
f"Evaluation for model torch.TLearner with {po_estimator.__class__} on {dataset} = {score['str']}" ) assert score["raw"]["pehe"][0] < pehe_threshold @pytest.mark.parametrize("dataset, pehe_threshold", [("ihdp", 1.5)]) @pytest.mark.parametrize( "po_estimator", [ XGBRegressor( n_estimators=1000, reg_lambda=1e-3, reg_alpha=1e-3, colsample_bytree=0.1, colsample_bynode=0.1, colsample_bylevel=0.1, max_depth=7, tree_method="hist", learning_rate=1e-2, min_child_weight=0, max_bin=256, random_state=0, eval_metric="logloss", ), RandomForestRegressor( n_estimators=100, max_depth=6, ), ], ) def test_sklearn_model_sanity_regression( dataset: str, pehe_threshold: float, po_estimator: Any ) -> None: X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load(dataset) W_train = W_train.ravel() model = TLearner( X_train.shape[1], binary_y=False, po_estimator=po_estimator, n_iter=10, ) score = evaluate_treatments_model(model, X_train, Y_train, Y_train_full, W_train) print( f"Evaluation for model torch.TLearner with {po_estimator.__class__ } on {dataset} = {score['str']}" ) def test_model_predict_api() -> None: X_train, W_train, Y_train, Y_train_full, X_test, Y_test = load("ihdp") W_train = W_train.ravel() model = TLearner( X_train.shape[1], binary_y=False, n_iter=10, ) model.fit(X_train, Y_train, W_train) out = model.predict(X_test) assert len(out) == len(X_test) out, p0, p1 = model.predict(X_test, return_po=True) assert len(out) == len(X_test) assert len(p0) == len(X_test) assert len(p1) == len(X_test) score = model.score(X_test, Y_test) assert score > 0