SYMBOL INDEX (447 symbols across 50 files) FILE: gpax/acquisition/acquisition.py function _compute_mean_and_var (line 22) | def _compute_mean_and_var( function _compute_penalties (line 38) | def _compute_penalties( function EI (line 49) | def EI(rng_key: jnp.ndarray, model: Type[ExactGP], function UCB (line 143) | def UCB(rng_key: jnp.ndarray, model: Type[ExactGP], function POI (line 227) | def POI(rng_key: jnp.ndarray, model: Type[ExactGP], function UE (line 314) | def UE(rng_key: jnp.ndarray, model: Type[ExactGP], function KG (line 397) | def KG(rng_key: jnp.ndarray, function Thompson (line 488) | def Thompson(rng_key: jnp.ndarray, FILE: gpax/acquisition/base_acq.py function ei (line 20) | def ei(moments: Tuple[jnp.ndarray, jnp.ndarray], function ucb (line 74) | def ucb(moments: Tuple[jnp.ndarray, jnp.ndarray], function ue (line 109) | def ue(moments: Tuple[jnp.ndarray, jnp.ndarray], **kwargs) -> jnp.ndarray: function poi (line 134) | def poi(moments: Tuple[jnp.ndarray, jnp.ndarray], function kg (line 158) | def kg(model: Type[ExactGP], FILE: gpax/acquisition/batch_acquisition.py function _compute_batch_acquisition (line 21) | def _compute_batch_acquisition( function qEI (line 62) | def qEI(rng_key: jnp.ndarray, function qUCB (line 119) | def qUCB(rng_key: jnp.ndarray, function qPOI (line 176) | def qPOI(rng_key: jnp.ndarray, function qKG (line 233) | def qKG(rng_key: jnp.ndarray, FILE: gpax/acquisition/optimize.py function optimize_acq (line 19) | def optimize_acq(rng_key: jnp.ndarray, function ensure_array (line 91) | def ensure_array(x): FILE: gpax/acquisition/penalties.py function compute_penalty (line 6) | def compute_penalty(X: jnp.ndarray, recent_points: jnp.ndarray, function penalty_point (line 37) | def penalty_point(x: jnp.ndarray, recent_points: jnp.ndarray) -> jnp.nda... function find_and_replace_point_indices (line 53) | def find_and_replace_point_indices(points, other_points): FILE: gpax/hypo.py function step (line 21) | def step(model: Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndar... function sample_next (line 102) | def sample_next(rewards: Union[np.array, jnp.array], function softmax (line 134) | def softmax(logits: Union[np.array, jnp.array], function eps_greedy (line 146) | def eps_greedy(rewards: Union[np.array, jnp.array], function update_record (line 159) | def update_record(record: np.array, action: int, r: Union[int, float]) -... FILE: gpax/kernels/kernels.py function _sqrt (line 20) | def _sqrt(x, eps=1e-12): function add_jitter (line 24) | def add_jitter(x, jitter=1e-6): function square_scaled_distance (line 28) | def square_scaled_distance(X: jnp.ndarray, Z: jnp.ndarray, function RBFKernel (line 45) | def RBFKernel(X: jnp.ndarray, Z: jnp.ndarray, function MaternKernel (line 69) | def MaternKernel(X: jnp.ndarray, Z: jnp.ndarray, function PeriodicKernel (line 95) | def PeriodicKernel(X: jnp.ndarray, Z: jnp.ndarray, function nngp_erf (line 120) | def nngp_erf(x1: jnp.ndarray, x2: jnp.ndarray, function nngp_relu (line 153) | def nngp_relu(x1: jnp.ndarray, x2: jnp.ndarray, function NNGPKernel (line 186) | def NNGPKernel(activation: str = 'erf', depth: int = 3 function get_kernel (line 227) | def get_kernel(kernel: Union[str, kernel_fn_type] = 'RBF', **kwargs): FILE: gpax/kernels/mtkernels.py function index_kernel (line 24) | def index_kernel(indices1, indices2, params): function MultitaskKernel (line 66) | def MultitaskKernel(base_kernel, **kwargs1): function MultivariateKernel (line 130) | def MultivariateKernel(base_kernel, num_tasks, **kwargs1): function LCMKernel (line 197) | def LCMKernel(base_kernel, shared_input_space=True, num_tasks=None, **kw... FILE: gpax/models/bnn.py class BNN (line 19) | class BNN(sPM): method __init__ (line 21) | def __init__(self, method _set_data (line 31) | def _set_data(self, X: jnp.ndarray, y: Optional[jnp.ndarray] = None function sample_weights (line 40) | def sample_weights(name: str, in_channels: int, out_channels: int) -> jn... function sample_biases (line 48) | def sample_biases(name: str, channels: int) -> jnp.ndarray: function get_mlp (line 55) | def get_mlp(architecture: List[int]) -> Callable[[jnp.ndarray, Dict[str,... function get_mlp_prior (line 68) | def get_mlp_prior(input_dim: int, output_dim: int, architecture: List[in... FILE: gpax/models/corgp.py class CoregGP (line 12) | class CoregGP(ExactGP): method __init__ (line 38) | def __init__(self, input_dim: int, data_kernel: str, method model (line 54) | def model(self, method _sample_task_kernel_params (line 105) | def _sample_task_kernel_params(self, n_tasks, rank): FILE: gpax/models/dkl.py class DKL (line 22) | class DKL(ExactGP): method __init__ (line 69) | def __init__(self, input_dim: int, z_dim: int = 2, kernel: str = 'RBF', method model (line 83) | def model(self, method get_mvn_posterior (line 113) | def get_mvn_posterior(self, method embed (line 135) | def embed(self, X_new: jnp.ndarray) -> jnp.ndarray: method _print_summary (line 145) | def _print_summary(self): function sample_weights (line 152) | def sample_weights(name: str, in_channels: int, out_channels: int) -> jn... function sample_biases (line 160) | def sample_biases(name: str, channels: int) -> jnp.ndarray: function get_mlp (line 167) | def get_mlp(architecture: List[int]) -> Callable[[jnp.ndarray, Dict[str,... function get_mlp_prior (line 180) | def get_mlp_prior(input_dim: int, output_dim: int, architecture: List[in... FILE: gpax/models/gp.py class ExactGP (line 29) | class ExactGP: method __init__ (line 96) | def __init__( method model (line 137) | def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs: float... method fit (line 166) | def fit( method _sample_noise (line 222) | def _sample_noise(self) -> jnp.ndarray: method _sample_kernel_params (line 229) | def _sample_kernel_params(self, output_scale=True) -> Dict[str, jnp.nd... method get_samples (line 249) | def get_samples(self, chain_dim: bool = False) -> Dict[str, jnp.ndarray]: method get_mvn_posterior (line 253) | def get_mvn_posterior( method _predict (line 279) | def _predict( method _predict_in_batches (line 295) | def _predict_in_batches( method predict_in_batches (line 325) | def predict_in_batches( method predict (line 351) | def predict( method sample_from_prior (line 401) | def sample_from_prior(self, rng_key: jnp.ndarray, X: jnp.ndarray, num_... method _set_data (line 410) | def _set_data(self, X: jnp.ndarray, y: Optional[jnp.ndarray] = None) -... method _set_training_data (line 416) | def _set_training_data( method _print_summary (line 430) | def _print_summary(self): FILE: gpax/models/hskgp.py class VarNoiseGP (line 24) | class VarNoiseGP(ExactGP): method __init__ (line 81) | def __init__( method model (line 105) | def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs: float... method _sample_noise_kernel_params (line 151) | def _sample_noise_kernel_params(self) -> Dict[str, jnp.ndarray]: method get_mvn_posterior (line 163) | def get_mvn_posterior( method get_data_var_samples (line 206) | def get_data_var_samples(self): method _print_summary (line 218) | def _print_summary(self): FILE: gpax/models/ibnn.py class iBNN (line 20) | class iBNN(ExactGP): method __init__ (line 42) | def __init__(self, input_dim: int, depth: int = 3, activation: str = '... method _sample_kernel_params (line 54) | def _sample_kernel_params(self) -> Dict[str, jnp.ndarray]: FILE: gpax/models/linreg.py class LinReg (line 9) | class LinReg: method __init__ (line 11) | def __init__(self): method model (line 15) | def model(x, y=None): method train (line 24) | def train(self, x, y, learning_rate=0.01, num_iterations=5000): method predict (line 32) | def predict(self, x_new): method get_params (line 38) | def get_params(self): FILE: gpax/models/mngp.py class MeasuredNoiseGP (line 29) | class MeasuredNoiseGP(ExactGP): method __init__ (line 61) | def __init__(self, method model (line 74) | def model(self, X: jnp.ndarray, y: jnp.ndarray = None, measured_noise:... method fit (line 100) | def fit( method _predict (line 159) | def _predict( method predict (line 184) | def predict( method linreg (line 248) | def linreg(self, x, y, x_new, **kwargs): method gpreg (line 253) | def gpreg(self, x, y, x_new, **kwargs): FILE: gpax/models/mtgp.py class MultiTaskGP (line 12) | class MultiTaskGP(ExactGP): method __init__ (line 57) | def __init__(self, input_dim: int, data_kernel: str, method model (line 92) | def model(self, method _sample_noise (line 147) | def _sample_noise(self): method _sample_task_kernel_params (line 159) | def _sample_task_kernel_params(self): method _sample_kernel_params (line 183) | def _sample_kernel_params(self): FILE: gpax/models/sparse_gp.py class viSparseGP (line 25) | class viSparseGP(viGP): method __init__ (line 49) | def __init__(self, input_dim: int, kernel: str, method model (line 62) | def model(self, method fit (line 116) | def fit(self, method get_mvn_posterior (line 173) | def get_mvn_posterior(self, X_new: jnp.ndarray, FILE: gpax/models/spm.py class sPM (line 29) | class sPM: method __init__ (line 44) | def __init__(self, method model (line 63) | def model(self, X: jnp.ndarray, y: jnp.ndarray = None) -> None: method _sample_noise (line 79) | def _sample_noise(self) -> jnp.ndarray: method fit (line 86) | def fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray, method get_samples (line 127) | def get_samples(self, chain_dim: bool = False) -> Dict[str, jnp.ndarray]: method get_param_means (line 131) | def get_param_means(self): method sample_from_prior (line 141) | def sample_from_prior(self, rng_key: jnp.ndarray, method sample_single_posterior_predictive (line 150) | def sample_single_posterior_predictive(self, rng_key, X_new, params, n... method _vmap_predict (line 156) | def _vmap_predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray, method predict (line 173) | def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray, method _print_summary (line 210) | def _print_summary(self): method _set_data (line 213) | def _set_data(self, FILE: gpax/models/uigp.py class UIGP (line 22) | class UIGP(ExactGP): method __init__ (line 63) | def __init__(self, method model (line 78) | def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs: float... method _sample_x (line 113) | def _sample_x(self, X: jnp.ndarray) -> jnp.ndarray: method get_mvn_posterior (line 131) | def get_mvn_posterior( method _predict (line 158) | def _predict( method _set_data (line 177) | def _set_data(self, X: jnp.ndarray, y: Optional[jnp.ndarray] = None) -... method _print_summary (line 192) | def _print_summary(self): FILE: gpax/models/vgp.py class vExactGP (line 23) | class vExactGP(ExactGP): method __init__ (line 42) | def __init__(self, input_dim: int, kernel: str, method model (line 55) | def model(self, method _sample_noise (line 92) | def _sample_noise(self, task_dim) -> jnp.ndarray: method _sample_kernel_params (line 101) | def _sample_kernel_params(self, task_dim: int = None) -> Dict[str, jnp... method _get_mvn_posterior (line 123) | def _get_mvn_posterior(self, method get_mvn_posterior (line 147) | def get_mvn_posterior(self, method predict_in_batches (line 175) | def predict_in_batches(self, rng_key: jnp.ndarray, method _set_data (line 198) | def _set_data(self, FILE: gpax/models/vi_ibnn.py class vi_iBNN (line 20) | class vi_iBNN(viGP): method __init__ (line 43) | def __init__(self, input_dim: int, depth: int = 3, activation: str = '... method _sample_kernel_params (line 53) | def _sample_kernel_params(self) -> Dict[str, jnp.ndarray]: FILE: gpax/models/vi_mtdkl.py class viMTDKL (line 25) | class viMTDKL(viDKL): method __init__ (line 75) | def __init__(self, input_dim: int, z_dim: int = 2, data_kernel: str = ... method model (line 104) | def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs) -> None: method _sample_noise (line 163) | def _sample_noise(self): method _sample_task_kernel_params (line 175) | def _sample_task_kernel_params(self): method _sample_kernel_params (line 199) | def _sample_kernel_params(self): method get_mvn_posterior (line 212) | def get_mvn_posterior(self, FILE: gpax/models/vidkl.py class viDKL (line 27) | class viDKL(ExactGP): method __init__ (line 71) | def __init__(self, input_dim: Union[int, Tuple[int]], z_dim: int = 2, ... method model (line 90) | def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs) -> None: method single_fit (line 126) | def single_fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray, method fit (line 163) | def fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray, method get_mvn_posterior (line 206) | def get_mvn_posterior(self, method sample_from_posterior (line 238) | def sample_from_posterior(self, rng_key: jnp.ndarray, method get_samples (line 253) | def get_samples(self) -> Tuple[Dict['str', jnp.ndarray]]: method predict_in_batches (line 257) | def predict_in_batches(self, rng_key: jnp.ndarray, method predict (line 277) | def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray, method fit_predict (line 320) | def fit_predict(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray, method embed (line 372) | def embed(self, X_new: jnp.ndarray) -> jnp.ndarray: method _print_summary (line 386) | def _print_summary(self) -> None: class MLP (line 400) | class MLP(hk.Module): method __init__ (line 402) | def __init__(self, embedim=2): method __call__ (line 406) | def __call__(self, x): FILE: gpax/models/vigp.py class viGP (line 23) | class viGP(ExactGP): method __init__ (line 61) | def __init__(self, input_dim: int, kernel: str, method fit (line 77) | def fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray, method get_samples (line 125) | def get_samples(self) -> Dict[str, jnp.ndarray]: method predict_in_batches (line 129) | def predict_in_batches(self, rng_key: jnp.ndarray, method predict (line 153) | def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray, method _print_summary (line 187) | def _print_summary(self) -> None: FILE: gpax/priors/priors.py function place_normal_prior (line 18) | def place_normal_prior(param_name: str, loc: float = 0.0, scale: float =... function place_lognormal_prior (line 27) | def place_lognormal_prior(param_name: str, loc: float = 0.0, scale: floa... function place_halfnormal_prior (line 36) | def place_halfnormal_prior(param_name: str, scale: float = 1.0): function place_uniform_prior (line 45) | def place_uniform_prior(param_name: str, function place_gamma_prior (line 58) | def place_gamma_prior(param_name: str, function normal_dist (line 71) | def normal_dist(loc: float = None, scale: float = None function lognormal_dist (line 93) | def lognormal_dist(loc: float = None, scale: float = None) -> numpyro.di... function halfnormal_dist (line 114) | def halfnormal_dist(scale: float = None) -> numpyro.distributions.Distri... function gamma_dist (line 134) | def gamma_dist(c: float = None, function uniform_dist (line 164) | def uniform_dist(low: float = None, function auto_priors (line 192) | def auto_priors(func: Callable, params_begin_with: int, dist_type: str =... function auto_normal_priors (line 219) | def auto_normal_priors(func: Callable, loc: float = 0.0, scale: float = ... function auto_lognormal_priors (line 235) | def auto_lognormal_priors(func: Callable, loc: float = 0.0, scale: float... function auto_normal_kernel_priors (line 251) | def auto_normal_kernel_priors(kernel_fn: Callable, loc: float = 0.0, sca... function auto_lognormal_kernel_priors (line 267) | def auto_lognormal_kernel_priors(kernel_fn: Callable, loc: float = 0.0, ... FILE: gpax/utils/fn.py function set_fn (line 21) | def set_fn(func: Callable) -> Callable: function set_kernel_fn (line 58) | def set_kernel_fn(func: Callable, function _set_noise_kernel_fn (line 119) | def _set_noise_kernel_fn(func: Callable) -> Callable: FILE: gpax/utils/utils.py function enable_x64 (line 19) | def enable_x64(): function get_keys (line 24) | def get_keys(seed: int = 0): function split_in_batches (line 33) | def split_in_batches(X_new: Union[onp.ndarray, jnp.ndarray], function split_dict (line 54) | def split_dict(data: Dict[str, jnp.ndarray], chunk_size: int function random_sample_dict (line 84) | def random_sample_dict(data: Dict[str, jnp.ndarray], function get_haiku_dict (line 105) | def get_haiku_dict(kernel_params: Dict[str, jnp.ndarray]) -> Dict[str, D... function dviz (line 126) | def dviz(d: Type[numpyro.distributions.Distribution], samples: int = 100... function preprocess_sparse_image (line 150) | def preprocess_sparse_image(sparse_image): function initialize_inducing_points (line 171) | def initialize_inducing_points(X, ratio=0.1, method='uniform', key=None): FILE: tests/test_acq.py class mock_GP (line 22) | class mock_GP: method __init__ (line 23) | def __init__(self): method get_samples (line 26) | def get_samples(self): function test_base_standard_acq (line 35) | def test_base_standard_acq(base_acq): function test_base_acq_kg (line 45) | def test_base_acq_kg(): function test_base_standard_acq_maximize (line 60) | def test_base_standard_acq_maximize(base_acq): function test_base_standard_acq_best_f (line 70) | def test_base_standard_acq_best_f(base_acq): function test_compute_mean_and_var (line 80) | def test_compute_mean_and_var(): function test_acq_gp (line 94) | def test_acq_gp(acq): function test_acq_vidkl (line 107) | def test_acq_vidkl(acq): function test_acq_dkl (line 120) | def test_acq_dkl(acq): function test_UCB_beta (line 131) | def test_UCB_beta(): function test_EI_gp_penalty_inv_distance (line 145) | def test_EI_gp_penalty_inv_distance(): function test_UCB_gp_penalty_inv_distance (line 159) | def test_UCB_gp_penalty_inv_distance(): function test_UE_gp_penalty_inv_distance (line 173) | def test_UE_gp_penalty_inv_distance(): function test_compute_batch_acquisition (line 188) | def test_compute_batch_acquisition(maximize_distance): function test_batched_acq_gp (line 202) | def test_batched_acq_gp(acq, q): function test_acq_penalty_indices (line 215) | def test_acq_penalty_indices(acq, pen): function test_compute_penalty_delta (line 232) | def test_compute_penalty_delta(): function test_compute_penalty_inverse_distance (line 240) | def test_compute_penalty_inverse_distance(): function test_acq_error (line 250) | def test_acq_error(acq_func): FILE: tests/test_bnn.py function get_dummy_data (line 15) | def get_dummy_data(feature_dim=1, target_dim=1, squeezed=False): function test_bnn_fit (line 23) | def test_bnn_fit(): function test_bnn_custom_layers_fit (line 31) | def test_bnn_custom_layers_fit(): function test_bnn_predict_with_samples (line 47) | def test_bnn_predict_with_samples(): function test_bnn_custom_layers_predict_custom_with_samples (line 66) | def test_bnn_custom_layers_predict_custom_with_samples(): function test_bnn_fit_predict (line 90) | def test_bnn_fit_predict(feature_dim, target_dim, squeezed): FILE: tests/test_corgp.py function get_dummy_data (line 14) | def get_dummy_data(): function attach_indices (line 20) | def attach_indices(X, num_tasks): function dummy_mean_fn (line 25) | def dummy_mean_fn(x, params): function dummy_mean_fn_priors (line 29) | def dummy_mean_fn_priors(): function test_fit_corgp (line 37) | def test_fit_corgp(data_kernel, num_tasks): function test_fit_corgp_meanfn (line 46) | def test_fit_corgp_meanfn(): FILE: tests/test_dkl.py function get_dummy_data (line 14) | def get_dummy_data(jax_ndarray=True): function test_fit (line 23) | def test_fit(jax_ndarray): function test_get_mvn_posterior (line 31) | def test_get_mvn_posterior(): function test_get_mvn_posterior_noiseless (line 54) | def test_get_mvn_posterior_noiseless(): function test_jitter_fit (line 79) | def test_jitter_fit(): function test_jitter_mvn_posterior (line 95) | def test_jitter_mvn_posterior(): FILE: tests/test_func_setter.py function linear_kernel_test (line 11) | def linear_kernel_test(X, Z, k_scale): function rbf_test (line 16) | def rbf_test(X, Z, k_length, k_scale): function sample_function (line 28) | def sample_function(x, a, b): function test_set_fn (line 32) | def test_set_fn(): function test_set_kernel_fn (line 38) | def test_set_kernel_fn(): function test_set_kernel_fn_with_jitter (line 54) | def test_set_kernel_fn_with_jitter(): function test_set_noise_kernel_fn (line 71) | def test_set_noise_kernel_fn(): FILE: tests/test_gp.py function get_dummy_data (line 15) | def get_dummy_data(jax_ndarray=True, unsqueeze=False): function dummy_mean_fn (line 25) | def dummy_mean_fn(x, params): function dummy_mean_fn_priors (line 29) | def dummy_mean_fn_priors(): function gp_kernel_custom_prior (line 35) | def gp_kernel_custom_prior(): function test_fit (line 44) | def test_fit(kernel, jax_ndarray, unsqueeze): function test_get_samples (line 54) | def test_get_samples(kernel, jax_ndarray): function test_get_samples_chain_dim (line 68) | def test_get_samples_chain_dim(chain_dim, samples_dim): function test_sample_kernel (line 80) | def test_sample_kernel(kernel): function test_sample_periodic_kernel (line 91) | def test_sample_periodic_kernel(): function test_sample_noise (line 101) | def test_sample_noise(): function test_sample_noise_custom_prior (line 108) | def test_sample_noise_custom_prior(): function test_sample_kernel_custom_lscale_prior (line 119) | def test_sample_kernel_custom_lscale_prior(): function test_fit_with_custom_kernel_priors (line 131) | def test_fit_with_custom_kernel_priors(kernel): function test_get_mvn_posterior (line 139) | def test_get_mvn_posterior(): function test_get_mvn_posterior_noiseless (line 155) | def test_get_mvn_posterior_noiseless(): function test_single_sample_prediction (line 173) | def test_single_sample_prediction(): function test_prediction (line 192) | def test_prediction(unsqueeze, n): function test_noiseless_prediction (line 209) | def test_noiseless_prediction(): function test_prediction_in_batches (line 227) | def test_prediction_in_batches(batch_size, n): function test_fit_predict (line 245) | def test_fit_predict(kernel): function test_fit_predict_in_batches (line 259) | def test_fit_predict_in_batches(n): function test_fit_noiseless_predict_in_batches (line 273) | def test_fit_noiseless_predict_in_batches(n): function test_fit_with_mean_fn (line 286) | def test_fit_with_mean_fn(jax_ndarray): function test_fit_with_prob_mean_fn (line 295) | def test_fit_with_prob_mean_fn(jax_ndarray): function test_fit_predict_with_mean_fn (line 303) | def test_fit_predict_with_mean_fn(): function test_fit_predict_with_prob_mean_fn (line 316) | def test_fit_predict_with_prob_mean_fn(): function test_sample_from_prior (line 329) | def test_sample_from_prior(): function test_jitter_fit (line 337) | def test_jitter_fit(): function test_jitter_predict (line 353) | def test_jitter_predict(): FILE: tests/test_hskgp.py function get_dummy_data (line 16) | def get_dummy_data(unsqueeze=False): function noise_fn (line 24) | def noise_fn(x, params): function noise_fn_prior (line 28) | def noise_fn_prior(): function test_fit (line 35) | def test_fit(noise_kernel): function test_fit_with_custom_noise_lscale (line 43) | def test_fit_with_custom_noise_lscale(): function test_fit_with_noise_mean_fn (line 51) | def test_fit_with_noise_mean_fn(): function test_fit_with_noise_and_regular_mean_fn (line 59) | def test_fit_with_noise_and_regular_mean_fn(): function test_get_mvn_posterior (line 68) | def test_get_mvn_posterior(): function test_get_mvn_posterior_with_mean_fn (line 87) | def test_get_mvn_posterior_with_mean_fn(): function test_get_mvn_posterior_with_noise_and_regular_mean_fn (line 109) | def test_get_mvn_posterior_with_noise_and_regular_mean_fn(): function test_get_noise_samples (line 133) | def test_get_noise_samples(): function test_get_noise_samples_with_mean_fn (line 142) | def test_get_noise_samples_with_mean_fn(): FILE: tests/test_hypo.py function get_dummy_data (line 13) | def get_dummy_data(jax_ndarray=True): function model (line 21) | def model(x, params): function model_priors (line 25) | def model_priors(): function test_sample_next (line 32) | def test_sample_next(method): function test_step (line 39) | def test_step(gp_wrap): FILE: tests/test_ibnn.py function get_dummy_data (line 14) | def get_dummy_data(jax_ndarray=True, unsqueeze=False): function test_ibnn_fit_predict (line 26) | def test_ibnn_fit_predict(activation, depth): function test_viibnn_fit_predict (line 41) | def test_viibnn_fit_predict(activation, depth): FILE: tests/test_kernels.py function test_data_kernel_shapes (line 16) | def test_data_kernel_shapes(kernel, dim): function test_periodkernel_shapes (line 25) | def test_periodkernel_shapes(dim): function test_data_kernel_ard_shapes (line 35) | def test_data_kernel_ard_shapes(kernel, dim): function test_index_kernel_shapes (line 43) | def test_index_kernel_shapes(): function test_index_kernel_shapes_uneven_obs (line 51) | def test_index_kernel_shapes_uneven_obs(): function test_index_kernel_computations (line 59) | def test_index_kernel_computations(): function test_nngp_shapes (line 71) | def test_nngp_shapes(kernel, dim, depth): function test_NNGPKernel (line 83) | def test_NNGPKernel(activation, dim, depth): function test_NNGPKernel_activations (line 92) | def test_NNGPKernel_activations(): function test_MultiTaskKernel (line 103) | def test_MultiTaskKernel(): function test_multitask_kernel_shapes_test_noiseless (line 111) | def test_multitask_kernel_shapes_test_noiseless(data_kernel, dim): function test_multitask_kernel_shapes_test_noisy (line 126) | def test_multitask_kernel_shapes_test_noisy(data_kernel, dim): function test_multitask_kernel_shapes_train (line 142) | def test_multitask_kernel_shapes_train(data_kernel, dim): function test_MultiVariateKernel (line 156) | def test_MultiVariateKernel(): function test_multivariate_kernel_shapes_test_noisy (line 167) | def test_multivariate_kernel_shapes_test_noisy(data_kernel, dim, num_tas... function test_multivariate_kernel_shapes_test_noiseless (line 182) | def test_multivariate_kernel_shapes_test_noiseless(data_kernel, dim, num... function test_multivariate_kernel_shapes_train (line 196) | def test_multivariate_kernel_shapes_train(data_kernel, dim, num_tasks, r... function test_LCMKernel (line 207) | def test_LCMKernel(): function test_LCMKernel_shapes_multitask (line 216) | def test_LCMKernel_shapes_multitask(data_kernel, dim, num_latent): function test_LCMKernel_shapes_multivariate (line 235) | def test_LCMKernel_shapes_multivariate(data_kernel, dim, num_latent, ran... FILE: tests/test_mngp.py function variable_noise (line 16) | def variable_noise(x): function get_dummy_data (line 19) | def get_dummy_data(): function test_fit (line 28) | def test_fit(): function test_get_mvn_posterior (line 36) | def test_get_mvn_posterior(): function test_predict_single_sample (line 57) | def test_predict_single_sample(n): function test_predict (line 80) | def test_predict(noise_pred_fn): FILE: tests/test_mtgp.py function get_dummy_data (line 14) | def get_dummy_data(): function attach_indices (line 20) | def attach_indices(X, num_tasks): function dummy_mean_fn (line 25) | def dummy_mean_fn(x, params): function dummy_mean_fn_priors (line 29) | def dummy_mean_fn_priors(): function test_fit_multitask (line 38) | def test_fit_multitask(data_kernel, num_tasks, num_latents): function test_fit_multivariate (line 50) | def test_fit_multivariate(data_kernel, num_tasks, num_latents): function test_fit_multitask_meanfn (line 61) | def test_fit_multitask_meanfn(): function test_sample_kernel_custom_lscale_prior (line 71) | def test_sample_kernel_custom_lscale_prior(): function test_sample_task_kernel_custom_W_prior (line 83) | def test_sample_task_kernel_custom_W_prior(): function test_sample_task_kernel_custom_v_prior (line 95) | def test_sample_task_kernel_custom_v_prior(): FILE: tests/test_optimize_acq.py function get_inputs (line 15) | def get_inputs(): function test_optimize_acq (line 22) | def test_optimize_acq(acq_fn): FILE: tests/test_priors.py function linear_kernel_test (line 14) | def linear_kernel_test(X, Z, k_scale): function sample_function (line 19) | def sample_function(x, a, b): function test_normal_prior (line 24) | def test_normal_prior(prior): function test_uniform_prior (line 30) | def test_uniform_prior(): function test_gamma_prior (line 36) | def test_gamma_prior(): function test_normal_prior_params (line 42) | def test_normal_prior_params(): function test_lognormal_prior_params (line 52) | def test_lognormal_prior_params(): function test_halfnormal_prior_params (line 62) | def test_halfnormal_prior_params(): function test_uniform_prior_params (line 71) | def test_uniform_prior_params(): function test_gamma_prior_params (line 81) | def test_gamma_prior_params(): function test_get_uniform_dist (line 91) | def test_get_uniform_dist(): function test_get_uniform_dist_infer_params (line 98) | def test_get_uniform_dist_infer_params(): function test_get_gamma_dist (line 104) | def test_get_gamma_dist(): function test_get_normal_dist (line 111) | def test_get_normal_dist(): function test_get_lognormal_dist (line 118) | def test_get_lognormal_dist(): function test_get_halfnormal_dist (line 125) | def test_get_halfnormal_dist(): function test_get_gamma_dist_infer_param (line 131) | def test_get_gamma_dist_infer_param(): function test_get_uniform_dist_error (line 138) | def test_get_uniform_dist_error(): function test_get_gamma_dist_error (line 147) | def test_get_gamma_dist_error(): function test_auto_priors (line 153) | def test_auto_priors(prior_type): function test_auto_normal_priors (line 169) | def test_auto_normal_priors(autopriors): function test_auto_normal_kernel_priors (line 179) | def test_auto_normal_kernel_priors(autopriors): FILE: tests/test_sparsegp.py function get_dummy_data (line 16) | def get_dummy_data(jax_ndarray=True, unsqueeze=False): function test_fit (line 28) | def test_fit(jax_ndarray, unsqueeze): function test_inducing_points_optimization (line 37) | def test_inducing_points_optimization(): function test_get_mvn_posterior (line 47) | def test_get_mvn_posterior(): FILE: tests/test_spm.py function get_dummy_data (line 15) | def get_dummy_data(jax_ndarray=True): function model (line 23) | def model(x, params): function model_priors (line 27) | def model_priors(): function test_fit (line 34) | def test_fit(jax_ndarray): function test_get_samples (line 42) | def test_get_samples(): function test_prediction (line 55) | def test_prediction(): function test_fit_predict (line 70) | def test_fit_predict(): FILE: tests/test_uigp.py function get_dummy_data (line 16) | def get_dummy_data(): function test_sample_x (line 24) | def test_sample_x(n_features): function test_fit (line 33) | def test_fit(): function test_fit_with_custom_sigma_x_prior (line 41) | def test_fit_with_custom_sigma_x_prior(): function test_get_mvn_posterior (line 49) | def test_get_mvn_posterior(): function test_predict_single_sample (line 72) | def test_predict_single_sample(noiseless): FILE: tests/test_utils.py function test_sparse_img_processing (line 15) | def test_sparse_img_processing(): function test_split_dict (line 33) | def test_split_dict(): function test_random_sample_size (line 56) | def test_random_sample_size(): function test_random_sample_consistency (line 69) | def test_random_sample_consistency(): function test_random_sample_difference (line 84) | def test_random_sample_difference(): function test_get_keys (line 99) | def test_get_keys(): function test_get_keys_different_seeds (line 105) | def test_get_keys_different_seeds(): function test_ratio_out_of_bounds (line 112) | def test_ratio_out_of_bounds(): function test_invalid_method (line 120) | def test_invalid_method(): function test_missing_key_for_random_method (line 126) | def test_missing_key_for_random_method(): function test_output_shape (line 133) | def test_output_shape(method): function test_kmeans_dependency (line 143) | def test_kmeans_dependency(): FILE: tests/test_vgp.py function get_dummy_data (line 15) | def get_dummy_data(jax_ndarray=True, unsqueeze=False): function test_fit (line 29) | def test_fit(kernel, jax_ndarray, unsqueeze): function test_get_samples (line 39) | def test_get_samples(kernel, jax_ndarray): function test_get_samples_chain_dim (line 53) | def test_get_samples_chain_dim(chain_dim, samples_dim): function test_sample_kernel (line 65) | def test_sample_kernel(kernel): function test_sample_periodic_kernel (line 76) | def test_sample_periodic_kernel(): function test_sample_noise (line 86) | def test_sample_noise(): function test_sample_noise_custom_prior (line 93) | def test_sample_noise_custom_prior(): function test_get_mvn_posterior (line 104) | def test_get_mvn_posterior(): function test_get_mvn_posterior_noiseless (line 120) | def test_get_mvn_posterior_noiseless(): function test_prediction (line 139) | def test_prediction(n): function test_fit_predict_in_batches (line 157) | def test_fit_predict_in_batches(n): function test_fit_predict_in_batches_noiseless (line 172) | def test_fit_predict_in_batches_noiseless(n): function test_jitter_predict (line 184) | def test_jitter_predict(): FILE: tests/test_vidkl.py function get_dummy_data (line 16) | def get_dummy_data(jax_ndarray=True): function get_dummy_image_data (line 24) | def get_dummy_image_data(jax_ndarray=True): function get_dummy_vector_data (line 32) | def get_dummy_vector_data(jax_ndarray=True): class CustomConvNet (line 38) | class CustomConvNet(hk.Module): method __init__ (line 39) | def __init__(self, embedim=2): method __call__ (line 43) | def __call__(self, x): function test_single_fit (line 55) | def test_single_fit(jax_ndarray): function test_single_fit_custom_net (line 67) | def test_single_fit_custom_net(jax_ndarray): function test_get_mvn_posterior (line 79) | def test_get_mvn_posterior(): function test_get_mvn_posterior_noiseless (line 98) | def test_get_mvn_posterior_noiseless(): function test_fit_scalar_target (line 119) | def test_fit_scalar_target(): function test_fit_vector_target (line 130) | def test_fit_vector_target(): function test_predict_scalar (line 144) | def test_predict_scalar(): function test_predict_vector (line 165) | def test_predict_vector(): function test_predict_in_batches_scalar (line 187) | def test_predict_in_batches_scalar(): function test_predict_in_batches_vector (line 208) | def test_predict_in_batches_vector(): function test_fit_predict_scalar (line 230) | def test_fit_predict_scalar(): function test_fit_predict_vector (line 243) | def test_fit_predict_vector(): function test_fit_predict_scalar_ensemble (line 256) | def test_fit_predict_scalar_ensemble(): function test_fit_predict_vector_ensemble (line 270) | def test_fit_predict_vector_ensemble(): function test_fit_predict_scalar_ensemble_custom_net (line 284) | def test_fit_predict_scalar_ensemble_custom_net(): FILE: tests/test_vigp.py function get_dummy_data (line 17) | def get_dummy_data(jax_ndarray=True, unsqueeze=False): function dummy_mean_fn (line 27) | def dummy_mean_fn(x, params): function dummy_mean_fn_priors (line 31) | def dummy_mean_fn_priors(): function gp_kernel_custom_prior (line 37) | def gp_kernel_custom_prior(): function test_fit (line 46) | def test_fit(kernel, jax_ndarray, unsqueeze): function test_get_samples (line 56) | def test_get_samples(kernel, jax_ndarray): function test_prediction (line 69) | def test_prediction(unsqueeze): function test_noiseless_prediction (line 86) | def test_noiseless_prediction(): function test_prediction_in_batches (line 104) | def test_prediction_in_batches(unsqueeze, batch_size): function test_fit_predict (line 123) | def test_fit_predict(kernel): function test_fit_predict_in_batches (line 137) | def test_fit_predict_in_batches(batch_size): function test_fit_with_mean_fn (line 153) | def test_fit_with_mean_fn(jax_ndarray): function test_fit_with_prob_mean_fn (line 162) | def test_fit_with_prob_mean_fn(jax_ndarray): function test_fit_predict_with_mean_fn (line 170) | def test_fit_predict_with_mean_fn(): function test_fit_predict_with_prob_mean_fn (line 183) | def test_fit_predict_with_prob_mean_fn(): function test_sample_from_prior (line 196) | def test_sample_from_prior(): function test_jitter_fit (line 204) | def test_jitter_fit(): function test_jitter_predict (line 220) | def test_jitter_predict(): function test_guide_type (line 236) | def test_guide_type(): FILE: tests/test_vimtdkl.py function get_dummy_data (line 13) | def get_dummy_data(): function attach_indices (line 19) | def attach_indices(X, num_tasks): function test_fit_multitask (line 27) | def test_fit_multitask(data_kernel, num_tasks, num_latents): function test_fit_multitask_shared_input (line 40) | def test_fit_multitask_shared_input(data_kernel, num_tasks, num_latents): function test_fit_predict_multitask (line 54) | def test_fit_predict_multitask(data_kernel, num_tasks, num_latents):