SYMBOL INDEX (356 symbols across 57 files) FILE: catenets/datasets/__init__.py function load (line 16) | def load(dataset: str, *args: Any, **kwargs: Any) -> Tuple: FILE: catenets/datasets/dataset_acic2016.py function get_acic_covariates (line 53) | def get_acic_covariates( function preprocess_simu (line 86) | def preprocess_simu( function get_acic_orig_filenames (line 202) | def get_acic_orig_filenames(data_path: Path, simu_num: int) -> list: function get_acic_orig_outcomes (line 210) | def get_acic_orig_outcomes(data_path: Path, simu_num: int, i_exp: int) -... function preprocess_acic_orig (line 220) | def preprocess_acic_orig( function preprocess (line 281) | def preprocess( function load (line 296) | def load( FILE: catenets/datasets/dataset_ihdp.py function load_data_npz (line 27) | def load_data_npz(fname: Path, get_po: bool = True) -> dict: function prepare_ihdp_data (line 59) | def prepare_ihdp_data( function get_one_data_set (line 143) | def get_one_data_set(D: dict, i_exp: int, get_po: bool = True) -> dict: function load (line 175) | def load(data_path: Path, exp: int = 1, rescale: bool = False, **kwargs:... function load_raw (line 233) | def load_raw(data_path: Path) -> Tuple: FILE: catenets/datasets/dataset_twins.py function preprocess (line 25) | def preprocess( function load (line 210) | def load( FILE: catenets/datasets/network.py function download_gdrive_if_needed (line 13) | def download_gdrive_if_needed(path: Path, file_id: str) -> None: function download_http_if_needed (line 32) | def download_http_if_needed(path: Path, url: str) -> None: function unarchive_if_needed (line 55) | def unarchive_if_needed(path: Path, output_folder: Path) -> None: function download_if_needed (line 78) | def download_if_needed( FILE: catenets/experiment_utils/base.py function eval_mse_model (line 33) | def eval_mse_model( function eval_mse (line 44) | def eval_mse(preds: jnp.ndarray, targets: jnp.ndarray) -> jnp.ndarray: function eval_root_mse (line 50) | def eval_root_mse(cate_pred: jnp.ndarray, cate_true: jnp.ndarray) -> jnp... function eval_abs_error_ate (line 56) | def eval_abs_error_ate(cate_pred: jnp.ndarray, cate_true: jnp.ndarray) -... function get_model_set (line 62) | def get_model_set( function get_all_snets (line 104) | def get_all_snets() -> Dict: function get_all_pseudoout_models (line 111) | def get_all_pseudoout_models() -> Dict: # DR, RA, PW learner function get_all_twostep_models (line 120) | def get_all_twostep_models() -> Dict: # DR, RA, R, X learner FILE: catenets/experiment_utils/simulation_utils.py function simulate_treatment_setup (line 11) | def simulate_treatment_setup( function get_multivariate_normal_params (line 130) | def get_multivariate_normal_params( function get_set_normal_covariates (line 149) | def get_set_normal_covariates(m: int, n: int, correlated: bool = False) ... function normal_covariate_model (line 156) | def normal_covariate_model( function propensity_AISTATS (line 173) | def propensity_AISTATS( function propensity_constant (line 210) | def propensity_constant( function mu0_AISTATS (line 216) | def mu0_AISTATS( function mu1_AISTATS (line 229) | def mu1_AISTATS( function uniform_covariate_model (line 257) | def uniform_covariate_model( function mu1_additive (line 271) | def mu1_additive( function mu0_hg (line 287) | def mu0_hg(X: np.ndarray, n_w: int = 0, n_c: int = 0, n_o: int = 0) -> n... function mu1_hg (line 295) | def mu1_hg( function propensity_hg (line 310) | def propensity_hg( FILE: catenets/experiment_utils/tester.py function generate_score (line 13) | def generate_score(metric: np.ndarray) -> Tuple[float, float]: function print_score (line 18) | def print_score(score: Tuple[float, float]) -> str: function evaluate_treatments_model (line 22) | def evaluate_treatments_model( FILE: catenets/experiment_utils/torch_metrics.py function sqrt_PEHE (line 5) | def sqrt_PEHE(po: torch.Tensor, hat_te: torch.Tensor) -> torch.Tensor: function abs_error_ATE (line 18) | def abs_error_ATE(po: torch.Tensor, hat_te: torch.Tensor) -> torch.Tensor: FILE: catenets/logger.py function remove (line 15) | def remove() -> None: function add (line 19) | def add( function traceback_and_raise (line 47) | def traceback_and_raise(e: Any, verbose: bool = False) -> NoReturn: function create_log_and_print_function (line 60) | def create_log_and_print_function(level: str) -> Callable: function traceback (line 78) | def traceback(*args: Any, **kwargs: Any) -> None: function critical (line 82) | def critical(*args: Any, **kwargs: Any) -> None: function error (line 86) | def error(*args: Any, **kwargs: Any) -> None: function warning (line 90) | def warning(*args: Any, **kwargs: Any) -> None: function info (line 94) | def info(*args: Any, **kwargs: Any) -> None: function debug (line 98) | def debug(*args: Any, **kwargs: Any) -> None: function trace (line 102) | def trace(*args: Any, **kwargs: Any) -> None: FILE: catenets/models/jax/__init__.py function get_catenet (line 90) | def get_catenet(name: str) -> Any: FILE: catenets/models/jax/base.py function ReprBlock (line 40) | def ReprBlock( function OutputHead (line 64) | def OutputHead( class BaseCATENet (line 99) | class BaseCATENet(BaseEstimator, RegressorMixin, abc.ABC): method score (line 104) | def score( method _get_train_function (line 132) | def _get_train_function(self) -> Callable: method fit (line 135) | def fit( method _get_predict_function (line 171) | def _get_predict_function(self) -> Callable: method predict (line 174) | def predict( method _check_inputs (line 205) | def _check_inputs(w: jnp.ndarray, p: jnp.ndarray) -> None: method fit_and_select_params (line 213) | def fit_and_select_params( function train_output_net_only (line 274) | def train_output_net_only( FILE: catenets/models/jax/disentangled_nets.py function _get_absolute_rowsums (line 45) | def _get_absolute_rowsums(mat: jnp.ndarray) -> jnp.ndarray: function _concatenate_representations (line 49) | def _concatenate_representations(reps: jnp.ndarray) -> jnp.ndarray: class SNet3 (line 53) | class SNet3(BaseCATENet): method __init__ (line 113) | def __init__( method _get_predict_function (line 169) | def _get_predict_function(self) -> Callable: method _get_train_function (line 172) | def _get_train_function(self) -> Callable: function train_snet3 (line 177) | def train_snet3( function predict_snet3 (line 522) | def predict_snet3( FILE: catenets/models/jax/flextenet.py class FlexTENet (line 47) | class FlexTENet(BaseCATENet): method __init__ (line 111) | def __init__( method _get_train_function (line 172) | def _get_train_function(self) -> Callable: method _get_predict_function (line 175) | def _get_predict_function(self) -> Callable: function train_flextenet (line 179) | def train_flextenet( function predict_flextenet (line 565) | def predict_flextenet( function _get_cos_reg (line 594) | def _get_cos_reg( function _compute_ortho_penalty_asymmetric (line 604) | def _compute_ortho_penalty_asymmetric( function _compute_penalty_l2 (line 661) | def _compute_penalty_l2( function _compute_penalty (line 735) | def _compute_penalty( function SplitLayerAsymmetric (line 774) | def SplitLayerAsymmetric( function TEOutputLayerAsymmetric (line 822) | def TEOutputLayerAsymmetric(private: bool = True, same_init: bool = True... function FlexTENetArchitecture (line 869) | def FlexTENetArchitecture( function elementwise_split (line 948) | def elementwise_split(fun: Callable, **fun_kwargs: Any) -> Tuple: function elementwise_parallel (line 966) | def elementwise_parallel(fun: Callable, **fun_kwargs: Any) -> Tuple: function DenseW (line 990) | def DenseW( FILE: catenets/models/jax/model_utils.py function check_shape_1d_data (line 17) | def check_shape_1d_data(y: jnp.ndarray) -> jnp.ndarray: function check_X_is_np (line 27) | def check_X_is_np(X: pd.DataFrame) -> jnp.ndarray: function make_val_split (line 32) | def make_val_split( function heads_l2_penalty (line 73) | def heads_l2_penalty( FILE: catenets/models/jax/offsetnet.py class OffsetNet (line 42) | class OffsetNet(BaseCATENet): method __init__ (line 87) | def __init__( method _get_train_function (line 130) | def _get_train_function(self) -> Callable: method _get_predict_function (line 133) | def _get_predict_function(self) -> Callable: function predict_offsetnet (line 137) | def predict_offsetnet( function train_offsetnet (line 171) | def train_offsetnet( FILE: catenets/models/jax/pseudo_outcome_nets.py class PseudoOutcomeNet (line 76) | class PseudoOutcomeNet(BaseCATENet): method __init__ (line 143) | def __init__( method _get_train_function (line 210) | def _get_train_function(self) -> Callable: method fit (line 213) | def fit( method _get_predict_function (line 240) | def _get_predict_function(self) -> Callable: method predict (line 244) | def predict( class DRNet (line 265) | class DRNet(PseudoOutcomeNet): method __init__ (line 268) | def __init__( class RANet (line 332) | class RANet(PseudoOutcomeNet): method __init__ (line 335) | def __init__( class PWNet (line 399) | class PWNet(PseudoOutcomeNet): method __init__ (line 402) | def __init__( function train_pseudooutcome_net (line 466) | def train_pseudooutcome_net( function _train_and_predict_first_stage (line 690) | def _train_and_predict_first_stage( FILE: catenets/models/jax/representation_nets.py class SNet1 (line 41) | class SNet1(BaseCATENet): method __init__ (line 91) | def __init__( method _get_train_function (line 140) | def _get_train_function(self) -> Callable: method _get_predict_function (line 143) | def _get_predict_function(self) -> Callable: class TARNet (line 147) | class TARNet(SNet1): method __init__ (line 150) | def __init__( class SNet2 (line 196) | class SNet2(BaseCATENet): method __init__ (line 249) | def __init__( method _get_train_function (line 298) | def _get_train_function(self) -> Callable: method _get_predict_function (line 301) | def _get_predict_function(self) -> Callable: class DragonNet (line 305) | class DragonNet(SNet2): method __init__ (line 308) | def __init__( function mmd2_lin (line 358) | def mmd2_lin(X: jnp.ndarray, w: jnp.ndarray) -> jnp.ndarray: function predict_snet1 (line 373) | def predict_snet1( function train_snet1 (line 404) | def train_snet1( function train_snet2 (line 632) | def train_snet2( function predict_snet2 (line 899) | def predict_snet2( FILE: catenets/models/jax/rnet.py class RNet (line 51) | class RNet(BaseCATENet): method __init__ (line 111) | def __init__( method _get_train_function (line 172) | def _get_train_function(self) -> Callable: method fit (line 175) | def fit( method _get_predict_function (line 194) | def _get_predict_function(self) -> Callable: method predict (line 198) | def predict( function train_r_net (line 215) | def train_r_net( function _train_and_predict_r_stage1 (line 397) | def _train_and_predict_r_stage1( function train_r_stage2 (line 476) | def train_r_stage2( FILE: catenets/models/jax/snet.py class SNet (line 51) | class SNet(BaseCATENet): method __init__ (line 118) | def __init__( method _get_predict_function (line 178) | def _get_predict_function(self) -> Callable: method _get_train_function (line 184) | def _get_train_function(self) -> Callable: function train_snet (line 191) | def train_snet( function predict_snet (line 587) | def predict_snet( function train_snet_noprop (line 636) | def train_snet_noprop( function predict_snet_noprop (line 934) | def predict_snet_noprop( FILE: catenets/models/jax/tnet.py class TNet (line 39) | class TNet(BaseCATENet): method __init__ (line 86) | def __init__( method _get_predict_function (line 126) | def _get_predict_function(self) -> Callable: method _get_train_function (line 129) | def _get_train_function(self) -> Callable: function train_tnet (line 133) | def train_tnet( function predict_t_net (line 249) | def predict_t_net( function _train_tnet_jointly (line 272) | def _train_tnet_jointly( FILE: catenets/models/jax/transformation_utils.py function aipw_te_transformation (line 16) | def aipw_te_transformation( function ht_te_transformation (line 53) | def ht_te_transformation( function ra_te_transformation (line 87) | def ra_te_transformation( function _get_transformation_function (line 125) | def _get_transformation_function(transformation_name: str) -> Any: FILE: catenets/models/jax/xnet.py class XNet (line 60) | class XNet(BaseCATENet): method __init__ (line 120) | def __init__( method _get_train_function (line 179) | def _get_train_function(self) -> Callable: method _get_predict_function (line 182) | def _get_predict_function(self) -> Callable: method predict (line 186) | def predict( function train_x_net (line 218) | def train_x_net( function _get_first_stage_pos (line 392) | def _get_first_stage_pos( function predict_x_net (line 466) | def predict_x_net( FILE: catenets/models/torch/base.py class BasicNet (line 43) | class BasicNet(nn.Module): method __init__ (line 81) | def __init__( method forward (line 168) | def forward(self, X: torch.Tensor) -> torch.Tensor: method fit (line 171) | def fit( method _check_tensor (line 251) | def _check_tensor(self, X: torch.Tensor) -> torch.Tensor: class RepresentationNet (line 258) | class RepresentationNet(nn.Module): method __init__ (line 274) | def __init__( method forward (line 303) | def forward(self, X: torch.Tensor) -> torch.Tensor: class PropensityNet (line 307) | class PropensityNet(nn.Module): method __init__ (line 350) | def __init__( method forward (line 438) | def forward(self, X: torch.Tensor) -> torch.Tensor: method get_importance_weights (line 441) | def get_importance_weights( method loss (line 447) | def loss(self, y_pred: torch.Tensor, y_target: torch.Tensor) -> torch.... method fit (line 450) | def fit(self, X: torch.Tensor, y: torch.Tensor) -> "PropensityNet": method _check_tensor (line 520) | def _check_tensor(self, X: torch.Tensor) -> torch.Tensor: class BaseCATEEstimator (line 527) | class BaseCATEEstimator(nn.Module): method __init__ (line 534) | def __init__( method score (line 539) | def score( method fit (line 568) | def fit( method forward (line 589) | def forward(self, X: torch.Tensor) -> torch.Tensor: method predict (line 605) | def predict( method _check_tensor (line 623) | def _check_tensor(self, X: torch.Tensor) -> torch.Tensor: FILE: catenets/models/torch/flextenet.py class FlexTELinearLayer (line 31) | class FlexTELinearLayer(nn.Module): method __init__ (line 35) | def __init__( method forward (line 52) | def forward(self, tensors: List[torch.Tensor]) -> List: class FlexTESplitLayer (line 64) | class FlexTESplitLayer(nn.Module): method __init__ (line 69) | def __init__( method forward (line 103) | def forward(self, tensors: List[torch.Tensor]) -> List: class FlexTEOutputLayer (line 134) | class FlexTEOutputLayer(nn.Module): method __init__ (line 135) | def __init__( method forward (line 160) | def forward(self, tensors: List[torch.Tensor]) -> torch.Tensor: class ElementWiseParallelActivation (line 183) | class ElementWiseParallelActivation(nn.Module): method __init__ (line 189) | def __init__(self, act: Callable, **act_kwargs: Any) -> None: method forward (line 194) | def forward(self, tensors: List[torch.Tensor]) -> List: class ElementWiseSplitActivation (line 208) | class ElementWiseSplitActivation(nn.Module): method __init__ (line 214) | def __init__(self, act: Callable, **act_kwargs: Any) -> None: method forward (line 219) | def forward(self, tensors: List[torch.Tensor]) -> List: class FlexTENet (line 231) | class FlexTENet(BaseCATEEstimator): method __init__ (line 292) | def __init__( method _ortho_penalty_asymmetric (line 468) | def _ortho_penalty_asymmetric(self) -> torch.Tensor: method loss (line 525) | def loss( method fit (line 546) | def fit( method predict (line 637) | def predict( FILE: catenets/models/torch/pseudo_outcome_nets.py class PseudoOutcomeLearner (line 43) | class PseudoOutcomeLearner(BaseCATEEstimator): method __init__ (line 105) | def __init__( method _generate_te_estimator (line 175) | def _generate_te_estimator(self, name: str = "te_estimator") -> nn.Mod... method _generate_po_estimator (line 200) | def _generate_po_estimator(self, name: str = "po_estimator") -> nn.Mod... method _generate_propensity_estimator (line 226) | def _generate_propensity_estimator( method fit (line 252) | def fit( method predict (line 315) | def predict( method _first_step (line 341) | def _first_step( method _second_step (line 352) | def _second_step( method _impute_pos (line 363) | def _impute_pos( method _impute_propensity (line 388) | def _impute_propensity( method _impute_unconditional_mean (line 409) | def _impute_unconditional_mean( class DRLearner (line 426) | class DRLearner(PseudoOutcomeLearner): method _first_step (line 431) | def _first_step( method _second_step (line 447) | def _second_step( class PWLearner (line 460) | class PWLearner(PseudoOutcomeLearner): method _first_step (line 465) | def _first_step( method _second_step (line 478) | def _second_step( class RALearner (line 491) | class RALearner(PseudoOutcomeLearner): method _first_step (line 496) | def _first_step( method _second_step (line 508) | def _second_step( class ULearner (line 521) | class ULearner(PseudoOutcomeLearner): method _first_step (line 526) | def _first_step( method _second_step (line 540) | def _second_step( class RLearner (line 553) | class RLearner(PseudoOutcomeLearner): method _first_step (line 559) | def _first_step( method _second_step (line 572) | def _second_step( class XLearner (line 587) | class XLearner(PseudoOutcomeLearner): method __init__ (line 593) | def __init__( method _first_step (line 605) | def _first_step( method _second_step (line 617) | def _second_step( method predict (line 637) | def predict( FILE: catenets/models/torch/representation_nets.py class BasicDragonNet (line 39) | class BasicDragonNet(BaseCATEEstimator): method __init__ (line 83) | def __init__( method loss (line 154) | def loss( method fit (line 189) | def fit( method _step (line 283) | def _step( method _forward (line 288) | def _forward(self, X: torch.Tensor) -> torch.Tensor: method predict (line 296) | def predict( method _maximum_mean_discrepancy (line 325) | def _maximum_mean_discrepancy( class TARNet (line 340) | class TARNet(BasicDragonNet): method __init__ (line 345) | def __init__( method _step (line 384) | def _step( class DragonNet (line 399) | class DragonNet(BasicDragonNet): method __init__ (line 404) | def __init__( method _step (line 441) | def _step( FILE: catenets/models/torch/slearner.py class SLearner (line 27) | class SLearner(BaseCATEEstimator): method __init__ (line 68) | def __init__( method fit (line 138) | def fit( method _create_extended_matrices (line 191) | def _create_extended_matrices(self, X: torch.Tensor) -> torch.Tensor: method predict (line 203) | def predict( FILE: catenets/models/torch/snet.py class SNet (line 40) | class SNet(BaseCATEEstimator): method __init__ (line 92) | def __init__( method loss (line 258) | def loss( method fit (line 301) | def fit( method _ortho_reg (line 399) | def _ortho_reg(self) -> float: method _maximum_mean_discrepancy (line 478) | def _maximum_mean_discrepancy( method _step (line 492) | def _step( method _forward (line 501) | def _forward( method predict (line 526) | def predict( FILE: catenets/models/torch/tlearner.py class TLearner (line 22) | class TLearner(BaseCATEEstimator): method __init__ (line 56) | def __init__( method predict (line 107) | def predict( method fit (line 139) | def fit( FILE: catenets/models/torch/utils/decorators.py function check_input_train (line 9) | def check_input_train(func: Callable) -> Callable: function benchmark (line 32) | def benchmark(func: Callable) -> Callable: FILE: catenets/models/torch/utils/model_utils.py function make_val_split (line 19) | def make_val_split( function train_wrapper (line 77) | def train_wrapper( function predict_wrapper (line 93) | def predict_wrapper(estimator: Any, X: torch.Tensor) -> torch.Tensor: FILE: catenets/models/torch/utils/transformations.py function dr_transformation_cate (line 10) | def dr_transformation_cate( function pw_transformation_cate (line 47) | def pw_transformation_cate( function ra_transformation_cate (line 79) | def ra_transformation_cate( function u_transformation_cate (line 110) | def u_transformation_cate( FILE: catenets/models/torch/utils/weight_utils.py function compute_importance_weights (line 26) | def compute_importance_weights( function compute_ipw (line 54) | def compute_ipw(propensity: torch.Tensor, w: torch.Tensor) -> torch.Tensor: function compute_trunc_ipw (line 59) | def compute_trunc_ipw( function compute_matching_weights (line 67) | def compute_matching_weights(propensity: torch.Tensor, w: torch.Tensor) ... function compute_overlap_weights (line 72) | def compute_overlap_weights(propensity: torch.Tensor, w: torch.Tensor) -... FILE: experiments/experiments_AISTATS21/ihdp_experiments.py function do_ihdp_experiments (line 62) | def do_ihdp_experiments( FILE: experiments/experiments_AISTATS21/simulations_AISTATS.py function simulation_experiment_loop (line 133) | def simulation_experiment_loop( function do_one_experiment_repeat (line 253) | def do_one_experiment_repeat( function one_simulation_experiment (line 342) | def one_simulation_experiment( function main_AISTATS (line 408) | def main_AISTATS( FILE: experiments/experiments_benchmarks_NeurIPS21/acic_experiments_catenets.py function do_acic_experiments (line 34) | def do_acic_experiments( FILE: experiments/experiments_benchmarks_NeurIPS21/ihdp_experiments_catenets.py function do_ihdp_experiments (line 36) | def do_ihdp_experiments( FILE: experiments/experiments_benchmarks_NeurIPS21/twins_experiments_catenets.py function do_twins_experiment_loop (line 38) | def do_twins_experiment_loop( function do_twins_experiments (line 56) | def do_twins_experiments( function prepare_twins (line 114) | def prepare_twins(treat_prop=0.5, seed=42, test_size=0.5, subset_train: ... FILE: experiments/experiments_inductivebias_NeurIPS21/experiments_AB.py function do_acic_simu_loops (line 206) | def do_acic_simu_loops( function do_acic_simu (line 249) | def do_acic_simu( function acic_simu (line 414) | def acic_simu( FILE: experiments/experiments_inductivebias_NeurIPS21/experiments_CD.py function do_ihdp_experiments (line 60) | def do_ihdp_experiments( FILE: experiments/experiments_inductivebias_NeurIPS21/experiments_acic.py function do_acic_orig_loop (line 60) | def do_acic_orig_loop( function do_acic_experiments (line 80) | def do_acic_experiments( FILE: experiments/experiments_inductivebias_NeurIPS21/experiments_twins.py function do_twins_experiment_loop (line 63) | def do_twins_experiment_loop( function do_twins_experiments (line 88) | def do_twins_experiments( function split_data (line 216) | def split_data(X, y, w, pos, test_size=0.5, random_state=42, subset_trai... function eval_roc_auc (line 232) | def eval_roc_auc(targets, preds): function eval_ap (line 238) | def eval_ap(targets, preds): FILE: run_experiments_AISTATS.py function init_arg (line 16) | def init_arg() -> Any: FILE: run_experiments_benchmarks_NeurIPS.py function init_arg (line 26) | def init_arg() -> Any: FILE: run_experiments_inductive_bias_NeurIPS.py function init_arg (line 28) | def init_arg() -> Any: FILE: setup.py function read (line 11) | def read(fname: str) -> str: function find_version (line 15) | def find_version() -> str: FILE: tests/datasets/test_datasets.py function test_dataset_sanity_twins (line 9) | def test_dataset_sanity_twins( function test_dataset_sanity_ihdp (line 29) | def test_dataset_sanity_ihdp() -> None: function test_dataset_sanity_acic2016 (line 41) | def test_dataset_sanity_acic2016(preprocessed: bool) -> None: FILE: tests/models/jax/test_jax_ite.py function test_model_sanity (line 30) | def test_model_sanity(dataset: str, pehe_threshold: float, model_name: s... function test_model_score (line 39) | def test_model_score() -> None: FILE: tests/models/jax/test_jax_model_utils.py function test_check_shape_1d_data_sanity (line 16) | def test_check_shape_1d_data_sanity(data: np.ndarray) -> None: function test_check_X_is_np_sanity (line 23) | def test_check_X_is_np_sanity(data: Any) -> None: function test_make_val_split_sanity (line 29) | def test_make_val_split_sanity() -> None: FILE: tests/models/jax/test_jax_transformation_utils.py function test_get_transformation_function_sanity (line 18) | def test_get_transformation_function_sanity() -> None: function test_aipw_te_transformation_sanity (line 31) | def test_aipw_te_transformation_sanity(fn: Callable) -> None: function test_ht_te_transformation_sanity (line 45) | def test_ht_te_transformation_sanity(fn: Callable) -> None: function test_ra_te_transformation_sanity (line 56) | def test_ra_te_transformation_sanity(fn: Callable) -> None: FILE: tests/models/torch/test_torch_flextenet.py function test_flextenet_model_params (line 9) | def test_flextenet_model_params() -> None: function test_flextenet_model_sanity (line 59) | def test_flextenet_model_sanity(dataset: str, pehe_threshold: float) -> ... function test_flextenet_model_predict_api (line 81) | def test_flextenet_model_predict_api( FILE: tests/models/torch/test_torch_pseudo_outcome_nets.py function test_nn_model_params (line 24) | def test_nn_model_params(model_t: Any) -> None: function test_nn_model_params_nonlin (line 39) | def test_nn_model_params_nonlin(nonlin: str, model_t: Any) -> None: function test_nn_model_sanity (line 54) | def test_nn_model_sanity(dataset: str, pehe_threshold: float, model_t: A... function test_sklearn_model_pseudo_outcome_binary (line 101) | def test_sklearn_model_pseudo_outcome_binary( function test_model_predict_api (line 130) | def test_model_predict_api() -> None: FILE: tests/models/torch/test_torch_representation_net.py function test_model_params (line 12) | def test_model_params(snet: Type) -> None: function test_model_params_nonlin (line 41) | def test_model_params_nonlin(nonlin: str, snet: Type) -> None: function test_model_sanity (line 61) | def test_model_sanity(dataset: str, pehe_threshold: float, snet: Type) -... function test_model_predict_api (line 79) | def test_model_predict_api() -> None: FILE: tests/models/torch/test_torch_slearner.py function test_nn_model_params (line 15) | def test_nn_model_params() -> None: function test_nn_model_params_nonlin (line 57) | def test_nn_model_params_nonlin(nonlin: str) -> None: function test_nn_model_sanity (line 72) | def test_nn_model_sanity( function test_sklearn_model_sanity_binary_output (line 124) | def test_sklearn_model_sanity_binary_output( function test_slearner_sklearn_model_ihdp (line 155) | def test_slearner_sklearn_model_ihdp(po_estimator: Any, exp: int) -> None: function test_model_predict_api (line 175) | def test_model_predict_api() -> None: FILE: tests/models/torch/test_torch_snet.py function test_model_params (line 10) | def test_model_params() -> None: function test_model_params_nonlin (line 83) | def test_model_params_nonlin(nonlin: str) -> None: function test_model_sanity (line 108) | def test_model_sanity(dataset: str, pehe_threshold: float) -> None: function test_model_predict_api (line 141) | def test_model_predict_api() -> None: FILE: tests/models/torch/test_torch_tlearner.py function test_nn_model_params (line 15) | def test_nn_model_params() -> None: function test_nn_model_params_nonlin (line 42) | def test_nn_model_params_nonlin(nonlin: str) -> None: function test_nn_model_sanity (line 58) | def test_nn_model_sanity(dataset: str, pehe_threshold: float) -> None: function test_sklearn_model_sanity_binary_output (line 103) | def test_sklearn_model_sanity_binary_output( function test_sklearn_model_sanity_regression (line 149) | def test_sklearn_model_sanity_regression( function test_model_predict_api (line 168) | def test_model_predict_api() -> None: