Showing preview only (553K chars total). Download the full file or copy to clipboard to get everything.
Repository: AliciaCurth/CATENets
Branch: main
Commit: f8c961c30776
Files: 91
Total size: 523.3 KB
Directory structure:
gitextract_9k1q0bj_/
├── .github/
│ └── workflows/
│ ├── release.yml
│ ├── scripts/
│ │ ├── release_linux.sh
│ │ ├── release_osx.sh
│ │ └── release_windows.bat
│ └── test.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── catenets/
│ ├── __init__.py
│ ├── datasets/
│ │ ├── __init__.py
│ │ ├── dataset_acic2016.py
│ │ ├── dataset_ihdp.py
│ │ ├── dataset_twins.py
│ │ └── network.py
│ ├── experiment_utils/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── simulation_utils.py
│ │ ├── tester.py
│ │ └── torch_metrics.py
│ ├── logger.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── jax/
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── disentangled_nets.py
│ │ │ ├── flextenet.py
│ │ │ ├── model_utils.py
│ │ │ ├── offsetnet.py
│ │ │ ├── pseudo_outcome_nets.py
│ │ │ ├── representation_nets.py
│ │ │ ├── rnet.py
│ │ │ ├── snet.py
│ │ │ ├── tnet.py
│ │ │ ├── transformation_utils.py
│ │ │ └── xnet.py
│ │ └── torch/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── flextenet.py
│ │ ├── pseudo_outcome_nets.py
│ │ ├── representation_nets.py
│ │ ├── slearner.py
│ │ ├── snet.py
│ │ ├── tlearner.py
│ │ └── utils/
│ │ ├── __init__.py
│ │ ├── decorators.py
│ │ ├── model_utils.py
│ │ ├── transformations.py
│ │ └── weight_utils.py
│ └── version.py
├── docs/
│ ├── Makefile
│ ├── conf.py
│ ├── datasets.rst
│ ├── index.rst
│ ├── jax_models.rst
│ ├── make.bat
│ ├── requirements.txt
│ └── torch_models.rst
├── experiments/
│ ├── __init__.py
│ ├── experiments_AISTATS21/
│ │ ├── ihdp_experiments.py
│ │ └── simulations_AISTATS.py
│ ├── experiments_benchmarks_NeurIPS21/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── acic_experiments_catenets.py
│ │ ├── acic_experiments_grf.R
│ │ ├── ihdp_experiments_catenets.py
│ │ ├── ihdp_experiments_grf.R
│ │ ├── twins_experiments_catenets.py
│ │ └── twins_experiments_grf.R
│ └── experiments_inductivebias_NeurIPS21/
│ ├── __init__.py
│ ├── experiments_AB.py
│ ├── experiments_CD.py
│ ├── experiments_acic.py
│ └── experiments_twins.py
├── pyproject.toml
├── pytest.ini
├── run_experiments_AISTATS.py
├── run_experiments_benchmarks_NeurIPS.py
├── run_experiments_inductive_bias_NeurIPS.py
├── setup.py
└── tests/
├── conftest.py
├── datasets/
│ └── test_datasets.py
└── models/
├── jax/
│ ├── test_jax_ite.py
│ ├── test_jax_model_utils.py
│ └── test_jax_transformation_utils.py
└── torch/
├── test_torch_flextenet.py
├── test_torch_pseudo_outcome_nets.py
├── test_torch_representation_net.py
├── test_torch_slearner.py
├── test_torch_snet.py
└── test_torch_tlearner.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/workflows/release.yml
================================================
name: Package release
on:
release:
types: [created]
jobs:
deploy_osx:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10"]
os: [macos-latest]
steps:
- uses: actions/checkout@v2
with:
submodules: true
- name: Set up Python
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
- name: Build and publish
env:
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
run: ${GITHUB_WORKSPACE}/.github/workflows/scripts/release_osx.sh
deploy_linux:
strategy:
matrix:
python-version:
- cp37-cp37m
- cp38-cp38
- cp39-cp39
- cp310-cp310
runs-on: ubuntu-latest
container: quay.io/pypa/manylinux2014_x86_64
steps:
- uses: actions/checkout@v1
with:
submodules: true
- name: Set target Python version PATH
run: |
echo "/opt/python/${{ matrix.python-version }}/bin" >> $GITHUB_PATH
- name: Build and publish
env:
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
run: ${GITHUB_WORKSPACE}/.github/workflows/scripts/release_linux.sh
deploy_windows:
runs-on: windows-latest
strategy:
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10"]
steps:
- uses: actions/checkout@v2
with:
submodules: true
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
- name: Build and publish
env:
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
run: |
../../.github/workflows/scripts/release_windows.bat
================================================
FILE: .github/workflows/scripts/release_linux.sh
================================================
#!/bin/bash
set -e
yum makecache -y
yum install centos-release-scl -y
yum-config-manager --enable rhel-server-rhscl-7-rpms
yum install llvm-toolset-7.0 python3 python3-devel -y
# Python
python3 -m pip install --upgrade pip
python3 -m pip install setuptools wheel twine auditwheel
# Publish
python3 -m pip wheel . -w dist/ --no-deps
twine upload --verbose --skip-existing dist/*
================================================
FILE: .github/workflows/scripts/release_osx.sh
================================================
#!/bin/sh
export MACOSX_DEPLOYMENT_TARGET=10.14
python -m pip install --upgrade pip
pip install setuptools wheel twine auditwheel
python3 setup.py build bdist_wheel --plat-name macosx_10_14_x86_64 --dist-dir wheel
twine upload --skip-existing wheel/*
================================================
FILE: .github/workflows/scripts/release_windows.bat
================================================
echo on
python -m pip install --upgrade pip
pip install setuptools wheel twine auditwheel
pip wheel . -w wheel/ --no-deps
twine upload --skip-existing wheel/*
================================================
FILE: .github/workflows/test.yml
================================================
name: CATENets Tests
on:
push:
branches: [main, release]
pull_request:
types: [opened, synchronize, reopened]
schedule:
- cron: '0 0 * * 0'
workflow_dispatch:
jobs:
Linter:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: [3.8]
os: [ubuntu-latest]
steps:
- uses: actions/checkout@v2
with:
submodules: true
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: pip install .[testing]
- name: pre-commit validation
run: pre-commit run --files catenets/*
- name: Security checks
run: |
bandit -r catenets/*
Library:
needs: [Linter]
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: ['3.8', '3.9', "3.10"]
os: [macos-latest, ubuntu-latest, windows-latest]
steps:
- uses: actions/checkout@v2
with:
submodules: true
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
- name: Install MacOS dependencies
run: |
brew install libomp
if: ${{ matrix.os == 'macos-latest' }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[testing]
- name: Test with pytest Unix
run: pytest -vvvsx -m "not slow"
if: ${{ matrix.os != 'windows-latest' }}
- name: Test with pytest Windows
run: |
cd tests\datasets
pytest -vvvsx -m "not slow"
cd ..\..
cd tests\models\torch
pytest -vvvsx -m "not slow"
if: ${{ matrix.os == 'windows-latest' }}
================================================
FILE: .gitignore
================================================
*.pyc
*.xml
*.iml
*.csv
*.xlsx
*.Rhistory
.idea/
.coverage
.ipynb_checkpoints
.ipynb_checkpoints/
*/.ipynb_checkpoints/
*/bin/
*/include/
*/lib/
*/lib64/
*/share/
*.cfg
.pytest_cache
data/
build/
catenets.egg-info/
dist/
generated/
_build
================================================
FILE: .pre-commit-config.yaml
================================================
exclude: 'setup.py|^docs'
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.4.0
hooks:
- id: trailing-whitespace
- id: check-added-large-files
- id: check-ast
- id: check-json
- id: check-merge-conflict
- id: check-xml
- id: check-yaml
- id: debug-statements
- id: check-executables-have-shebangs
- id: end-of-file-fixer
- id: requirements-txt-fixer
- id: mixed-line-ending
args: ['--fix=auto'] # replace 'auto' with 'lf' to enforce Linux/Mac line endings or 'crlf' for Windows
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
- repo: https://github.com/psf/black
rev: 22.3.0
hooks:
- id: black
language_version: python3
- repo: https://github.com/pycqa/flake8
rev: 3.9.1
hooks:
- id: flake8
args: [
"--max-line-length=140",
"--extend-ignore=E203,W503"
]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.812
hooks:
- id: mypy
args: [
"--ignore-missing-imports",
"--scripts-are-modules",
"--disallow-incomplete-defs",
"--no-implicit-optional",
"--warn-unused-ignores",
"--warn-redundant-casts",
"--strict-equality",
"--warn-unreachable",
"--disallow-untyped-defs",
"--disallow-untyped-calls",
]
- repo: local
hooks:
- id: flynt
name: flynt
entry: flynt
args: [--fail-on-change]
types: [python]
language: python
additional_dependencies:
- flynt
================================================
FILE: LICENSE
================================================
BSD 3-Clause License
Copyright (c) 2021, Alicia Curth
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
================================================
FILE: README.md
================================================
# CATENets - Conditional Average Treatment Effect Estimation Using Neural Networks
[](https://github.com/AliciaCurth/CATENets/actions/workflows/test.yml)
[](https://catenets.readthedocs.io/en/latest/?badge=latest)
[](https://github.com/AliciaCurth/CATENets/blob/main/LICENSE)
Code Author: Alicia Curth (amc253@cam.ac.uk)
This repo contains Jax-based, sklearn-style implementations of Neural Network-based Conditional
Average Treatment Effect (CATE) Estimators, which were used in the AISTATS21 paper
['Nonparametric Estimation of Heterogeneous Treatment Effects: From Theory to Learning
Algorithms']( https://arxiv.org/abs/2101.10943) (Curth & vd Schaar, 2021a) as well as the follow up
NeurIPS21 paper ["On Inductive Biases for Heterogeneous Treatment Effect Estimation"](https://arxiv.org/abs/2106.03765) (Curth & vd
Schaar, 2021b) and the NeurIPS21 Datasets & Benchmarks track paper ["Really Doing Great at Estimating CATE? A Critical Look at ML Benchmarking Practices in Treatment Effect Estimation"](https://openreview.net/forum?id=FQLzQqGEAH) (Curth et al, 2021).
We implement the SNet-class we introduce in Curth & vd Schaar (2021a), as well as FlexTENet and
OffsetNet as discussed in Curth & vd Schaar (2021b), and re-implement a number of
NN-based algorithms from existing literature (Shalit et al (2017), Shi et al (2019), Hassanpour
& Greiner (2020)). We also provide Neural Network (NN)-based instantiations of a number of so-called
meta-learners for CATE estimation, including two-step pseudo-outcome regression estimators (the
DR-learner (Kennedy, 2020) and single-robust propensity-weighted (PW) and regression-adjusted (RA) learners), Nie & Wager (2017)'s R-learner and Kuenzel et al (2019)'s X-learner. The jax implementations in ``catenets.models.jax`` were used in all papers listed; additionally, pytorch versions of some models (``catenets.models.torch``) were contributed by [Bogdan Cebere](https://github.com/bcebere).
### Interface
The repo contains a package ``catenets``, which contains all general code used for modeling and evaluation, and a folder ``experiments``, in which the code for replicating experimental results is contained. All implemented learning algorithms in ``catenets`` (``SNet, FlexTENet, OffsetNet, TNet, SNet1 (TARNet), SNet2
(DragonNet), SNet3, DRNet, RANet, PWNet, RNet, XNet``) come with a sklearn-style wrapper, implementing a ``.fit(X, y, w)`` and a ``.predict(X)`` method, where predict returns CATE by default. All hyperparameters are documented in detail in the respective files in ``catenets.models`` folder.
Example usage:
```python
from catenets.models.jax import TNet, SNet
from catenets.experiment_utils.simulation_utils import simulate_treatment_setup
# simulate some data (here: unconfounded, 10 prognostic variables and 5 predictive variables)
X, y, w, p, cate = simulate_treatment_setup(n=2000, n_o=10, n_t=5, n_c=0)
# estimate CATE using TNet
t = TNet()
t.fit(X, y, w)
cate_pred_t = t.predict(X) # without potential outcomes
cate_pred_t, po0_pred_t, po1_pred_t = t.predict(X, return_po=True) # predict potential outcomes too
# estimate CATE using SNet
s = SNet(penalty_orthogonal=0.01)
s.fit(X, y, w)
cate_pred_s = s.predict(X)
```
All experiments in Curth & vd Schaar (2021a) can be replicated using this repository; the necessary
code is in ``experiments.experiments_AISTATS21``. To do so from shell, clone the repo, create a new
virtual environment and run
```shell
pip install catenets # install the library from PyPI
# OR
pip install . # install the library from the local repository
# Run the experiments
python run_experiments_AISTATS.py
```
```shell
Options:
--experiment # defaults to 'simulation', 'ihdp' will run ihdp experiments
--setting # different simulation settings in synthetic experiments (can be 1-5)
--models # defaults to None which will train all models considered in paper,
# can be string of model name (e.g 'TNet'), 'plug' for all plugin models,
# 'pseudo' for all pseudo-outcome regression models
--file_name # base file name to write to, defaults to 'results'
--n_repeats # number of experiments to run for each configuration, defaults to 10 (should be set to 100 for IHDP)
```
Similarly, the experiments in Curth & vd Schaar (2021b) can be replicated using the code in
``experiments.experiments_inductivebias_NeurIPS21`` (or from shell using ```python
run_experiments_inductive_bias_NeurIPS.py```) and the experiments in Curth et al (2021) can be replicated using the code in ``experiments.experiments_benchmarks_NeurIPS21`` (the catenets experiments can also be run from shell using ``python run_experiments_benchmarks_NeurIPS``).
The code can also be installed as a python package (``catenets``). From a local copy of the repo, run ``python setup.py install``.
Note: jax is currently only supported on macOS and linux, but can be run from windows using WSL (the windows subsystem for linux).
### Citing
If you use this software please cite the corresponding paper(s):
```
@inproceedings{curth2021nonparametric,
title={Nonparametric Estimation of Heterogeneous Treatment Effects: From Theory to Learning Algorithms},
author={Curth, Alicia and van der Schaar, Mihaela},
year={2021},
booktitle={Proceedings of the 24th International Conference on Artificial
Intelligence and Statistics (AISTATS)},
organization={PMLR}
}
@article{curth2021inductive,
title={On Inductive Biases for Heterogeneous Treatment Effect Estimation},
author={Curth, Alicia and van der Schaar, Mihaela},
booktitle={Proceedings of the Thirty-Fifth Conference on Neural Information Processing Systems},
year={2021}
}
@article{curth2021really,
title={Really Doing Great at Estimating CATE? A Critical Look at ML Benchmarking Practices in Treatment Effect Estimation},
author={Curth, Alicia and Svensson, David and Weatherall, James and van der Schaar, Mihaela},
booktitle={Proceedings of the Neural Information Processing Systems Track on Datasets and Benchmarks},
year={2021}
}
```
================================================
FILE: catenets/__init__.py
================================================
import sys
from . import logger # noqa: F401
from . import datasets, models # noqa: F401
logger.add(sink=sys.stderr, level="CRITICAL")
================================================
FILE: catenets/datasets/__init__.py
================================================
# stdlib
import os
from pathlib import Path
from typing import Any, Tuple
from . import dataset_acic2016, dataset_ihdp, dataset_twins
DATA_PATH = Path(os.path.dirname(__file__)) / Path("data")
try:
os.mkdir(DATA_PATH)
except BaseException:
pass
def load(dataset: str, *args: Any, **kwargs: Any) -> Tuple:
"""
Input:
dataset: the name of the dataset to load
Outputs:
- Train_X, Test_X: Train and Test features
- Train_Y: Observable outcomes
- Train_T: Assigned treatment
- Test_Y: Potential outcomes.
"""
if dataset == "twins":
return dataset_twins.load(DATA_PATH, *args, **kwargs)
if dataset == "ihdp":
return dataset_ihdp.load(DATA_PATH, *args, **kwargs)
if dataset == "acic2016":
return dataset_acic2016.load(DATA_PATH, *args, **kwargs)
else:
raise Exception("Unsupported dataset")
__all__ = ["dataset_ihdp", "dataset_twins", "dataset_acic2016", "load"]
================================================
FILE: catenets/datasets/dataset_acic2016.py
================================================
"""
ACIC2016 dataset
"""
import glob
# stdlib
import random
from pathlib import Path
from typing import Any, Tuple
# third party
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder, StandardScaler
import catenets.logger as log
from .network import download_if_needed
np.random.seed(0)
random.seed(0)
FILE_ID = "0B7pG5PPgj6A3N09ibmFwNWE1djA"
PREPROCESSED_FILE_ID = "1iOfEAk402o3jYBs2Prfiz6oaailwWcR5"
NUMERIC_COLS = [
0,
3,
4,
16,
17,
18,
20,
21,
22,
24,
24,
25,
30,
31,
32,
33,
39,
40,
41,
53,
54,
]
N_NUM_COLS = len(NUMERIC_COLS)
def get_acic_covariates(
fn_csv: Path, keep_categorical: bool = False, preprocessed: bool = True
) -> np.ndarray:
X = pd.read_csv(fn_csv)
if not keep_categorical:
X = X.drop(columns=["x_2", "x_21", "x_24"])
else:
# encode categorical features
feature_list = []
for cols_ in X.columns:
if type(X.loc[X.index[0], cols_]) not in [np.int64, np.float64]:
enc = OneHotEncoder(drop="first")
enc.fit(np.array(X[[cols_]]).reshape((-1, 1)))
for k in range(len(list(enc.get_feature_names()))):
X[cols_ + list(enc.get_feature_names())[k]] = enc.transform(
np.array(X[[cols_]]).reshape((-1, 1))
).toarray()[:, k]
feature_list.append(cols_)
X.drop(feature_list, axis=1, inplace=True)
if preprocessed:
X_t = X.values
else:
scaler = StandardScaler()
X_t = scaler.fit_transform(X)
return X_t
def preprocess_simu(
fn_csv: Path,
n_0: int = 2000,
n_1: int = 200,
n_test: int = 500,
error_sd: float = 1,
sp_lin: float = 0.6,
sp_nonlin: float = 0.3,
prop_gamma: float = 0,
prop_omega: float = 0,
ate_goal: float = 0,
inter: bool = True,
i_exp: int = 0,
keep_categorical: bool = False,
preprocessed: bool = True,
) -> Tuple:
X = get_acic_covariates(
fn_csv, keep_categorical=keep_categorical, preprocessed=preprocessed
)
np.random.seed(i_exp)
# shuffle indices
n_total, n_cov = X.shape
ind = np.arange(n_total)
np.random.shuffle(ind)
ind_test = ind[-n_test:]
ind_1 = ind[n_0 : (n_0 + n_1)]
# create treatment indicator (treatment assignment does not matter in test set)
w = np.zeros(n_total).reshape((-1, 1))
w[ind_1] = 1
# create dgp
coeffs_ = [0, 1]
# sample baseline coefficients
beta_0 = np.random.choice(coeffs_, size=n_cov, replace=True, p=[1 - sp_lin, sp_lin])
intercept = np.random.choice([x for x in np.arange(-1, 1.25, 0.25)])
# sample treatment effect coefficients
gamma = np.random.choice(
coeffs_, size=n_cov, replace=True, p=[1 - prop_gamma, prop_gamma]
)
omega = np.random.choice(
[0, 1], replace=True, size=n_cov, p=[prop_omega, 1 - prop_omega]
)
# simulate mu_0 and mu_1
mu_0 = (intercept + np.dot(X, beta_0)).reshape((-1, 1))
mu_1 = (intercept + np.dot(X, gamma + beta_0 * omega)).reshape((-1, 1))
if sp_nonlin > 0:
coefs_sq = [0, 0.1]
beta_sq = np.random.choice(
coefs_sq, size=N_NUM_COLS, replace=True, p=[1 - sp_nonlin, sp_nonlin]
)
omega = np.random.choice(
[0, 1], replace=True, size=N_NUM_COLS, p=[prop_omega, 1 - prop_omega]
)
X_sq = X[:, NUMERIC_COLS] ** 2
mu_0 = mu_0 + np.dot(X_sq, beta_sq).reshape((-1, 1))
mu_1 = mu_1 + np.dot(X_sq, beta_sq * omega).reshape((-1, 1))
if inter:
# randomly add some interactions
ind_c = np.arange(n_cov)
np.random.shuffle(ind_c)
inter_list = list()
for i in range(0, n_cov - 2, 2):
inter_list.append(X[:, ind_c[i]] * X[:, ind_c[i + 1]])
X_inter = np.array(inter_list).T
n_inter = X_inter.shape[1]
beta_inter = np.random.choice(
coefs_sq, size=n_inter, replace=True, p=[1 - sp_nonlin, sp_nonlin]
)
omega = np.random.choice(
[0, 1], replace=True, size=n_inter, p=[prop_omega, 1 - prop_omega]
)
mu_0 = mu_0 + np.dot(X_inter, beta_inter).reshape((-1, 1))
mu_1 = mu_1 + np.dot(X_inter, beta_inter * omega).reshape((-1, 1))
ate = np.mean(mu_1 - mu_0)
mu_1 = mu_1 - ate + ate_goal
y = (
w * mu_1
+ (1 - w) * mu_0
+ np.random.normal(0, error_sd, n_total).reshape((-1, 1))
)
X_train, y_train, w_train, mu_0_train, mu_1_train = (
X[ind[: (n_0 + n_1)], :],
y[ind[: (n_0 + n_1)]],
w[ind[: (n_0 + n_1)]],
mu_0[ind[: (n_0 + n_1)]],
mu_1[ind[: (n_0 + n_1)]],
)
X_test, y_test, w_test, mu_0_t, mu_1_t = (
X[ind_test, :],
y[ind_test],
w[ind_test],
mu_0[ind_test],
mu_1[ind_test],
)
return (
X_train,
w_train,
y_train,
np.asarray([mu_0_train, mu_1_train]).squeeze().T,
X_test,
w_test,
y_test,
np.asarray([mu_0_t, mu_1_t]).squeeze().T,
)
def get_acic_orig_filenames(data_path: Path, simu_num: int) -> list:
return sorted(
glob.glob(
(data_path / ("data_cf_all/" + str(simu_num) + "/zymu_*.csv")).__str__()
)
)
def get_acic_orig_outcomes(data_path: Path, simu_num: int, i_exp: int) -> Tuple:
file_list = get_acic_orig_filenames(data_path=data_path, simu_num=simu_num)
out = pd.read_csv(file_list[i_exp])
w = out["z"]
y = w * out["y1"] + (1 - w) * out["y0"]
mu_0, mu_1 = out["mu0"], out["mu1"]
return y.values, w.values, mu_0.values, mu_1.values
def preprocess_acic_orig(
fn_csv: Path,
data_path: Path,
preprocessed: bool = False,
keep_categorical: bool = True,
simu_num: int = 1,
i_exp: int = 0,
train_size: int = 4000,
random_split: bool = False,
) -> Tuple:
X = get_acic_covariates(
fn_csv, keep_categorical=keep_categorical, preprocessed=preprocessed
)
y, w, mu_0, mu_1 = get_acic_orig_outcomes(
data_path=data_path, simu_num=simu_num, i_exp=i_exp
)
if not random_split:
X_train, y_train, w_train, mu_0_train, mu_1_train = (
X[:train_size, :],
y[:train_size],
w[:train_size],
mu_0[:train_size],
mu_1[:train_size],
)
X_test, y_test, w_test, mu_0_test, mu_1_test = (
X[train_size:, :],
y[train_size:],
w[train_size:],
mu_0[train_size:],
mu_1[train_size:],
)
else:
(
X_train,
X_test,
y_train,
y_test,
w_train,
w_test,
mu_0_train,
mu_0_test,
mu_1_train,
mu_1_test,
) = train_test_split(
X, y, w, mu_0, mu_1, test_size=1 - train_size, random_state=i_exp
)
return (
X_train,
w_train,
y_train,
np.asarray([mu_0_train, mu_1_train]).squeeze().T,
X_test,
w_test,
y_test,
np.asarray([mu_0_test, mu_1_test]).squeeze().T,
)
def preprocess(
fn_csv: Path,
data_path: Path,
preprocessed: bool = True,
original_acic_outcomes: bool = False,
**kwargs: Any,
) -> Tuple:
if not original_acic_outcomes:
return preprocess_simu(fn_csv=fn_csv, preprocessed=preprocessed, **kwargs)
else:
return preprocess_acic_orig(
fn_csv=fn_csv, preprocessed=preprocessed, data_path=data_path, **kwargs
)
def load(
data_path: Path,
preprocessed: bool = True,
original_acic_outcomes: bool = False,
**kwargs: Any,
) -> Tuple:
"""
ACIC2016 dataset dataloader.
- Download the dataset if needed.
- Load the dataset.
- Preprocess the data.
- Return train/test split.
Parameters
----------
data_path: Path
Path to the CSV. If it is missing, it will be downloaded.
preprocessed: bool
Switch between the raw and preprocessed versions of the dataset.
original_acic_outcomes: bool
Switch between new simulations (Inductive bias paper) and original acic outcomes
Returns
-------
train_x: array or pd.DataFrame
Features in training data.
train_t: array or pd.DataFrame
Treatments in training data.
train_y: array or pd.DataFrame
Observed outcomes in training data.
train_potential_y: array or pd.DataFrame
Potential outcomes in training data.
test_x: array or pd.DataFrame
Features in testing data.
test_potential_y: array or pd.DataFrame
Potential outcomes in testing data.
"""
if preprocessed:
csv = data_path / "x_trans.csv"
download_if_needed(csv, file_id=PREPROCESSED_FILE_ID)
else:
arch = data_path / "data_cf_all.tar.gz"
download_if_needed(
arch, file_id=FILE_ID, unarchive=True, unarchive_folder=data_path
)
csv = data_path / "data_cf_all/x.csv"
log.debug(f"load dataset {csv}")
return preprocess(
csv,
data_path=data_path,
preprocessed=preprocessed,
original_acic_outcomes=original_acic_outcomes,
**kwargs,
)
================================================
FILE: catenets/datasets/dataset_ihdp.py
================================================
"""
IHDP (Infant Health and Development Program) dataset
"""
# stdlib
import os
import random
from pathlib import Path
from typing import Any, Tuple
# third party
import numpy as np
import catenets.logger as log
from .network import download_if_needed
np.random.seed(0)
random.seed(0)
TRAIN_DATASET = "ihdp_npci_1-100.train.npz"
TEST_DATASET = "ihdp_npci_1-100.test.npz"
TRAIN_URL = "https://www.fredjo.com/files/ihdp_npci_1-100.train.npz"
TEST_URL = "https://www.fredjo.com/files/ihdp_npci_1-100.test.npz"
# helper functions
def load_data_npz(fname: Path, get_po: bool = True) -> dict:
"""
Helper function for loading the IHDP data set (adapted from https://github.com/clinicalml/cfrnet)
Parameters
----------
fname: Path
Dataset path
Returns
-------
data: dict
Raw IHDP dict, with X, w, y and yf keys.
"""
data_in = np.load(fname)
data = {"X": data_in["x"], "w": data_in["t"], "y": data_in["yf"]}
try:
data["ycf"] = data_in["ycf"]
except BaseException:
data["ycf"] = None
if get_po:
data["mu0"] = data_in["mu0"]
data["mu1"] = data_in["mu1"]
data["HAVE_TRUTH"] = not data["ycf"] is None
data["dim"] = data["X"].shape[1]
data["n"] = data["X"].shape[0]
return data
def prepare_ihdp_data(
data_train: dict,
data_test: dict,
rescale: bool = False,
setting: str = "C",
return_pos: bool = False,
) -> Tuple:
"""
Helper for preprocessing the IHDP dataset.
Parameters
----------
data_train: pd.DataFrame or dict
Train dataset
data_test: pd.DataFrame or dict
Test dataset
rescale: bool, default False
Rescale the outcomes to have similar scale
setting: str, default C
Experiment setting
return_pos: bool
Return potential outcomes
Returns
-------
X: dict or pd.DataFrame
Training Feature set
y: pd.DataFrame or list
Outcome list
t: pd.DataFrame or list
Treatment list
cate_true_in: pd.DataFrame or list
Average treatment effects for the training set
X_t: pd.Dataframe or list
Test feature set
cate_true_out: pd.DataFrame of list
Average treatment effects for the testing set
"""
X, y, w, mu0, mu1 = (
data_train["X"],
data_train["y"],
data_train["w"],
data_train["mu0"],
data_train["mu1"],
)
X_t, _, _, mu0_t, mu1_t = (
data_test["X"],
data_test["y"],
data_test["w"],
data_test["mu0"],
data_test["mu1"],
)
if setting == "D":
y[w == 1] = y[w == 1] + mu0[w == 1]
mu1 = mu0 + mu1
mu1_t = mu0_t + mu1_t
if rescale:
# rescale all outcomes to have similar scale of CATEs if sd_cate > 1
cate_in = mu0 - mu1
sd_cate = np.sqrt(cate_in.var())
if sd_cate > 1:
# training data
error = y - w * mu1 - (1 - w) * mu0
mu0 = mu0 / sd_cate
mu1 = mu1 / sd_cate
y = w * mu1 + (1 - w) * mu0 + error
# test data
mu0_t = mu0_t / sd_cate
mu1_t = mu1_t / sd_cate
cate_true_in = mu1 - mu0
cate_true_out = mu1_t - mu0_t
if return_pos:
return X, y, w, cate_true_in, X_t, cate_true_out, mu0, mu1, mu0_t, mu1_t
return X, y, w, cate_true_in, X_t, cate_true_out
def get_one_data_set(D: dict, i_exp: int, get_po: bool = True) -> dict:
"""
Helper for getting the IHDP data for one experiment. Adapted from https://github.com/clinicalml/cfrnet
Parameters
----------
D: dict or pd.DataFrame
All the experiment
i_exp: int
Experiment number
Returns
-------
data: dict or pd.Dataframe
dict with the experiment
"""
D_exp = {}
D_exp["X"] = D["X"][:, :, i_exp - 1]
D_exp["w"] = D["w"][:, i_exp - 1 : i_exp]
D_exp["y"] = D["y"][:, i_exp - 1 : i_exp]
if D["HAVE_TRUTH"]:
D_exp["ycf"] = D["ycf"][:, i_exp - 1 : i_exp]
else:
D_exp["ycf"] = None
if get_po:
D_exp["mu0"] = D["mu0"][:, i_exp - 1 : i_exp]
D_exp["mu1"] = D["mu1"][:, i_exp - 1 : i_exp]
return D_exp
def load(data_path: Path, exp: int = 1, rescale: bool = False, **kwargs: Any) -> Tuple:
"""
Get IHDP train/test datasets with treatments and labels.
Parameters
----------
data_path: Path
Path to the dataset csv. If the data is missing, it will be downloaded.
Returns
-------
X: pd.Dataframe or array
The training feature set
w: pd.DataFrame or array
Training treatment assignments.
y: pd.Dataframe or array
The training labels
training potential outcomes: pd.DataFrame or array.
Potential outcomes for the training set.
X_t: pd.DataFrame or array
The testing feature set
testing potential outcomes: pd.DataFrame of array
Potential outcomes for the testing set.
"""
data_train, data_test = load_raw(data_path)
data_exp = get_one_data_set(data_train, i_exp=exp, get_po=True)
data_exp_test = get_one_data_set(data_test, i_exp=exp, get_po=True)
(
X,
y,
w,
cate_true_in,
X_t,
cate_true_out,
mu0,
mu1,
mu0_t,
mu1_t,
) = prepare_ihdp_data(
data_exp,
data_exp_test,
rescale=rescale,
return_pos=True,
)
return (
X,
w,
y,
np.asarray([mu0, mu1]).squeeze().T,
X_t,
np.asarray([mu0_t, mu1_t]).squeeze().T,
)
def load_raw(data_path: Path) -> Tuple:
"""
Get IHDP raw train/test sets.
Parameters
----------
data_path: Path
Path to the dataset csv. If the data is missing, it will be downloaded.
Returns
-------
data_train: dict or pd.DataFrame
Training data
data_test: dict or pd.DataFrame
Testing data
"""
try:
os.mkdir(data_path)
except BaseException:
pass
train_csv = data_path / TRAIN_DATASET
test_csv = data_path / TEST_DATASET
log.debug(f"load raw dataset {train_csv}")
download_if_needed(train_csv, http_url=TRAIN_URL)
download_if_needed(test_csv, http_url=TEST_URL)
data_train = load_data_npz(train_csv, get_po=True)
data_test = load_data_npz(test_csv, get_po=True)
return data_train, data_test
================================================
FILE: catenets/datasets/dataset_twins.py
================================================
"""
Twins dataset
Load real-world individualized treatment effects estimation datasets
- Reference: http://data.nber.org/data/linked-birth-infant-death-data-vital-statistics-data.html
"""
# stdlib
import random
from pathlib import Path
from typing import Tuple
# third party
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
import catenets.logger as log
from .network import download_if_needed
DATASET = "Twin_Data.csv.gz"
URL = "https://bitbucket.org/mvdschaar/mlforhealthlabpub/raw/0b0190bcd38a76c405c805f1ca774971fcd85233/data/twins/Twin_Data.csv.gz" # noqa: E501
def preprocess(
fn_csv: Path,
train_ratio: float = 0.8,
treatment_type: str = "rand",
seed: int = 42,
treat_prop: float = 0.5,
) -> Tuple:
"""Helper for preprocessing the Twins dataset.
Parameters
----------
fn_csv: Path
Dataset CSV file path.
train_ratio: float
The ratio of training data.
treatment_type: string
The treatment selection strategy.
seed: float
Random seed.
Returns
-------
train_x: array or pd.DataFrame
Features in training data.
train_t: array or pd.DataFrame
Treatments in training data.
train_y: array or pd.DataFrame
Observed outcomes in training data.
train_potential_y: array or pd.DataFrame
Potential outcomes in training data.
test_x: array or pd.DataFrame
Features in testing data.
test_potential_y: array or pd.DataFrame
Potential outcomes in testing data.
"""
np.random.seed(seed)
random.seed(seed)
# Load original data (11400 patients, 30 features, 2 dimensional potential outcomes)
df = pd.read_csv(fn_csv)
cleaned_columns = []
for col in df.columns:
cleaned_columns.append(col.replace("'", "").replace("’", ""))
df.columns = cleaned_columns
feat_list = list(df)
# 8: factor not on certificate, 9: factor not classifiable --> np.nan --> mode imputation
medrisk_list = [
"anemia",
"cardiac",
"lung",
"diabetes",
"herpes",
"hydra",
"hemo",
"chyper",
"phyper",
"eclamp",
"incervix",
"pre4000",
"dtotord",
"preterm",
"renal",
"rh",
"uterine",
"othermr",
]
# 99: missing
other_list = ["cigar", "drink", "wtgain", "gestat", "dmeduc", "nprevist"]
other_list2 = ["pldel", "resstatb"] # but no samples are missing..
bin_list = ["dmar"] + medrisk_list
con_list = ["dmage", "mpcb"] + other_list
cat_list = ["adequacy"] + other_list2
for feat in medrisk_list:
df[feat] = df[feat].apply(lambda x: df[feat].mode()[0] if x in [8, 9] else x)
for feat in other_list:
df.loc[df[feat] == 99, feat] = df.loc[df[feat] != 99, feat].mean()
df_features = df[con_list + bin_list]
for feat in cat_list:
df_features = pd.concat(
[df_features, pd.get_dummies(df[feat], prefix=feat)], axis=1
)
# Define features
feat_list = [
"dmage",
"mpcb",
"cigar",
"drink",
"wtgain",
"gestat",
"dmeduc",
"nprevist",
"dmar",
"anemia",
"cardiac",
"lung",
"diabetes",
"herpes",
"hydra",
"hemo",
"chyper",
"phyper",
"eclamp",
"incervix",
"pre4000",
"dtotord",
"preterm",
"renal",
"rh",
"uterine",
"othermr",
"adequacy_1",
"adequacy_2",
"adequacy_3",
"pldel_1",
"pldel_2",
"pldel_3",
"pldel_4",
"pldel_5",
"resstatb_1",
"resstatb_2",
"resstatb_3",
"resstatb_4",
]
x = np.asarray(df_features[feat_list])
y0 = np.asarray(df[["outcome(t=0)"]]).reshape((-1,))
y0 = np.array(y0 < 9999, dtype=int)
y1 = np.asarray(df[["outcome(t=1)"]]).reshape((-1,))
y1 = np.array(y1 < 9999, dtype=int)
# Preprocessing
scaler = MinMaxScaler()
scaler.fit(x)
x = scaler.transform(x)
no, dim = x.shape
if treatment_type == "rand":
# assign with p=0.5
prob = np.ones(x.shape[0]) * treat_prop
elif treatment_type == "logistic":
# assign with logistic prob
coef = np.random.uniform(-0.1, 0.1, size=[np.shape(x)[1], 1])
prob = 1 / (1 + np.exp(-np.matmul(x, coef)))
w = np.random.binomial(1, prob)
y = y1 * w + y0 * (1 - w)
potential_y = np.vstack((y0, y1)).T
# Train/test division
if train_ratio < 1:
idx = np.random.permutation(no)
train_idx = idx[: int(train_ratio * no)]
test_idx = idx[int(train_ratio * no) :]
train_x = x[train_idx, :]
train_w = w[train_idx]
train_y = y[train_idx]
train_potential_y = potential_y[train_idx, :]
test_x = x[test_idx, :]
test_potential_y = potential_y[test_idx, :]
else:
train_x = x
train_w = w
train_y = y
train_potential_y = potential_y
test_x = None
test_potential_y = None
return train_x, train_w, train_y, train_potential_y, test_x, test_potential_y
def load(
data_path: Path,
train_ratio: float = 0.8,
treatment_type: str = "rand",
seed: int = 42,
treat_prop: float = 0.5,
) -> Tuple:
"""
Twins dataset dataloader.
- Download the dataset if needed.
- Load the dataset.
- Preprocess the data.
- Return train/test split.
Parameters
----------
data_path: Path
Path to the CSV. If it is missing, it will be downloaded.
train_ratio: float
Train/test ratio
treatment_type: str
Treatment generation strategy
seed: float
Random seed
treat_prop: float
Treatment proportion
Returns
-------
train_x: array or pd.DataFrame
Features in training data.
train_t: array or pd.DataFrame
Treatments in training data.
train_y: array or pd.DataFrame
Observed outcomes in training data.
train_potential_y: array or pd.DataFrame
Potential outcomes in training data.
test_x: array or pd.DataFrame
Features in testing data.
test_potential_y: array or pd.DataFrame
Potential outcomes in testing data.
"""
csv = data_path / DATASET
download_if_needed(csv, http_url=URL)
log.debug(f"load dataset {csv}")
return preprocess(
csv,
train_ratio=train_ratio,
treatment_type=treatment_type,
seed=seed,
treat_prop=treat_prop,
)
================================================
FILE: catenets/datasets/network.py
================================================
"""
Utilities and helpers for retrieving the datasets
"""
# stdlib
import tarfile
import urllib.request
from pathlib import Path
from typing import Optional
import gdown
def download_gdrive_if_needed(path: Path, file_id: str) -> None:
"""
Helper for downloading a file from Google Drive, if it is now already on the disk.
Parameters
----------
path: Path
Where to download the file
file_id: str
Google Drive File ID. Details: https://developers.google.com/drive/api/v3/about-files
"""
path = Path(path)
if path.exists():
return
gdown.download(id=file_id, output=str(path), quiet=False)
def download_http_if_needed(path: Path, url: str) -> None:
"""
Helper for downloading a file, if it is now already on the disk.
Parameters
----------
path: Path
Where to download the file.
url: URL string
HTTP URL for the dataset.
"""
path = Path(path)
if path.exists():
return
if url.lower().startswith("http"):
urllib.request.urlretrieve(url, path) # nosec
return
raise ValueError(f"Invalid url provided {url}")
def unarchive_if_needed(path: Path, output_folder: Path) -> None:
"""
Helper for uncompressing archives. Supports .tar.gz and .tar.
Parameters
----------
path: Path
Source archive.
output_folder: Path
Where to unarchive.
"""
if str(path).endswith(".tar.gz"):
tar = tarfile.open(path, "r:gz")
tar.extractall(path=output_folder) # nosec
tar.close()
elif str(path).endswith(".tar"):
tar = tarfile.open(path, "r:")
tar.extractall(path=output_folder) # nosec
tar.close()
else:
raise NotImplementedError(f"archive not supported {path}")
def download_if_needed(
download_path: Path,
file_id: Optional[str] = None, # used for downloading from Google Drive
http_url: Optional[str] = None, # used for downloading from a HTTP URL
unarchive: bool = False, # unzip a downloaded archive
unarchive_folder: Optional[Path] = None, # unzip folder
) -> None:
"""
Helper for retrieving online datasets.
Parameters
----------
download_path: str
Where to download the archive
file_id: str, optional
Set this if you want to download from a public Google drive share
http_url: str, optional
Set this if you want to download from a HTTP URL
unarchive: bool
Set this if you want to try to unarchive the downloaded file
unarchive_folder: str
Mandatory if you set unarchive to True.
"""
download_path = Path(download_path)
if file_id is not None:
download_gdrive_if_needed(download_path, file_id)
elif http_url is not None:
download_http_if_needed(download_path, http_url)
else:
raise ValueError("Please provide a download URL")
if unarchive and unarchive_folder is None:
raise ValueError("Please provide a folder for the archive")
if unarchive and unarchive_folder is not None:
try:
unarchive_if_needed(download_path, unarchive_folder)
except BaseException as e:
print(f"Failed to unpack {download_path}. Error {e}")
download_path.unlink()
================================================
FILE: catenets/experiment_utils/__init__.py
================================================
================================================
FILE: catenets/experiment_utils/base.py
================================================
"""
Some utils for experiments
"""
# Author: Alicia Curth
from typing import Callable, Dict, Optional, Union
import jax.numpy as jnp
from catenets.models.jax import (
DRNET_NAME,
PSEUDOOUT_NAME,
RANET_NAME,
RNET_NAME,
SNET1_NAME,
SNET2_NAME,
SNET3_NAME,
SNET_NAME,
T_NAME,
XNET_NAME,
PseudoOutcomeNet,
get_catenet,
)
from catenets.models.jax.base import check_shape_1d_data
from catenets.models.jax.transformation_utils import (
DR_TRANSFORMATION,
PW_TRANSFORMATION,
RA_TRANSFORMATION,
)
SEP = "_"
def eval_mse_model(
inputs: jnp.ndarray,
targets: jnp.ndarray,
predict_fun: Callable,
params: jnp.ndarray,
) -> jnp.ndarray:
# evaluate the mse of a model given its function and params
preds = predict_fun(params, inputs)
return jnp.mean((preds - targets) ** 2)
def eval_mse(preds: jnp.ndarray, targets: jnp.ndarray) -> jnp.ndarray:
preds = check_shape_1d_data(preds)
targets = check_shape_1d_data(targets)
return jnp.mean((preds - targets) ** 2)
def eval_root_mse(cate_pred: jnp.ndarray, cate_true: jnp.ndarray) -> jnp.ndarray:
cate_true = check_shape_1d_data(cate_true)
cate_pred = check_shape_1d_data(cate_pred)
return jnp.sqrt(eval_mse(cate_pred, cate_true))
def eval_abs_error_ate(cate_pred: jnp.ndarray, cate_true: jnp.ndarray) -> jnp.ndarray:
cate_true = check_shape_1d_data(cate_true)
cate_pred = check_shape_1d_data(cate_pred)
return jnp.abs(jnp.mean(cate_pred) - jnp.mean(cate_true))
def get_model_set(
model_selection: Union[str, list] = "all", model_params: Optional[dict] = None
) -> Dict:
"""Helper function to retrieve a set of models"""
# get model selection
if type(model_selection) is str:
if model_selection == "snet":
models = get_all_snets()
elif model_selection == "pseudo":
models = get_all_pseudoout_models()
elif model_selection == "twostep":
models = get_all_twostep_models()
elif model_selection == "all":
models = dict(**get_all_snets(), **get_all_pseudoout_models())
else:
models = {model_selection: get_catenet(model_selection)()} # type: ignore
elif type(model_selection) is list:
models = {}
for model in model_selection:
models.update({model: get_catenet(model)()})
else:
raise ValueError("model_selection should be string or list.")
# set hyperparameters
if model_params is not None:
for model in models.values():
existing_params = model.get_params()
new_params = {
key: val
for key, val in model_params.items()
if key in existing_params.keys()
}
model.set_params(**new_params)
return models
ALL_SNETS = [T_NAME, SNET1_NAME, SNET2_NAME, SNET3_NAME, SNET_NAME]
ALL_PSEUDOOUT_MODELS = [DR_TRANSFORMATION, PW_TRANSFORMATION, RA_TRANSFORMATION]
ALL_TWOSTEP_MODELS = [DRNET_NAME, RANET_NAME, XNET_NAME, RNET_NAME]
def get_all_snets() -> Dict:
model_dict = {}
for name in ALL_SNETS:
model_dict.update({name: get_catenet(name)()})
return model_dict
def get_all_pseudoout_models() -> Dict: # DR, RA, PW learner
model_dict = {}
for trans in ALL_PSEUDOOUT_MODELS:
model_dict.update(
{PSEUDOOUT_NAME + SEP + trans: PseudoOutcomeNet(transformation=trans)}
)
return model_dict
def get_all_twostep_models() -> Dict: # DR, RA, R, X learner
model_dict = {}
for name in ALL_TWOSTEP_MODELS:
model_dict.update({name: get_catenet(name)()})
return model_dict
================================================
FILE: catenets/experiment_utils/simulation_utils.py
================================================
"""
Simulation utils, allowing to flexibly consider different DGPs
"""
# Author: Alicia Curth
from typing import Any, Optional, Tuple
import numpy as np
from scipy.special import expit
def simulate_treatment_setup(
n: int,
d: int = 25,
n_w: int = 0,
n_c: int = 0,
n_o: int = 0,
n_t: int = 0,
covariate_model: Any = None,
covariate_model_params: Optional[dict] = None,
propensity_model: Any = None,
propensity_model_params: Optional[dict] = None,
mu_0_model: Any = None,
mu_0_model_params: Optional[dict] = None,
mu_1_model: Any = None,
mu_1_model_params: Optional[dict] = None,
error_sd: float = 1,
seed: int = 42,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""
Generic function to flexibly simulate a treatment setup.
Parameters
----------
n: int
Number of observations to generate
d: int
dimension of X to generate
n_o: int
Dimension of outcome-factor
n_c: int
Dimension of confounding factor
n_t: int
Dimension of purely predictive variables (support of tau(x)
n_w: int
Dimension of treatment assignment factor
covariate_model:
Model to generate covariates. Default: multivariate normal
covariate_model_params: dict
Additional parameters to pass to covariate model
propensity_model:
Model to generate propensity scores
propensity_model_params:
Additional parameters to pass to propensity model
mu_0_model:
Model to generate untreated outcomes
mu_0_model_params:
Additional parameters to pass to untreated outcome model
mu_1_model:
Model to generate treated outcomes.
mu_1_model_params:
Additional parameters to pass to treated outcome model
error_sd: float, default 1
Standard deviation of normal errors
seed: int
Seed
Returns
-------
X, y, w, p, t - Covariates, observed outcomes, treatment indicators, propensities, CATE
"""
# input checks
n_nuisance = d - (n_c + n_o + n_w + n_t)
if n_nuisance < 0:
raise ValueError("Dimensions should add up to maximally d.")
# set defaults
if covariate_model is None:
covariate_model = normal_covariate_model
if covariate_model_params is None:
covariate_model_params = {}
if propensity_model is None:
propensity_model = propensity_AISTATS
if propensity_model_params is None:
propensity_model_params = {}
if mu_0_model is None:
mu_0_model = mu0_AISTATS
if mu_0_model_params is None:
mu_0_model_params = {}
if mu_1_model is None:
mu_1_model = mu1_AISTATS
if mu_1_model_params is None:
mu_1_model_params = {}
np.random.seed(seed)
# generate data and outcomes
X = covariate_model(
n=n,
n_nuisance=n_nuisance,
n_c=n_c,
n_o=n_o,
n_w=n_w,
n_t=n_t,
**covariate_model_params
)
mu_0 = mu_0_model(X, n_c=n_c, n_o=n_o, n_w=n_w, **mu_0_model_params)
mu_1 = mu_1_model(
X, n_c=n_c, n_o=n_o, n_w=n_w, n_t=n_t, mu_0=mu_0, **mu_1_model_params
)
t = mu_1 - mu_0
# generate treatments
p = propensity_model(X, n_c=n_c, n_w=n_w, **propensity_model_params)
w = np.random.binomial(1, p=p)
# generate observables
y = w * mu_1 + (1 - w) * mu_0 + np.random.normal(0, error_sd, n)
return X, y, w, p, t
# normal covariate model (Adapted from Hassanpour & Greiner, 2020) -------------
def get_multivariate_normal_params(
m: int, correlated: bool = False
) -> Tuple[np.ndarray, np.ndarray]:
# Adapted from Hassanpour & Greiner (2020)
if correlated:
mu = np.zeros(m) # np.random.normal(size=m)/10
temp = np.random.uniform(size=(m, m))
temp = 0.5 * (np.transpose(temp) + temp)
sig = (np.ones((m, m)) - np.eye(m)) * temp / 10 + 0.5 * np.eye(
m
) # (temp + m * np.eye(m)) / 10
else:
mu = np.zeros(m)
sig = np.eye(m)
return mu, sig
def get_set_normal_covariates(m: int, n: int, correlated: bool = False) -> np.ndarray:
if m == 0:
return
mu, sig = get_multivariate_normal_params(m, correlated=correlated)
return np.random.multivariate_normal(mean=mu, cov=sig, size=n)
def normal_covariate_model(
n: int,
n_nuisance: int = 25,
n_c: int = 0,
n_o: int = 0,
n_w: int = 0,
n_t: int = 0,
correlated: bool = False,
) -> np.ndarray:
X_stack: Tuple = ()
for n_x in [n_w, n_c, n_o, n_t, n_nuisance]:
if n_x > 0:
X_stack = (*X_stack, get_set_normal_covariates(n_x, n, correlated))
return np.hstack(X_stack)
def propensity_AISTATS(
X: np.ndarray,
n_c: int = 0,
n_w: int = 0,
xi: float = 0.5,
nonlinear: bool = True,
offset: Any = 0,
target_prop: Optional[np.ndarray] = None,
) -> np.ndarray:
if n_c + n_w == 0:
# constant propensity
return xi * np.ones(X.shape[0])
else:
coefs = np.ones(n_c + n_w)
if nonlinear:
z = np.dot(X[:, : (n_c + n_w)] ** 2, coefs) / (n_c + n_w)
else:
z = np.dot(X[:, : (n_c + n_w)], coefs) / (n_c + n_w)
if type(offset) is float or type(offset) is int:
prop = expit(xi * z + offset)
if target_prop is not None:
avg_prop = np.average(prop)
prop = target_prop / avg_prop * prop
return prop
elif offset == "center":
# center the propensity scores to median 0.5
prop = expit(xi * (z - np.median(z)))
if target_prop is not None:
avg_prop = np.average(prop)
prop = target_prop / avg_prop * prop
return prop
else:
raise ValueError("Not a valid value for offset")
def propensity_constant(
X: np.ndarray, n_c: int = 0, n_w: int = 0, xi: float = 0.5
) -> np.ndarray:
return xi * np.ones(X.shape[0])
def mu0_AISTATS(
X: np.ndarray, n_w: int = 0, n_c: int = 0, n_o: int = 0, scale: bool = False
) -> np.ndarray:
if n_c + n_o == 0:
return np.zeros((X.shape[0]))
else:
if not scale:
coefs = np.ones(n_c + n_o)
else:
coefs = 10 * np.ones(n_c + n_o) / (n_c + n_o)
return np.dot(X[:, n_w : (n_w + n_c + n_o)] ** 2, coefs)
def mu1_AISTATS(
X: np.ndarray,
n_w: int = 0,
n_c: int = 0,
n_o: int = 0,
n_t: int = 0,
mu_0: Optional[np.ndarray] = None,
nonlinear: int = 2,
withbase: bool = True,
scale: bool = False,
) -> np.ndarray:
if n_t == 0:
return mu_0
# use additive effect
else:
if scale:
coefs = 10 * np.ones(n_t) / n_t
else:
coefs = np.ones(n_t)
X_sel = X[:, (n_w + n_c + n_o) : (n_w + n_c + n_o + n_t)]
if withbase:
return mu_0 + np.dot(X_sel**nonlinear, coefs)
else:
return np.dot(X_sel**nonlinear, coefs)
# Other simulation settings not used in AISTATS paper
# uniform covariate model
def uniform_covariate_model(
n: int,
n_nuisance: int = 0,
n_c: int = 0,
n_o: int = 0,
n_w: int = 0,
n_t: int = 0,
low: int = -1,
high: int = 1,
) -> np.ndarray:
d = n_nuisance + n_c + n_o + n_w + n_t
return np.random.uniform(low=low, high=high, size=(n, d))
def mu1_additive(
X: np.ndarray,
n_w: int = 0,
n_c: int = 0,
n_o: int = 0,
n_t: int = 0,
mu_0: Optional[np.ndarray] = None,
) -> np.ndarray:
if n_t == 0:
return mu_0
else:
coefs = np.random.normal(size=n_t)
return np.dot(X[:, (n_w + n_c + n_o) : (n_w + n_c + n_o + n_t)], coefs) / n_t
# regression surfaces from Hassanpour & Greiner
def mu0_hg(X: np.ndarray, n_w: int = 0, n_c: int = 0, n_o: int = 0) -> np.ndarray:
if n_c + n_o == 0:
return np.zeros((X.shape[0]))
else:
coefs = np.random.normal(size=n_c + n_o)
return np.dot(X[:, n_w : (n_w + n_c + n_o)], coefs) / (n_c + n_o)
def mu1_hg(
X: np.ndarray,
n_w: int = 0,
n_c: int = 0,
n_o: int = 0,
n_t: int = 0,
mu_0: Optional[np.ndarray] = None,
) -> np.ndarray:
if n_c + n_o == 0:
return np.zeros((X.shape[0]))
else:
coefs = np.random.normal(size=n_c + n_o)
return np.dot(X[:, n_w : (n_w + n_c + n_o)] ** 2, coefs) / (n_c + n_o)
def propensity_hg(
X: np.ndarray, n_c: int = 0, n_w: int = 0, xi: Optional[float] = None
) -> np.ndarray:
# propensity set-up used in Hassanpour & Greiner (2020)
if n_c + n_w == 0:
return 0.5 * np.ones(X.shape[0])
else:
if xi is None:
xi = 1
coefs = np.random.normal(size=n_c + n_w)
z = np.dot(X[:, : (n_c + n_w)], coefs)
return expit(xi * z)
================================================
FILE: catenets/experiment_utils/tester.py
================================================
# stdlib
import copy
from typing import Any, Tuple
# third party
import numpy as np
import torch
from sklearn.model_selection import KFold, StratifiedKFold
from catenets.experiment_utils.torch_metrics import abs_error_ATE, sqrt_PEHE
def generate_score(metric: np.ndarray) -> Tuple[float, float]:
percentile_val = 1.96
return (np.mean(metric), percentile_val * np.std(metric) / np.sqrt(len(metric)))
def print_score(score: Tuple[float, float]) -> str:
return str(round(score[0], 4)) + " +/- " + str(round(score[1], 4))
def evaluate_treatments_model(
estimator: Any,
X: torch.Tensor,
Y: torch.Tensor,
Y_full: torch.Tensor,
W: torch.Tensor,
n_folds: int = 3,
seed: int = 0,
) -> dict:
metric_pehe = np.zeros(n_folds)
metric_ate = np.zeros(n_folds)
indx = 0
if len(np.unique(Y)) == 2:
skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)
else:
skf = KFold(n_splits=n_folds, shuffle=True, random_state=seed)
for train_index, test_index in skf.split(X, Y):
X_train = X[train_index]
Y_train = Y[train_index]
W_train = W[train_index]
X_test = X[test_index]
Y_full_test = Y_full[test_index]
model = copy.deepcopy(estimator)
model.fit(X_train, Y_train, W_train)
try:
te_pred = model.predict(X_test).detach().cpu().numpy()
except BaseException:
te_pred = np.asarray(model.predict(X_test))
metric_ate[indx] = abs_error_ATE(Y_full_test, te_pred)
metric_pehe[indx] = sqrt_PEHE(Y_full_test, te_pred)
indx += 1
output_pehe = generate_score(metric_pehe)
output_ate = generate_score(metric_ate)
return {
"raw": {
"pehe": output_pehe,
"ate": output_ate,
},
"str": {
"pehe": print_score(output_pehe),
"ate": print_score(output_ate),
},
}
================================================
FILE: catenets/experiment_utils/torch_metrics.py
================================================
# third party
import torch
def sqrt_PEHE(po: torch.Tensor, hat_te: torch.Tensor) -> torch.Tensor:
"""
Precision in Estimation of Heterogeneous Effect(PyTorch version).
PEHE reflects the ability to capture individual variation in treatment effects.
Args:
po: expected outcome.
hat_te: estimated outcome.
"""
po = torch.Tensor(po)
hat_te = torch.Tensor(hat_te)
return torch.sqrt(torch.mean(((po[:, 1] - po[:, 0]) - hat_te) ** 2))
def abs_error_ATE(po: torch.Tensor, hat_te: torch.Tensor) -> torch.Tensor:
"""
Average Treatment Effect.
ATE measures what is the expected causal effect of the treatment across all individuals in the population.
Args:
po: expected outcome.
hat_te: estimated outcome.
"""
po = torch.Tensor(po)
hat_te = torch.Tensor(hat_te)
return torch.abs(torch.mean(po[:, 1] - po[:, 0]) - torch.mean(hat_te))
================================================
FILE: catenets/logger.py
================================================
# stdlib
import logging
import os
from typing import Any, Callable, NoReturn, TextIO, Union
# third party
from loguru import logger
LOG_FORMAT = "[{time}][{process.id}][{level}] {message}"
logger.remove()
DEFAULT_SINK = "catenets_{time}.log"
def remove() -> None:
logger.remove()
def add(
sink: Union[None, str, os.PathLike, TextIO, logging.Handler] = None,
level: str = "ERROR",
) -> None:
sink = DEFAULT_SINK if sink is None else sink
try:
logger.add(
sink=sink,
format=LOG_FORMAT,
enqueue=True,
colorize=False,
diagnose=True,
backtrace=True,
rotation="10 MB",
retention="1 day",
level=level,
)
except BaseException:
logger.add(
sink=sink,
format=LOG_FORMAT,
colorize=False,
diagnose=True,
backtrace=True,
level=level,
)
def traceback_and_raise(e: Any, verbose: bool = False) -> NoReturn:
try:
if verbose:
logger.opt(lazy=True).exception(e)
else:
logger.opt(lazy=True).critical(e)
except BaseException as ex:
logger.debug("failed to print exception", ex)
if not issubclass(type(e), Exception):
e = Exception(e)
raise e
def create_log_and_print_function(level: str) -> Callable:
def log_and_print(*args: Any, **kwargs: Any) -> None:
try:
method = getattr(logger.opt(lazy=True), level, None)
if method is not None:
method(*args, **kwargs)
else:
logger.debug(*args, **kwargs)
except BaseException as e:
msg = f"failed to log exception. {e}"
try:
logger.debug(msg)
except Exception as e:
print(f"{msg}. {e}")
return log_and_print
def traceback(*args: Any, **kwargs: Any) -> None:
return create_log_and_print_function(level="exception")(*args, **kwargs)
def critical(*args: Any, **kwargs: Any) -> None:
return create_log_and_print_function(level="critical")(*args, **kwargs)
def error(*args: Any, **kwargs: Any) -> None:
return create_log_and_print_function(level="error")(*args, **kwargs)
def warning(*args: Any, **kwargs: Any) -> None:
return create_log_and_print_function(level="warning")(*args, **kwargs)
def info(*args: Any, **kwargs: Any) -> None:
return create_log_and_print_function(level="info")(*args, **kwargs)
def debug(*args: Any, **kwargs: Any) -> None:
return create_log_and_print_function(level="debug")(*args, **kwargs)
def trace(*args: Any, **kwargs: Any) -> None:
return create_log_and_print_function(level="trace")(*args, **kwargs)
================================================
FILE: catenets/models/__init__.py
================================================
import catenets.logger as log
try:
from . import jax
except ImportError:
log.error("JAX models disabled")
try:
from . import torch
except ImportError:
log.error("PyTorch models disabled")
__all__ = ["jax", "torch"]
================================================
FILE: catenets/models/constants.py
================================================
"""
Define some constants for initialisation of hyperparamters etc
"""
import numpy as np
# default model architectures
DEFAULT_LAYERS_OUT = 2
DEFAULT_LAYERS_OUT_T = 2
DEFAULT_LAYERS_R = 3
DEFAULT_LAYERS_R_T = 3
DEFAULT_UNITS_OUT = 100
DEFAULT_UNITS_R = 200
DEFAULT_UNITS_OUT_T = 100
DEFAULT_UNITS_R_T = 200
DEFAULT_NONLIN = "elu"
# other default hyperparameters
DEFAULT_STEP_SIZE = 0.0001
DEFAULT_STEP_SIZE_T = 0.0001
DEFAULT_N_ITER = 10000
DEFAULT_BATCH_SIZE = 100
DEFAULT_PENALTY_L2 = 1e-4
DEFAULT_PENALTY_DISC = 0
DEFAULT_PENALTY_ORTHOGONAL = 1 / 100
DEFAULT_AVG_OBJECTIVE = True
# defaults for early stopping
DEFAULT_VAL_SPLIT = 0.3
DEFAULT_N_ITER_MIN = 200
DEFAULT_PATIENCE = 10
# Defaults for crossfitting
DEFAULT_CF_FOLDS = 2
# other defaults
DEFAULT_SEED = 42
DEFAULT_N_ITER_PRINT = 50
LARGE_VAL = np.iinfo(np.int32).max
DEFAULT_UNITS_R_BIG_S = 100
DEFAULT_UNITS_R_SMALL_S = 50
DEFAULT_UNITS_R_BIG_S3 = 150
DEFAULT_UNITS_R_SMALL_S3 = 50
N_SUBSPACES = 3
DEFAULT_DIM_S_OUT = 50
DEFAULT_DIM_S_R = 100
DEFAULT_DIM_P_OUT = 50
DEFAULT_DIM_P_R = 100
================================================
FILE: catenets/models/jax/__init__.py
================================================
"""
JAX-based implementations for the CATE estimators.
"""
from typing import Any
from catenets.models.jax.disentangled_nets import SNet3
from catenets.models.jax.flextenet import FlexTENet
from catenets.models.jax.offsetnet import OffsetNet
from catenets.models.jax.pseudo_outcome_nets import (
DRNet,
PseudoOutcomeNet,
PWNet,
RANet,
)
from catenets.models.jax.representation_nets import DragonNet, SNet1, SNet2, TARNet
from catenets.models.jax.rnet import RNet
from catenets.models.jax.snet import SNet
from catenets.models.jax.tnet import TNet
from catenets.models.jax.xnet import XNet
SNET1_NAME = "SNet1"
T_NAME = "TNet"
SNET2_NAME = "SNet2"
PSEUDOOUT_NAME = "PseudoOutcomeNet"
SNET3_NAME = "SNet3"
SNET_NAME = "SNet"
XNET_NAME = "XNet"
RNET_NAME = "RNet"
DRNET_NAME = "DRNet"
PWNET_NAME = "PWNet"
RANET_NAME = "RANet"
TARNET_NAME = "TARNet"
FLEXTE_NAME = "FlexTENet"
OFFSET_NAME = "OffsetNet"
DRAGON_NAME = "DragonNet"
ALL_MODELS = [
T_NAME,
SNET1_NAME,
SNET2_NAME,
SNET3_NAME,
SNET_NAME,
PSEUDOOUT_NAME,
RNET_NAME,
XNET_NAME,
DRNET_NAME,
PWNET_NAME,
RANET_NAME,
TARNET_NAME,
FLEXTE_NAME,
OFFSET_NAME,
]
MODEL_DICT = {
T_NAME: TNet,
SNET1_NAME: SNet1,
SNET2_NAME: SNet2,
SNET3_NAME: SNet3,
SNET_NAME: SNet,
PSEUDOOUT_NAME: PseudoOutcomeNet,
RNET_NAME: RNet,
XNET_NAME: XNet,
DRNET_NAME: DRNet,
PWNET_NAME: PWNet,
RANET_NAME: RANet,
TARNET_NAME: TARNet,
DRAGON_NAME: DragonNet,
OFFSET_NAME: OffsetNet,
FLEXTE_NAME: FlexTENet,
}
__all__ = [
T_NAME,
SNET1_NAME,
SNET2_NAME,
SNET3_NAME,
SNET_NAME,
PSEUDOOUT_NAME,
RNET_NAME,
XNET_NAME,
DRNET_NAME,
PWNET_NAME,
RANET_NAME,
TARNET_NAME,
DRAGON_NAME,
FLEXTE_NAME,
OFFSET_NAME,
]
def get_catenet(name: str) -> Any:
if name not in ALL_MODELS:
raise ValueError(
f"Model name should be in catenets.models.jax.ALL_MODELS You passed {name}"
)
return MODEL_DICT[name]
================================================
FILE: catenets/models/jax/base.py
================================================
"""
Base modules shared across different nets
"""
# Author: Alicia Curth
import abc
from typing import Any, Callable, List, Optional, Tuple
import jax.numpy as jnp
import numpy as onp
from jax import grad, jit, random
from jax.example_libraries import optimizers, stax
from jax.example_libraries.stax import Dense, Elu, Relu, Sigmoid
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.model_selection import ParameterGrid
import catenets.logger as log
from catenets.models.constants import (
DEFAULT_BATCH_SIZE,
DEFAULT_LAYERS_OUT,
DEFAULT_N_ITER,
DEFAULT_N_ITER_MIN,
DEFAULT_N_ITER_PRINT,
DEFAULT_NONLIN,
DEFAULT_PATIENCE,
DEFAULT_PENALTY_L2,
DEFAULT_SEED,
DEFAULT_STEP_SIZE,
DEFAULT_UNITS_OUT,
DEFAULT_UNITS_R,
DEFAULT_VAL_SPLIT,
LARGE_VAL,
)
from catenets.models.jax.model_utils import (
check_shape_1d_data,
check_X_is_np,
make_val_split,
)
def ReprBlock(
n_layers: int = 3, n_units: int = 100, nonlin: str = DEFAULT_NONLIN
) -> Any:
# Creates a representation block using jax.stax
# create first layer
if nonlin == "elu":
NL = Elu
elif nonlin == "relu":
NL = Relu
elif nonlin == "sigmoid":
NL = Sigmoid
else:
raise ValueError("Unknown nonlinearity")
layers: Tuple
layers = (Dense(n_units), NL)
# add required number of layers
for i in range(n_layers - 1):
layers = (*layers, Dense(n_units), NL)
return stax.serial(*layers)
def OutputHead(
n_layers_out: int = DEFAULT_LAYERS_OUT,
n_units_out: int = DEFAULT_UNITS_OUT,
binary_y: bool = False,
n_layers_r: int = 0,
n_units_r: int = DEFAULT_UNITS_R,
nonlin: str = DEFAULT_NONLIN,
) -> Any:
# Creates an output head using jax.stax
if nonlin == "elu":
NL = Elu
elif nonlin == "relu":
NL = Relu
elif nonlin == "sigmoid":
NL = Sigmoid
else:
raise ValueError("Unknown nonlinearity")
layers: Tuple = ()
# add required number of layers
for i in range(n_layers_r):
layers = (*layers, Dense(n_units_r), NL)
# add required number of layers
for i in range(n_layers_out):
layers = (*layers, Dense(n_units_out), NL)
# return final architecture
if not binary_y:
return stax.serial(*layers, Dense(1))
else:
return stax.serial(*layers, Dense(1), Sigmoid)
class BaseCATENet(BaseEstimator, RegressorMixin, abc.ABC):
"""
Base CATENet class to serve as template for all other nets
"""
def score(
self,
X: jnp.ndarray,
y: jnp.ndarray,
sample_weight: Optional[jnp.ndarray] = None,
) -> float:
"""
Return the sqrt PEHE error (Oracle metric).
Parameters
----------
X: pd.DataFrame or np.array
Covariate matrix
y: np.array
Expected potential outcome vector
"""
X = check_X_is_np(X)
y = check_X_is_np(y)
if len(X) != len(y):
raise ValueError("X/y length mismatch for score")
if y.shape[-1] != 2:
raise ValueError(f"y has invalid shape {y.shape}")
hat_te = self.predict(X)
return jnp.sqrt(jnp.mean(((y[:, 1] - y[:, 0]) - hat_te) ** 2))
@abc.abstractmethod
def _get_train_function(self) -> Callable:
...
def fit(
self,
X: jnp.ndarray,
y: jnp.ndarray,
w: jnp.ndarray,
p: Optional[jnp.ndarray] = None,
) -> "BaseCATENet":
"""
Fit method for a CATENet. Takes covariates, outcome variable and treatment indicator as
input
Parameters
----------
X: pd.DataFrame or np.array
Covariate matrix
y: np.array
Outcome vector
w: np.array
Treatment indicator
p: np.array
Vector of (known) treatment propensities. Currently only supported for TwoStepNets.
"""
# some quick input checks
if p is not None:
raise NotImplementedError("Only two-step-nets take p as input. ")
X = check_X_is_np(X)
self._check_inputs(w, p)
train_func = self._get_train_function()
train_params = self.get_params()
self._params, self._predict_funs = train_func(X, y, w, **train_params)
return self
@abc.abstractmethod
def _get_predict_function(self) -> Callable:
...
def predict(
self, X: jnp.ndarray, return_po: bool = False, return_prop: bool = False
) -> jnp.ndarray:
"""
Predict treatment effect estimates using a CATENet. Depending on method, can also return
potential outcome estimate and propensity score estimate.
Parameters
----------
X: pd.DataFrame or np.array
Covariate matrix
return_po: bool, default False
Whether to return potential outcome estimate
return_prop: bool, default False
Whether to return propensity estimate
Returns
-------
array of CATE estimates, optionally also potential outcomes and propensity
"""
X = check_X_is_np(X)
predict_func = self._get_predict_function()
return predict_func(
X,
trained_params=self._params,
predict_funs=self._predict_funs,
return_po=return_po,
return_prop=return_prop,
)
@staticmethod
def _check_inputs(w: jnp.ndarray, p: jnp.ndarray) -> None:
if p is not None:
if onp.sum(p > 1) > 0 or onp.sum(p < 0) > 0:
raise ValueError("p should be in [0,1]")
if not ((w == 0) | (w == 1)).all():
raise ValueError("W should be binary")
def fit_and_select_params(
self,
X: jnp.ndarray,
y: jnp.ndarray,
w: jnp.ndarray,
p: Optional[jnp.ndarray] = None,
param_grid: dict = {},
) -> "BaseCATENet":
# some quick input checks
if param_grid is None:
raise ValueError("No param_grid to evaluate. ")
X = check_X_is_np(X)
self._check_inputs(w, p)
param_grid = ParameterGrid(param_grid)
self_param_dict = self.get_params()
train_function = self._get_train_function()
models = []
losses = []
param_settings: list = []
for param_setting in param_grid:
log.debug(
"Testing parameter setting: "
+ " ".join(
[key + ": " + str(value) for key, value in param_setting.items()]
)
)
# replace params
train_param_dict = {
key: (val if key not in param_setting.keys() else param_setting[key])
for key, val in self_param_dict.items()
}
if p is not None:
params, funs, val_loss = train_function(
X, y, w, p=p, return_val_loss=True, **train_param_dict
)
else:
params, funs, val_loss = train_function(
X, y, w, return_val_loss=True, **train_param_dict
)
models.append((params, funs))
losses.append(val_loss)
# save results
param_settings.extend(param_grid)
self._selection_results = {
"param_settings": param_settings,
"val_losses": losses,
}
# find lowest loss and set params
best_idx = jnp.array(losses).argmin()
self._params, self._predict_funs = models[best_idx]
self.set_params(**param_settings[best_idx])
return self
def train_output_net_only(
X: jnp.ndarray,
y: jnp.ndarray,
binary_y: bool = False,
n_layers_out: int = DEFAULT_LAYERS_OUT,
n_units_out: int = DEFAULT_UNITS_OUT,
n_layers_r: int = 0,
n_units_r: int = DEFAULT_UNITS_R,
penalty_l2: float = DEFAULT_PENALTY_L2,
step_size: float = DEFAULT_STEP_SIZE,
n_iter: int = DEFAULT_N_ITER,
batch_size: int = DEFAULT_BATCH_SIZE,
val_split_prop: float = DEFAULT_VAL_SPLIT,
early_stopping: bool = True,
patience: int = DEFAULT_PATIENCE,
n_iter_min: int = DEFAULT_N_ITER_MIN,
n_iter_print: int = DEFAULT_N_ITER_PRINT,
seed: int = DEFAULT_SEED,
return_val_loss: bool = False,
nonlin: str = DEFAULT_NONLIN,
avg_objective: bool = False,
) -> Any:
# function to train a single output head
# input check
y = check_shape_1d_data(y)
d = X.shape[1]
input_shape = (-1, d)
rng_key = random.PRNGKey(seed)
onp.random.seed(seed) # set seed for data generation via numpy as well
# get validation split (can be none)
X, y, X_val, y_val, val_string = make_val_split(
X, y, val_split_prop=val_split_prop, seed=seed
)
n = X.shape[0] # could be different from before due to split
# get output head
init_fun, predict_fun = OutputHead(
n_layers_out=n_layers_out,
n_units_out=n_units_out,
binary_y=binary_y,
n_layers_r=n_layers_r,
n_units_r=n_units_r,
nonlin=nonlin,
)
# get functions
if not binary_y:
# define loss and grad
@jit
def loss(
params: List, batch: Tuple[jnp.ndarray, jnp.ndarray], penalty: float
) -> jnp.ndarray:
# mse loss function
inputs, targets = batch
preds = predict_fun(params, inputs)
weightsq = sum(
[
jnp.sum(params[i][0] ** 2)
for i in range(0, 2 * (n_layers_out + n_layers_r) + 1, 2)
]
)
if not avg_objective:
return jnp.sum((preds - targets) ** 2) + 0.5 * penalty * weightsq
else:
return jnp.average((preds - targets) ** 2) + 0.5 * penalty * weightsq
else:
# get loss and grad
@jit
def loss(
params: List, batch: Tuple[jnp.ndarray, jnp.ndarray], penalty: float
) -> jnp.ndarray:
# mse loss function
inputs, targets = batch
preds = predict_fun(params, inputs)
weightsq = sum(
[
jnp.sum(params[i][0] ** 2)
for i in range(0, 2 * (n_layers_out + n_layers_r) + 1, 2)
]
)
if not avg_objective:
return (
-jnp.sum(
targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds)
)
+ 0.5 * penalty * weightsq
)
else:
return (
-jnp.average(
targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds)
)
+ 0.5 * penalty * weightsq
)
# set optimization routine
# set optimizer
opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)
# set update function
@jit
def update(i: int, state: dict, batch: jnp.ndarray, penalty: float) -> jnp.ndarray:
params = get_params(state)
g_params = grad(loss)(params, batch, penalty)
return opt_update(i, g_params, state)
# initialise states
_, init_params = init_fun(rng_key, input_shape)
opt_state = opt_init(init_params)
# calculate number of batches per epoch
batch_size = batch_size if batch_size < n else n
n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1
train_indices = onp.arange(n)
l_best = LARGE_VAL
p_curr = 0
# do training
for i in range(n_iter):
# shuffle data for minibatches
onp.random.shuffle(train_indices)
for b in range(n_batches):
idx_next = train_indices[
(b * batch_size) : min((b + 1) * batch_size, n - 1)
]
next_batch = X[idx_next, :], y[idx_next, :]
opt_state = update(i * n_batches + b, opt_state, next_batch, penalty_l2)
if (i % n_iter_print == 0) or early_stopping:
params_curr = get_params(opt_state)
l_curr = loss(params_curr, (X_val, y_val), penalty_l2)
if i % n_iter_print == 0:
log.info(f"Epoch: {i}, current {val_string} loss: {l_curr}")
if early_stopping and ((i + 1) * n_batches > n_iter_min):
# check if loss updated
if l_curr < l_best:
l_best = l_curr
p_curr = 0
else:
p_curr = p_curr + 1
if p_curr > patience:
trained_params = get_params(opt_state)
if return_val_loss:
# return loss without penalty
l_final = loss(trained_params, (X_val, y_val), 0)
return trained_params, predict_fun, l_final
return trained_params, predict_fun
# get final parameters
trained_params = get_params(opt_state)
if return_val_loss:
# return loss without penalty
l_final = loss(trained_params, (X_val, y_val), 0)
return trained_params, predict_fun, l_final
return trained_params, predict_fun
================================================
FILE: catenets/models/jax/disentangled_nets.py
================================================
"""
Class implements SNet-3, a variation on DR-CFR discussed in
Hassanpour and Greiner (2020) and Wu et al (2020).
"""
# Author: Alicia Curth
from typing import Any, Callable, List, Tuple
import jax.numpy as jnp
import numpy as onp
from jax import grad, jit, random
from jax.example_libraries import optimizers
import catenets.logger as log
from catenets.models.constants import (
DEFAULT_AVG_OBJECTIVE,
DEFAULT_BATCH_SIZE,
DEFAULT_LAYERS_OUT,
DEFAULT_LAYERS_R,
DEFAULT_N_ITER,
DEFAULT_N_ITER_MIN,
DEFAULT_N_ITER_PRINT,
DEFAULT_NONLIN,
DEFAULT_PATIENCE,
DEFAULT_PENALTY_DISC,
DEFAULT_PENALTY_L2,
DEFAULT_PENALTY_ORTHOGONAL,
DEFAULT_SEED,
DEFAULT_STEP_SIZE,
DEFAULT_UNITS_OUT,
DEFAULT_UNITS_R_BIG_S3,
DEFAULT_UNITS_R_SMALL_S3,
DEFAULT_VAL_SPLIT,
LARGE_VAL,
)
from catenets.models.jax.base import BaseCATENet, OutputHead, ReprBlock
from catenets.models.jax.model_utils import (
check_shape_1d_data,
heads_l2_penalty,
make_val_split,
)
from catenets.models.jax.representation_nets import mmd2_lin
# helper functions to avoid abstract tracer values in jit
def _get_absolute_rowsums(mat: jnp.ndarray) -> jnp.ndarray:
return jnp.sum(jnp.abs(mat), axis=1)
def _concatenate_representations(reps: jnp.ndarray) -> jnp.ndarray:
return jnp.concatenate(reps, axis=1)
class SNet3(BaseCATENet):
"""
Class implements SNet-3, which is based on Hassanpour & Greiner (2020)'s DR-CFR (Without
propensity weighting), using an orthogonal regularizer to enforce decomposition similar to
Wu et al (2020).
Parameters
----------
binary_y: bool, default False
Whether the outcome is binary
n_layers_out: int
Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)
n_layers_out_prop: int
Number of hypothesis layers for propensity score(n_layers_out x n_units_out + 1 x Dense
layer)
n_units_out: int
Number of hidden units in each hypothesis layer
n_units_out_prop: int
Number of hidden units in each propensity score hypothesis layer
n_layers_r: int
Number of shared & private representation layers before hypothesis layers
n_units_r: int
Number of hidden units in representation layer shared by propensity score and outcome
function (the 'confounding factor')
n_units_r_small: int
Number of hidden units in representation layer NOT shared by propensity score and outcome
functions (the 'outcome factor' and the 'instrumental factor')
penalty_l2: float
l2 (ridge) penalty
step_size: float
learning rate for optimizer
n_iter: int
Maximum number of iterations
batch_size: int
Batch size
val_split_prop: float
Proportion of samples used for validation split (can be 0)
early_stopping: bool, default True
Whether to use early stopping
patience: int
Number of iterations to wait before early stopping after decrease in validation loss
n_iter_min: int
Minimum number of iterations to go through before starting early stopping
n_iter_print: int
Number of iterations after which to print updates
seed: int
Seed used
reg_diff: bool, default False
Whether to regularize the difference between the two potential outcome heads
penalty_diff: float
l2-penalty for regularizing the difference between output heads. used only if
train_separate=False
same_init: bool, False
Whether to initialise the two output heads with same values
nonlin: string, default 'elu'
Nonlinearity to use in NN
penalty_disc: float, default zero
Discrepancy penalty. Defaults to zero as this feature is not tested.
"""
def __init__(
self,
binary_y: bool = False,
n_layers_r: int = DEFAULT_LAYERS_R,
n_units_r: int = DEFAULT_UNITS_R_BIG_S3,
n_layers_out: int = DEFAULT_LAYERS_OUT,
n_units_r_small: int = DEFAULT_UNITS_R_SMALL_S3,
n_units_out: int = DEFAULT_UNITS_OUT,
n_units_out_prop: int = DEFAULT_UNITS_OUT,
n_layers_out_prop: int = DEFAULT_LAYERS_OUT,
penalty_l2: float = DEFAULT_PENALTY_L2,
penalty_orthogonal: float = DEFAULT_PENALTY_ORTHOGONAL,
penalty_disc: float = DEFAULT_PENALTY_DISC,
step_size: float = DEFAULT_STEP_SIZE,
n_iter: int = DEFAULT_N_ITER,
batch_size: int = DEFAULT_BATCH_SIZE,
val_split_prop: float = DEFAULT_VAL_SPLIT,
early_stopping: bool = True,
patience: int = DEFAULT_PATIENCE,
n_iter_min: int = DEFAULT_N_ITER_MIN,
n_iter_print: int = DEFAULT_N_ITER_PRINT,
seed: int = DEFAULT_SEED,
nonlin: str = DEFAULT_NONLIN,
reg_diff: bool = False,
penalty_diff: float = DEFAULT_PENALTY_L2,
same_init: bool = False,
) -> None:
self.binary_y = binary_y
self.n_layers_r = n_layers_r
self.n_layers_out = n_layers_out
self.n_layers_out_prop = n_layers_out_prop
self.n_units_r = n_units_r
self.n_units_r_small = n_units_r_small
self.n_units_out = n_units_out
self.n_units_out_prop = n_units_out_prop
self.nonlin = nonlin
self.penalty_l2 = penalty_l2
self.penalty_orthogonal = penalty_orthogonal
self.penalty_disc = penalty_disc
self.reg_diff = reg_diff
self.penalty_diff = penalty_diff
self.same_init = same_init
self.step_size = step_size
self.n_iter = n_iter
self.batch_size = batch_size
self.val_split_prop = val_split_prop
self.early_stopping = early_stopping
self.patience = patience
self.n_iter_min = n_iter_min
self.seed = seed
self.n_iter_print = n_iter_print
def _get_predict_function(self) -> Callable:
return predict_snet3
def _get_train_function(self) -> Callable:
return train_snet3
# SNET-3 -------------------------------------------------------------
def train_snet3(
X: jnp.ndarray,
y: jnp.ndarray,
w: jnp.ndarray,
binary_y: bool = False,
n_layers_r: int = DEFAULT_LAYERS_R,
n_units_r: int = DEFAULT_UNITS_R_BIG_S3,
n_units_r_small: int = DEFAULT_UNITS_R_SMALL_S3,
n_layers_out: int = DEFAULT_LAYERS_OUT,
n_units_out: int = DEFAULT_UNITS_OUT,
n_units_out_prop: int = DEFAULT_UNITS_OUT,
n_layers_out_prop: int = DEFAULT_LAYERS_OUT,
penalty_l2: float = DEFAULT_PENALTY_L2,
penalty_disc: float = DEFAULT_PENALTY_DISC,
penalty_orthogonal: float = DEFAULT_PENALTY_ORTHOGONAL,
step_size: float = DEFAULT_STEP_SIZE,
n_iter: int = DEFAULT_N_ITER,
batch_size: int = DEFAULT_BATCH_SIZE,
val_split_prop: float = DEFAULT_VAL_SPLIT,
early_stopping: bool = True,
n_iter_min: int = DEFAULT_N_ITER_MIN,
patience: int = DEFAULT_PATIENCE,
n_iter_print: int = DEFAULT_N_ITER_PRINT,
seed: int = DEFAULT_SEED,
return_val_loss: bool = False,
reg_diff: bool = False,
penalty_diff: float = DEFAULT_PENALTY_L2,
nonlin: str = DEFAULT_NONLIN,
avg_objective: bool = DEFAULT_AVG_OBJECTIVE,
same_init: bool = False,
) -> Any:
"""
SNet-3, based on the decompostion used in Hassanpour and Greiner (2020)
"""
# function to train a net with 3 representations
y, w = check_shape_1d_data(y), check_shape_1d_data(w)
d = X.shape[1]
input_shape = (-1, d)
rng_key = random.PRNGKey(seed)
onp.random.seed(seed) # set seed for data generation via numpy as well
if not reg_diff:
penalty_diff = penalty_l2
# get validation split (can be none)
X, y, w, X_val, y_val, w_val, val_string = make_val_split(
X, y, w, val_split_prop=val_split_prop, seed=seed
)
n = X.shape[0] # could be different from before due to split
# get representation layers
init_fun_repr, predict_fun_repr = ReprBlock(
n_layers=n_layers_r, n_units=n_units_r, nonlin=nonlin
)
init_fun_repr_small, predict_fun_repr_small = ReprBlock(
n_layers=n_layers_r, n_units=n_units_r_small, nonlin=nonlin
)
# get output head functions (output heads share same structure)
init_fun_head_po, predict_fun_head_po = OutputHead(
n_layers_out=n_layers_out,
n_units_out=n_units_out,
binary_y=binary_y,
nonlin=nonlin,
)
# add propensity head
init_fun_head_prop, predict_fun_head_prop = OutputHead(
n_layers_out=n_layers_out_prop,
n_units_out=n_units_out_prop,
binary_y=True,
nonlin=nonlin,
)
def init_fun_snet3(rng: float, input_shape: Tuple) -> Tuple[Tuple, List]:
# chain together the layers
# param should look like [repr_c, repr_o, repr_t, po_0, po_1, prop]
# initialise representation layers
rng, layer_rng = random.split(rng)
input_shape_repr, param_repr_c = init_fun_repr(layer_rng, input_shape)
rng, layer_rng = random.split(rng)
input_shape_repr_small, param_repr_o = init_fun_repr_small(
layer_rng, input_shape
)
rng, layer_rng = random.split(rng)
_, param_repr_w = init_fun_repr_small(layer_rng, input_shape)
# each head gets two representations
input_shape_repr = input_shape_repr[:-1] + (
input_shape_repr[-1] + input_shape_repr_small[-1],
)
# initialise output heads
rng, layer_rng = random.split(rng)
if same_init:
# initialise both on same values
input_shape, param_0 = init_fun_head_po(layer_rng, input_shape_repr)
input_shape, param_1 = init_fun_head_po(layer_rng, input_shape_repr)
else:
input_shape, param_0 = init_fun_head_po(layer_rng, input_shape_repr)
rng, layer_rng = random.split(rng)
input_shape, param_1 = init_fun_head_po(layer_rng, input_shape_repr)
rng, layer_rng = random.split(rng)
input_shape, param_prop = init_fun_head_prop(layer_rng, input_shape_repr)
return input_shape, [
param_repr_c,
param_repr_o,
param_repr_w,
param_0,
param_1,
param_prop,
]
# Define loss functions
# loss functions for the head
if not binary_y:
def loss_head(
params: List,
batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],
penalty: float,
) -> jnp.ndarray:
# mse loss function
inputs, targets, weights = batch
preds = predict_fun_head_po(params, inputs)
return jnp.sum(weights * ((preds - targets) ** 2))
else:
def loss_head(
params: List,
batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],
penalty: float,
) -> jnp.ndarray:
# log loss function
inputs, targets, weights = batch
preds = predict_fun_head_po(params, inputs)
return -jnp.sum(
weights
* (targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds))
)
def loss_head_prop(
params: List,
batch: Tuple[jnp.ndarray, jnp.ndarray],
penalty: float,
) -> jnp.ndarray:
# log loss function for propensities
inputs, targets = batch
preds = predict_fun_head_prop(params, inputs)
return -jnp.sum(targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds))
# complete loss function for all parts
@jit
def loss_snet3(
params: List,
batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],
penalty_l2: float,
penalty_orthogonal: float,
penalty_disc: float,
) -> jnp.ndarray:
# params: list[repr_c, repr_o, repr_t, po_0, po_1, prop]
# batch: (X, y, w)
X, y, w = batch
# get representation
reps_c = predict_fun_repr(params[0], X)
reps_o = predict_fun_repr_small(params[1], X)
reps_w = predict_fun_repr_small(params[2], X)
# concatenate
reps_po = _concatenate_representations((reps_c, reps_o))
reps_prop = _concatenate_representations((reps_c, reps_w))
# pass down to heads
loss_0 = loss_head(params[3], (reps_po, y, 1 - w), penalty_l2)
loss_1 = loss_head(params[4], (reps_po, y, w), penalty_l2)
# pass down to propensity head
loss_prop = loss_head_prop(params[5], (reps_prop, w), penalty_l2)
weightsq_prop = sum(
[
jnp.sum(params[5][i][0] ** 2)
for i in range(0, 2 * n_layers_out_prop + 1, 2)
]
)
# which variable has impact on which representation
col_c = _get_absolute_rowsums(params[0][0][0])
col_o = _get_absolute_rowsums(params[1][0][0])
col_w = _get_absolute_rowsums(params[2][0][0])
loss_o = penalty_orthogonal * (
jnp.sum(col_c * col_o + col_c * col_w + col_w * col_o)
)
# is rep_o balanced between groups?
loss_disc = penalty_disc * mmd2_lin(reps_o, w)
# weight decay on representations
weightsq_body = sum(
[
sum(
[jnp.sum(params[j][i][0] ** 2) for i in range(0, 2 * n_layers_r, 2)]
)
for j in range(3)
]
)
weightsq_head = heads_l2_penalty(
params[3], params[4], n_layers_out, reg_diff, penalty_l2, penalty_diff
)
if not avg_objective:
return (
loss_0
+ loss_1
+ loss_prop
+ loss_o
+ loss_disc
+ 0.5 * (penalty_l2 * (weightsq_body + weightsq_prop) + weightsq_head)
)
else:
n_batch = y.shape[0]
return (
(loss_0 + loss_1) / n_batch
+ loss_prop / n_batch
+ loss_o
+ loss_disc
+ 0.5 * (penalty_l2 * (weightsq_body + weightsq_prop) + weightsq_head)
)
# Define optimisation routine
opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)
@jit
def update(
i: int,
state: dict,
batch: jnp.ndarray,
penalty_l2: float,
penalty_orthogonal: float,
penalty_disc: float,
) -> jnp.ndarray:
# updating function
params = get_params(state)
return opt_update(
i,
grad(loss_snet3)(
params, batch, penalty_l2, penalty_orthogonal, penalty_disc
),
state,
)
# initialise states
_, init_params = init_fun_snet3(rng_key, input_shape)
opt_state = opt_init(init_params)
# calculate number of batches per epoch
batch_size = batch_size if batch_size < n else n
n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1
train_indices = onp.arange(n)
l_best = LARGE_VAL
p_curr = 0
# do training
for i in range(n_iter):
# shuffle data for minibatches
onp.random.shuffle(train_indices)
for b in range(n_batches):
idx_next = train_indices[
(b * batch_size) : min((b + 1) * batch_size, n - 1)
]
next_batch = X[idx_next, :], y[idx_next, :], w[idx_next]
opt_state = update(
i * n_batches + b,
opt_state,
next_batch,
penalty_l2,
penalty_orthogonal,
penalty_disc,
)
if (i % n_iter_print == 0) or early_stopping:
params_curr = get_params(opt_state)
l_curr = loss_snet3(
params_curr,
(X_val, y_val, w_val),
penalty_l2,
penalty_orthogonal,
penalty_disc,
)
if i % n_iter_print == 0:
log.info(f"Epoch: {i}, current {val_string} loss {l_curr}")
if early_stopping and ((i + 1) * n_batches > n_iter_min):
# check if loss updated
if l_curr < l_best:
l_best = l_curr
p_curr = 0
params_best = params_curr
else:
if onp.isnan(l_curr):
# if diverged, return best
return params_best, (
predict_fun_repr,
predict_fun_head_po,
predict_fun_head_prop,
)
p_curr = p_curr + 1
if p_curr > patience:
if return_val_loss:
# return loss without penalty
l_final = loss_snet3(params_curr, (X_val, y_val, w_val), 0, 0, 0)
return (
params_curr,
(predict_fun_repr, predict_fun_head_po, predict_fun_head_prop),
l_final,
)
return params_curr, (
predict_fun_repr,
predict_fun_head_po,
predict_fun_head_prop,
)
# return the parameters
trained_params = get_params(opt_state)
if return_val_loss:
# return loss without penalty
l_final = loss_snet3(get_params(opt_state), (X_val, y_val, w_val), 0, 0)
return (
trained_params,
(predict_fun_repr, predict_fun_head_po, predict_fun_head_prop),
l_final,
)
return trained_params, (
predict_fun_repr,
predict_fun_head_po,
predict_fun_head_prop,
)
def predict_snet3(
X: jnp.ndarray,
trained_params: dict,
predict_funs: list,
return_po: bool = False,
return_prop: bool = False,
) -> jnp.ndarray:
# unpack inputs
predict_fun_repr, predict_fun_head, predict_fun_prop = predict_funs
param_repr_c, param_repr_o, param_repr_t = (
trained_params[0],
trained_params[1],
trained_params[2],
)
param_0, param_1, param_prop = (
trained_params[3],
trained_params[4],
trained_params[5],
)
# get representations
rep_c = predict_fun_repr(param_repr_c, X)
rep_o = predict_fun_repr(param_repr_o, X)
rep_w = predict_fun_repr(param_repr_t, X)
# concatenate
reps_po = jnp.concatenate((rep_c, rep_o), axis=1)
reps_prop = jnp.concatenate((rep_c, rep_w), axis=1)
# get potential outcomes
mu_0 = predict_fun_head(param_0, reps_po)
mu_1 = predict_fun_head(param_1, reps_po)
te = mu_1 - mu_0
if return_prop:
# get propensity
prop = predict_fun_prop(param_prop, reps_prop)
# stack other outputs
if return_po:
if return_prop:
return te, mu_0, mu_1, prop
else:
return te, mu_0, mu_1
else:
if return_prop:
return te, prop
else:
return te
================================================
FILE: catenets/models/jax/flextenet.py
================================================
"""
Module implements FlexTENet, also referred to as the 'flexible approach' in "On inductive biases
for heterogeneous treatment effect estimation", Curth & vd Schaar (2021).
"""
# Author: Alicia Curth
from typing import Any, Callable, Optional, Tuple
import jax.numpy as jnp
import numpy as onp
from jax import grad, jit, random
from jax.example_libraries import optimizers
from jax.example_libraries.stax import (
Dense,
Sigmoid,
elu,
glorot_normal,
normal,
serial,
)
import catenets.logger as log
from catenets.models.constants import (
DEFAULT_BATCH_SIZE,
DEFAULT_DIM_P_OUT,
DEFAULT_DIM_P_R,
DEFAULT_DIM_S_OUT,
DEFAULT_DIM_S_R,
DEFAULT_LAYERS_OUT,
DEFAULT_LAYERS_R,
DEFAULT_N_ITER,
DEFAULT_N_ITER_MIN,
DEFAULT_N_ITER_PRINT,
DEFAULT_NONLIN,
DEFAULT_PATIENCE,
DEFAULT_PENALTY_L2,
DEFAULT_PENALTY_ORTHOGONAL,
DEFAULT_SEED,
DEFAULT_STEP_SIZE,
DEFAULT_VAL_SPLIT,
LARGE_VAL,
N_SUBSPACES,
)
from catenets.models.jax.base import BaseCATENet
from catenets.models.jax.model_utils import check_shape_1d_data, make_val_split
class FlexTENet(BaseCATENet):
"""
Module implements FlexTENet, an architecture for treatment effect estimation that allows for
both shared and private information in each layer of the network.
Parameters
----------
binary_y: bool, default False
Whether the outcome is binary
n_layers_out: int
Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)
n_units_s_out: int
Number of hidden units in each shared hypothesis layer
n_units_p_out: int
Number of hidden units in each private hypothesis layer
n_layers_r: int
Number of representation layers before hypothesis layers (distinction between
hypothesis layers and representation layers is made to match TARNet & SNets)
n_units_s_r: int
Number of hidden units in each shared representation layer
n_units_s_r: int
Number of hidden units in each private representation layer
private_out: bool, False
Whether the final prediction layer should be fully private, or retain a shared component.
penalty_l2: float
l2 (ridge) penalty
penalty_l2_p: float
l2 (ridge) penalty for private layers
penalty_orthogonal: float
orthogonalisation penalty
step_size: float
learning rate for optimizer
n_iter: int
Maximum number of iterations
batch_size: int
Batch size
val_split_prop: float
Proportion of samples used for validation split (can be 0)
early_stopping: bool, default True
Whether to use early stopping
patience: int
Number of iterations to wait before early stopping after decrease in validation loss
n_iter_min: int
Minimum number of iterations to go through before starting early stopping
n_iter_print: int
Number of iterations after which to print updates
seed: int
Seed used
opt: str, default 'adam'
Optimizer to use, accepts 'adam' and 'sgd'
shared_repr: bool, False
Whether to use a shared representation block as TARNet
pretrain_shared: bool, False
Whether to pretrain the shared component of the network while freezing the private
parameters
same_init: bool, True
Whether to use the same initialisation for all private spaces
lr_scale: float
Whether to scale down the learning rate after unfreezing the private components of the
network (only used if pretrain_shared=True)
normalize_ortho: bool, False
Whether to normalize the orthogonality penalty (by depth of network)
"""
def __init__(
self,
binary_y: bool = False,
n_layers_out: int = DEFAULT_LAYERS_OUT,
n_units_s_out: int = DEFAULT_DIM_S_OUT,
n_units_p_out: int = DEFAULT_DIM_P_OUT,
n_layers_r: int = DEFAULT_LAYERS_R,
n_units_s_r: int = DEFAULT_DIM_S_R,
n_units_p_r: int = DEFAULT_DIM_P_R,
private_out: bool = False,
penalty_l2: float = DEFAULT_PENALTY_L2,
penalty_l2_p: float = DEFAULT_PENALTY_L2,
penalty_orthogonal: float = DEFAULT_PENALTY_ORTHOGONAL,
step_size: float = DEFAULT_STEP_SIZE,
n_iter: int = DEFAULT_N_ITER,
batch_size: int = DEFAULT_BATCH_SIZE,
val_split_prop: float = DEFAULT_VAL_SPLIT,
early_stopping: bool = True,
patience: int = DEFAULT_PATIENCE,
n_iter_min: int = DEFAULT_N_ITER_MIN,
n_iter_print: int = DEFAULT_N_ITER_PRINT,
seed: int = DEFAULT_SEED,
return_val_loss: bool = False,
opt: str = "adam",
shared_repr: bool = False,
pretrain_shared: bool = False,
same_init: bool = True,
lr_scale: float = 10,
normalize_ortho: bool = False,
) -> None:
self.binary_y = binary_y
self.n_layers_r = n_layers_r
self.n_layers_out = n_layers_out
self.n_units_s_out = n_units_s_out
self.n_units_p_out = n_units_p_out
self.n_units_s_r = n_units_s_r
self.n_units_p_r = n_units_p_r
self.private_out = private_out
self.penalty_orthogonal = penalty_orthogonal
self.penalty_l2 = penalty_l2
self.penalty_l2_p = penalty_l2_p
self.step_size = step_size
self.n_iter = n_iter
self.batch_size = batch_size
self.val_split_prop = val_split_prop
self.early_stopping = early_stopping
self.patience = patience
self.n_iter_min = n_iter_min
self.opt = opt
self.same_init = same_init
self.shared_repr = shared_repr
self.normalize_ortho = normalize_ortho
self.pretrain_shared = pretrain_shared
self.lr_scale = lr_scale
self.seed = seed
self.n_iter_print = n_iter_print
self.return_val_loss = return_val_loss
def _get_train_function(self) -> Callable:
return train_flextenet
def _get_predict_function(self) -> Callable:
return predict_flextenet
def train_flextenet(
X: jnp.ndarray,
y: jnp.ndarray,
w: jnp.ndarray,
binary_y: bool = False,
n_layers_out: int = DEFAULT_LAYERS_OUT,
n_units_s_out: int = DEFAULT_DIM_S_OUT,
n_units_p_out: int = DEFAULT_DIM_P_OUT,
n_layers_r: int = DEFAULT_LAYERS_R,
n_units_s_r: int = DEFAULT_DIM_S_R,
n_units_p_r: int = DEFAULT_DIM_P_R,
private_out: bool = False,
penalty_l2: float = DEFAULT_PENALTY_L2,
penalty_l2_p: float = DEFAULT_PENALTY_L2,
penalty_orthogonal: float = DEFAULT_PENALTY_ORTHOGONAL,
step_size: float = DEFAULT_STEP_SIZE,
n_iter: int = DEFAULT_N_ITER,
batch_size: int = DEFAULT_BATCH_SIZE,
val_split_prop: float = DEFAULT_VAL_SPLIT,
early_stopping: bool = True,
patience: int = DEFAULT_PATIENCE,
n_iter_min: int = DEFAULT_N_ITER_MIN,
avg_objective: bool = True,
n_iter_print: int = DEFAULT_N_ITER_PRINT,
seed: int = DEFAULT_SEED,
return_val_loss: bool = False,
opt: str = "adam",
shared_repr: bool = False,
pretrain_shared: bool = False,
same_init: bool = True,
lr_scale: float = 10,
normalize_ortho: bool = False,
nonlin: str = DEFAULT_NONLIN,
n_units_r: Optional[int] = None,
n_units_out: Optional[int] = None,
) -> Tuple: # TODO incorporate different nonlins here
# function to train a single output head
# input check
y, w = check_shape_1d_data(y), check_shape_1d_data(w)
d = X.shape[1]
input_shape = (-1, d)
rng_key = random.PRNGKey(seed)
onp.random.seed(seed) # set seed for data generation via numpy as well
# get validation split (can be none)
X, y, w, X_val, y_val, w_val, val_string = make_val_split(
X, y, w, val_split_prop=val_split_prop, seed=seed
)
n = X.shape[0] # could be different from before due to split
# get output head
init_fun, predict_fun = FlexTENetArchitecture(
n_layers_out=n_layers_out,
n_layers_r=n_layers_r,
n_units_p_r=n_units_p_r,
n_units_p_out=n_units_p_out,
n_units_s_r=n_units_s_r,
n_units_s_out=n_units_s_out,
private_out=private_out,
shared_repr=shared_repr,
same_init=same_init,
binary_y=binary_y,
)
# get functions
if not binary_y:
# define loss and grad
@jit
def loss(
params: jnp.ndarray,
batch: jnp.ndarray,
penalty_l2: float,
penalty_l2_p: float,
penalty_orthogonal: float,
mode: int,
) -> jnp.ndarray:
# mse loss function
inputs, targets = batch
preds = predict_fun(params, inputs, mode=mode)
penalty = _compute_penalty(
params,
n_layers_out,
n_layers_r,
private_out,
penalty_l2,
penalty_l2_p,
penalty_orthogonal,
shared_repr,
normalize_ortho,
mode,
)
if not avg_objective:
return jnp.sum((preds - targets) ** 2) + penalty
else:
return jnp.average((preds - targets) ** 2) + penalty
else:
# get loss and grad
@jit
def loss(
params: jnp.ndarray,
batch: jnp.ndarray,
penalty_l2: float,
penalty_l2_p: float,
penalty_orthogonal: float,
mode: int,
) -> jnp.ndarray:
# mse loss function
inputs, targets = batch
preds = predict_fun(params, inputs, mode=mode)
penalty = _compute_penalty(
params,
n_layers_out,
n_layers_r,
private_out,
penalty_l2,
penalty_l2_p,
penalty_orthogonal,
shared_repr,
normalize_ortho,
mode,
)
if not avg_objective:
return (
-jnp.sum(
targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds)
)
+ penalty
)
else:
return (
-jnp.average(
targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds)
)
+ penalty
)
# set optimization routine
# set optimizer
if opt == "adam":
opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)
elif opt == "sgd":
opt_init, opt_update, get_params = optimizers.sgd(step_size=step_size)
else:
raise ValueError("opt should be adam or sgd")
# set update function
@jit
def update(
i: int,
state: dict,
batch: jnp.ndarray,
penalty_l2: float,
penalty_l2_p: float,
penalty_orthogonal: float,
mode: int,
) -> jnp.ndarray:
params = get_params(state)
g_params = grad(loss)(
params, batch, penalty_l2, penalty_l2_p, penalty_orthogonal, mode
)
return opt_update(i, g_params, state)
# initialise states
_, init_params = init_fun(rng_key, input_shape)
opt_state = opt_init(init_params)
# calculate number of batches per epoch
batch_size = batch_size if batch_size < n else n
n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1
train_indices = onp.arange(n)
l_best = LARGE_VAL
p_curr = 0
# do training
if not pretrain_shared: # train entire model together
for i in range(n_iter):
# shuffle data for minibatches
onp.random.shuffle(train_indices)
for b in range(n_batches):
idx_next = train_indices[
(b * batch_size) : min((b + 1) * batch_size, n - 1)
]
next_batch = (X[idx_next, :], w[idx_next]), y[idx_next, :]
opt_state = update(
i * n_batches + b,
opt_state,
next_batch,
penalty_l2,
penalty_l2_p,
penalty_orthogonal,
mode=1,
)
if (i % n_iter_print == 0) or early_stopping:
params_curr = get_params(opt_state)
l_curr = loss(
params_curr,
((X_val, w_val), y_val),
penalty_l2,
penalty_l2_p,
penalty_orthogonal,
mode=1,
)
if i % n_iter_print == 0:
log.debug(f"Epoch: {i}, current {val_string} loss: {l_curr}")
if early_stopping and ((i + 1) * n_batches > n_iter_min):
# check if loss updated
if l_curr < l_best:
l_best = l_curr
p_curr = 0
else:
p_curr = p_curr + 1
if p_curr > patience:
trained_params = get_params(opt_state)
if return_val_loss:
# return loss without penalty
l_final = loss(
trained_params, ((X_val, w_val), y_val), 0, 0, 0, mode=1
)
return trained_params, predict_fun, l_final
return trained_params, predict_fun
# get final parameters
trained_params = get_params(opt_state)
if return_val_loss:
# return loss without penalty
l_final = loss(trained_params, ((X_val, w_val), y_val), 0, 0, 0, mode=1)
return trained_params, predict_fun, l_final
return trained_params, predict_fun
else:
# Step 1: pretrain only shared bit of network (mode=0)
for i in range(n_iter):
# shuffle data for minibatches
onp.random.shuffle(train_indices)
for b in range(n_batches):
idx_next = train_indices[
(b * batch_size) : min((b + 1) * batch_size, n - 1)
]
next_batch = (X[idx_next, :], w[idx_next]), y[idx_next, :]
opt_state = update(
i * n_batches + b,
opt_state,
next_batch,
penalty_l2,
penalty_l2_p,
penalty_orthogonal,
mode=0,
)
if (i % n_iter_print == 0) or early_stopping:
params_curr = get_params(opt_state)
l_curr = loss(
params_curr,
((X_val, w_val), y_val),
penalty_l2,
penalty_l2_p,
penalty_orthogonal,
mode=0,
)
if i % n_iter_print == 0:
log.debug(
f"Pre-training epoch: {i}, current {val_string} loss: {l_curr}"
)
if early_stopping and ((i + 1) * n_batches > n_iter_min):
# check if loss updated
if l_curr < l_best:
l_best = l_curr
p_curr = 0
else:
p_curr = p_curr + 1
if p_curr > patience:
break
# get final parameters
pre_trained_params = get_params(opt_state)
# Step 2: train also private parts of network (mode=1)
# set new optimizer
if opt == "adam":
opt_init2, opt_update2, get_params2 = optimizers.adam(
step_size=step_size / lr_scale
)
elif opt == "sgd":
opt_init2, opt_update2, get_params2 = optimizers.sgd(
step_size=step_size / lr_scale
)
else:
raise ValueError("opt should be adam or sgd")
# set update function
@jit
def update2(
i: int,
state: dict,
batch: jnp.ndarray,
penalty_l2: float,
penalty_l2_p: float,
penalty_orthogonal: float,
mode: int,
) -> Any:
params = get_params(state)
g_params = grad(loss)(
params, batch, penalty_l2, penalty_l2_p, penalty_orthogonal, mode
)
return opt_update2(i, g_params, state)
opt_state = opt_init2(pre_trained_params)
l_best = LARGE_VAL
p_curr = 0
# train full
for i in range(n_iter):
# shuffle data for minibatches
onp.random.shuffle(train_indices)
for b in range(n_batches):
idx_next = train_indices[
(b * batch_size) : min((b + 1) * batch_size, n - 1)
]
next_batch = (X[idx_next, :], w[idx_next]), y[idx_next, :]
opt_state = update2(
i * n_batches + b,
opt_state,
next_batch,
penalty_l2,
penalty_l2_p,
penalty_orthogonal,
mode=1,
)
if (i % n_iter_print == 0) or early_stopping:
params_curr = get_params2(opt_state)
l_curr = loss(
params_curr,
((X_val, w_val), y_val),
penalty_l2,
penalty_l2_p,
penalty_orthogonal,
mode=1,
)
if i % n_iter_print == 0:
log.debug(f"Epoch: {i}, current {val_string} loss: {l_curr}")
if early_stopping and ((i + 1) * n_batches > n_iter_min):
# check if loss updated
if l_curr < l_best:
l_best = l_curr
p_curr = 0
else:
p_curr = p_curr + 1
if p_curr > patience:
trained_params = get_params2(opt_state)
if return_val_loss:
# return loss without penalty
l_final = loss(
trained_params, ((X_val, w_val), y_val), 0, 0, 0, mode=1
)
return trained_params, predict_fun, l_final
return trained_params, predict_fun
# get final parameters
trained_params = get_params2(opt_state)
if return_val_loss:
# return loss without penalty
l_final = loss(trained_params, ((X_val, w_val), y_val), 0, 0, 0, mode=1)
return trained_params, predict_fun, l_final
return trained_params, predict_fun
def predict_flextenet(
X: jnp.ndarray,
trained_params: jnp.ndarray,
predict_funs: Callable,
return_po: bool = False,
return_prop: bool = False,
) -> Any:
# unpack inputs
n, _ = X.shape
W1 = check_shape_1d_data(jnp.ones(n))
W0 = check_shape_1d_data(jnp.zeros(n))
# get potential outcomes
mu_0 = predict_funs(trained_params, (X, W0))
mu_1 = predict_funs(trained_params, (X, W1))
te = mu_1 - mu_0
if return_prop:
raise ValueError("does not have propensity score estimator")
# stack other outputs
if return_po:
return te, mu_0, mu_1
else:
return te
# helper functions for training
def _get_cos_reg(
params_0: jnp.ndarray, params_1: jnp.ndarray, normalize: bool
) -> jnp.ndarray:
if normalize:
params_0 = params_0 / jnp.linalg.norm(params_0, axis=0)
params_1 = params_1 / jnp.linalg.norm(params_1, axis=0)
return jnp.linalg.norm(jnp.dot(jnp.transpose(params_0), params_1), "fro") ** 2
def _compute_ortho_penalty_asymmetric(
params: jnp.ndarray,
n_layers_out: int,
n_layers_r: int,
private_out: int,
penalty_orthogonal: float,
shared_repr: bool,
normalize_ortho: bool,
mode: int = 1,
) -> float:
# where to start counting: is there a fully shared representation?
if shared_repr:
lb = 2 * n_layers_r
else:
lb = 0
n_in = [
params[i][0][0].shape[0] for i in range(lb, 2 * (n_layers_out + n_layers_r), 2)
]
ortho_body = _get_cos_reg(params[lb][1][0], params[lb][2][0], normalize_ortho)
ortho_body = ortho_body + sum(
[
_get_cos_reg(
params[i][0][0],
params[i][1][0][: n_in[int(i / 2 - lb / 2)], :],
normalize_ortho,
)
+ _get_cos_reg(
params[i][0][0],
params[i][2][0][: n_in[int(i / 2 - lb / 2)], :],
normalize_ortho,
)
for i in range(lb, 2 * (n_layers_out + n_layers_r), 2)
]
)
if not private_out:
# add also orthogonal regularization on final layer
idx_out = 2 * (n_layers_r + n_layers_out)
n_idx = params[idx_out][0][0].shape[0]
ortho_body = (
ortho_body
+ _get_cos_reg(
params[idx_out][0][0],
params[idx_out][1][0][:n_idx, :],
normalize_ortho,
)
+ _get_cos_reg(
params[idx_out][0][0], params[idx_out][2][0][:n_idx, :], normalize_ortho
)
)
return mode * penalty_orthogonal * ortho_body
def _compute_penalty_l2(
params: jnp.ndarray,
n_layers_out: int,
n_layers_r: int,
private_out: int,
penalty_l2: float,
penalty_l2_p: float,
shared_repr: bool,
mode: int = 1,
) -> jnp.ndarray:
n_bodys = N_SUBSPACES
# compute l2 penalty
if shared_repr:
# get representation and then heads
weightsq_body = penalty_l2 * sum(
[jnp.sum(params[i][0] ** 2) for i in range(0, 2 * n_layers_r, 2)]
)
weightsq_body = weightsq_body + penalty_l2 * sum(
[
jnp.sum(params[i][0][0] ** 2)
for i in range(2 * n_layers_r, 2 * (n_layers_out + n_layers_r), 2)
]
)
weightsq_body = weightsq_body + penalty_l2_p * mode * sum(
[
sum(
[
jnp.sum(params[i][j][0] ** 2)
for i in range(
2 * n_layers_r, 2 * (n_layers_out + n_layers_r), 2
)
]
)
for j in range(1, n_bodys)
]
)
else:
weightsq_body = penalty_l2 * sum(
[
jnp.sum(params[i][0][0] ** 2)
for i in range(0, 2 * (n_layers_out + n_layers_r), 2)
]
)
weightsq_body = weightsq_body + penalty_l2_p * mode * sum(
[
sum(
[
jnp.sum(params[i][j][0] ** 2)
for i in range(0, 2 * (n_layers_out + n_layers_r), 2)
]
)
for j in range(1, n_bodys)
]
)
idx_out = 2 * (n_layers_r + n_layers_out)
if private_out:
weightsq = (
weightsq_body
+ penalty_l2 * jnp.sum(params[idx_out][0][0] ** 2)
+ jnp.sum(params[idx_out][1][0] ** 2)
)
else:
weightsq = (
weightsq_body
+ penalty_l2 * jnp.sum(params[idx_out][0][0] ** 2)
+ penalty_l2_p * mode * jnp.sum(params[idx_out][1][0] ** 2)
+ penalty_l2_p * mode * jnp.sum(params[idx_out][2][0] ** 2)
)
return 0.5 * weightsq
def _compute_penalty(
params: jnp.ndarray,
n_layers_out: int,
n_layers_r: int,
private_out: int,
penalty_l2: float,
penalty_l2_p: float,
penalty_orthogonal: float,
shared_repr: bool,
normalize_ortho: bool,
mode: int = 1,
) -> jnp.ndarray:
l2_penalty = _compute_penalty_l2(
params,
n_layers_out,
n_layers_r,
private_out,
penalty_l2,
penalty_l2_p,
shared_repr,
mode,
)
ortho_penalty = _compute_ortho_penalty_asymmetric(
params,
n_layers_out,
n_layers_r,
private_out,
penalty_orthogonal,
shared_repr,
normalize_ortho,
mode,
)
return l2_penalty + ortho_penalty
# ------------------------------------------------------------
# construction of FlexTENetlayers/architecture
def SplitLayerAsymmetric(
n_units_s: int, n_units_p: int, first_layer: bool = False, same_init: bool = True
) -> Tuple:
# create multitask layer has shape [shared, private_0, private_1]
init_s, apply_s = Dense(n_units_s)
init_p, apply_p = Dense(n_units_p)
def init_fun(rng: float, input_shape: Tuple) -> Tuple:
if first_layer: # put input shape in expected format
input_shape = (input_shape, input_shape, input_shape)
out_shape = (
input_shape[0][:-1] + (n_units_s,),
input_shape[1][:-1] + (n_units_p + n_units_s,),
input_shape[2][:-1] + (n_units_p + n_units_s,),
)
rng_1, rng_2, rng_3 = random.split(rng, N_SUBSPACES)
if same_init: # use same init for the two private layers
return out_shape, (
init_s(rng_1, input_shape[0])[1],
init_p(rng_2, input_shape[1])[1],
init_p(rng_2, input_shape[2])[1],
)
else: # initialise all separately
return out_shape, (
init_s(rng_1, input_shape[0])[1],
init_p(rng_2, input_shape[1])[1],
init_p(rng_3, input_shape[2])[1],
)
def apply_fun(params: jnp.ndarray, inputs: jnp.ndarray, **kwargs: Any) -> Tuple:
mode = kwargs["mode"] if "mode" in kwargs.keys() else 1
if first_layer:
# X is the only input
X, W = inputs
rep_s = apply_s(params[0], X)
rep_p0 = mode * apply_p(params[1], X)
rep_p1 = mode * apply_p(params[2], X)
else:
X_s, X_p0, X_p1, W = inputs
rep_s = apply_s(params[0], X_s)
rep_p0 = mode * apply_p(params[1], jnp.concatenate([X_s, X_p0], axis=1))
rep_p1 = mode * apply_p(params[2], jnp.concatenate([X_s, X_p1], axis=1))
return (rep_s, rep_p0, rep_p1, W)
return init_fun, apply_fun
def TEOutputLayerAsymmetric(private: bool = True, same_init: bool = True) -> Tuple:
init_f, apply_f = Dense(1)
if private:
# the two output layers are private
def init_fun(rng: float, input_shape: Tuple) -> Tuple:
out_shape = input_shape[1][:-1] + (1,)
rng_1, rng_2 = random.split(rng, N_SUBSPACES - 1)
return out_shape, (
init_f(rng_1, input_shape[1])[1],
init_f(rng_2, input_shape[2])[1],
)
def apply_fun(params: jnp.ndarray, inputs: Tuple, **kwargs: Any) -> jnp.ndarray:
X_s, X_p0, X_p1, W = inputs
rep_p0 = apply_f(params[0], jnp.concatenate([X_s, X_p0], axis=1))
rep_p1 = apply_f(params[1], jnp.concatenate([X_s, X_p1], axis=1))
return (1 - W) * rep_p0 + W * rep_p1
else:
# also have a shared piece of output layer
def init_fun(rng: float, input_shape: Tuple) -> Tuple:
out_shape = input_shape[1][:-1] + (1,)
rng_1, rng_2, rng_3 = random.split(rng, N_SUBSPACES)
if same_init:
return out_shape, (
init_f(rng_1, input_shape[0])[1],
init_f(rng_2, input_shape[1])[1],
init_f(rng_2, input_shape[2])[1],
)
else:
return out_shape, (
init_f(rng_1, input_shape[0])[1],
init_f(rng_2, input_shape[1])[1],
init_f(rng_3, input_shape[2])[1],
)
def apply_fun(params: jnp.ndarray, inputs: Tuple, **kwargs: Any) -> jnp.ndarray:
mode = kwargs["mode"] if "mode" in kwargs.keys() else 1
X_s, X_p0, X_p1, W = inputs
rep_s = apply_f(params[0], X_s)
rep_p0 = mode * apply_f(params[1], jnp.concatenate([X_s, X_p0], axis=1))
rep_p1 = mode * apply_f(params[2], jnp.concatenate([X_s, X_p1], axis=1))
return (1 - W) * rep_p0 + W * rep_p1 + rep_s
return init_fun, apply_fun
def FlexTENetArchitecture(
n_layers_out: int = DEFAULT_LAYERS_OUT,
n_units_s_out: int = DEFAULT_DIM_S_OUT,
n_units_p_out: int = DEFAULT_DIM_P_OUT,
n_layers_r: int = DEFAULT_LAYERS_R,
n_units_s_r: int = DEFAULT_DIM_S_R,
n_units_p_r: int = DEFAULT_DIM_P_R,
private_out: bool = False,
binary_y: bool = False,
shared_repr: bool = False,
same_init: bool = True,
) -> Any:
if n_layers_out < 1:
raise ValueError(
"FlexTENet needs at least one hidden output layer (else there are no "
"parameters to be shared)"
)
Nonlin_Elu = Elu_parallel
Layer = SplitLayerAsymmetric
Head = TEOutputLayerAsymmetric
# give broader body (as in e.g. TARNet)
has_body = n_layers_r > 0
layers: Tuple = ()
if has_body:
# representation block first
if shared_repr: # fully shared representation as in TARNet
layers = (DenseW(n_units_s_r), Elu_split)
# add required number of layers
for i in range(n_layers_r - 1):
layers = (*layers, DenseW(n_units_s_r), Elu_split)
else: # shared AND private representations
layers = (
Layer(n_units_s_r, n_units_p_r, first_layer=True, same_init=same_init),
Nonlin_Elu,
)
# add required number of layers
for i in range(n_layers_r - 1):
layers = (
*layers,
Layer(n_units_s_r, n_units_p_r, same_init=same_init),
Nonlin_Elu,
)
else:
layers = ()
# add output layers
first_layer = (has_body is False) | (shared_repr is True)
layers = (
*layers,
Layer(
n_units_s_out, n_units_p_out, first_layer=first_layer, same_init=same_init
),
Nonlin_Elu,
)
if n_layers_out > 1:
# add required number of layers
for i in range(n_layers_out - 1):
layers = (
*layers,
Layer(n_units_s_out, n_units_p_out, same_init=same_init),
Nonlin_Elu,
)
# return final architecture
if not binary_y:
return serial(*layers, Head(private=private_out, same_init=same_init))
else:
return serial(*layers, Head(private=private_out, same_init=same_init), Sigmoid)
# ------------------------------------------------
# rewrite some jax.stax code to allow different input types to be passed
def elementwise_split(fun: Callable, **fun_kwargs: Any) -> Tuple:
"""Layer that applies a scalar function elementwise on its inputs. Adapted from original
jax.stax to skip treatment indicator.
Input looks like: X, t = inputs"""
def init_fun(rng: float, input_shape: Tuple) -> Tuple:
return (input_shape, ())
def apply_fun(params: jnp.ndarray, inputs: jnp.ndarray, **kwargs: Any) -> Tuple:
return fun(inputs[0], **fun_kwargs), inputs[1]
return init_fun, apply_fun
Elu_split = elementwise_split(elu)
def elementwise_parallel(fun: Callable, **fun_kwargs: Any) -> Tuple:
"""Layer that applies a scalar function elementwise on its inputs. Adapted from original
jax.stax to allow three inputs and to skip treatment indicator.
Input looks like: X_s, X_p0, X_p1, t = inputs
"""
def init_fun(rng: float, input_shape: Tuple) -> Tuple:
return input_shape, ()
def apply_fun(params: jnp.ndarray, inputs: jnp.ndarray, **kwargs: Any) -> Tuple:
return (
fun(inputs[0], **fun_kwargs),
fun(inputs[1], **fun_kwargs),
fun(inputs[2], **fun_kwargs),
inputs[3],
)
return init_fun, apply_fun
Elu_parallel = elementwise_parallel(elu)
def DenseW(
out_dim: int, W_init: Callable = glorot_normal(), b_init: Callable = normal()
) -> Tuple:
"""Layer constructor function for a dense (fully-connected) layer. Adapted to allow passing
treatment indicator through layer without using it"""
def init_fun(rng: float, input_shape: Tuple) -> Tuple:
output_shape = input_shape[:-1] + (out_dim,)
k1, k2 = random.split(rng)
W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim,))
return output_shape, (W, b)
def apply_fun(
params: jnp.ndarray, inputs: jnp.ndarray, **kwargs: Any
) -> jnp.ndarray:
W, b = params
x, t = inputs
return (jnp.dot(x, W) + b, t)
return init_fun, apply_fun
================================================
FILE: catenets/models/jax/model_utils.py
================================================
"""
Model utils shared across different nets
"""
# Author: Alicia Curth
from typing import Any, Optional
import jax.numpy as jnp
import pandas as pd
from sklearn.model_selection import train_test_split
from catenets.models.constants import DEFAULT_SEED, DEFAULT_VAL_SPLIT
TRAIN_STRING = "training"
VALIDATION_STRING = "validation"
def check_shape_1d_data(y: jnp.ndarray) -> jnp.ndarray:
# helper func to ensure that output shape won't clash
# with jax internally
shape_y = y.shape
if len(shape_y) == 1:
# should be shape (n_obs, 1), not (n_obs,)
return y.reshape((shape_y[0], 1))
return y
def check_X_is_np(X: pd.DataFrame) -> jnp.ndarray:
# function to make sure we are using arrays only
return jnp.asarray(X)
def make_val_split(
X: jnp.ndarray,
y: jnp.ndarray,
w: Optional[jnp.ndarray] = None,
val_split_prop: float = DEFAULT_VAL_SPLIT,
seed: int = DEFAULT_SEED,
stratify_w: bool = True,
) -> Any:
if val_split_prop == 0:
# return original data
if w is None:
return X, y, X, y, TRAIN_STRING
return X, y, w, X, y, w, TRAIN_STRING
# make actual split
if w is None:
X_t, X_val, y_t, y_val = train_test_split(
X, y, test_size=val_split_prop, random_state=seed, shuffle=True
)
return X_t, y_t, X_val, y_val, VALIDATION_STRING
if stratify_w:
# split to stratify by group
X_t, X_val, y_t, y_val, w_t, w_val = train_test_split(
X,
y,
w,
test_size=val_split_prop,
random_state=seed,
stratify=w,
shuffle=True,
)
else:
X_t, X_val, y_t, y_val, w_t, w_val = train_test_split(
X, y, w, test_size=val_split_prop, random_state=seed, shuffle=True
)
return X_t, y_t, w_t, X_val, y_val, w_val, VALIDATION_STRING
def heads_l2_penalty(
params_0: jnp.ndarray,
params_1: jnp.ndarray,
n_layers_out: jnp.ndarray,
reg_diff: jnp.ndarray,
penalty_0: jnp.ndarray,
penalty_1: jnp.ndarray,
) -> jnp.ndarray:
# Compute l2 penalty for output heads. Either seperately, or regularizing their difference
# get l2-penalty for first head
weightsq_0 = penalty_0 * sum(
[jnp.sum(params_0[i][0] ** 2) for i in range(0, 2 * n_layers_out + 1, 2)]
)
# get l2-penalty for second head
if reg_diff:
weightsq_1 = penalty_1 * sum(
[
jnp.sum((params_1[i][0] - params_0[i][0]) ** 2)
for i in range(0, 2 * n_layers_out + 1, 2)
]
)
else:
weightsq_1 = penalty_1 * sum(
[jnp.sum(params_1[i][0] ** 2) for i in range(0, 2 * n_layers_out + 1, 2)]
)
return weightsq_1 + weightsq_0
================================================
FILE: catenets/models/jax/offsetnet.py
================================================
"""
Module implements OffsetNet, also referred to as the 'reparametrization approach' and 'hard
approach' in "On inductive biases for heterogeneous treatment effect estimation", Curth & vd
Schaar (2021); modeling the POs using a shared prognostic function and
an offset (treatment effect)
"""
# Author: Alicia Curth
from typing import Any, Callable, List, Tuple
import jax.numpy as jnp
import numpy as onp
from jax import grad, jit, random
from jax.example_libraries import optimizers
from jax.example_libraries.stax import sigmoid
import catenets.logger as log
from catenets.models.constants import (
DEFAULT_BATCH_SIZE,
DEFAULT_LAYERS_OUT,
DEFAULT_LAYERS_R,
DEFAULT_N_ITER,
DEFAULT_N_ITER_MIN,
DEFAULT_N_ITER_PRINT,
DEFAULT_NONLIN,
DEFAULT_PATIENCE,
DEFAULT_PENALTY_L2,
DEFAULT_SEED,
DEFAULT_STEP_SIZE,
DEFAULT_UNITS_OUT,
DEFAULT_UNITS_R,
DEFAULT_VAL_SPLIT,
LARGE_VAL,
)
from catenets.models.jax.base import BaseCATENet, OutputHead
from catenets.models.jax.model_utils import (
check_shape_1d_data,
heads_l2_penalty,
make_val_split,
)
class OffsetNet(BaseCATENet):
"""
Module implements OffsetNet, also referred to as the 'reparametrization approach' and 'hard
approach' in Curth & vd Schaar (2021); modeling the POs using a shared prognostic function and
an offset (treatment effect).
Parameters
----------
binary_y: bool, default False
Whether the outcome is binary
n_layers_out: int
Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)
n_units_out: int
Number of hidden units in each hypothesis layer
n_layers_r: int
Number of representation layers before hypothesis layers (distinction between
hypothesis layers and representation layers is made to match TARNet & SNets)
n_units_r: int
Number of hidden units in each representation layer
penalty_l2: float
l2 (ridge) penalty
step_size: float
learning rate for optimizer
n_iter: int
Maximum number of iterations
batch_size: int
Batch size
val_split_prop: float
Proportion of samples used for validation split (can be 0)
early_stopping: bool, default True
Whether to use early stopping
patience: int
Number of iterations to wait before early stopping after decrease in validation loss
n_iter_min: int
Minimum number of iterations to go through before starting early stopping
n_iter_print: int
Number of iterations after which to print updates
seed: int
Seed used
penalty_l2_p: float
l2-penalty for regularizing the offset
nonlin: string, default 'elu'
Nonlinearity to use in NN
"""
def __init__(
self,
binary_y: bool = False,
n_layers_r: int = DEFAULT_LAYERS_R,
n_units_r: int = DEFAULT_UNITS_R,
n_layers_out: int = DEFAULT_LAYERS_OUT,
n_units_out: int = DEFAULT_UNITS_OUT,
penalty_l2: float = DEFAULT_PENALTY_L2,
penalty_l2_p: float = DEFAULT_PENALTY_L2,
step_size: float = DEFAULT_STEP_SIZE,
n_iter: int = DEFAULT_N_ITER,
batch_size: int = DEFAULT_BATCH_SIZE,
val_split_prop: float = DEFAULT_VAL_SPLIT,
early_stopping: bool = True,
patience: int = DEFAULT_PATIENCE,
n_iter_min: int = DEFAULT_N_ITER_MIN,
n_iter_print: int = DEFAULT_N_ITER_PRINT,
seed: int = DEFAULT_SEED,
nonlin: str = DEFAULT_NONLIN,
):
# structure of net
self.binary_y = binary_y
self.n_layers_r = n_layers_r
self.n_layers_out = n_layers_out
self.n_units_r = n_units_r
self.n_units_out = n_units_out
self.nonlin = nonlin
# penalties
self.penalty_l2 = penalty_l2
self.penalty_l2_p = penalty_l2_p
# training params
self.step_size = step_size
self.n_iter = n_iter
self.batch_size = batch_size
self.n_iter_print = n_iter_print
self.seed = seed
self.val_split_prop = val_split_prop
self.early_stopping = early_stopping
self.patience = patience
self.n_iter_min = n_iter_min
def _get_train_function(self) -> Callable:
return train_offsetnet
def _get_predict_function(self) -> Callable:
return predict_offsetnet
def predict_offsetnet(
X: jnp.ndarray,
trained_params: jnp.ndarray,
predict_funs: List[Any],
return_po: bool = False,
return_prop: bool = False,
) -> jnp.ndarray:
if return_prop:
raise NotImplementedError("OffsetNet does not implement a propensity model.")
# unpack inputs
predict_fun_head = predict_funs[0]
binary_y = predict_funs[1]
param_0, param_1 = trained_params[0], trained_params[1]
# get potential outcomes
mu_0 = predict_fun_head(param_0, X)
offset = predict_fun_head(param_1, X)
if not binary_y:
if return_po:
return offset, mu_0, mu_0 + offset
else:
return offset
else:
# still need to sigmoid
po_0 = sigmoid(mu_0)
po_1 = sigmoid(mu_0 + offset)
if return_po:
return po_1 - po_0, po_0, po_1
else:
return po_1 - po_0
def train_offsetnet(
X: jnp.ndarray,
y: jnp.ndarray,
w: jnp.ndarray,
binary_y: bool = False,
n_layers_r: int = DEFAULT_LAYERS_R,
n_units_r: int = DEFAULT_UNITS_R,
n_layers_out: int = DEFAULT_LAYERS_OUT,
n_units_out: int = DEFAULT_UNITS_OUT,
penalty_l2: float = DEFAULT_PENALTY_L2,
penalty_l2_p: float = DEFAULT_PENALTY_L2,
step_size: float = DEFAULT_STEP_SIZE,
n_iter: int = DEFAULT_N_ITER,
batch_size: int = DEFAULT_BATCH_SIZE,
val_split_prop: float = DEFAULT_VAL_SPLIT,
early_stopping: bool = True,
patience: int = DEFAULT_PATIENCE,
n_iter_min: int = DEFAULT_N_ITER_MIN,
n_iter_print: int = DEFAULT_N_ITER_PRINT,
seed: int = DEFAULT_SEED,
return_val_loss: bool = False,
nonlin: str = DEFAULT_NONLIN,
avg_objective: bool = True,
) -> Tuple:
# input check
y, w = check_shape_1d_data(y), check_shape_1d_data(w)
d = X.shape[1]
input_shape = (-1, d)
rng_key = random.PRNGKey(seed)
onp.random.seed(seed) # set seed for data generation via numpy as well
# get validation split (can be none)
X, y, w, X_val, y_val, w_val, val_string = make_val_split(
X, y, w, val_split_prop=val_split_prop, seed=seed
)
n = X.shape[0] # could be different from before due to split
# get output head functions (both heads share same structure)
init_fun_head, predict_fun_head = OutputHead(
n_layers_out=n_layers_out,
n_units_out=n_units_out,
binary_y=False,
n_layers_r=n_layers_r,
n_units_r=n_units_r,
nonlin=nonlin,
)
def init_fun_offset(rng: float, input_shape: Tuple) -> Tuple:
# chain together the layers
# param should look like [param_base, param_offset]
rng, layer_rng = random.split(rng)
_, param_base = init_fun_head(layer_rng, input_shape)
rng, layer_rng = random.split(rng)
input_shape, param_offset = init_fun_head(layer_rng, input_shape)
return input_shape, [param_base, param_offset]
# Define loss functions
if not binary_y:
@jit
def loss_offsetnet(
params: jnp.ndarray, batch: jnp.ndarray, penalty: float, penalty_l2_p: float
) -> jnp.ndarray:
# params: list[representation, head_0, head_1]
# batch: (X, y, w)
inputs, targets, w = batch
preds_0 = predict_fun_head(params[0], inputs)
offset = predict_fun_head(params[1], inputs)
preds = preds_0 + w * offset
weightsq_head = heads_l2_penalty(
params[0],
params[1],
n_layers_out + n_layers_r,
False,
penalty,
penalty_l2_p,
)
if not avg_objective:
return jnp.sum((preds - targets) ** 2) + 0.5 * weightsq_head
else:
return jnp.average((preds - targets) ** 2) + 0.5 * weightsq_head
else:
def loss_offsetnet(
params: jnp.ndarray,
batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],
penalty: float,
penalty_l2_p: float,
) -> jnp.ndarray:
# params: list[representation, head_0, head_1]
# batch: (X, y, w)
inputs, targets, w = batch
preds_0 = predict_fun_head(params[0], inputs)
offset = predict_fun_head(params[1], inputs)
preds = sigmoid(preds_0 + w * offset)
weightsq_head = heads_l2_penalty(
params[0],
params[1],
n_layers_out + n_layers_r,
False,
penalty,
penalty_l2_p,
)
if not avg_objective:
return (
-jnp.sum(
(targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds))
)
+ 0.5 * weightsq_head
)
else:
n_batch = y.shape[0]
return (
-jnp.sum(
(targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds))
)
/ n_batch
+ 0.5 * weightsq_head
)
# Define optimisation routine
opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)
@jit
def update(
i: int, state: dict, batch: jnp.ndarray, penalty_l2: float, penalty_l2_p: float
) -> jnp.ndarray:
# updating function
params = get_params(state)
return opt_update(
i, grad(loss_offsetnet)(params, batch, penalty_l2, penalty_l2_p), state
)
# initialise states
_, init_params = init_fun_offset(rng_key, input_shape)
opt_state = opt_init(init_params)
# calculate number of batches per epoch
batch_size = batch_size if batch_size < n else n
n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1
train_indices = onp.arange(n)
l_best = LARGE_VAL
p_curr = 0
pred_funs = predict_fun_head, binary_y
# do training
for i in range(n_iter):
# shuffle data for minibatches
onp.random.shuffle(train_indices)
for b in range(n_batches):
idx_next = train_indices[
(b * batch_size) : min((b + 1) * batch_size, n - 1)
]
next_batch = X[idx_next, :], y[idx_next, :], w[idx_next]
opt_state = update(
i * n_batches + b, opt_state, next_batch, penalty_l2, penalty_l2_p
)
if (i % n_iter_print == 0) or early_stopping:
params_curr = get_params(opt_state)
l_curr = loss_offsetnet(
params_curr, (X_val, y_val, w_val), penalty_l2, penalty_l2_p
)
if i % n_iter_print == 0:
log.info(f"Epoch: {i}, current {val_string} loss {l_curr}")
if early_stopping and ((i + 1) * n_batches > n_iter_min):
if l_curr < l_best:
l_best = l_curr
p_curr = 0
else:
p_curr = p_curr + 1
if p_curr > patience:
if return_val_loss:
# return loss without penalty
l_final = loss_offsetnet(params_curr, (X_val, y_val, w_val), 0, 0)
return params_curr, pred_funs, l_final
return params_curr, pred_funs
# return the parameters
trained_params = get_params(opt_state)
if return_val_loss:
# return loss without penalty
l_final = loss_offsetnet(get_params(opt_state), (X_val, y_val, w_val), 0, 0)
return trained_params, pred_funs, l_final
return trained_params, pred_funs
================================================
FILE: catenets/models/jax/pseudo_outcome_nets.py
================================================
"""
Implements Pseudo-outcome based Two-step Nets, namely the DR-learner, the PW-learner and the
RA-learner.
"""
# Author: Alicia Curth
from typing import Callable, Optional, Tuple
import jax.numpy as jnp
import numpy as onp
import pandas as pd
from sklearn.model_selection import StratifiedKFold
import catenets.logger as log
from catenets.models.constants import (
DEFAULT_AVG_OBJECTIVE,
DEFAULT_BATCH_SIZE,
DEFAULT_CF_FOLDS,
DEFAULT_LAYERS_OUT,
DEFAULT_LAYERS_OUT_T,
DEFAULT_LAYERS_R,
DEFAULT_LAYERS_R_T,
DEFAULT_N_ITER,
DEFAULT_N_ITER_MIN,
DEFAULT_N_ITER_PRINT,
DEFAULT_NONLIN,
DEFAULT_PATIENCE,
DEFAULT_PENALTY_L2,
DEFAULT_SEED,
DEFAULT_STEP_SIZE,
DEFAULT_STEP_SIZE_T,
DEFAULT_UNITS_OUT,
DEFAULT_UNITS_OUT_T,
DEFAULT_UNITS_R,
DEFAULT_UNITS_R_T,
DEFAULT_VAL_SPLIT,
)
from catenets.models.jax.base import BaseCATENet, train_output_net_only
from catenets.models.jax.disentangled_nets import predict_snet3, train_snet3
from catenets.models.jax.flextenet import predict_flextenet, train_flextenet
from catenets.models.jax.model_utils import check_shape_1d_data, check_X_is_np
from catenets.models.jax.offsetnet import predict_offsetnet, train_offsetnet
from catenets.models.jax.representation_nets import (
predict_snet1,
predict_snet2,
train_snet1,
train_snet2,
)
from catenets.models.jax.snet import predict_snet, train_snet
from catenets.models.jax.tnet import predict_t_net, train_tnet
from catenets.models.jax.transformation_utils import (
DR_TRANSFORMATION,
PW_TRANSFORMATION,
RA_TRANSFORMATION,
_get_transformation_function,
)
T_STRATEGY = "T"
S1_STRATEGY = "Tar"
S2_STRATEGY = "S2"
S3_STRATEGY = "S3"
S_STRATEGY = "S"
OFFSET_STRATEGY = "Offset"
FLEX_STRATEGY = "Flex"
ALL_STRATEGIES = [
T_STRATEGY,
S1_STRATEGY,
S2_STRATEGY,
S3_STRATEGY,
S_STRATEGY,
FLEX_STRATEGY,
OFFSET_STRATEGY,
]
class PseudoOutcomeNet(BaseCATENet):
"""
Class implements TwoStepLearners based on pseudo-outcome regression as discussed in
Curth &vd Schaar (2021): RA-learner, PW-learner and DR-learner
Parameters
----------
first_stage_strategy: str, default 't'
which nuisance estimator to use in first stage
first_stage_args: dict
Any additional arguments to pass to first stage training function
data_split: bool, default False
Whether to split the data in two folds for estimation
cross_fit: bool, default False
Whether to perform cross fitting
n_cf_folds: int
Number of crossfitting folds to use
transformation: str, default 'AIPW'
pseudo-outcome to use ('AIPW' for DR-learner, 'HT' for PW learner, 'RA' for RA-learner)
binary_y: bool, default False
Whether the outcome is binary
n_layers_out: int
First stage Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)
n_units_out: int
First stage Number of hidden units in each hypothesis layer
n_layers_r: int
First stage Number of representation layers before hypothesis layers (distinction between
hypothesis layers and representation layers is made to match TARNet & SNets)
n_units_r: int
First stage Number of hidden units in each representation layer
n_layers_out_t: int
Second stage Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)
n_units_out_t: int
Second stage Number of hidden units in each hypothesis layer
n_layers_r_t: int
Second stage Number of representation layers before hypothesis layers (distinction between
hypothesis layers and representation layers is made to match TARNet & SNets)
n_units_r_t: int
Second stage Number of hidden units in each representation layer
penalty_l2: float
First stage l2 (ridge) penalty
penalty_l2_t: float
Second stage l2 (ridge) penalty
step_size: float
First stage learning rate for optimizer
step_size_t: float
Second stage learning rate for optimizer
n_iter: int
Maximum number of iterations
batch_size: int
Batch size
val_split_prop: float
Proportion of samples used for validation split (can be 0)
early_stopping: bool, default True
Whether to use early stopping
patience: int
Number of iterations to wait before early stopping after decrease in validation loss
n_iter_min: int
Minimum number of iterations to go through before starting early stopping
n_iter_print: int
Number of iterations after which to print updates
seed: int
Seed used
nonlin: string, default 'elu'
Nonlinearity to use in NN
"""
def __init__(
self,
first_stage_strategy: str = T_STRATEGY,
first_stage_args: Optional[dict] = None,
data_split: bool = False,
cross_fit: bool = False,
n_cf_folds: int = DEFAULT_CF_FOLDS,
transformation: str = DR_TRANSFORMATION,
binary_y: bool = False,
n_layers_out: int = DEFAULT_LAYERS_OUT,
n_layers_r: int = DEFAULT_LAYERS_R,
n_layers_out_t: int = DEFAULT_LAYERS_OUT_T,
n_layers_r_t: int = DEFAULT_LAYERS_R_T,
n_units_out: int = DEFAULT_UNITS_OUT,
n_units_r: int = DEFAULT_UNITS_R,
n_units_out_t: int = DEFAULT_UNITS_OUT_T,
n_units_r_t: int = DEFAULT_UNITS_R_T,
penalty_l2: float = DEFAULT_PENALTY_L2,
penalty_l2_t: float = DEFAULT_PENALTY_L2,
step_size: float = DEFAULT_STEP_SIZE,
step_size_t: float = DEFAULT_STEP_SIZE_T,
n_iter: int = DEFAULT_N_ITER,
batch_size: int = DEFAULT_BATCH_SIZE,
n_iter_min: int = DEFAULT_N_ITER_MIN,
val_split_prop: float = DEFAULT_VAL_SPLIT,
early_stopping: bool = True,
patience: int = DEFAULT_PATIENCE,
n_iter_print: int = DEFAULT_N_ITER_PRINT,
seed: int = DEFAULT_SEED,
rescale_transformation: bool = False,
nonlin: str = DEFAULT_NONLIN,
) -> None:
# settings
self.first_stage_strategy = first_stage_strategy
self.first_stage_args = first_stage_args
self.binary_y = binary_y
self.transformation = transformation
self.data_split = data_split
self.cross_fit = cross_fit
self.n_cf_folds = n_cf_folds
# model architecture hyperparams
self.n_layers_out = n_layers_out
self.n_layers_out_t = n_layers_out_t
self.n_layers_r = n_layers_r
self.n_layers_r_t = n_layers_r_t
self.n_units_out = n_units_out
self.n_units_out_t = n_units_out_t
self.n_units_r = n_units_r
self.n_units_r_t = n_units_r_t
self.nonlin = nonlin
# other hyperparameters
self.penalty_l2 = penalty_l2
self.penalty_l2_t = penalty_l2_t
self.step_size = step_size
self.step_size_t = step_size_t
self.n_iter = n_iter
self.batch_size = batch_size
self.n_iter_print = n_iter_print
self.seed = seed
self.val_split_prop = val_split_prop
self.early_stopping = early_stopping
self.patience = patience
self.n_iter_min = n_iter_min
self.rescale_transformation = rescale_transformation
def _get_train_function(self) -> Callable:
return train_pseudooutcome_net
def fit(
self,
X: jnp.ndarray,
y: jnp.ndarray,
w: jnp.ndarray,
p: Optional[jnp.ndarray] = None,
) -> "PseudoOutcomeNet":
# overwrite super so we can pass p as extra param
# some quick input checks
X = check_X_is_np(X)
self._check_inputs(w, p)
train_func = self._get_train_function()
train_params = self.get_params()
if "transformation" not in train_params.keys():
train_params.update({"transformation": self.transformation})
if self.rescale_transformation:
self._params, self._predict_funs, self._scale_factor = train_func(
X, y, w, p, **train_params
)
else:
self._params, self._predict_funs = train_func(X, y, w, p, **train_params)
return self
def _get_predict_function(self) -> Callable:
# Two step nets do not need this
pass
def predict(
self, X: jnp.ndarray, return_po: bool = False, return_prop: bool = False
) -> jnp.ndarray:
# check input
if return_po:
raise NotImplementedError(
"TwoStepNets have no Potential outcome predictors."
)
if return_prop:
raise NotImplementedError("TwoStepNets have no Propensity predictors.")
if isinstance(X, pd.DataFrame):
X = X.values
if self.rescale_transformation:
return 1 / self._scale_factor * self._predict_funs(self._params, X)
else:
return self._predict_funs(self._params, X)
class DRNet(PseudoOutcomeNet):
"""Wrapper for DR-learner using PseudoOutcomeNet"""
def __init__(
self,
first_stage_strategy: str = T_STRATEGY,
data_split: bool = False,
cross_fit: bool = False,
n_cf_folds: int = DEFAULT_CF_FOLDS,
binary_y: bool = False,
n_layers_out: int = DEFAULT_LAYERS_OUT,
n_layers_r: int = DEFAULT_LAYERS_R,
n_layers_out_t: int = DEFAULT_LAYERS_OUT_T,
n_layers_r_t: int = DEFAULT_LAYERS_R_T,
n_units_out: int = DEFAULT_UNITS_OUT,
n_units_r: int = DEFAULT_UNITS_R,
n_units_out_t: int = DEFAULT_UNITS_OUT_T,
n_units_r_t: int = DEFAULT_UNITS_R_T,
penalty_l2: float = DEFAULT_PENALTY_L2,
penalty_l2_t: float = DEFAULT_PENALTY_L2,
step_size: float = DEFAULT_STEP_SIZE,
step_size_t: float = DEFAULT_STEP_SIZE_T,
n_iter: int = DEFAULT_N_ITER,
batch_size: int = DEFAULT_BATCH_SIZE,
n_iter_min: int = DEFAULT_N_ITER_MIN,
val_split_prop: float = DEFAULT_VAL_SPLIT,
early_stopping: bool = True,
patience: int = DEFAULT_PATIENCE,
n_iter_print: int = DEFAULT_N_ITER_PRINT,
seed: int = DEFAULT_SEED,
rescale_transformation: bool = False,
nonlin: str = DEFAULT_NONLIN,
first_stage_args: Optional[dict] = None,
) -> None:
super().__init__(
first_stage_strategy=first_stage_strategy,
data_split=data_split,
cross_fit=cross_fit,
n_cf_folds=n_cf_folds,
transformation=DR_TRANSFORMATION,
binary_y=binary_y,
n_layers_out=n_layers_out,
n_layers_r=n_layers_r,
n_layers_out_t=n_layers_out_t,
n_layers_r_t=n_layers_r_t,
n_units_out=n_units_out,
n_units_r=n_units_r,
n_units_out_t=n_units_out_t,
n_units_r_t=n_units_r_t,
penalty_l2=penalty_l2,
penalty_l2_t=penalty_l2_t,
step_size=step_size,
step_size_t=step_size_t,
n_iter=n_iter,
batch_size=batch_size,
n_iter_min=n_iter_min,
val_split_prop=val_split_prop,
early_stopping=early_stopping,
patience=patience,
n_iter_print=n_iter_print,
seed=seed,
nonlin=nonlin,
rescale_transformation=rescale_transformation,
first_stage_args=first_stage_args,
)
class RANet(PseudoOutcomeNet):
"""Wrapper for RA-learner using PseudoOutcomeNet"""
def __init__(
self,
first_stage_strategy: str = T_STRATEGY,
data_split: bool = False,
cross_fit: bool = False,
n_cf_folds: int = DEFAULT_CF_FOLDS,
binary_y: bool = False,
n_layers_out: int = DEFAULT_LAYERS_OUT,
n_layers_r: int = DEFAULT_LAYERS_R,
n_layers_out_t: int = DEFAULT_LAYERS_OUT_T,
n_layers_r_t: int = DEFAULT_LAYERS_R_T,
n_units_out: int = DEFAULT_UNITS_OUT,
n_units_r: int = DEFAULT_UNITS_R,
n_units_out_t: int = DEFAULT_UNITS_OUT_T,
n_units_r_t: int = DEFAULT_UNITS_R_T,
penalty_l2: float = DEFAULT_PENALTY_L2,
penalty_l2_t: float = DEFAULT_PENALTY_L2,
step_size: float = DEFAULT_STEP_SIZE,
step_size_t: float = DEFAULT_STEP_SIZE_T,
n_iter: int = DEFAULT_N_ITER,
batch_size: int = DEFAULT_BATCH_SIZE,
n_iter_min: int = DEFAULT_N_ITER_MIN,
val_split_prop: float = DEFAULT_VAL_SPLIT,
early_stopping: bool = True,
patience: int = DEFAULT_PATIENCE,
n_iter_print: int = DEFAULT_N_ITER_PRINT,
seed: int = DEFAULT_SEED,
rescale_transformation: bool = False,
nonlin: str = DEFAULT_NONLIN,
first_stage_args: Optional[dict] = None,
) -> None:
super().__init__(
first_stage_strategy=first_stage_strategy,
data_split=data_split,
cross_fit=cross_fit,
n_cf_folds=n_cf_folds,
transformation=RA_TRANSFORMATION,
binary_y=binary_y,
n_layers_out=n_layers_out,
n_layers_r=n_layers_r,
n_layers_out_t=n_layers_out_t,
n_layers_r_t=n_layers_r_t,
n_units_out=n_units_out,
n_units_r=n_units_r,
n_units_out_t=n_units_out_t,
n_units_r_t=n_units_r_t,
penalty_l2=penalty_l2,
penalty_l2_t=penalty_l2_t,
step_size=step_size,
step_size_t=step_size_t,
n_iter=n_iter,
batch_size=batch_size,
n_iter_min=n_iter_min,
val_split_prop=val_split_prop,
early_stopping=early_stopping,
patience=patience,
n_iter_print=n_iter_print,
seed=seed,
nonlin=nonlin,
rescale_transformation=rescale_transformation,
first_stage_args=first_stage_args,
)
class PWNet(PseudoOutcomeNet):
"""Wrapper for PW-learner using PseudoOutcomeNet"""
def __init__(
self,
first_stage_strategy: str = T_STRATEGY,
data_split: bool = False,
cross_fit: bool = False,
n_cf_folds: int = DEFAULT_CF_FOLDS,
binary_y: bool = False,
n_layers_out: int = DEFAULT_LAYERS_OUT,
n_layers_r: int = DEFAULT_LAYERS_R,
n_layers_out_t: int = DEFAULT_LAYERS_OUT_T,
n_layers_r_t: int = DEFAULT_LAYERS_R_T,
n_units_out: int = DEFAULT_UNITS_OUT,
n_units_r: int = DEFAULT_UNITS_R,
n_units_out_t: int = DEFAULT_UNITS_OUT_T,
n_units_r_t: int = DEFAULT_UNITS_R_T,
penalty_l2: float = DEFAULT_PENALTY_L2,
penalty_l2_t: float = DEFAULT_PENALTY_L2,
step_size: float = DEFAULT_STEP_SIZE,
step_size_t: float = DEFAULT_STEP_SIZE_T,
n_iter: int = DEFAULT_N_ITER,
batch_size: int = DEFAULT_BATCH_SIZE,
n_iter_min: int = DEFAULT_N_ITER_MIN,
val_split_prop: float = DEFAULT_VAL_SPLIT,
early_stopping: bool = True,
patience: int = DEFAULT_PATIENCE,
n_iter_print: int = DEFAULT_N_ITER_PRINT,
seed: int = DEFAULT_SEED,
rescale_transformation: bool = False,
nonlin: str = DEFAULT_NONLIN,
first_stage_args: Optional[dict] = None,
) -> None:
super().__init__(
first_stage_strategy=first_stage_strategy,
data_split=data_split,
cross_fit=cross_fit,
n_cf_folds=n_cf_folds,
transformation=PW_TRANSFORMATION,
binary_y=binary_y,
n_layers_out=n_layers_out,
n_layers_r=n_layers_r,
n_layers_out_t=n_layers_out_t,
n_layers_r_t=n_layers_r_t,
n_units_out=n_units_out,
n_units_r=n_units_r,
n_units_out_t=n_units_out_t,
n_units_r_t=n_units_r_t,
penalty_l2=penalty_l2,
penalty_l2_t=penalty_l2_t,
step_size=step_size,
step_size_t=step_size_t,
n_iter=n_iter,
batch_size=batch_size,
n_iter_min=n_iter_min,
val_split_prop=val_split_prop,
early_stopping=early_stopping,
patience=patience,
n_iter_print=n_iter_print,
seed=seed,
nonlin=nonlin,
rescale_transformation=rescale_transformation,
first_stage_args=first_stage_args,
)
def train_pseudooutcome_net(
X: jnp.ndarray,
y: jnp.ndarray,
w: jnp.ndarray,
p: Optional[jnp.ndarray] = None,
first_stage_strategy: str = T_STRATEGY,
data_split: bool = False,
cross_fit: bool = False,
n_cf_folds: int = DEFAULT_CF_FOLDS,
transformation: str = DR_TRANSFORMATION,
binary_y: bool = False,
n_layers_out: int = DEFAULT_LAYERS_OUT,
n_layers_r: int = DEFAULT_LAYERS_R,
n_layers_r_t: int = DEFAULT_LAYERS_R_T,
n_layers_out_t: int = DEFAULT_LAYERS_OUT_T,
n_units_out: int = DEFAULT_UNITS_OUT,
n_units_r: int = DEFAULT_UNITS_R,
n_units_out_t: int = DEFAULT_UNITS_OUT_T,
n_units_r_t: int = DEFAULT_UNITS_R_T,
penalty_l2: float = DEFAULT_PENALTY_L2,
penalty_l2_t: float = DEFAULT_PENALTY_L2,
step_size: float = DEFAULT_STEP_SIZE,
step_size_t: float = DEFAULT_STEP_SIZE_T,
n_iter: int = DEFAULT_N_ITER,
batch_size: int = DEFAULT_BATCH_SIZE,
val_split_prop: float = DEFAULT_VAL_SPLIT,
early_stopping: bool = True,
patience: int = DEFAULT_PATIENCE,
n_iter_min: int = DEFAULT_N_ITER_MIN,
n_iter_print: int = DEFAULT_N_ITER_PRINT,
seed: int = DEFAULT_SEED,
rescale_transformation: bool = False,
return_val_loss: bool = False,
nonlin: str = DEFAULT_NONLIN,
avg_objective: bool = DEFAULT_AVG_OBJECTIVE,
first_stage_args: Optional[dict] = None,
) -> Tuple:
# get shape of data
n, d = X.shape
if p is not None:
p = check_shape_1d_data(p)
# get transformation function
transformation_function = _get_transformation_function(transformation)
# get strategy name
if first_stage_strategy not in ALL_STRATEGIES:
raise ValueError(
"Parameter first stage should be in "
"catenets.models.pseudo_outcome_nets.ALL_STRATEGIES. "
"You passed {}".format(first_stage_strategy)
)
# split data as wanted
if p is None or transformation is not PW_TRANSFORMATION:
if not cross_fit:
if not data_split:
log.debug("Training first stage with all data (no data splitting)")
# use all data for both
fit_mask = onp.ones(n, dtype=bool)
pred_mask = onp.ones(n, dtype=bool)
else:
log.debug("Training first stage with half of the data (data splitting)")
# split data in half
fit_idx = onp.random.choice(n, int(onp.round(n / 2)))
fit_mask = onp.zeros(n, dtype=bool)
fit_mask[fit_idx] = 1
pred_mask = ~fit_mask
mu_0, mu_1, pi_hat = _train_and_predict_first_stage(
X,
y,
w,
fit_mask,
pred_mask,
first_stage_strategy=first_stage_strategy,
binary_y=binary_y,
n_layers_out=n_layers_out,
n_layers_r=n_layers_r,
n_units_out=n_units_out,
n_units_r=n_units_r,
penalty_l2=penalty_l2,
step_size=step_size,
n_iter=n_iter,
batch_size=batch_size,
val_split_prop=val_split_prop,
early_stopping=early_stopping,
patience=patience,
n_iter_min=n_iter_min,
n_iter_print=n_iter_print,
seed=seed,
nonlin=nonlin,
avg_objective=avg_objective,
transformation=transformation,
first_stage_args=first_stage_args,
)
if data_split:
# keep only prediction data
X, y, w = X[pred_mask, :], y[pred_mask, :], w[pred_mask, :]
if p is not None:
p = p[pred_mask, :]
else:
log.debug(f"Training first stage in {n_cf_folds} folds (cross-fitting)")
# do cross fitting
mu_0, mu_1, pi_hat = onp.zeros((n, 1)), onp.zeros((n, 1)), onp.zeros((n, 1))
splitter = StratifiedKFold(
n_splits=n_cf_folds, shuffle=True, random_state=seed
)
fold_count = 1
for train_idx, test_idx in splitter.split(X, w):
log.debug(f"Training fold {fold_count}.")
fold_count = fold_count + 1
pred_mask = onp.zeros(n, dtype=bool)
pred_mask[test_idx] = 1
fit_mask = ~pred_mask
(
mu_0[pred_mask],
mu_1[pred_mask],
pi_hat[pred_mask],
) = _train_and_predict_first_stage(
X,
y,
w,
fit_mask,
pred_mask,
first_stage_strategy=first_stage_strategy,
binary_y=binary_y,
n_layers_out=n_layers_out,
n_layers_r=n_layers_r,
n_units_out=n_units_out,
n_units_r=n_units_r,
penalty_l2=penalty_l2,
step_size=step_size,
n_iter=n_iter,
batch_size=batch_size,
val_split_prop=val_split_prop,
early_stopping=early_stopping,
patience=patience,
n_iter_min=n_iter_min,
n_iter_print=n_iter_print,
seed=seed,
nonlin=nonlin,
avg_objective=avg_objective,
transformation=transformation,
first_stage_args=first_stage_args,
)
log.debug("Training second stage.")
if p is not None:
# use known propensity score
p = check_shape_1d_data(p)
pi_hat = p
# second stage
y, w = check_shape_1d_data(y), check_shape_1d_data(w)
# transform data and fit on transformed data
if transformation is PW_TRANSFORMATION:
mu_0 = None
mu_1 = None
pseudo_outcome = transformation_function(y=y, w=w, p=pi_hat, mu_0=mu_0, mu_1=mu_1)
if rescale_transformation:
scale_factor = onp.std(y) / onp.std(pseudo_outcome)
if scale_factor > 1:
scale_factor = 1
else:
pseudo_outcome = scale_factor * pseudo_outcome
params, predict_funs = train_output_net_only(
X,
pseudo_outcome,
binary_y=False,
n_layers_out=n_layers_out_t,
n_units_out=n_units_out_t,
n_layers_r=n_layers_r_t,
n_units_r=n_units_r_t,
penalty_l2=penalty_l2_t,
step_size=step_size_t,
n_iter=n_iter,
batch_size=batch_size,
val_split_prop=val_split_prop,
early_stopping=early_stopping,
patience=patience,
n_iter_min=n_iter_min,
n_iter_print=n_iter_print,
seed=seed,
return_val_loss=return_val_loss,
nonlin=nonlin,
avg_objective=avg_objective,
)
return params, predict_funs, scale_factor
else:
return train_output_net_only(
X,
pseudo_outcome,
binary_y=False,
n_layers_out=n_layers_out_t,
n_units_out=n_units_out_t,
n_layers_r=n_layers_r_t,
n_units_r=n_units_r_t,
penalty_l2=penalty_l2_t,
step_size=step_size_t,
n_iter=n_iter,
batch_size=batch_size,
val_split_prop=val_split_prop,
early_stopping=early_stopping,
patience=patience,
n_iter_min=n_iter_min,
n_iter_print=n_iter_print,
seed=seed,
return_val_loss=return_val_loss,
nonlin=nonlin,
avg_objective=avg_objective,
)
def _train_and_predict_first_stage(
X: jnp.ndarray,
y: jnp.ndarray,
w: jnp.ndarray,
fit_mask: jnp.ndarray,
pred_mask: jnp.ndarray,
first_stage_strategy: str,
binary_y: bool = False,
n_layers_out: int = DEFAULT_LAYERS_OUT,
n_layers_r: int = DEFAULT_LAYERS_R,
n_units_out: int = DEFAULT_UNITS_OUT,
n_units_r: int = DEFAULT_UNITS_R,
penalty_l2: float = DEFAULT_PENALTY_L2,
step_size: float = DEFAULT_STEP_SIZE,
n_iter: int = DEFAULT_N_ITER,
batch_size: int = DEFAULT_BATCH_SIZE,
val_split_prop: float = DEFAULT_VAL_SPLIT,
early_stopping: bool = True,
patience: int = DEFAULT_PATIENCE,
n_iter_min: int = DEFAULT_N_ITER_MIN,
n_iter_print: int = DEFAULT_N_ITER_PRINT,
seed: int = DEFAULT_SEED,
nonlin: str = DEFAULT_NONLIN,
avg_objective: bool = False,
transformation: str = DR_TRANSFORMATION,
first_stage_args: Optional[dict] = None,
) -> Tuple:
if len(w.shape) > 1:
w = w.reshape((len(w),))
if first_stage_args is None:
first_stage_args = {}
# split the data
X_fit, y_fit, w_fit = X[fit_mask, :], y[fit_mask], w[fit_mask]
X_pred = X[pred_mask, :]
train_fun: Callable
predict_fun: Callable
if first_stage_strategy == T_STRATEGY:
train_fun, predict_fun = train_tnet, predict_t_net
elif first_stage_strategy == S_STRATEGY:
train_fun, predict_fun = train_snet, predict_snet
elif first_stage_strategy == S1_STRATEGY:
train_fun, predict_fun = train_snet1, predict_snet1
elif first_stage_strategy == S2_STRATEGY:
train_fun, predict_fun = train_snet2, predict_snet2
elif first_stage_strategy == S3_STRATEGY:
train_fun, predict_fun = train_snet3, predict_snet3
elif first_stage_strategy == OFFSET_STRATEGY:
train_fun, predict_fun = train_offsetnet, predict_offsetnet
elif first_stage_strategy == FLEX_STRATEGY:
train_fun, predict_fun = train_flextenet, predict_flextenet
else:
raise ValueError(
"{} is not a valid first stage strategy for a PseudoOutcomeNet".format(
first_stage_strategy
)
)
log.debug("Training PO estimators")
trained_params, pred_fun = train_fun(
X_fit,
y_fit,
w_fit,
binary_y=binary_y,
n_layers_r=n_layers_r,
n_units_r=n_units_r,
n_layers_out=n_layers_out,
n_units_out=n_units_out,
penalty_l2=penalty_l2,
step_size=step_size,
n_iter=n_iter,
batch_size=batch_size,
val_split_prop=val_split_prop,
early_stopping=early_stopping,
patience=patience,
n_iter_min=n_iter_min,
n_iter_print=n_iter_print,
seed=seed,
nonlin=nonlin,
avg_objective=avg_objective,
**first_stage_args,
)
if first_stage_strategy in [S_STRATEGY, S2_STRATEGY, S3_STRATEGY]:
_, mu_0, mu_1, pi_hat = predict_fun(
X_pred, trained_params, pred_fun, return_po=True, return_prop=True
)
else:
if transformation is not PW_TRANSFORMATION:
_, mu_0, mu_1 = predict_fun(
X_pred, trained_params, pred_fun, return_po=True
)
else:
mu_0, mu_1 = onp.nan, onp.nan
if transformation is not RA_TRANSFORMATION:
log.debug("Training propensity net")
params_prop, predict_fun_prop = train_output_net_only(
X_fit,
w_fit,
binary_y=True,
n_layers_out=n_layers_out,
n_units_out=n_units_out,
n_layers_r=n_layers_r,
n_units_r=n_units_r,
penalty_l2=penalty_l2,
step_size=step_size,
n_iter=n_iter,
batch_size=batch_size,
val_split_prop=val_split_prop,
early_stopping=early_stopping,
patience=patience,
n_iter_min=n_iter_min,
n_iter_print=n_iter_print,
seed=seed,
nonlin=nonlin,
avg_objective=avg_objective,
)
pi_hat = predict_fun_prop(params_prop, X_pred)
else:
pi_hat = onp.nan
return mu_0, mu_1, pi_hat
================================================
FILE: catenets/models/jax/representation_nets.py
================================================
"""
Module implements SNet1 and SNet2, which are based on CFRNet/TARNet from Shalit et al (2017) and
DragonNet from Shi et al (2019), respectively.
"""
# Author: Alicia Curth
from typing import Any, Callable, List, Tuple
import jax.numpy as jnp
import numpy as onp
from jax import grad, jit, random
from jax.example_libraries import optimizers
import catenets.logger as log
from catenets.models.constants import (
DEFAULT_AVG_OBJECTIVE,
DEFAULT_BATCH_SIZE,
DEFAULT_LAYERS_OUT,
DEFAULT_LAYERS_R,
DEFAULT_N_ITER,
DEFAULT_N_ITER_MIN,
DEFAULT_N_ITER_PRINT,
DEFAULT_NONLIN,
DEFAULT_PATIENCE,
DEFAULT_PENALTY_DISC,
DEFAULT_PENALTY_L2,
DEFAULT_SEED,
DEFAULT_STEP_SIZE,
DEFAULT_UNITS_OUT,
DEFAULT_UNITS_R,
DEFAULT_VAL_SPLIT,
LARGE_VAL,
)
from catenets.models.jax.base import BaseCATENet, OutputHead, ReprBlock
from catenets.models.jax.model_utils import (
check_shape_1d_data,
heads_l2_penalty,
make_val_split,
)
class SNet1(BaseCATENet):
"""
Class implements Shalit et al (2017)'s TARNet & CFR (discrepancy regularization is NOT
TESTED). Also referred to as SNet-1 in our paper.
Parameters
----------
binary_y: bool, default False
Whether the outcome is binary
n_layers_out: int
Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)
n_units_out: int
Number of hidden units in each hypothesis layer
n_layers_r: int
Number of shared representation layers before hypothesis layers
n_units_r: int
Number of hidden units in each representation layer
penalty_l2: float
l2 (ridge) penalty
step_size: float
learning rate for optimizer
n_iter: int
Maximum number of iterations
batch_size: int
Batch size
val_split_prop: float
Proportion of samples used for validation split (can be 0)
early_stopping: bool, default True
Whether to use early stopping
patience: int
Number of iterations to wait before early stopping after decrease in validation loss
n_iter_min: int
Minimum number of iterations to go through before starting early stopping
n_iter_print: int
Number of iterations after which to print updates
seed: int
Seed used
reg_diff: bool, default False
Whether to regularize the difference between the two potential outcome heads
penalty_diff: float
l2-penalty for regularizing the difference between output heads. used only if
train_separate=False
same_init: bool, False
Whether to initialise the two output heads with same values
nonlin: string, default 'elu'
Nonlinearity to use in NN
penalty_disc: float, default zero
Discrepancy penalty. Defaults to zero as this feature is not tested.
"""
def __init__(
self,
binary_y: bool = False,
n_layers_r: int = DEFAULT_LAYERS_R,
n_units_r: int = DEFAULT_UNITS_R,
n_layers_out: int = DEFAULT_LAYERS_OUT,
n_units_out: int = DEFAULT_UNITS_OUT,
penalty_l2: float = DEFAULT_PENALTY_L2,
step_size: float = DEFAULT_STEP_SIZE,
n_iter: int = DEFAULT_N_ITER,
batch_size: int = DEFAULT_BATCH_SIZE,
val_split_prop: float = DEFAULT_VAL_SPLIT,
early_stopping: bool = True,
patience: int = DEFAULT_PATIENCE,
n_iter_min: int = DEFAULT_N_ITER_MIN,
n_iter_print: int = DEFAULT_N_ITER_PRINT,
seed: int = DEFAULT_SEED,
reg_diff: bool = False,
penalty_diff: float = DEFAULT_PENALTY_L2,
same_init: bool = False,
nonlin: str = DEFAULT_NONLIN,
penalty_disc: float = DEFAULT_PENALTY_DISC,
) -> None:
# structure of net
self.binary_y = binary_y
self.n_layers_r = n_layers_r
self.n_layers_out = n_layers_out
self.n_units_r = n_units_r
self.n_units_out = n_units_out
self.nonlin = nonlin
# penalties
self.penalty_l2 = penalty_l2
self.penalty_disc = penalty_disc
self.reg_diff = reg_diff
self.penalty_diff = penalty_diff
self.same_init = same_init
# training params
self.step_size = step_size
self.n_iter = n_iter
self.batch_size = batch_size
self.n_iter_print = n_iter_print
self.seed = seed
self.val_split_prop = val_split_prop
self.early_stopping = early_stopping
self.patience = patience
self.n_iter_min = n_iter_min
def _get_train_function(self) -> Callable:
return train_snet1
def _get_predict_function(self) -> Callable:
return predict_snet1
class TARNet(SNet1):
"""Wrapper for TARNet"""
def __init__(
self,
binary_y: bool = False,
n_layers_r: int = DEFAULT_LAYERS_R,
n_units_r: int = DEFAULT_UNITS_R,
n_layers_out: int = DEFAULT_LAYERS_OUT,
n_units_out: int = DEFAULT_UNITS_OUT,
penalty_l2: float = DEFAULT_PENALTY_L2,
step_size: float = DEFAULT_STEP_SIZE,
n_iter: int = DEFAULT_N_ITER,
batch_size: int = DEFAULT_BATCH_SIZE,
val_split_prop: float = DEFAULT_VAL_SPLIT,
early_stopping: bool = True,
patience: int = DEFAULT_PATIENCE,
n_iter_min: int = DEFAULT_N_ITER_MIN,
n_iter_print: int = DEFAULT_N_ITER_PRINT,
seed: int = DEFAULT_SEED,
reg_diff: bool = False,
penalty_diff: float = DEFAULT_PENALTY_L2,
same_init: bool = False,
nonlin: str = DEFAULT_NONLIN,
):
super().__init__(
binary_y=binary_y,
n_layers_r=n_layers_r,
n_units_r=n_units_r,
n_layers_out=n_layers_out,
n_units_out=n_units_out,
penalty_l2=penalty_l2,
step_size=step_size,
n_iter=n_iter,
batch_size=batch_size,
val_split_prop=val_split_prop,
early_stopping=early_stopping,
patience=patience,
n_iter_min=n_iter_min,
n_iter_print=n_iter_print,
seed=seed,
reg_diff=reg_diff,
penalty_diff=penalty_diff,
same_init=same_init,
nonlin=nonlin,
penalty_disc=0,
)
class SNet2(BaseCATENet):
"""
Class implements SNet-2, which is based on Shi et al (2019)'s DragonNet (this version does
NOT use targeted regularization and has a (possibly deeper) propensity head.
Parameters
----------
binary_y: bool, default False
Whether the outcome is binary
n_layers_out: int
Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)
n_layers_out_prop: int
Number of hypothesis layers for propensity score(n_layers_out x n_units_out + 1 x Dense
layer)
n_units_out: int
Number of hidden units in each hypothesis layer
n_units_out_prop: int
Number of hidden units in each propensity score hypothesis layer
n_layers_r: int
Number of shared representation layers before hypothesis layers
n_units_r: int
Number of hidden units in each representation layer
penalty_l2: float
l2 (ridge) penalty
step_size: float
learning rate for optimizer
n_iter: int
Maximum number of iterations
batch_size: int
Batch size
val_split_prop: float
Proportion of samples used for validation split (can be 0)
early_stopping: bool, default True
Whether to use early stopping
patience: int
Number of iterations to wait before early stopping after decrease in validation loss
n_iter_min: int
Minimum number of iterations to go through before starting early stopping
n_iter_print: int
Number of iterations after which to print updates
seed: int
Seed used
reg_diff: bool, default False
Whether to regularize the difference between the two potential outcome heads
penalty_diff: float
l2-penalty for regularizing the difference between output heads. used only if
train_separate=False
same_init: bool, False
Whether to initialise the two output heads with same values
nonlin: string, default 'elu'
Nonlinearity to use in NN
"""
def __init__(
self,
binary_y: bool = False,
n_layers_r: int = DEFAULT_LAYERS_R,
n_units_r: int = DEFAULT_UNITS_R,
n_layers_out: int = DEFAULT_LAYERS_OUT,
n_units_out: int = DEFAULT_UNITS_OUT,
penalty_l2: float = DEFAULT_PENALTY_L2,
n_units_out_prop: int = DEFAULT_UNITS_OUT,
n_layers_out_prop: int = DEFAULT_LAYERS_OUT,
step_size: float = DEFAULT_STEP_SIZE,
n_iter: int = DEFAULT_N_ITER,
batch_size: int = DEFAULT_BATCH_SIZE,
val_split_prop: float = DEFAULT_VAL_SPLIT,
early_stopping: bool = True,
patience: int = DEFAULT_PATIENCE,
n_iter_min: int = DEFAULT_N_ITER_MIN,
n_iter_print: int = DEFAULT_N_ITER_PRINT,
seed: int = DEFAULT_SEED,
reg_diff: bool = False,
same_init: bool = False,
penalty_diff: float = DEFAULT_PENALTY_L2,
nonlin: str = DEFAULT_NONLIN,
) -> None:
self.binary_y = binary_y
self.n_layers_r = n_layers_r
self.n_layers_out = n_layers_out
self.n_layers_out_prop = n_layers_out_prop
self.n_units_r = n_units_r
self.n_units_out = n_units_out
self.n_units_out_prop = n_units_out_prop
self.nonlin = nonlin
self.penalty_l2 = penalty_l2
self.step_size = step_size
self.n_iter = n_iter
self.batch_size = batch_size
self.val_split_prop = val_split_prop
self.early_stopping = early_stopping
self.patience = patience
self.n_iter_min = n_iter_min
self.reg_diff = reg_diff
self.penalty_diff = penalty_diff
self.same_init = same_init
self.seed = seed
self.n_iter_print = n_iter_print
def _get_train_function(self) -> Callable:
return train_snet2
def _get_predict_function(self) -> Callable:
return predict_snet2
class DragonNet(SNet2):
"""Wrapper for DragonNet"""
def __init__(
self,
binary_y: bool = False,
n_layers_r: int = DEFAULT_LAYERS_R,
n_units_r: int = DEFAULT_UNITS_R,
n_layers_out: int = DEFAULT_LAYERS_OUT,
n_units_out: int = DEFAULT_UNITS_OUT,
penalty_l2: float = DEFAULT_PENALTY_L2,
n_units_out_prop: int = DEFAULT_UNITS_OUT,
n_layers_out_prop: int = 0,
step_size: float = DEFAULT_STEP_SIZE,
n_iter: int = DEFAULT_N_ITER,
batch_size: int = DEFAULT_BATCH_SIZE,
val_split_prop: float = DEFAULT_VAL_SPLIT,
early_stopping: bool = True,
patience: int = DEFAULT_PATIENCE,
n_iter_min: int = DEFAULT_N_ITER_MIN,
n_iter_print: int = DEFAULT_N_ITER_PRINT,
seed: int = DEFAULT_SEED,
reg_diff: bool = False,
same_init: bool = False,
penalty_diff: float = DEFAULT_PENALTY_L2,
nonlin: str = DEFAULT_NONLIN,
):
super().__init__(
binary_y=binary_y,
n_layers_r=n_layers_r,
n_units_r=n_units_r,
n_layers_out=n_layers_out,
n_units_out=n_units_out,
penalty_l2=penalty_l2,
n_units_out_prop=n_units_out_prop,
n_layers_out_prop=n_layers_out_prop,
step_size=step_size,
n_iter=n_iter,
batch_size=batch_size,
val_split_prop=val_split_prop,
early_stopping=early_stopping,
patience=patience,
n_iter_min=n_iter_min,
n_iter_print=n_iter_print,
seed=seed,
reg_diff=reg_diff,
penalty_diff=penalty_diff,
same_init=same_init,
nonlin=nonlin,
)
# Training functions for SNet1 -------------------------------------------------
def mmd2_lin(X: jnp.ndarray, w: jnp.ndarray) -> jnp.ndarray:
# Squared Linear MMD as implemented in CFR
# jax does not support indexing, so this is a workaround with reweighting in means
n = w.shape[0]
n_t = jnp.sum(w)
# normalize X so scale matters
X = X / jnp.sqrt(jnp.var(X, axis=0))
mean_control = (n / (n - n_t)) * jnp.mean((1 - w) * X, axis=0)
mean_treated = (n / n_t) * jnp.mean(w * X, axis=0)
return jnp.sum((mean_treated - mean_control) ** 2)
def predict_snet1(
X: jnp.ndarray,
trained_params: dict,
predict_funs: list,
return_po: bool = False,
return_prop: bool = False,
) -> jnp.ndarray:
if return_prop:
raise NotImplementedError("SNet1 does not implement a propensity model.")
# unpack inputs
predict_fun_repr, predict_fun_head = predict_funs
param_repr, param_0, param_1 = (
trained_params[0],
trained_params[1],
trained_params[2],
)
# get representation
representation = predict_fun_repr(param_repr, X)
# get potential outcomes
mu_0 = predict_fun_head(param_0, representation)
mu_1 = predict_fun_head(param_1, representation)
if return_po:
return mu_1 - mu_0, mu_0, mu_1
else:
return mu_1 - mu_0
def train_snet1(
X: jnp.ndarray,
y: jnp.ndarray,
w: jnp.ndarray,
binary_y: bool = False,
n_layers_r: int = DEFAULT_LAYERS_R,
n_units_r: int = DEFAULT_UNITS_R,
n_layers_out: int = DEFAULT_LAYERS_OUT,
n_units_out: int = DEFAULT_UNITS_OUT,
penalty_l2: float = DEFAULT_PENALTY_L2,
penalty_disc: int = DEFAULT_PENALTY_DISC,
step_size: float = DEFAULT_STEP_SIZE,
n_iter: int = DEFAULT_N_ITER,
batch_size: int = DEFAULT_BATCH_SIZE,
val_split_prop: float = DEFAULT_VAL_SPLIT,
early_stopping: bool = True,
patience: int = DEFAULT_PATIENCE,
n_iter_min: int = DEFAULT_N_ITER_MIN,
n_iter_print: int = DEFAULT_N_ITER_PRINT,
seed: int = DEFAULT_SEED,
return_val_loss: bool = False,
reg_diff: bool = False,
same_init: bool = False,
penalty_diff: float = DEFAULT_PENALTY_L2,
nonlin: str = DEFAULT_NONLIN,
avg_objective: bool = DEFAULT_AVG_OBJECTIVE,
) -> Any:
# function to train TARNET (Johansson et al) using jax
# input check
y, w = check_shape_1d_data(y), check_shape_1d_data(w)
d = X.shape[1]
input_shape = (-1, d)
rng_key = random.PRNGKey(seed)
onp.random.seed(seed) # set seed for data generation via numpy as well
if not reg_diff:
penalty_diff = penalty_l2
# get validation split (can be none)
X, y, w, X_val, y_val, w_val, val_string = make_val_split(
X, y, w, val_split_prop=val_split_prop, seed=seed
)
n = X.shape[0] # could be different from before due to split
# get representation layer
init_fun_repr, predict_fun_repr = ReprBlock(
n_layers=n_layers_r, n_units=n_units_r, nonlin=nonlin
)
# get output head functions (both heads share same structure)
init_fun_head, predict_fun_head = OutputHead(
n_layers_out=n_layers_out,
n_units_out=n_units_out,
binary_y=binary_y,
nonlin=nonlin,
)
def init_fun_snet1(rng: float, input_shape: Tuple) -> Tuple[Tuple, List]:
# chain together the layers
# param should look like [repr, po_0, po_1]
rng, layer_rng = random.split(rng)
input_shape_repr, param_repr = init_fun_repr(layer_rng, input_shape)
rng, layer_rng = random.split(rng)
if same_init:
# initialise both on same values
input_shape, param_0 = init_fun_head(layer_rng, input_shape_repr)
input_shape, param_1 = init_fun_head(layer_rng, input_shape_repr)
else:
input_shape, param_0 = init_fun_head(layer_rng, input_shape_repr)
rng, layer_rng = random.split(rng)
input_shape, param_1 = init_fun_head(layer_rng, input_shape_repr)
return input_shape, [param_repr, param_0, param_1]
# Define loss functions
# loss functions for the head
if not binary_y:
def loss_head(
params: List, batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]
) -> jnp.ndarray:
# mse loss function
inputs, targets, weights = batch
preds = predict_fun_head(params, inputs)
return jnp.sum(weights * ((preds - targets) ** 2))
else:
def loss_head(
params: List, batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]
) -> jnp.ndarray:
# mse loss function
inputs, targets, weights = batch
preds = predict_fun_head(params, inputs)
return -jnp.sum(
weights
* (targets * jnp.log(preds) + (1 - targets) * jnp.log(1 - preds))
)
# complete loss function for all parts
@jit
def loss_snet1(
params: List,
batch: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],
penalty_l2: float,
penalty_disc: float,
penalty_diff: float,
) -> jnp.ndarray:
# params: list[representation, head_0, head_1]
# batch: (X, y, w)
X, y, w = batch
# get representation
reps = predict_fun_repr(params[0], X)
# get mmd
disc = mmd2_lin(reps, w)
# pass down to two heads
loss_0 = loss_head(params[1], (reps, y, 1 - w))
loss_1 = loss_head(params[2], (reps, y, w))
# regularization on representation
weightsq_body = sum(
[jnp.sum(params[0][i][0] ** 2) for i in range(0, 2 * n_layers_r, 2)]
)
weightsq_head = heads_l2_penalty(
params[1], params[2], n_layers_out, reg_diff, penalty_l2, penalty_diff
)
if not avg_objective:
return (
loss_0
+ loss_1
+ penalty_disc * disc
+ 0.5 * (penalty_l2 * weightsq_body + weightsq_head)
)
else:
n_batch = y.shape[0]
return (
(loss_0 + loss_1) / n_batch
+ penalty_disc * disc
+ 0.5 * (penalty_l2 * weightsq_body + weightsq_head)
)
# Define optimisation routine
opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)
@jit
def update(
i: int, state: dict, batch: jnp.ndarray, penalty_l2: float, penalty_disc: float
) -> jnp.ndarray:
# updating function
params = get_params(state)
return opt_update(
i,
grad(loss_snet1)(params, batch, penalty_l2, penalty_disc, penalty_diff),
state,
)
# initialise states
_, init_params = init_fun_snet1(rng_key, input_shape)
opt_state = opt_init(init_params)
# calculate number of batches per epoch
batch_size = batch_size if batch_size < n else n
n_batches = int(onp.round(n / batch_size)) if batch_size < n else 1
train_indices = onp.arange(n)
l_best = LARGE_VAL
p_curr = 0
# do training
for i in range(n_iter):
# shuffle data for minibatches
onp.random.shuffle(train_indices)
for b in range(n_batches):
idx_next = train_indices[
(b * batch_size) : min((b + 1) * batch_size, n - 1)
]
next_batch = X[idx_next, :], y[idx_next, :], w[idx_next]
opt_state = update(
i * n_batches + b, opt_state, next_batch, penalty_l2, penalty_disc
)
if (i % n_iter_print == 0) or early_stopping:
params_curr = get_params(opt_state)
l_curr = loss_snet1(
params_curr,
(X_val, y_val, w_val),
penalty_l2,
penalty_disc,
penalty_diff,
)
if i % n_iter_print == 0:
log.info(f"Epoch: {i}, current {val_string} loss {l_curr}")
if early_stopping:
if l_curr < l_best:
l_best = l_curr
p_curr = 0
params_best = params_curr
else:
if onp.isnan(l_curr):
# if diverged, return best
return params_best, (predict_fun_repr, predict_fun_head)
p_curr = p_curr + 1
if p_curr > patience and ((i + 1) * n_batches > n_iter_min):
if return_val_loss:
# return loss without penalty
l_final = loss_snet1(params_curr, (X_val, y_val, w_val), 0, 0, 0)
return params_curr, (predict_fun_repr, predict_fun_head), l_final
return params_curr, (predict_fun_repr, predict_fun_head)
# return the parameters
trained_params = get_params(opt_state)
if return_val_loss:
# return loss without penalty
l_final = loss_snet1(get_params(opt_state), (X_val, y_val, w_val), 0, 0, 0)
return trained_params, (predict_fun_repr, predict_fun_head), l_final
return trained_params, (predict_fun_repr, predict_fun_head)
# SNET-2 -----------------------------------------------------------------------------------------
def train_snet2(
X: jnp.ndarray,
y: jnp.ndarray,
w: jnp.ndarray,
binary_y: bool = False,
n_layers_r: int = DEFAULT_LAYERS_R,
n_units_r: int = DEFAULT_UNITS_R,
n_layers_out: int = DEFAULT_LAYERS_OUT,
n_units_out: int = DEFAULT_UNITS_OUT,
penalty_l2: float = DEFAULT_PENALTY_L2,
n_units_out_prop: int = DEFAULT_UNITS_OUT,
n_layers_out_prop: int = DEFAULT_LAYERS_OUT,
step_size: float = DEFAULT_STEP_SIZE,
n_iter: int = DEFAULT_N_ITER,
batch_size: int = DEFAULT_BATCH_SIZE,
val_split_prop: float = DEFAULT_VAL_SPLIT,
early_stopping: bool = True,
patience: int = DEFAULT_PATIENCE,
n_iter_min: int = DEFAULT_N_ITER_MIN,
n_iter_print: int = DEFAULT_N_ITER_PRINT,
seed: int = DEFAULT_SEED,
return_val_loss: bool = False,
reg_diff: bool = False,
penalty_diff: float = DEFAULT_PENALTY_L2,
nonlin: str = DEFAULT_NONLIN,
avg_objective: bool = DEFAULT_AVG_OBJECTIVE,
same_init: bool = False,
) -> Any:
"""
SNet2 corresponds to DragonNet (Shi et al, 2019) [without TMLE regularisation term].
"""
y, w = check_shape_1d_data(y), check_shape_1d_data(w)
d = X.shape[1]
input_shape = (-1, d)
rng_
gitextract_9k1q0bj_/
├── .github/
│ └── workflows/
│ ├── release.yml
│ ├── scripts/
│ │ ├── release_linux.sh
│ │ ├── release_osx.sh
│ │ └── release_windows.bat
│ └── test.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── catenets/
│ ├── __init__.py
│ ├── datasets/
│ │ ├── __init__.py
│ │ ├── dataset_acic2016.py
│ │ ├── dataset_ihdp.py
│ │ ├── dataset_twins.py
│ │ └── network.py
│ ├── experiment_utils/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── simulation_utils.py
│ │ ├── tester.py
│ │ └── torch_metrics.py
│ ├── logger.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── jax/
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── disentangled_nets.py
│ │ │ ├── flextenet.py
│ │ │ ├── model_utils.py
│ │ │ ├── offsetnet.py
│ │ │ ├── pseudo_outcome_nets.py
│ │ │ ├── representation_nets.py
│ │ │ ├── rnet.py
│ │ │ ├── snet.py
│ │ │ ├── tnet.py
│ │ │ ├── transformation_utils.py
│ │ │ └── xnet.py
│ │ └── torch/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── flextenet.py
│ │ ├── pseudo_outcome_nets.py
│ │ ├── representation_nets.py
│ │ ├── slearner.py
│ │ ├── snet.py
│ │ ├── tlearner.py
│ │ └── utils/
│ │ ├── __init__.py
│ │ ├── decorators.py
│ │ ├── model_utils.py
│ │ ├── transformations.py
│ │ └── weight_utils.py
│ └── version.py
├── docs/
│ ├── Makefile
│ ├── conf.py
│ ├── datasets.rst
│ ├── index.rst
│ ├── jax_models.rst
│ ├── make.bat
│ ├── requirements.txt
│ └── torch_models.rst
├── experiments/
│ ├── __init__.py
│ ├── experiments_AISTATS21/
│ │ ├── ihdp_experiments.py
│ │ └── simulations_AISTATS.py
│ ├── experiments_benchmarks_NeurIPS21/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── acic_experiments_catenets.py
│ │ ├── acic_experiments_grf.R
│ │ ├── ihdp_experiments_catenets.py
│ │ ├── ihdp_experiments_grf.R
│ │ ├── twins_experiments_catenets.py
│ │ └── twins_experiments_grf.R
│ └── experiments_inductivebias_NeurIPS21/
│ ├── __init__.py
│ ├── experiments_AB.py
│ ├── experiments_CD.py
│ ├── experiments_acic.py
│ └── experiments_twins.py
├── pyproject.toml
├── pytest.ini
├── run_experiments_AISTATS.py
├── run_experiments_benchmarks_NeurIPS.py
├── run_experiments_inductive_bias_NeurIPS.py
├── setup.py
└── tests/
├── conftest.py
├── datasets/
│ └── test_datasets.py
└── models/
├── jax/
│ ├── test_jax_ite.py
│ ├── test_jax_model_utils.py
│ └── test_jax_transformation_utils.py
└── torch/
├── test_torch_flextenet.py
├── test_torch_pseudo_outcome_nets.py
├── test_torch_representation_net.py
├── test_torch_slearner.py
├── test_torch_snet.py
└── test_torch_tlearner.py
SYMBOL INDEX (356 symbols across 57 files)
FILE: catenets/datasets/__init__.py
function load (line 16) | def load(dataset: str, *args: Any, **kwargs: Any) -> Tuple:
FILE: catenets/datasets/dataset_acic2016.py
function get_acic_covariates (line 53) | def get_acic_covariates(
function preprocess_simu (line 86) | def preprocess_simu(
function get_acic_orig_filenames (line 202) | def get_acic_orig_filenames(data_path: Path, simu_num: int) -> list:
function get_acic_orig_outcomes (line 210) | def get_acic_orig_outcomes(data_path: Path, simu_num: int, i_exp: int) -...
function preprocess_acic_orig (line 220) | def preprocess_acic_orig(
function preprocess (line 281) | def preprocess(
function load (line 296) | def load(
FILE: catenets/datasets/dataset_ihdp.py
function load_data_npz (line 27) | def load_data_npz(fname: Path, get_po: bool = True) -> dict:
function prepare_ihdp_data (line 59) | def prepare_ihdp_data(
function get_one_data_set (line 143) | def get_one_data_set(D: dict, i_exp: int, get_po: bool = True) -> dict:
function load (line 175) | def load(data_path: Path, exp: int = 1, rescale: bool = False, **kwargs:...
function load_raw (line 233) | def load_raw(data_path: Path) -> Tuple:
FILE: catenets/datasets/dataset_twins.py
function preprocess (line 25) | def preprocess(
function load (line 210) | def load(
FILE: catenets/datasets/network.py
function download_gdrive_if_needed (line 13) | def download_gdrive_if_needed(path: Path, file_id: str) -> None:
function download_http_if_needed (line 32) | def download_http_if_needed(path: Path, url: str) -> None:
function unarchive_if_needed (line 55) | def unarchive_if_needed(path: Path, output_folder: Path) -> None:
function download_if_needed (line 78) | def download_if_needed(
FILE: catenets/experiment_utils/base.py
function eval_mse_model (line 33) | def eval_mse_model(
function eval_mse (line 44) | def eval_mse(preds: jnp.ndarray, targets: jnp.ndarray) -> jnp.ndarray:
function eval_root_mse (line 50) | def eval_root_mse(cate_pred: jnp.ndarray, cate_true: jnp.ndarray) -> jnp...
function eval_abs_error_ate (line 56) | def eval_abs_error_ate(cate_pred: jnp.ndarray, cate_true: jnp.ndarray) -...
function get_model_set (line 62) | def get_model_set(
function get_all_snets (line 104) | def get_all_snets() -> Dict:
function get_all_pseudoout_models (line 111) | def get_all_pseudoout_models() -> Dict: # DR, RA, PW learner
function get_all_twostep_models (line 120) | def get_all_twostep_models() -> Dict: # DR, RA, R, X learner
FILE: catenets/experiment_utils/simulation_utils.py
function simulate_treatment_setup (line 11) | def simulate_treatment_setup(
function get_multivariate_normal_params (line 130) | def get_multivariate_normal_params(
function get_set_normal_covariates (line 149) | def get_set_normal_covariates(m: int, n: int, correlated: bool = False) ...
function normal_covariate_model (line 156) | def normal_covariate_model(
function propensity_AISTATS (line 173) | def propensity_AISTATS(
function propensity_constant (line 210) | def propensity_constant(
function mu0_AISTATS (line 216) | def mu0_AISTATS(
function mu1_AISTATS (line 229) | def mu1_AISTATS(
function uniform_covariate_model (line 257) | def uniform_covariate_model(
function mu1_additive (line 271) | def mu1_additive(
function mu0_hg (line 287) | def mu0_hg(X: np.ndarray, n_w: int = 0, n_c: int = 0, n_o: int = 0) -> n...
function mu1_hg (line 295) | def mu1_hg(
function propensity_hg (line 310) | def propensity_hg(
FILE: catenets/experiment_utils/tester.py
function generate_score (line 13) | def generate_score(metric: np.ndarray) -> Tuple[float, float]:
function print_score (line 18) | def print_score(score: Tuple[float, float]) -> str:
function evaluate_treatments_model (line 22) | def evaluate_treatments_model(
FILE: catenets/experiment_utils/torch_metrics.py
function sqrt_PEHE (line 5) | def sqrt_PEHE(po: torch.Tensor, hat_te: torch.Tensor) -> torch.Tensor:
function abs_error_ATE (line 18) | def abs_error_ATE(po: torch.Tensor, hat_te: torch.Tensor) -> torch.Tensor:
FILE: catenets/logger.py
function remove (line 15) | def remove() -> None:
function add (line 19) | def add(
function traceback_and_raise (line 47) | def traceback_and_raise(e: Any, verbose: bool = False) -> NoReturn:
function create_log_and_print_function (line 60) | def create_log_and_print_function(level: str) -> Callable:
function traceback (line 78) | def traceback(*args: Any, **kwargs: Any) -> None:
function critical (line 82) | def critical(*args: Any, **kwargs: Any) -> None:
function error (line 86) | def error(*args: Any, **kwargs: Any) -> None:
function warning (line 90) | def warning(*args: Any, **kwargs: Any) -> None:
function info (line 94) | def info(*args: Any, **kwargs: Any) -> None:
function debug (line 98) | def debug(*args: Any, **kwargs: Any) -> None:
function trace (line 102) | def trace(*args: Any, **kwargs: Any) -> None:
FILE: catenets/models/jax/__init__.py
function get_catenet (line 90) | def get_catenet(name: str) -> Any:
FILE: catenets/models/jax/base.py
function ReprBlock (line 40) | def ReprBlock(
function OutputHead (line 64) | def OutputHead(
class BaseCATENet (line 99) | class BaseCATENet(BaseEstimator, RegressorMixin, abc.ABC):
method score (line 104) | def score(
method _get_train_function (line 132) | def _get_train_function(self) -> Callable:
method fit (line 135) | def fit(
method _get_predict_function (line 171) | def _get_predict_function(self) -> Callable:
method predict (line 174) | def predict(
method _check_inputs (line 205) | def _check_inputs(w: jnp.ndarray, p: jnp.ndarray) -> None:
method fit_and_select_params (line 213) | def fit_and_select_params(
function train_output_net_only (line 274) | def train_output_net_only(
FILE: catenets/models/jax/disentangled_nets.py
function _get_absolute_rowsums (line 45) | def _get_absolute_rowsums(mat: jnp.ndarray) -> jnp.ndarray:
function _concatenate_representations (line 49) | def _concatenate_representations(reps: jnp.ndarray) -> jnp.ndarray:
class SNet3 (line 53) | class SNet3(BaseCATENet):
method __init__ (line 113) | def __init__(
method _get_predict_function (line 169) | def _get_predict_function(self) -> Callable:
method _get_train_function (line 172) | def _get_train_function(self) -> Callable:
function train_snet3 (line 177) | def train_snet3(
function predict_snet3 (line 522) | def predict_snet3(
FILE: catenets/models/jax/flextenet.py
class FlexTENet (line 47) | class FlexTENet(BaseCATENet):
method __init__ (line 111) | def __init__(
method _get_train_function (line 172) | def _get_train_function(self) -> Callable:
method _get_predict_function (line 175) | def _get_predict_function(self) -> Callable:
function train_flextenet (line 179) | def train_flextenet(
function predict_flextenet (line 565) | def predict_flextenet(
function _get_cos_reg (line 594) | def _get_cos_reg(
function _compute_ortho_penalty_asymmetric (line 604) | def _compute_ortho_penalty_asymmetric(
function _compute_penalty_l2 (line 661) | def _compute_penalty_l2(
function _compute_penalty (line 735) | def _compute_penalty(
function SplitLayerAsymmetric (line 774) | def SplitLayerAsymmetric(
function TEOutputLayerAsymmetric (line 822) | def TEOutputLayerAsymmetric(private: bool = True, same_init: bool = True...
function FlexTENetArchitecture (line 869) | def FlexTENetArchitecture(
function elementwise_split (line 948) | def elementwise_split(fun: Callable, **fun_kwargs: Any) -> Tuple:
function elementwise_parallel (line 966) | def elementwise_parallel(fun: Callable, **fun_kwargs: Any) -> Tuple:
function DenseW (line 990) | def DenseW(
FILE: catenets/models/jax/model_utils.py
function check_shape_1d_data (line 17) | def check_shape_1d_data(y: jnp.ndarray) -> jnp.ndarray:
function check_X_is_np (line 27) | def check_X_is_np(X: pd.DataFrame) -> jnp.ndarray:
function make_val_split (line 32) | def make_val_split(
function heads_l2_penalty (line 73) | def heads_l2_penalty(
FILE: catenets/models/jax/offsetnet.py
class OffsetNet (line 42) | class OffsetNet(BaseCATENet):
method __init__ (line 87) | def __init__(
method _get_train_function (line 130) | def _get_train_function(self) -> Callable:
method _get_predict_function (line 133) | def _get_predict_function(self) -> Callable:
function predict_offsetnet (line 137) | def predict_offsetnet(
function train_offsetnet (line 171) | def train_offsetnet(
FILE: catenets/models/jax/pseudo_outcome_nets.py
class PseudoOutcomeNet (line 76) | class PseudoOutcomeNet(BaseCATENet):
method __init__ (line 143) | def __init__(
method _get_train_function (line 210) | def _get_train_function(self) -> Callable:
method fit (line 213) | def fit(
method _get_predict_function (line 240) | def _get_predict_function(self) -> Callable:
method predict (line 244) | def predict(
class DRNet (line 265) | class DRNet(PseudoOutcomeNet):
method __init__ (line 268) | def __init__(
class RANet (line 332) | class RANet(PseudoOutcomeNet):
method __init__ (line 335) | def __init__(
class PWNet (line 399) | class PWNet(PseudoOutcomeNet):
method __init__ (line 402) | def __init__(
function train_pseudooutcome_net (line 466) | def train_pseudooutcome_net(
function _train_and_predict_first_stage (line 690) | def _train_and_predict_first_stage(
FILE: catenets/models/jax/representation_nets.py
class SNet1 (line 41) | class SNet1(BaseCATENet):
method __init__ (line 91) | def __init__(
method _get_train_function (line 140) | def _get_train_function(self) -> Callable:
method _get_predict_function (line 143) | def _get_predict_function(self) -> Callable:
class TARNet (line 147) | class TARNet(SNet1):
method __init__ (line 150) | def __init__(
class SNet2 (line 196) | class SNet2(BaseCATENet):
method __init__ (line 249) | def __init__(
method _get_train_function (line 298) | def _get_train_function(self) -> Callable:
method _get_predict_function (line 301) | def _get_predict_function(self) -> Callable:
class DragonNet (line 305) | class DragonNet(SNet2):
method __init__ (line 308) | def __init__(
function mmd2_lin (line 358) | def mmd2_lin(X: jnp.ndarray, w: jnp.ndarray) -> jnp.ndarray:
function predict_snet1 (line 373) | def predict_snet1(
function train_snet1 (line 404) | def train_snet1(
function train_snet2 (line 632) | def train_snet2(
function predict_snet2 (line 899) | def predict_snet2(
FILE: catenets/models/jax/rnet.py
class RNet (line 51) | class RNet(BaseCATENet):
method __init__ (line 111) | def __init__(
method _get_train_function (line 172) | def _get_train_function(self) -> Callable:
method fit (line 175) | def fit(
method _get_predict_function (line 194) | def _get_predict_function(self) -> Callable:
method predict (line 198) | def predict(
function train_r_net (line 215) | def train_r_net(
function _train_and_predict_r_stage1 (line 397) | def _train_and_predict_r_stage1(
function train_r_stage2 (line 476) | def train_r_stage2(
FILE: catenets/models/jax/snet.py
class SNet (line 51) | class SNet(BaseCATENet):
method __init__ (line 118) | def __init__(
method _get_predict_function (line 178) | def _get_predict_function(self) -> Callable:
method _get_train_function (line 184) | def _get_train_function(self) -> Callable:
function train_snet (line 191) | def train_snet(
function predict_snet (line 587) | def predict_snet(
function train_snet_noprop (line 636) | def train_snet_noprop(
function predict_snet_noprop (line 934) | def predict_snet_noprop(
FILE: catenets/models/jax/tnet.py
class TNet (line 39) | class TNet(BaseCATENet):
method __init__ (line 86) | def __init__(
method _get_predict_function (line 126) | def _get_predict_function(self) -> Callable:
method _get_train_function (line 129) | def _get_train_function(self) -> Callable:
function train_tnet (line 133) | def train_tnet(
function predict_t_net (line 249) | def predict_t_net(
function _train_tnet_jointly (line 272) | def _train_tnet_jointly(
FILE: catenets/models/jax/transformation_utils.py
function aipw_te_transformation (line 16) | def aipw_te_transformation(
function ht_te_transformation (line 53) | def ht_te_transformation(
function ra_te_transformation (line 87) | def ra_te_transformation(
function _get_transformation_function (line 125) | def _get_transformation_function(transformation_name: str) -> Any:
FILE: catenets/models/jax/xnet.py
class XNet (line 60) | class XNet(BaseCATENet):
method __init__ (line 120) | def __init__(
method _get_train_function (line 179) | def _get_train_function(self) -> Callable:
method _get_predict_function (line 182) | def _get_predict_function(self) -> Callable:
method predict (line 186) | def predict(
function train_x_net (line 218) | def train_x_net(
function _get_first_stage_pos (line 392) | def _get_first_stage_pos(
function predict_x_net (line 466) | def predict_x_net(
FILE: catenets/models/torch/base.py
class BasicNet (line 43) | class BasicNet(nn.Module):
method __init__ (line 81) | def __init__(
method forward (line 168) | def forward(self, X: torch.Tensor) -> torch.Tensor:
method fit (line 171) | def fit(
method _check_tensor (line 251) | def _check_tensor(self, X: torch.Tensor) -> torch.Tensor:
class RepresentationNet (line 258) | class RepresentationNet(nn.Module):
method __init__ (line 274) | def __init__(
method forward (line 303) | def forward(self, X: torch.Tensor) -> torch.Tensor:
class PropensityNet (line 307) | class PropensityNet(nn.Module):
method __init__ (line 350) | def __init__(
method forward (line 438) | def forward(self, X: torch.Tensor) -> torch.Tensor:
method get_importance_weights (line 441) | def get_importance_weights(
method loss (line 447) | def loss(self, y_pred: torch.Tensor, y_target: torch.Tensor) -> torch....
method fit (line 450) | def fit(self, X: torch.Tensor, y: torch.Tensor) -> "PropensityNet":
method _check_tensor (line 520) | def _check_tensor(self, X: torch.Tensor) -> torch.Tensor:
class BaseCATEEstimator (line 527) | class BaseCATEEstimator(nn.Module):
method __init__ (line 534) | def __init__(
method score (line 539) | def score(
method fit (line 568) | def fit(
method forward (line 589) | def forward(self, X: torch.Tensor) -> torch.Tensor:
method predict (line 605) | def predict(
method _check_tensor (line 623) | def _check_tensor(self, X: torch.Tensor) -> torch.Tensor:
FILE: catenets/models/torch/flextenet.py
class FlexTELinearLayer (line 31) | class FlexTELinearLayer(nn.Module):
method __init__ (line 35) | def __init__(
method forward (line 52) | def forward(self, tensors: List[torch.Tensor]) -> List:
class FlexTESplitLayer (line 64) | class FlexTESplitLayer(nn.Module):
method __init__ (line 69) | def __init__(
method forward (line 103) | def forward(self, tensors: List[torch.Tensor]) -> List:
class FlexTEOutputLayer (line 134) | class FlexTEOutputLayer(nn.Module):
method __init__ (line 135) | def __init__(
method forward (line 160) | def forward(self, tensors: List[torch.Tensor]) -> torch.Tensor:
class ElementWiseParallelActivation (line 183) | class ElementWiseParallelActivation(nn.Module):
method __init__ (line 189) | def __init__(self, act: Callable, **act_kwargs: Any) -> None:
method forward (line 194) | def forward(self, tensors: List[torch.Tensor]) -> List:
class ElementWiseSplitActivation (line 208) | class ElementWiseSplitActivation(nn.Module):
method __init__ (line 214) | def __init__(self, act: Callable, **act_kwargs: Any) -> None:
method forward (line 219) | def forward(self, tensors: List[torch.Tensor]) -> List:
class FlexTENet (line 231) | class FlexTENet(BaseCATEEstimator):
method __init__ (line 292) | def __init__(
method _ortho_penalty_asymmetric (line 468) | def _ortho_penalty_asymmetric(self) -> torch.Tensor:
method loss (line 525) | def loss(
method fit (line 546) | def fit(
method predict (line 637) | def predict(
FILE: catenets/models/torch/pseudo_outcome_nets.py
class PseudoOutcomeLearner (line 43) | class PseudoOutcomeLearner(BaseCATEEstimator):
method __init__ (line 105) | def __init__(
method _generate_te_estimator (line 175) | def _generate_te_estimator(self, name: str = "te_estimator") -> nn.Mod...
method _generate_po_estimator (line 200) | def _generate_po_estimator(self, name: str = "po_estimator") -> nn.Mod...
method _generate_propensity_estimator (line 226) | def _generate_propensity_estimator(
method fit (line 252) | def fit(
method predict (line 315) | def predict(
method _first_step (line 341) | def _first_step(
method _second_step (line 352) | def _second_step(
method _impute_pos (line 363) | def _impute_pos(
method _impute_propensity (line 388) | def _impute_propensity(
method _impute_unconditional_mean (line 409) | def _impute_unconditional_mean(
class DRLearner (line 426) | class DRLearner(PseudoOutcomeLearner):
method _first_step (line 431) | def _first_step(
method _second_step (line 447) | def _second_step(
class PWLearner (line 460) | class PWLearner(PseudoOutcomeLearner):
method _first_step (line 465) | def _first_step(
method _second_step (line 478) | def _second_step(
class RALearner (line 491) | class RALearner(PseudoOutcomeLearner):
method _first_step (line 496) | def _first_step(
method _second_step (line 508) | def _second_step(
class ULearner (line 521) | class ULearner(PseudoOutcomeLearner):
method _first_step (line 526) | def _first_step(
method _second_step (line 540) | def _second_step(
class RLearner (line 553) | class RLearner(PseudoOutcomeLearner):
method _first_step (line 559) | def _first_step(
method _second_step (line 572) | def _second_step(
class XLearner (line 587) | class XLearner(PseudoOutcomeLearner):
method __init__ (line 593) | def __init__(
method _first_step (line 605) | def _first_step(
method _second_step (line 617) | def _second_step(
method predict (line 637) | def predict(
FILE: catenets/models/torch/representation_nets.py
class BasicDragonNet (line 39) | class BasicDragonNet(BaseCATEEstimator):
method __init__ (line 83) | def __init__(
method loss (line 154) | def loss(
method fit (line 189) | def fit(
method _step (line 283) | def _step(
method _forward (line 288) | def _forward(self, X: torch.Tensor) -> torch.Tensor:
method predict (line 296) | def predict(
method _maximum_mean_discrepancy (line 325) | def _maximum_mean_discrepancy(
class TARNet (line 340) | class TARNet(BasicDragonNet):
method __init__ (line 345) | def __init__(
method _step (line 384) | def _step(
class DragonNet (line 399) | class DragonNet(BasicDragonNet):
method __init__ (line 404) | def __init__(
method _step (line 441) | def _step(
FILE: catenets/models/torch/slearner.py
class SLearner (line 27) | class SLearner(BaseCATEEstimator):
method __init__ (line 68) | def __init__(
method fit (line 138) | def fit(
method _create_extended_matrices (line 191) | def _create_extended_matrices(self, X: torch.Tensor) -> torch.Tensor:
method predict (line 203) | def predict(
FILE: catenets/models/torch/snet.py
class SNet (line 40) | class SNet(BaseCATEEstimator):
method __init__ (line 92) | def __init__(
method loss (line 258) | def loss(
method fit (line 301) | def fit(
method _ortho_reg (line 399) | def _ortho_reg(self) -> float:
method _maximum_mean_discrepancy (line 478) | def _maximum_mean_discrepancy(
method _step (line 492) | def _step(
method _forward (line 501) | def _forward(
method predict (line 526) | def predict(
FILE: catenets/models/torch/tlearner.py
class TLearner (line 22) | class TLearner(BaseCATEEstimator):
method __init__ (line 56) | def __init__(
method predict (line 107) | def predict(
method fit (line 139) | def fit(
FILE: catenets/models/torch/utils/decorators.py
function check_input_train (line 9) | def check_input_train(func: Callable) -> Callable:
function benchmark (line 32) | def benchmark(func: Callable) -> Callable:
FILE: catenets/models/torch/utils/model_utils.py
function make_val_split (line 19) | def make_val_split(
function train_wrapper (line 77) | def train_wrapper(
function predict_wrapper (line 93) | def predict_wrapper(estimator: Any, X: torch.Tensor) -> torch.Tensor:
FILE: catenets/models/torch/utils/transformations.py
function dr_transformation_cate (line 10) | def dr_transformation_cate(
function pw_transformation_cate (line 47) | def pw_transformation_cate(
function ra_transformation_cate (line 79) | def ra_transformation_cate(
function u_transformation_cate (line 110) | def u_transformation_cate(
FILE: catenets/models/torch/utils/weight_utils.py
function compute_importance_weights (line 26) | def compute_importance_weights(
function compute_ipw (line 54) | def compute_ipw(propensity: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
function compute_trunc_ipw (line 59) | def compute_trunc_ipw(
function compute_matching_weights (line 67) | def compute_matching_weights(propensity: torch.Tensor, w: torch.Tensor) ...
function compute_overlap_weights (line 72) | def compute_overlap_weights(propensity: torch.Tensor, w: torch.Tensor) -...
FILE: experiments/experiments_AISTATS21/ihdp_experiments.py
function do_ihdp_experiments (line 62) | def do_ihdp_experiments(
FILE: experiments/experiments_AISTATS21/simulations_AISTATS.py
function simulation_experiment_loop (line 133) | def simulation_experiment_loop(
function do_one_experiment_repeat (line 253) | def do_one_experiment_repeat(
function one_simulation_experiment (line 342) | def one_simulation_experiment(
function main_AISTATS (line 408) | def main_AISTATS(
FILE: experiments/experiments_benchmarks_NeurIPS21/acic_experiments_catenets.py
function do_acic_experiments (line 34) | def do_acic_experiments(
FILE: experiments/experiments_benchmarks_NeurIPS21/ihdp_experiments_catenets.py
function do_ihdp_experiments (line 36) | def do_ihdp_experiments(
FILE: experiments/experiments_benchmarks_NeurIPS21/twins_experiments_catenets.py
function do_twins_experiment_loop (line 38) | def do_twins_experiment_loop(
function do_twins_experiments (line 56) | def do_twins_experiments(
function prepare_twins (line 114) | def prepare_twins(treat_prop=0.5, seed=42, test_size=0.5, subset_train: ...
FILE: experiments/experiments_inductivebias_NeurIPS21/experiments_AB.py
function do_acic_simu_loops (line 206) | def do_acic_simu_loops(
function do_acic_simu (line 249) | def do_acic_simu(
function acic_simu (line 414) | def acic_simu(
FILE: experiments/experiments_inductivebias_NeurIPS21/experiments_CD.py
function do_ihdp_experiments (line 60) | def do_ihdp_experiments(
FILE: experiments/experiments_inductivebias_NeurIPS21/experiments_acic.py
function do_acic_orig_loop (line 60) | def do_acic_orig_loop(
function do_acic_experiments (line 80) | def do_acic_experiments(
FILE: experiments/experiments_inductivebias_NeurIPS21/experiments_twins.py
function do_twins_experiment_loop (line 63) | def do_twins_experiment_loop(
function do_twins_experiments (line 88) | def do_twins_experiments(
function split_data (line 216) | def split_data(X, y, w, pos, test_size=0.5, random_state=42, subset_trai...
function eval_roc_auc (line 232) | def eval_roc_auc(targets, preds):
function eval_ap (line 238) | def eval_ap(targets, preds):
FILE: run_experiments_AISTATS.py
function init_arg (line 16) | def init_arg() -> Any:
FILE: run_experiments_benchmarks_NeurIPS.py
function init_arg (line 26) | def init_arg() -> Any:
FILE: run_experiments_inductive_bias_NeurIPS.py
function init_arg (line 28) | def init_arg() -> Any:
FILE: setup.py
function read (line 11) | def read(fname: str) -> str:
function find_version (line 15) | def find_version() -> str:
FILE: tests/datasets/test_datasets.py
function test_dataset_sanity_twins (line 9) | def test_dataset_sanity_twins(
function test_dataset_sanity_ihdp (line 29) | def test_dataset_sanity_ihdp() -> None:
function test_dataset_sanity_acic2016 (line 41) | def test_dataset_sanity_acic2016(preprocessed: bool) -> None:
FILE: tests/models/jax/test_jax_ite.py
function test_model_sanity (line 30) | def test_model_sanity(dataset: str, pehe_threshold: float, model_name: s...
function test_model_score (line 39) | def test_model_score() -> None:
FILE: tests/models/jax/test_jax_model_utils.py
function test_check_shape_1d_data_sanity (line 16) | def test_check_shape_1d_data_sanity(data: np.ndarray) -> None:
function test_check_X_is_np_sanity (line 23) | def test_check_X_is_np_sanity(data: Any) -> None:
function test_make_val_split_sanity (line 29) | def test_make_val_split_sanity() -> None:
FILE: tests/models/jax/test_jax_transformation_utils.py
function test_get_transformation_function_sanity (line 18) | def test_get_transformation_function_sanity() -> None:
function test_aipw_te_transformation_sanity (line 31) | def test_aipw_te_transformation_sanity(fn: Callable) -> None:
function test_ht_te_transformation_sanity (line 45) | def test_ht_te_transformation_sanity(fn: Callable) -> None:
function test_ra_te_transformation_sanity (line 56) | def test_ra_te_transformation_sanity(fn: Callable) -> None:
FILE: tests/models/torch/test_torch_flextenet.py
function test_flextenet_model_params (line 9) | def test_flextenet_model_params() -> None:
function test_flextenet_model_sanity (line 59) | def test_flextenet_model_sanity(dataset: str, pehe_threshold: float) -> ...
function test_flextenet_model_predict_api (line 81) | def test_flextenet_model_predict_api(
FILE: tests/models/torch/test_torch_pseudo_outcome_nets.py
function test_nn_model_params (line 24) | def test_nn_model_params(model_t: Any) -> None:
function test_nn_model_params_nonlin (line 39) | def test_nn_model_params_nonlin(nonlin: str, model_t: Any) -> None:
function test_nn_model_sanity (line 54) | def test_nn_model_sanity(dataset: str, pehe_threshold: float, model_t: A...
function test_sklearn_model_pseudo_outcome_binary (line 101) | def test_sklearn_model_pseudo_outcome_binary(
function test_model_predict_api (line 130) | def test_model_predict_api() -> None:
FILE: tests/models/torch/test_torch_representation_net.py
function test_model_params (line 12) | def test_model_params(snet: Type) -> None:
function test_model_params_nonlin (line 41) | def test_model_params_nonlin(nonlin: str, snet: Type) -> None:
function test_model_sanity (line 61) | def test_model_sanity(dataset: str, pehe_threshold: float, snet: Type) -...
function test_model_predict_api (line 79) | def test_model_predict_api() -> None:
FILE: tests/models/torch/test_torch_slearner.py
function test_nn_model_params (line 15) | def test_nn_model_params() -> None:
function test_nn_model_params_nonlin (line 57) | def test_nn_model_params_nonlin(nonlin: str) -> None:
function test_nn_model_sanity (line 72) | def test_nn_model_sanity(
function test_sklearn_model_sanity_binary_output (line 124) | def test_sklearn_model_sanity_binary_output(
function test_slearner_sklearn_model_ihdp (line 155) | def test_slearner_sklearn_model_ihdp(po_estimator: Any, exp: int) -> None:
function test_model_predict_api (line 175) | def test_model_predict_api() -> None:
FILE: tests/models/torch/test_torch_snet.py
function test_model_params (line 10) | def test_model_params() -> None:
function test_model_params_nonlin (line 83) | def test_model_params_nonlin(nonlin: str) -> None:
function test_model_sanity (line 108) | def test_model_sanity(dataset: str, pehe_threshold: float) -> None:
function test_model_predict_api (line 141) | def test_model_predict_api() -> None:
FILE: tests/models/torch/test_torch_tlearner.py
function test_nn_model_params (line 15) | def test_nn_model_params() -> None:
function test_nn_model_params_nonlin (line 42) | def test_nn_model_params_nonlin(nonlin: str) -> None:
function test_nn_model_sanity (line 58) | def test_nn_model_sanity(dataset: str, pehe_threshold: float) -> None:
function test_sklearn_model_sanity_binary_output (line 103) | def test_sklearn_model_sanity_binary_output(
function test_sklearn_model_sanity_regression (line 149) | def test_sklearn_model_sanity_regression(
function test_model_predict_api (line 168) | def test_model_predict_api() -> None:
Condensed preview — 91 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (564K chars).
[
{
"path": ".github/workflows/release.yml",
"chars": 1977,
"preview": "name: Package release\n\non:\n release:\n types: [created]\n\n\njobs:\n deploy_osx:\n runs-on: ${{ matrix.os }}\n strat"
},
{
"path": ".github/workflows/scripts/release_linux.sh",
"chars": 382,
"preview": "#!/bin/bash\n\nset -e\n\nyum makecache -y\nyum install centos-release-scl -y\nyum-config-manager --enable rhel-server-rhscl-7-"
},
{
"path": ".github/workflows/scripts/release_osx.sh",
"chars": 254,
"preview": "#!/bin/sh\n\nexport MACOSX_DEPLOYMENT_TARGET=10.14\n\npython -m pip install --upgrade pip\npip install setuptools wheel twine"
},
{
"path": ".github/workflows/scripts/release_windows.bat",
"chars": 161,
"preview": "echo on\n\npython -m pip install --upgrade pip\npip install setuptools wheel twine auditwheel\n\npip wheel . -w wheel/ --no-d"
},
{
"path": ".github/workflows/test.yml",
"chars": 1909,
"preview": "name: CATENets Tests\n \non:\n push:\n branches: [main, release]\n pull_request:\n types: [opened, synchronize, reo"
},
{
"path": ".gitignore",
"chars": 239,
"preview": "*.pyc\n*.xml\n*.iml\n*.csv\n*.xlsx\n*.Rhistory\n.idea/\n.coverage\n.ipynb_checkpoints\n.ipynb_checkpoints/\n*/.ipynb_checkpoints/\n"
},
{
"path": ".pre-commit-config.yaml",
"chars": 1528,
"preview": "exclude: 'setup.py|^docs'\n\nrepos:\n- repo: https://github.com/pre-commit/pre-commit-hooks\n rev: v3.4.0\n hooks:\n - id: "
},
{
"path": "LICENSE",
"chars": 1520,
"preview": "BSD 3-Clause License\n\nCopyright (c) 2021, Alicia Curth\nAll rights reserved.\n\nRedistribution and use in source and binary"
},
{
"path": "README.md",
"chars": 6279,
"preview": "# CATENets - Conditional Average Treatment Effect Estimation Using Neural Networks\n\n[ dataset\n\"\"\"\n# stdlib\nimport os\nimport random\nfrom pathlib import Path\nf"
},
{
"path": "catenets/datasets/dataset_twins.py",
"chars": 6707,
"preview": "\"\"\"\nTwins dataset\nLoad real-world individualized treatment effects estimation datasets\n\n- Reference: http://data.nber.or"
},
{
"path": "catenets/datasets/network.py",
"chars": 3305,
"preview": "\"\"\"\nUtilities and helpers for retrieving the datasets\n\"\"\"\n# stdlib\nimport tarfile\nimport urllib.request\nfrom pathlib imp"
},
{
"path": "catenets/experiment_utils/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "catenets/experiment_utils/base.py",
"chars": 3675,
"preview": "\"\"\"\nSome utils for experiments\n\"\"\"\n# Author: Alicia Curth\nfrom typing import Callable, Dict, Optional, Union\n\nimport jax"
},
{
"path": "catenets/experiment_utils/simulation_utils.py",
"chars": 9213,
"preview": "\"\"\"\r\nSimulation utils, allowing to flexibly consider different DGPs\r\n\"\"\"\r\n# Author: Alicia Curth\r\nfrom typing import Any"
},
{
"path": "catenets/experiment_utils/tester.py",
"chars": 1951,
"preview": "# stdlib\nimport copy\nfrom typing import Any, Tuple\n\n# third party\nimport numpy as np\nimport torch\nfrom sklearn.model_sel"
},
{
"path": "catenets/experiment_utils/torch_metrics.py",
"chars": 922,
"preview": "# third party\nimport torch\n\n\ndef sqrt_PEHE(po: torch.Tensor, hat_te: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Precisio"
},
{
"path": "catenets/logger.py",
"chars": 2771,
"preview": "# stdlib\nimport logging\nimport os\nfrom typing import Any, Callable, NoReturn, TextIO, Union\n\n# third party\nfrom loguru i"
},
{
"path": "catenets/models/__init__.py",
"chars": 234,
"preview": "import catenets.logger as log\n\ntry:\n from . import jax\nexcept ImportError:\n log.error(\"JAX models disabled\")\n\ntry:"
},
{
"path": "catenets/models/constants.py",
"chars": 1114,
"preview": "\"\"\"\r\nDefine some constants for initialisation of hyperparamters etc\r\n\"\"\"\r\nimport numpy as np\r\n\r\n# default model architec"
},
{
"path": "catenets/models/jax/__init__.py",
"chars": 2044,
"preview": "\"\"\"\nJAX-based implementations for the CATE estimators.\n\"\"\"\nfrom typing import Any\n\nfrom catenets.models.jax.disentangled"
},
{
"path": "catenets/models/jax/base.py",
"chars": 13743,
"preview": "\"\"\"\r\nBase modules shared across different nets\r\n\"\"\"\r\n# Author: Alicia Curth\r\nimport abc\r\nfrom typing import Any, Callabl"
},
{
"path": "catenets/models/jax/disentangled_nets.py",
"chars": 19103,
"preview": "\"\"\"\nClass implements SNet-3, a variation on DR-CFR discussed in\nHassanpour and Greiner (2020) and Wu et al (2020).\n\"\"\"\n#"
},
{
"path": "catenets/models/jax/flextenet.py",
"chars": 33263,
"preview": "\"\"\"\nModule implements FlexTENet, also referred to as the 'flexible approach' in \"On inductive biases\nfor heterogeneous t"
},
{
"path": "catenets/models/jax/model_utils.py",
"chars": 2905,
"preview": "\"\"\"\r\nModel utils shared across different nets\r\n\"\"\"\r\n# Author: Alicia Curth\r\nfrom typing import Any, Optional\r\n\r\nimport j"
},
{
"path": "catenets/models/jax/offsetnet.py",
"chars": 12103,
"preview": "\"\"\"\nModule implements OffsetNet, also referred to as the 'reparametrization approach' and 'hard\napproach' in \"On inducti"
},
{
"path": "catenets/models/jax/pseudo_outcome_nets.py",
"chars": 28778,
"preview": "\"\"\"\nImplements Pseudo-outcome based Two-step Nets, namely the DR-learner, the PW-learner and the\nRA-learner.\n\"\"\"\n# Autho"
},
{
"path": "catenets/models/jax/representation_nets.py",
"chars": 31554,
"preview": "\"\"\"\nModule implements SNet1 and SNet2, which are based on CFRNet/TARNet from Shalit et al (2017) and\nDragonNet from Shi"
},
{
"path": "catenets/models/jax/rnet.py",
"chars": 20157,
"preview": "\"\"\"\nImplements NN based on R-learner and U-learner (as discussed in Nie & Wager (2017))\n\"\"\"\n# Author: Alicia Curth\nfrom "
},
{
"path": "catenets/models/jax/snet.py",
"chars": 34017,
"preview": "\"\"\"\nModule implements SNet class as discussed in Curth & van der Schaar (2021)\n\"\"\"\n# Author: Alicia Curth\nfrom typing im"
},
{
"path": "catenets/models/jax/tnet.py",
"chars": 14991,
"preview": "\"\"\"\nImplements a T-Net: T-learner for CATE based on a dense NN\n\"\"\"\n# Author: Alicia Curth\nfrom typing import Any, Callab"
},
{
"path": "catenets/models/jax/transformation_utils.py",
"chars": 4044,
"preview": "\"\"\"\nUtils for transformations\n\"\"\"\n# Author: Alicia Curth\nfrom typing import Any, Optional\n\nimport numpy as np\n\nPW_TRANSF"
},
{
"path": "catenets/models/jax/xnet.py",
"chars": 16947,
"preview": "\"\"\"\nModule implements X-learner from Kuenzel et al (2019) using NNs\n\"\"\"\n# Author: Alicia Curth\nfrom typing import Callab"
},
{
"path": "catenets/models/torch/__init__.py",
"chars": 562,
"preview": "\"\"\"\nPyTorch-based implementations for the CATE estimators.\n\"\"\"\nfrom .flextenet import FlexTENet\nfrom .pseudo_outcome_net"
},
{
"path": "catenets/models/torch/base.py",
"chars": 20487,
"preview": "import abc\r\nfrom typing import Optional\r\n\r\nimport numpy as np\r\nimport torch\r\nfrom torch import nn\r\n\r\nimport catenets.log"
},
{
"path": "catenets/models/torch/flextenet.py",
"chars": 22738,
"preview": "from typing import Any, Callable, List\n\nimport numpy as np\nimport torch\nfrom torch import nn\n\nimport catenets.logger as "
},
{
"path": "catenets/models/torch/pseudo_outcome_nets.py",
"chars": 22113,
"preview": "import abc\nimport copy\nfrom typing import Any, Optional, Tuple\n\nimport numpy as np\nimport torch\nfrom sklearn.model_selec"
},
{
"path": "catenets/models/torch/representation_nets.py",
"chars": 14345,
"preview": "import abc\nfrom typing import Any, Optional, Tuple\n\nimport numpy as np\nimport torch\nfrom torch import nn\n\nimport catenet"
},
{
"path": "catenets/models/torch/slearner.py",
"chars": 7539,
"preview": "from typing import Any, Optional\n\nimport torch\n\nimport catenets.logger as log\nfrom catenets.models.constants import (\n "
},
{
"path": "catenets/models/torch/snet.py",
"chars": 19368,
"preview": "from typing import Tuple\n\nimport numpy as np\nimport torch\nfrom torch import nn\n\nimport catenets.logger as log\nfrom caten"
},
{
"path": "catenets/models/torch/tlearner.py",
"chars": 5095,
"preview": "import copy\nfrom typing import Any\n\nimport torch\n\nfrom catenets.models.constants import (\n DEFAULT_BATCH_SIZE,\n DE"
},
{
"path": "catenets/models/torch/utils/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "catenets/models/torch/utils/decorators.py",
"chars": 1090,
"preview": "import time\nfrom typing import Any, Callable\n\nimport torch\n\nimport catenets.logger as log\n\n\ndef check_input_train(func: "
},
{
"path": "catenets/models/torch/utils/model_utils.py",
"chars": 3082,
"preview": "\"\"\"\r\nModel utils shared across different nets\r\n\"\"\"\r\n# Author: Alicia Curth, Bogdan Cebere\r\nfrom typing import Any, Optio"
},
{
"path": "catenets/models/torch/utils/transformations.py",
"chars": 4388,
"preview": "\"\"\"\nUnbiased Transformations for CATE\n\"\"\"\n# Author: Alicia Curth\nfrom typing import Optional\n\nimport torch\n\n\ndef dr_tran"
},
{
"path": "catenets/models/torch/utils/weight_utils.py",
"chars": 2188,
"preview": "\"\"\"\nImplement different reweighting/balancing strategies as in Li et al (2018)\n\"\"\"\n# Author: Alicia Curth\nfrom typing im"
},
{
"path": "catenets/version.py",
"chars": 22,
"preview": "__version__ = \"0.2.3\"\n"
},
{
"path": "docs/Makefile",
"chars": 634,
"preview": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line, and also\n# from the "
},
{
"path": "docs/conf.py",
"chars": 2600,
"preview": "# Configuration file for the Sphinx documentation builder.\n#\n# This file only contains a selection of the most common op"
},
{
"path": "docs/datasets.rst",
"chars": 378,
"preview": "Datasets\n=========================\n\nDataloaders for datasets used for experiments.\n\n.. toctree::\n :glob:\n :maxdept"
},
{
"path": "docs/index.rst",
"chars": 398,
"preview": "Welcome to CATENets's documentation!\n====================================\n\n.. mdinclude:: ../README.md\n\n\nAPI documentati"
},
{
"path": "docs/jax_models.rst",
"chars": 684,
"preview": "JAX models\n=========================\n\nJAX-based CATE estimators\n\n.. toctree::\n :glob:\n :maxdepth: 2\n\n T-Learner"
},
{
"path": "docs/make.bat",
"chars": 795,
"preview": "@ECHO OFF\r\n\r\npushd %~dp0\r\n\r\nREM Command file for Sphinx documentation\r\n\r\nif \"%SPHINXBUILD%\" == \"\" (\r\n\tset SPHINXBUILD=sp"
},
{
"path": "docs/requirements.txt",
"chars": 362,
"preview": "autodoc\nbandit\nblack\ncatboost\nflake8\ngdown\njax>=0.3.16\njaxlib>=0.3.14; sys_platform != 'win32'\njupyter\nloguru>=0.5.3\nm2r"
},
{
"path": "docs/torch_models.rst",
"chars": 458,
"preview": "PyTorch models\n=========================\n\nPyTorch-based CATE estimators\n\n.. toctree::\n :glob:\n :maxdepth: 2\n\n T"
},
{
"path": "experiments/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "experiments/experiments_AISTATS21/ihdp_experiments.py",
"chars": 3765,
"preview": "\"\"\"\nScript to run experiments on Johansson's IHDP dataset (retrieved via https://www.fredjo.com/)\n\"\"\"\n# Author: Alicia C"
},
{
"path": "experiments/experiments_AISTATS21/simulations_AISTATS.py",
"chars": 13756,
"preview": "\"\"\"\nScript to generate synthetic simulations in AISTATS paper\n\"\"\"\n# Author: Alicia Curth\nimport csv\nimport os\nfrom typin"
},
{
"path": "experiments/experiments_benchmarks_NeurIPS21/README.md",
"chars": 1017,
"preview": "# Replication code for \"Really Doing Great at Estimating CATE? A Critical Look at ML Benchmarking Practices in Treatment"
},
{
"path": "experiments/experiments_benchmarks_NeurIPS21/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "experiments/experiments_benchmarks_NeurIPS21/acic_experiments_catenets.py",
"chars": 3004,
"preview": "\"\"\"\nUtils to replicate ACIC2016 experiments with catenets\n\"\"\"\n# Author: Alicia Curth\nimport csv\nimport os\nfrom pathlib i"
},
{
"path": "experiments/experiments_benchmarks_NeurIPS21/acic_experiments_grf.R",
"chars": 6407,
"preview": "library(grf)\n\ndo_acic_exper_loop <-\n function(simnums = c(2, 26, 7),\n n_reps = 5,\n n_exp = 10,\n "
},
{
"path": "experiments/experiments_benchmarks_NeurIPS21/ihdp_experiments_catenets.py",
"chars": 3612,
"preview": "\"\"\"\nUtils to replicate IHDP experiments with catenets\n\"\"\"\n# Author: Alicia Curth\nimport csv\nimport os\nfrom pathlib impor"
},
{
"path": "experiments/experiments_benchmarks_NeurIPS21/ihdp_experiments_grf.R",
"chars": 4203,
"preview": "library(grf)\nlibrary(reticulate)\n\ndo_ihdp_exper <- function(n_exp = 100,\n n_reps = 5,\n "
},
{
"path": "experiments/experiments_benchmarks_NeurIPS21/twins_experiments_catenets.py",
"chars": 5845,
"preview": "\"\"\"\nUtils to replicate Twins experiments with catenets\n\"\"\"\nimport csv\n\n# Author: Alicia Curth\nimport os\nfrom pathlib imp"
},
{
"path": "experiments/experiments_benchmarks_NeurIPS21/twins_experiments_grf.R",
"chars": 4174,
"preview": "library(grf)\n\ndo_twins_exper <- function(\n n_reps = 10,\n subset_train "
},
{
"path": "experiments/experiments_inductivebias_NeurIPS21/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "experiments/experiments_inductivebias_NeurIPS21/experiments_AB.py",
"chars": 12844,
"preview": "\"\"\"\nUtils to replicate setups A & B\n\"\"\"\n# Author: Alicia Curth\nimport csv\nimport os\nfrom typing import Optional, Tuple, "
},
{
"path": "experiments/experiments_inductivebias_NeurIPS21/experiments_CD.py",
"chars": 3559,
"preview": "\"\"\"\nUtils to replicate experiments C and D\n\"\"\"\n# Author: Alicia Curth\nimport csv\nimport os\nfrom pathlib import Path\nfrom"
},
{
"path": "experiments/experiments_inductivebias_NeurIPS21/experiments_acic.py",
"chars": 4117,
"preview": "\"\"\"\nUtils to replicate ACIC2016 experiments (Appendix E.1)\n\"\"\"\n# Author: Alicia Curth\nimport csv\nimport os\nfrom pathlib "
},
{
"path": "experiments/experiments_inductivebias_NeurIPS21/experiments_twins.py",
"chars": 7014,
"preview": "\"\"\"\nUtils to replicate Twins experiments (Appendix E.2)\n\"\"\"\n# Author: Alicia Curth\nimport csv\nimport os\nfrom pathlib imp"
},
{
"path": "pyproject.toml",
"chars": 159,
"preview": "[build-system]\n# AVOID CHANGING REQUIRES: IT WILL BE UPDATED BY PYSCAFFOLD!\nrequires = [\"setuptools>=46.1.0\", \"wheel\"]\nb"
},
{
"path": "pytest.ini",
"chars": 50,
"preview": "[pytest]\nmarkers =\n slow: mark a test as slow.\n"
},
{
"path": "run_experiments_AISTATS.py",
"chars": 1236,
"preview": "\"\"\"\nFile to run AISTATS experiments from shell\n\"\"\"\n# Author: Alicia Curth\nimport argparse\nimport sys\nfrom typing import "
},
{
"path": "run_experiments_benchmarks_NeurIPS.py",
"chars": 2128,
"preview": "\"\"\"\nFile to run the catenets experiments for\n\"Really Doing Great at Estimating CATE? A Critical Look at ML Benchmarking "
},
{
"path": "run_experiments_inductive_bias_NeurIPS.py",
"chars": 2596,
"preview": "\"\"\"\nFile to run experiments for\n\"On Inductive Biases for Heterogeneous Treatment Effect Estimation\" (Curth & vdS, NeurIP"
},
{
"path": "setup.py",
"chars": 958,
"preview": "# stdlib\nimport os\nimport re\n\n# third party\nfrom setuptools import setup\n\nPKG_DIR = os.path.dirname(os.path.abspath(__fi"
},
{
"path": "tests/conftest.py",
"chars": 86,
"preview": "import sys\n\nimport catenets.logger as log\n\nlog.add(sink=sys.stderr, level=\"CRITICAL\")\n"
},
{
"path": "tests/datasets/test_datasets.py",
"chars": 1727,
"preview": "import pytest\n\nfrom catenets.datasets import load\n\n\n@pytest.mark.parametrize(\"train_ratio\", [0.5, 0.8])\n@pytest.mark.par"
},
{
"path": "tests/models/jax/test_jax_ite.py",
"chars": 1553,
"preview": "from copy import deepcopy\n\nimport pytest\n\nfrom catenets.datasets import load\nfrom catenets.experiment_utils.tester impor"
},
{
"path": "tests/models/jax/test_jax_model_utils.py",
"chars": 1117,
"preview": "from typing import Any\n\nimport jax.numpy as jnp\nimport numpy as np\nimport pandas as pd\nimport pytest\n\nfrom catenets.mode"
},
{
"path": "tests/models/jax/test_jax_transformation_utils.py",
"chars": 1690,
"preview": "from typing import Callable\n\nimport numpy as np\nimport pytest\n\nfrom catenets.models.jax.transformation_utils import (\n "
},
{
"path": "tests/models/torch/test_torch_flextenet.py",
"chars": 3106,
"preview": "import numpy as np\nimport pytest\n\nfrom catenets.datasets import load\nfrom catenets.experiment_utils.tester import evalua"
},
{
"path": "tests/models/torch/test_torch_pseudo_outcome_nets.py",
"chars": 3912,
"preview": "from typing import Any\n\nimport numpy as np\nimport pytest\nfrom sklearn.ensemble import RandomForestRegressor\nfrom torch i"
},
{
"path": "tests/models/torch/test_torch_representation_net.py",
"chars": 2656,
"preview": "from typing import Type\n\nimport pytest\nfrom torch import nn\n\nfrom catenets.datasets import load\nfrom catenets.experiment"
},
{
"path": "tests/models/torch/test_torch_slearner.py",
"chars": 5717,
"preview": "from typing import Any, Optional\n\nimport numpy as np\nimport pytest\nfrom sklearn.ensemble import RandomForestClassifier, "
},
{
"path": "tests/models/torch/test_torch_snet.py",
"chars": 4373,
"preview": "import numpy as np\nimport pytest\nfrom torch import nn\n\nfrom catenets.datasets import load\nfrom catenets.experiment_utils"
},
{
"path": "tests/models/torch/test_torch_tlearner.py",
"chars": 5218,
"preview": "from typing import Any\n\nimport numpy as np\nimport pytest\nfrom sklearn.ensemble import RandomForestClassifier, RandomFore"
}
]
About this extraction
This page contains the full source code of the AliciaCurth/CATENets GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 91 files (523.3 KB), approximately 139.3k tokens, and a symbol index with 356 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.