SYMBOL INDEX (5897 symbols across 565 files) FILE: docs/source/conf.py function setup (line 212) | def setup(app): FILE: examples/air/air.py function default_z_pres_prior_p (line 24) | def default_z_pres_prior_p(t): class AIR (line 34) | class AIR(nn.Module): method __init__ (line 35) | def __init__( method prior (line 128) | def prior(self, n, **kwargs): method prior_step (line 145) | def prior_step(self, t, n, prev, z_pres_prior_p=default_z_pres_prior_p): method model (line 193) | def model(self, data, batch_size, **kwargs): method guide (line 211) | def guide(self, data, batch_size, **kwargs): method guide_step (line 262) | def guide_step(self, t, n, prev, inputs): method baseline_step (line 313) | def baseline_step(self, prev, inputs): function expand_z_where (line 352) | def expand_z_where(z_where): function z_where_inv (line 369) | def z_where_inv(z_where): function window_to_image (line 382) | def window_to_image(z_where, window_size, image_size, windows): function image_to_window (line 391) | def image_to_window(z_where, window_size, image_size, images): function latents_to_tensor (line 403) | def latents_to_tensor(z): FILE: examples/air/main.py function count_accuracy (line 33) | def count_accuracy(X, true_counts, air, batch_size): function make_prior (line 70) | def make_prior(k): function lin_decay (line 103) | def lin_decay(initial, final, begin, duration, t): function exp_decay (line 109) | def exp_decay(initial, final, begin, duration, t): function load_data (line 118) | def load_data(): function main (line 130) | def main(**kwargs): FILE: examples/air/modules.py class Encoder (line 12) | class Encoder(nn.Module): method __init__ (line 13) | def __init__(self, x_size, h_sizes, z_size, non_linear_layer): method forward (line 19) | def forward(self, x): class Decoder (line 25) | class Decoder(nn.Module): method __init__ (line 26) | def __init__(self, x_size, h_sizes, z_size, bias, use_sigmoid, non_lin... method forward (line 32) | def forward(self, z): class MLP (line 44) | class MLP(nn.Module): method __init__ (line 45) | def __init__( method forward (line 61) | def forward(self, x): class Predict (line 67) | class Predict(nn.Module): method __init__ (line 68) | def __init__( method forward (line 77) | def forward(self, h): class Identity (line 85) | class Identity(nn.Module): method __init__ (line 86) | def __init__(self): method forward (line 89) | def forward(self, x): FILE: examples/air/viz.py function bounding_box (line 11) | def bounding_box(z_where, x_size): function arr2img (line 23) | def arr2img(arr): function img2arr (line 30) | def img2arr(img): function colors (line 38) | def colors(k): function draw_one (line 42) | def draw_one(imgarr, z_arr): function draw_many (line 64) | def draw_many(imgarrs, z_arr): function tensor_to_objs (line 75) | def tensor_to_objs(latents): FILE: examples/baseball.py function fully_pooled (line 63) | def fully_pooled(at_bats, hits): function not_pooled (line 79) | def not_pooled(at_bats, hits): function partially_pooled (line 95) | def partially_pooled(at_bats, hits): function partially_pooled_with_logit (line 118) | def partially_pooled_with_logit(at_bats, hits): function get_summary_table (line 141) | def get_summary_table( function train_test_split (line 176) | def train_test_split(pd_dataframe): function sample_posterior_predictive (line 203) | def sample_posterior_predictive(model, posterior_samples, baseball_datas... function evaluate_pointwise_pred_density (line 234) | def evaluate_pointwise_pred_density(model, posterior_samples, baseball_d... function main (line 255) | def main(args): FILE: examples/capture_recapture/cjs.py function model_1 (line 54) | def model_1(capture_history, sex): function model_2 (line 89) | def model_2(capture_history, sex): function model_3 (line 128) | def model_3(capture_history, sex): function model_4 (line 178) | def model_4(capture_history, sex): function model_5 (line 224) | def model_5(capture_history, sex): function main (line 269) | def main(args): FILE: examples/contrib/autoname/mixture.py function model (line 21) | def model(data, k): function local_model (line 35) | def local_model(latent, ps, locs, scales, obs=None): function guide (line 40) | def guide(data, k): function local_guide (line 48) | def local_guide(latent, k): function main (line 54) | def main(args): FILE: examples/contrib/autoname/scoping_mixture.py function model (line 16) | def model(K, data): function local_model (line 27) | def local_model(weights, locs, scale, data): function guide (line 34) | def guide(K, data): function local_guide (line 45) | def local_guide(probs): function main (line 49) | def main(args): FILE: examples/contrib/autoname/tree_data.py function model (line 25) | def model(data): function model_recurse (line 31) | def model_recurse(data, latent): function guide (line 51) | def guide(data): function guide_recurse (line 55) | def guide_recurse(data, latent): function main (line 73) | def main(args): FILE: examples/contrib/cevae/synthetic.py function generate_data (line 29) | def generate_data(args): function main (line 46) | def main(args): FILE: examples/contrib/epidemiology/regional.py function Model (line 15) | def Model(args, data): function generate_data (line 22) | def generate_data(args): function infer_mcmc (line 53) | def infer_mcmc(args, model): function infer_svi (line 87) | def infer_svi(args, model): function predict (line 109) | def predict(args, model, truth): function main (line 152) | def main(args): FILE: examples/contrib/epidemiology/sir.py function Model (line 29) | def Model(args, data): function generate_data (line 58) | def generate_data(args): function infer_mcmc (line 103) | def infer_mcmc(args, model): function infer_svi (line 143) | def infer_svi(args, model): function evaluate (line 167) | def evaluate(args, model, samples): function predict (line 261) | def predict(args, model, truth): function main (line 316) | def main(args): FILE: examples/contrib/forecast/bart.py function preprocess (line 20) | def preprocess(args): class Model (line 45) | class Model(ForecastingModel): method model (line 51) | def model(self, zero_data, covariates): function main (line 125) | def main(args): FILE: examples/contrib/funsor/hmm.py function model_0 (line 94) | def model_0(sequences, lengths, args, batch_size=None, include_prior=True): function model_1 (line 185) | def model_1(sequences, lengths, args, batch_size=None, include_prior=True): function model_2 (line 276) | def model_2(sequences, lengths, args, batch_size=None, include_prior=True): function model_3 (line 327) | def model_3(sequences, lengths, args, batch_size=None, include_prior=True): function model_4 (line 382) | def model_4(sequences, lengths, args, batch_size=None, include_prior=True): class TonesGenerator (line 437) | class TonesGenerator(nn.Module): method __init__ (line 438) | def __init__(self, args, data_dim): method forward (line 448) | def forward(self, x, y): function model_5 (line 470) | def model_5(sequences, lengths, args, batch_size=None, include_prior=True): function model_6 (line 523) | def model_6(sequences, lengths, args, batch_size=None, include_prior=Fal... function model_7 (line 591) | def model_7(sequences, lengths, args, batch_size=None, include_prior=True): function main (line 671) | def main(args): FILE: examples/contrib/gp/sv-dkl.py class CNN (line 45) | class CNN(nn.Module): method __init__ (line 46) | def __init__(self): method forward (line 53) | def forward(self, x): function train (line 62) | def train(args, train_loader, gpmodule, optimizer, loss_fn, epoch): function test (line 87) | def test(args, test_loader, gpmodule): function main (line 111) | def main(args): FILE: examples/contrib/mue/FactorMuE.py function generate_data (line 47) | def generate_data(small_test, include_stop, device): function main (line 62) | def main(args): FILE: examples/contrib/mue/ProfileHMM.py function generate_data (line 51) | def generate_data(small_test, include_stop, device): function main (line 66) | def main(args): FILE: examples/contrib/oed/ab_test.py function estimated_ape (line 60) | def estimated_ape(ns, num_vi_steps): function true_ape (line 82) | def true_ape(ns): function main (line 94) | def main(num_vi_steps, num_bo_steps, seed): FILE: examples/contrib/oed/gp_bayes_opt.py class GPBayesOptimizer (line 14) | class GPBayesOptimizer(pyro.optim.multi.MultiOptimizer): method __init__ (line 19) | def __init__(self, constraints, gpmodel, num_acquisitions, acquisition... method update_posterior (line 36) | def update_posterior(self, X, y): method find_a_candidate (line 50) | def find_a_candidate(self, differentiable, x_init): method opt_differentiable (line 83) | def opt_differentiable(self, differentiable, num_candidates=5): method acquire_thompson (line 110) | def acquire_thompson(self, num_acquisitions=1, **opt_params): method get_step (line 132) | def get_step(self, loss, params, verbose=False): FILE: examples/contrib/timeseries/gp_models.py function download_data (line 16) | def download_data(): function main (line 23) | def main(args): FILE: examples/cvae/baseline.py class BaselineNet (line 14) | class BaselineNet(nn.Module): method __init__ (line 15) | def __init__(self, hidden_1, hidden_2): method forward (line 22) | def forward(self, x): class MaskedBCELoss (line 30) | class MaskedBCELoss(nn.Module): method __init__ (line 31) | def __init__(self, masked_with=-1): method forward (line 35) | def forward(self, input, target): function train (line 46) | def train( FILE: examples/cvae/cvae.py class Encoder (line 16) | class Encoder(nn.Module): method __init__ (line 17) | def __init__(self, z_dim, hidden_1, hidden_2): method forward (line 25) | def forward(self, x, y): class Decoder (line 40) | class Decoder(nn.Module): method __init__ (line 41) | def __init__(self, z_dim, hidden_1, hidden_2): method forward (line 48) | def forward(self, z): class CVAE (line 55) | class CVAE(nn.Module): method __init__ (line 56) | def __init__(self, z_dim, hidden_1, hidden_2, pre_trained_baseline_net): method model (line 68) | def model(self, xs, ys=None): method guide (line 107) | def guide(self, xs, ys=None): function train (line 122) | def train( FILE: examples/cvae/main.py function main (line 15) | def main(args): FILE: examples/cvae/mnist.py class CVAEMNIST (line 12) | class CVAEMNIST(Dataset): method __init__ (line 13) | def __init__(self, root, train=True, transform=None, download=False): method __len__ (line 17) | def __len__(self): method __getitem__ (line 20) | def __getitem__(self, item): class ToTensor (line 29) | class ToTensor: method __call__ (line 30) | def __call__(self, sample): class MaskImages (line 38) | class MaskImages: method __init__ (line 46) | def __init__(self, num_quadrant_inputs, mask_with=-1): method __call__ (line 52) | def __call__(self, sample): function get_data (line 77) | def get_data(num_quadrant_inputs, batch_size): FILE: examples/cvae/util.py function imshow (line 19) | def imshow(inp, image_path=None): function visualize (line 41) | def visualize( function generate_table (line 115) | def generate_table( FILE: examples/dmm.py class Emitter (line 44) | class Emitter(nn.Module): method __init__ (line 49) | def __init__(self, input_dim, z_dim, emission_dim): method forward (line 58) | def forward(self, z_t): class GatedTransition (line 69) | class GatedTransition(nn.Module): method __init__ (line 75) | def __init__(self, z_dim, transition_dim): method forward (line 92) | def forward(self, z_t_1): class Combiner (line 114) | class Combiner(nn.Module): method __init__ (line 121) | def __init__(self, z_dim, rnn_dim): method forward (line 131) | def forward(self, z_t_1, h_rnn): class DMM (line 147) | class DMM(nn.Module): method __init__ (line 153) | def __init__( method model (line 203) | def model( method guide (line 263) | def guide( function main (line 334) | def main(args): FILE: examples/eight_schools/mcmc.py function model (line 19) | def model(sigma): function conditioned_model (line 29) | def conditioned_model(model, sigma, y): function main (line 33) | def main(args): FILE: examples/eight_schools/svi.py function model (line 20) | def model(data): function guide (line 34) | def guide(data): function main (line 67) | def main(args): FILE: examples/einsum.py function jit_prob (line 38) | def jit_prob(equation, *operands, **kwargs): function jit_logprob (line 55) | def jit_logprob(equation, *operands, **kwargs): function jit_gradient (line 74) | def jit_gradient(equation, *operands, **kwargs): function _jit_adjoint (line 105) | def _jit_adjoint(equation, *operands, **kwargs): function jit_map (line 144) | def jit_map(equation, *operands, **kwargs): function jit_sample (line 150) | def jit_sample(equation, *operands, **kwargs): function jit_marginal (line 156) | def jit_marginal(equation, *operands, **kwargs): function time_fn (line 162) | def time_fn(fn, equation, *operands, **kwargs): function main (line 175) | def main(args): FILE: examples/hmm.py function model_0 (line 83) | def model_0(sequences, lengths, args, batch_size=None, include_prior=True): function model_1 (line 174) | def model_1(sequences, lengths, args, batch_size=None, include_prior=True): function model_2 (line 265) | def model_2(sequences, lengths, args, batch_size=None, include_prior=True): function model_3 (line 316) | def model_3(sequences, lengths, args, batch_size=None, include_prior=True): function model_4 (line 371) | def model_4(sequences, lengths, args, batch_size=None, include_prior=True): class TonesGenerator (line 426) | class TonesGenerator(nn.Module): method __init__ (line 427) | def __init__(self, args, data_dim): method forward (line 437) | def forward(self, x, y): function model_5 (line 459) | def model_5(sequences, lengths, args, batch_size=None, include_prior=True): function model_6 (line 512) | def model_6(sequences, lengths, args, batch_size=None, include_prior=Fal... function model_7 (line 580) | def model_7(sequences, lengths, args, batch_size=None, include_prior=True): function main (line 621) | def main(args): FILE: examples/inclined_plane.py function simulate (line 36) | def simulate(mu, length=2.0, phi=np.pi / 6.0, dt=0.005, noise_sigma=None): function analytic_T (line 66) | def analytic_T(mu, length=2.0, phi=np.pi / 6.0): function model (line 86) | def model(observed_data): function main (line 101) | def main(args): FILE: examples/lda.py function model (line 42) | def model(data=None, args=None, batch_size=None): function make_predictor (line 78) | def make_predictor(args): function parametrized_guide (line 96) | def parametrized_guide(predictor, data, args, batch_size=None): function main (line 125) | def main(args): FILE: examples/lkj.py function model (line 22) | def model(y): function main (line 45) | def main(args): FILE: examples/minipyro.py function main (line 19) | def main(args): FILE: examples/mixed_hmm/experiment.py function aic_num_parameters (line 19) | def aic_num_parameters(model, guide=None): function run_expt (line 37) | def run_expt(args): FILE: examples/mixed_hmm/model.py function guide_generic (line 14) | def guide_generic(config): function model_generic (line 70) | def model_generic(config): FILE: examples/mixed_hmm/seal_data.py function download_seal_data (line 13) | def download_seal_data(filename): function prepare_seal (line 20) | def prepare_seal(filename, random_effects): FILE: examples/neutra.py class BananaShaped (line 46) | class BananaShaped(dist.TorchDistribution): method __init__ (line 50) | def __init__(self, a, b, rho=0.9): method sample (line 58) | def sample(self, sample_shape=()): method log_prob (line 66) | def log_prob(self, x): function model (line 74) | def model(a, b, rho=0.9): function fit_guide (line 78) | def fit_guide(guide, args): function run_hmc (line 88) | def run_hmc(args, model): function main (line 96) | def main(args): FILE: examples/rsa/generics.py function Marginal (line 26) | def Marginal(fn): function discretize_beta_pdf (line 38) | def discretize_beta_pdf(bins, gamma, delta): function structured_prior_model (line 55) | def structured_prior_model(params): function threshold_prior (line 73) | def threshold_prior(): function utterance_prior (line 81) | def utterance_prior(): function meaning (line 87) | def meaning(utterance, state, threshold): function listener0 (line 106) | def listener0(utterance, threshold, prior): function speaker1 (line 114) | def speaker1(state, threshold, prior): function listener1 (line 124) | def listener1(utterance, prior): function speaker2 (line 133) | def speaker2(prevalence, prior): function main (line 140) | def main(args): FILE: examples/rsa/hyperbole.py function Marginal (line 23) | def Marginal(fn): function approx (line 35) | def approx(x, b=None): function price_prior (line 43) | def price_prior(): function valence_prior (line 52) | def valence_prior(price): function meaning (line 68) | def meaning(utterance, price): function qud_prior (line 83) | def qud_prior(): function utterance_cost (line 91) | def utterance_cost(numberUtt): function utterance_prior (line 96) | def utterance_prior(): function literal_listener (line 106) | def literal_listener(utterance, qud): function speaker (line 114) | def speaker(qudValue, qud): function pragmatic_listener (line 124) | def pragmatic_listener(utterance): function test_truth (line 138) | def test_truth(): function main (line 204) | def main(args): FILE: examples/rsa/schelling.py function location (line 23) | def location(preference): function alice (line 32) | def alice(preference, depth): function bob (line 42) | def bob(preference, depth): function main (line 55) | def main(args): FILE: examples/rsa/schelling_false.py function location (line 24) | def location(preference): function alice_fb (line 33) | def alice_fb(preference, depth): function alice (line 46) | def alice(preference, depth): function bob (line 56) | def bob(preference, depth): function main (line 69) | def main(args): FILE: examples/rsa/search_inference.py function memoize (line 22) | def memoize(fn=None, **kwargs): class HashingMarginal (line 28) | class HashingMarginal(dist.Distribution): method __init__ (line 37) | def __init__(self, trace_dist, sites=None): method _dist_and_values (line 54) | def _dist_and_values(self): method sample (line 85) | def sample(self): method log_prob (line 90) | def log_prob(self, val): method enumerate_support (line 100) | def enumerate_support(self): method _dict_to_tuple (line 104) | def _dict_to_tuple(self, d): method _weighted_mean (line 115) | def _weighted_mean(self, value, dim=0): method mean (line 122) | def mean(self): method variance (line 127) | def variance(self): class Search (line 138) | class Search(TracePosterior): method __init__ (line 143) | def __init__(self, model, max_tries=int(1e6), **kwargs): method _traces (line 148) | def _traces(self, *args, **kwargs): function pqueue (line 162) | def pqueue(fn, queue): class BestFirstSearch (line 200) | class BestFirstSearch(TracePosterior): method __init__ (line 206) | def __init__(self, model, num_samples=None, **kwargs): method _traces (line 213) | def _traces(self, *args, **kwargs): FILE: examples/rsa/semantic_parsing.py function Marginal (line 22) | def Marginal(fn=None, **kwargs): function flip (line 35) | def flip(name, p): function Obj (line 43) | def Obj(name): class Meaning (line 52) | class Meaning: method sem (line 53) | def sem(self, world): method syn (line 58) | def syn(self): class UndefinedMeaning (line 62) | class UndefinedMeaning(Meaning): method sem (line 63) | def sem(self, world): method syn (line 66) | def syn(self): class BlondMeaning (line 70) | class BlondMeaning(Meaning): method sem (line 71) | def sem(self, world): method syn (line 74) | def syn(self): class NiceMeaning (line 78) | class NiceMeaning(Meaning): method sem (line 79) | def sem(self, world): method syn (line 82) | def syn(self): class TallMeaning (line 86) | class TallMeaning(Meaning): method sem (line 87) | def sem(self, world): method syn (line 90) | def syn(self): class BobMeaning (line 94) | class BobMeaning(Meaning): method sem (line 95) | def sem(self, world): method syn (line 98) | def syn(self): class SomeMeaning (line 102) | class SomeMeaning(Meaning): method sem (line 103) | def sem(self, world): method syn (line 112) | def syn(self): class AllMeaning (line 124) | class AllMeaning(Meaning): method sem (line 125) | def sem(self, world): method syn (line 136) | def syn(self): class NoneMeaning (line 148) | class NoneMeaning(Meaning): method sem (line 149) | def sem(self, world): method syn (line 158) | def syn(self): class CompoundMeaning (line 170) | class CompoundMeaning(Meaning): method __init__ (line 171) | def __init__(self, sem, syn): method sem (line 175) | def sem(self, world): method syn (line 178) | def syn(self): function heuristic (line 187) | def heuristic(is_good): function world_prior (line 193) | def world_prior(num_objs, meaning_fn): function lexical_meaning (line 206) | def lexical_meaning(word): function apply_world_passing (line 221) | def apply_world_passing(f, a): function syntax_match (line 225) | def syntax_match(s, t): function can_apply (line 236) | def can_apply(meanings): function combine_meaning (line 255) | def combine_meaning(meanings, c): function combine_meanings (line 273) | def combine_meanings(meanings, c=0): function meaning (line 280) | def meaning(utterance): function literal_listener (line 288) | def literal_listener(utterance): function utterance_prior (line 295) | def utterance_prior(): function speaker (line 306) | def speaker(world): function rsa_listener (line 313) | def rsa_listener(utterance, qud): function literal_listener_raw (line 320) | def literal_listener_raw(utterance, qud): function main (line 327) | def main(args): FILE: examples/scanvi/scanvi.py function make_fc (line 41) | def make_fc(dims): function split_in_half (line 51) | def split_in_half(t): function broadcast_inputs (line 56) | def broadcast_inputs(input_args): class Z2Decoder (line 63) | class Z2Decoder(nn.Module): method __init__ (line 64) | def __init__(self, z1_dim, y_dim, z2_dim, hidden_dims): method forward (line 69) | def forward(self, z1, y): class XDecoder (line 84) | class XDecoder(nn.Module): method __init__ (line 85) | def __init__(self, num_genes, z2_dim, hidden_dims): method forward (line 90) | def forward(self, z2): class Z2LEncoder (line 97) | class Z2LEncoder(nn.Module): method __init__ (line 98) | def __init__(self, num_genes, z2_dim, hidden_dims): method forward (line 103) | def forward(self, x): class Z1Encoder (line 115) | class Z1Encoder(nn.Module): method __init__ (line 116) | def __init__(self, num_labels, z1_dim, z2_dim, hidden_dims): method forward (line 121) | def forward(self, z2, y): class Classifier (line 136) | class Classifier(nn.Module): method __init__ (line 137) | def __init__(self, z2_dim, hidden_dims, num_labels): method forward (line 142) | def forward(self, x): class SCANVI (line 148) | class SCANVI(nn.Module): method __init__ (line 149) | def __init__( method model (line 209) | def model(self, x, y=None): method guide (line 252) | def guide(self, x, y=None): function main (line 280) | def main(args): FILE: examples/sir_hmc.py function global_model (line 63) | def global_model(population): function discrete_model (line 75) | def discrete_model(args, data): function generate_data (line 94) | def generate_data(args): function reparameterized_discrete_model (line 166) | def reparameterized_discrete_model(args, data): function infer_hmc_enum (line 209) | def infer_hmc_enum(args, data): function _infer_hmc (line 214) | def _infer_hmc(args, data, model, init_values={}): function quantize (line 267) | def quantize(name, x_real, min, max): function continuous_model (line 303) | def continuous_model(args, data): function heuristic_init (line 350) | def heuristic_init(args, data): function infer_hmc_cont (line 371) | def infer_hmc_cont(model, args, data): function quantize_enumerate (line 383) | def quantize_enumerate(x_real, min, max): function vectorized_model (line 415) | def vectorized_model(args, data): function evaluate (line 482) | def evaluate(args, samples): function predict (line 527) | def predict(args, data, samples, truth=None): function main (line 611) | def main(args): FILE: examples/smcfilter.py class SimpleHarmonicModel (line 25) | class SimpleHarmonicModel: method __init__ (line 26) | def __init__(self, process_noise, measurement_noise): method init (line 32) | def init(self, state, initial): method step (line 36) | def step(self, state, y=None): class SimpleHarmonicModel_Guide (line 48) | class SimpleHarmonicModel_Guide: method __init__ (line 49) | def __init__(self, model): method init (line 52) | def init(self, state, initial): method step (line 56) | def step(self, state, y=None): function generate_data (line 68) | def generate_data(args): function main (line 84) | def main(args): FILE: examples/sparse_gamma_def.py function rand_tensor (line 39) | def rand_tensor(shape, mean, sigma): class SparseGammaDEF (line 43) | class SparseGammaDEF: method __init__ (line 44) | def __init__(self): method model (line 61) | def model(self, x): method guide (line 113) | def guide(self, x): function clip_params (line 161) | def clip_params(): class MyEasyGuide (line 178) | class MyEasyGuide(EasyGuide): method guide (line 179) | def guide(self, x): function main (line 206) | def main(args): FILE: examples/sparse_regression.py function dot (line 47) | def dot(X, Z): function kernel (line 52) | def kernel(X, Z, eta1, eta2, c): function model (line 62) | def model(X, Y, hypers, jitter=1.0e-4): function compute_posterior_stats (line 102) | def compute_posterior_stats(X, Y, msq, lam, eta1, xisq, c, sigma, jitter... function get_data (line 219) | def get_data(N=20, P=10, S=2, Q=2, sigma_obs=0.15): function init_loc_fn (line 255) | def init_loc_fn(site): function main (line 264) | def main(args): FILE: examples/svi_horovod.py class Model (line 42) | class Model(PyroModule): method __init__ (line 43) | def __init__(self, size): method forward (line 47) | def forward(self, covariates, data=None): function main (line 64) | def main(args): FILE: examples/svi_lightning.py class Model (line 30) | class Model(PyroModule): method __init__ (line 31) | def __init__(self, size): method forward (line 35) | def forward(self, covariates, data=None): class PyroLightningModule (line 53) | class PyroLightningModule(pl.LightningModule): method __init__ (line 54) | def __init__(self, loss_fn: pyro.infer.elbo.ELBOModule, lr: float): method forward (line 64) | def forward(self, *args): method training_step (line 67) | def training_step(self, batch, batch_idx): method configure_optimizers (line 74) | def configure_optimizers(self): function main (line 79) | def main(args): FILE: examples/svi_torch.py class Model (line 24) | class Model(PyroModule): method __init__ (line 25) | def __init__(self, size): method forward (line 34) | def forward(self, covariates, data=None): function main (line 48) | def main(args): FILE: examples/toy_mixture_model_discrete_enumeration.py function main (line 40) | def main(args): function generate_data (line 48) | def generate_data(num_obs): function model (line 71) | def model(prior, obs, num_obs): function guide (line 86) | def guide(prior, obs, num_obs): function train (line 95) | def train(prior, data, num_steps, num_obs): function evaluate (line 114) | def evaluate(CPDs, posterior_params): function get_true_pred_CPDs (line 129) | def get_true_pred_CPDs(CPD, posterior_param): FILE: examples/vae/ss_vae_M2.py class SSVAE (line 27) | class SSVAE(nn.Module): method __init__ (line 44) | def __init__( method setup_networks (line 69) | def setup_networks(self): method model (line 109) | def model(self, xs, ys=None): method guide (line 152) | def guide(self, xs, ys=None): method classifier (line 179) | def classifier(self, xs): method model_classify (line 198) | def model_classify(self, xs, ys=None): method guide_classify (line 214) | def guide_classify(self, xs, ys=None): function run_inference_for_epoch (line 221) | def run_inference_for_epoch(data_loaders, losses, periodic_interval_batc... function get_accuracy (line 268) | def get_accuracy(data_loader, classifier_fn, batch_size): function visualize (line 292) | def visualize(ss_vae, viz, test_loader): function main (line 298) | def main(args): FILE: examples/vae/utils/custom_mlp.py class Exp (line 12) | class Exp(nn.Module): method __init__ (line 17) | def __init__(self): method forward (line 20) | def forward(self, val): class ConcatModule (line 24) | class ConcatModule(nn.Module): method __init__ (line 29) | def __init__(self, allow_broadcast=False): method forward (line 33) | def forward(self, *input_args): class ListOutModule (line 51) | class ListOutModule(nn.ModuleList): method __init__ (line 56) | def __init__(self, modules): method forward (line 59) | def forward(self, *args, **kwargs): function call_nn_op (line 64) | def call_nn_op(op): class MLP (line 78) | class MLP(nn.Module): method __init__ (line 79) | def __init__( method forward (line 202) | def forward(self, *args, **kwargs): FILE: examples/vae/utils/mnist_cached.py function fn_x_mnist (line 20) | def fn_x_mnist(x, use_cuda): function fn_y_mnist (line 35) | def fn_y_mnist(y, use_cuda): function get_ss_indices_per_class (line 48) | def get_ss_indices_per_class(y, sup_per_class): function split_sup_unsup_valid (line 73) | def split_sup_unsup_valid(X, y, sup_num, validation_num=10000): function print_distribution_labels (line 104) | def print_distribution_labels(y): class MNISTCached (line 119) | class MNISTCached(MNIST): method __init__ (line 133) | def __init__(self, mode, sup_num, use_cuda=True, *args, **kwargs): method __getitem__ (line 201) | def __getitem__(self, index): function setup_data_loaders (line 215) | def setup_data_loaders( function mkdir_p (line 251) | def mkdir_p(path): FILE: examples/vae/utils/vae_plots.py function plot_conditional_samples_ssvae (line 7) | def plot_conditional_samples_ssvae(ssvae, visdom_session): function plot_llk (line 28) | def plot_llk(train_elbo, test_elbo): function plot_vae_samples (line 65) | def plot_vae_samples(vae, visdom_session): function mnist_test_tsne (line 78) | def mnist_test_tsne(vae=None, test_loader=None): function mnist_test_tsne_ssvae (line 89) | def mnist_test_tsne_ssvae(name=None, ssvae=None, test_loader=None): function plot_tsne (line 101) | def plot_tsne(z_loc, classes, name): FILE: examples/vae/vae.py class Encoder (line 22) | class Encoder(nn.Module): method __init__ (line 23) | def __init__(self, z_dim, hidden_dim): method forward (line 32) | def forward(self, x): class Decoder (line 47) | class Decoder(nn.Module): method __init__ (line 48) | def __init__(self, z_dim, hidden_dim): method forward (line 56) | def forward(self, z): class VAE (line 67) | class VAE(nn.Module): method __init__ (line 70) | def __init__(self, z_dim=50, hidden_dim=400, use_cuda=False): method model (line 84) | def model(self, x): method guide (line 105) | def guide(self, x): method reconstruct_img (line 115) | def reconstruct_img(self, x): function main (line 125) | def main(args): FILE: examples/vae/vae_comparison.py class Encoder (line 35) | class Encoder(nn.Module): method __init__ (line 36) | def __init__(self): method forward (line 43) | def forward(self, x): class Decoder (line 50) | class Decoder(nn.Module): method __init__ (line 51) | def __init__(self): method forward (line 57) | def forward(self, z): class VAE (line 62) | class VAE(object, metaclass=ABCMeta): method __init__ (line 68) | def __init__(self, args, train_loader, test_loader): method set_train (line 76) | def set_train(self, is_train=True): method compute_loss_and_gradient (line 87) | def compute_loss_and_gradient(self, x): method model_eval (line 97) | def model_eval(self, x): method train (line 112) | def train(self, epoch): method test (line 124) | def test(self, epoch): class PyTorchVAEImpl (line 146) | class PyTorchVAEImpl(VAE): method __init__ (line 152) | def __init__(self, *args, **kwargs): method compute_loss_and_gradient (line 156) | def compute_loss_and_gradient(self, x): method initialize_optimizer (line 174) | def initialize_optimizer(self, lr=1e-3): class PyroVAEImpl (line 181) | class PyroVAEImpl(VAE): method __init__ (line 188) | def __init__(self, *args, **kwargs): method model (line 192) | def model(self, data): method guide (line 204) | def guide(self, data): method compute_loss_and_gradient (line 210) | def compute_loss_and_gradient(self, x): method initialize_optimizer (line 218) | def initialize_optimizer(self, lr): function setup (line 224) | def setup(args): function main (line 248) | def main(args): FILE: profiler/distributions.py function T (line 26) | def T(arr): function get_tool (line 66) | def get_tool(): function get_tool_cfg (line 70) | def get_tool_cfg(): function sample (line 82) | def sample(dist, batch_size): function log_prob (line 94) | def log_prob(dist, batch): function run_with_tool (line 98) | def run_with_tool(tool, dists, batch_sizes): function set_tool_cfg (line 127) | def set_tool_cfg(args): function main (line 139) | def main(): FILE: profiler/gaussianhmm.py function random_mvn (line 12) | def random_mvn(batch_shape, dim, requires_grad=False): function main (line 22) | def main(args): FILE: profiler/hmm.py function main (line 20) | def main(args): FILE: profiler/profiling_utils.py class ProfilePrinter (line 20) | class ProfilePrinter: method __init__ (line 21) | def __init__(self, column_widths=None, field_format=None, template="co... method _formatted_values (line 32) | def _formatted_values(self, values): method _add_using_row_format (line 41) | def _add_using_row_format(self, values): method _add_using_column_format (line 47) | def _add_using_column_format(self, values): method push (line 51) | def push(self, values): method header (line 57) | def header(self, values): method print (line 70) | def print(self): function profile_print (line 75) | def profile_print(column_widths=None, field_format=None, template="colum... function profile_timeit (line 83) | def profile_timeit(fn_callable, repeat=1): function profile_cprofile (line 88) | def profile_cprofile(fn_callable, prof_file): class Profile (line 98) | class Profile: method __init__ (line 99) | def __init__(self, tool, tool_cfg, fn_id): method _set_decorator_params (line 104) | def _set_decorator_params(self): method __call__ (line 110) | def __call__(self, fn): FILE: pyro/contrib/autoname/autoname.py function genname (line 15) | def genname(name="name"): class NameScope (line 19) | class NameScope: method __init__ (line 20) | def __init__(self, name=None): method __str__ (line 25) | def __str__(self): method allocate (line 30) | def allocate(self, name): class ScopeStack (line 36) | class ScopeStack: method __init__ (line 41) | def __init__(self): method __str__ (line 44) | def __str__(self): method global_scope (line 48) | def global_scope(self): method current_scope (line 52) | def current_scope(self): method push_scope (line 57) | def push_scope(self, scope): method pop_scope (line 61) | def pop_scope(self): method fresh_name (line 64) | def fresh_name(self, name): class AutonameMessenger (line 71) | class AutonameMessenger(ReentrantMessenger): method __init__ (line 115) | def __init__(self, name=None): method __call__ (line 119) | def __call__(self, fn_or_iter): method __enter__ (line 131) | def __enter__(self): method __exit__ (line 136) | def __exit__(self, *args): method __iter__ (line 140) | def __iter__(self): method _pyro_genname (line 148) | def _pyro_genname(msg): function autoname (line 157) | def autoname(fn=None, name=None): ... function sample (line 161) | def sample(*args): function _sample_name (line 166) | def _sample_name(name, fn, *args, **kwargs): # the current syntax of py... function _sample_dist (line 172) | def _sample_dist(fn, *args, **kwargs): FILE: pyro/contrib/autoname/named.py class Object (line 57) | class Object: method __init__ (line 81) | def __init__(self, name): method __str__ (line 85) | def __str__(self): method __getattribute__ (line 88) | def __getattribute__(self, key): method __setattr__ (line 101) | def __setattr__(self, key, value): method sample_ (line 111) | def sample_(self, fn, *args, **kwargs): method param_ (line 121) | def param_(self, *args, **kwargs): class List (line 129) | class List(list): method __init__ (line 147) | def __init__(self, name=None): method __str__ (line 150) | def __str__(self): method _set_name (line 153) | def _set_name(self, name): method add (line 160) | def add(self): method __setitem__ (line 179) | def __setitem__(self, pos, value): class Dict (line 195) | class Dict(dict): method __init__ (line 213) | def __init__(self, name=None): method __str__ (line 216) | def __str__(self): method _set_name (line 219) | def _set_name(self, name): method __getitem__ (line 226) | def __getitem__(self, key): method __setitem__ (line 239) | def __setitem__(self, key, value): FILE: pyro/contrib/autoname/scoping.py class NameCountMessenger (line 15) | class NameCountMessenger(Messenger): method __enter__ (line 20) | def __enter__(self): method _increment_name (line 24) | def _increment_name(self, name, label): method _pyro_sample (line 34) | def _pyro_sample(self, msg): method _pyro_post_sample (line 37) | def _pyro_post_sample(self, msg): method _pyro_post_scope (line 40) | def _pyro_post_scope(self, msg): method _pyro_scope (line 43) | def _pyro_scope(self, msg): class ScopeMessenger (line 47) | class ScopeMessenger(Messenger): method __init__ (line 52) | def __init__(self, prefix=None, inner=None): method _collect_scope (line 59) | def _collect_scope(prefixed_scope): method __enter__ (line 62) | def __enter__(self): method __call__ (line 73) | def __call__(self, fn): method _pyro_scope (line 84) | def _pyro_scope(self, msg): method _pyro_sample (line 87) | def _pyro_sample(self, msg): function scope (line 91) | def scope(fn=None, prefix=None, inner=None): function name_count (line 146) | def name_count(fn=None): FILE: pyro/contrib/bnn/hidden_layer.py class HiddenLayer (line 12) | class HiddenLayer(TorchDistribution): method __init__ (line 61) | def __init__( method log_prob (line 90) | def log_prob(self, value): method KL (line 94) | def KL(self): method rsample (line 101) | def rsample(self, sample_shape=torch.Size()): FILE: pyro/contrib/bnn/utils.py function xavier_uniform (line 9) | def xavier_uniform(D_in, D_out): function adjoin_ones_vector (line 15) | def adjoin_ones_vector(x): function adjoin_zeros_vector (line 19) | def adjoin_zeros_vector(x): FILE: pyro/contrib/cevae/__init__.py class FullyConnected (line 42) | class FullyConnected(nn.Sequential): method __init__ (line 47) | def __init__(self, sizes, final_activation=None): method append (line 57) | def append(self, layer): class DistributionNet (line 62) | class DistributionNet(nn.Module): method get_class (line 68) | def get_class(dtype): class BernoulliNet (line 80) | class BernoulliNet(DistributionNet): method __init__ (line 94) | def __init__(self, sizes): method forward (line 99) | def forward(self, x): method make_dist (line 104) | def make_dist(logits): class ExponentialNet (line 108) | class ExponentialNet(DistributionNet): method __init__ (line 122) | def __init__(self, sizes): method forward (line 127) | def forward(self, x): method make_dist (line 133) | def make_dist(rate): class LaplaceNet (line 137) | class LaplaceNet(DistributionNet): method __init__ (line 152) | def __init__(self, sizes): method forward (line 157) | def forward(self, x): method make_dist (line 164) | def make_dist(loc, scale): class NormalNet (line 168) | class NormalNet(DistributionNet): method __init__ (line 183) | def __init__(self, sizes): method forward (line 188) | def forward(self, x): method make_dist (line 195) | def make_dist(loc, scale): class StudentTNet (line 199) | class StudentTNet(DistributionNet): method __init__ (line 214) | def __init__(self, sizes): method forward (line 220) | def forward(self, x): method make_dist (line 228) | def make_dist(df, loc, scale): class DiagNormalNet (line 232) | class DiagNormalNet(nn.Module): method __init__ (line 250) | def __init__(self, sizes): method forward (line 256) | def forward(self, x): class PreWhitener (line 265) | class PreWhitener(nn.Module): method __init__ (line 270) | def __init__(self, data): method forward (line 279) | def forward(self, data): class Model (line 283) | class Model(PyroModule): method __init__ (line 301) | def __init__(self, config): method forward (line 319) | def forward(self, x, t=None, y=None, size=None): method y_mean (line 329) | def y_mean(self, x, t=None): method z_dist (line 336) | def z_dist(self): method x_dist (line 339) | def x_dist(self, z): method y_dist (line 343) | def y_dist(self, t, z): method t_dist (line 351) | def t_dist(self, z): class Guide (line 356) | class Guide(PyroModule): method __init__ (line 374) | def __init__(self, config): method forward (line 397) | def forward(self, x, t=None, y=None, size=None): method t_dist (line 409) | def t_dist(self, x): method y_dist (line 413) | def y_dist(self, t, x): method z_dist (line 423) | def z_dist(self, y, t, x): class TraceCausalEffect_ELBO (line 435) | class TraceCausalEffect_ELBO(Trace_ELBO): method _differentiable_loss_particle (line 443) | def _differentiable_loss_particle(self, model_trace, guide_trace): method loss (line 466) | def loss(self, model, guide, *args, **kwargs): class CEVAE (line 470) | class CEVAE(nn.Module): method __init__ (line 512) | def __init__( method fit (line 539) | def fit( method ite (line 607) | def ite(self, x, num_samples=None, batch_size=None): method to_script_module (line 648) | def to_script_module(self): FILE: pyro/contrib/conjugate/infer.py function _make_cls (line 15) | def _make_cls(base, static_attrs, instance_attrs, parent_linkage=None): function _latent (line 46) | def _latent(base, parent): function _conditional (line 52) | def _conditional(base, parent): function _compound (line 58) | def _compound(base, parent): class BetaBinomialPair (line 62) | class BetaBinomialPair: method __init__ (line 63) | def __init__(self): method latent (line 67) | def latent(self, *args, **kwargs): method conditional (line 71) | def conditional(self, *args, **kwargs): method posterior (line 75) | def posterior(self, obs): method compound (line 90) | def compound(self): class GammaPoissonPair (line 98) | class GammaPoissonPair: method __init__ (line 99) | def __init__(self): method latent (line 103) | def latent(self, *args, **kwargs): method conditional (line 107) | def conditional(self, *args, **kwargs): method posterior (line 111) | def posterior(self, obs): method compound (line 119) | def compound(self): class UncollapseConjugateMessenger (line 125) | class UncollapseConjugateMessenger(Messenger): method __init__ (line 131) | def __init__(self, trace): method _pyro_sample (line 141) | def _pyro_sample(self, msg): function uncollapse_conjugate (line 172) | def uncollapse_conjugate(fn=None, trace=None): class CollapseConjugateMessenger (line 185) | class CollapseConjugateMessenger(Messenger): method _pyro_sample (line 186) | def _pyro_sample(self, msg): function collapse_conjugate (line 198) | def collapse_conjugate(fn=None): function posterior_replay (line 210) | def posterior_replay(model, posterior_samples, *args, **kwargs): FILE: pyro/contrib/easyguide/easyguide.py class _EasyGuideMeta (line 22) | class _EasyGuideMeta(type(PyroModule), ABCMeta): class EasyGuide (line 26) | class EasyGuide(PyroModule, metaclass=_EasyGuideMeta): method __init__ (line 46) | def __init__(self, model): method model (line 56) | def model(self): method _setup_prototype (line 59) | def _setup_prototype(self, *args, **kwargs): method guide (line 75) | def guide(self, *args, **kargs): method init (line 81) | def init(self, site): method forward (line 95) | def forward(self, *args, **kwargs): method plate (line 108) | def plate( method group (line 122) | def group(self, match=".*"): method map_estimate (line 145) | def map_estimate(self, name): class Group (line 177) | class Group: method __init__ (line 189) | def __init__(self, guide, sites): method __getstate__ (line 232) | def __getstate__(self): method __setstate__ (line 237) | def __setstate__(self, state): method guide (line 242) | def guide(self): method sample (line 245) | def sample(self, guide_name, fn, infer=None): method map_estimate (line 305) | def map_estimate(self): function easy_guide (line 318) | def easy_guide(model): FILE: pyro/contrib/epidemiology/compartmental.py function _require_double_precision (line 57) | def _require_double_precision(): function _disallow_latent_variables (line 67) | def _disallow_latent_variables(section_name): class CompartmentalModel (line 81) | class CompartmentalModel(ABC): method __init__ (line 150) | def __init__(self, compartments, duration, population, *, approximate=... method time_plate (line 183) | def time_plate(self): method region_plate (line 194) | def region_plate(self): method _clear_plates (line 206) | def _clear_plates(self): method full_mass (line 211) | def full_mass(self): method series (line 226) | def series(self): method global_model (line 248) | def global_model(self): method initialize (line 258) | def initialize(self, params): method transition (line 269) | def transition(self, params, state, t): method finalize (line 297) | def finalize(self, params, prev, curr): method compute_flows (line 322) | def compute_flows(self, prev, curr, t): method generate (line 361) | def generate(self, fixed={}): method fit_svi (line 384) | def fit_svi( method fit_mcmc (line 534) | def fit_mcmc(self, **options): method predict (line 663) | def predict(self, forecast=0): method heuristic (line 737) | def heuristic(self, num_particles=1024, ess_threshold=0.5, retries=10): method _heuristic (line 788) | def _heuristic(self, haar, **options): method _concat_series (line 804) | def _concat_series(self, samples, trace, forecast=0): method _non_compartmental (line 829) | def _non_compartmental(self): method _sample_auxiliary (line 861) | def _sample_auxiliary(self): method _transition_bwd (line 900) | def _transition_bwd(self, params, prev, curr, t): method _generative_model (line 921) | def _generative_model(self, forecast=0): method _sequential_model (line 948) | def _sequential_model(self): method _quantized_model (line 1000) | def _quantized_model(self): method _relaxed_model (line 1097) | def _relaxed_model(self): class _SMCModel (line 1138) | class _SMCModel: method __init__ (line 1143) | def __init__(self, model): method init (line 1147) | def init(self, state): method step (line 1158) | def step(self, state): class _SMCGuide (line 1180) | class _SMCGuide(_SMCModel): method init (line 1185) | def init(self, state): method step (line 1188) | def step(self, state): class _HaarSplitReparam (line 1193) | class _HaarSplitReparam: method __init__ (line 1199) | def __init__(self, split, duration, dims, supports): method __bool__ (line 1206) | def __bool__(self): method reparam (line 1209) | def reparam(self, model): method aux_to_user (line 1229) | def aux_to_user(self, samples): FILE: pyro/contrib/epidemiology/distributions.py function _all (line 17) | def _all(x): function _is_zero (line 21) | def _is_zero(x): function set_approx_sample_thresh (line 26) | def set_approx_sample_thresh(thresh): function set_approx_log_prob_tol (line 56) | def set_approx_log_prob_tol(tol): function set_relaxed_distributions (line 84) | def set_relaxed_distributions(relaxed=True): function _validate_overdispersion (line 94) | def _validate_overdispersion(overdispersion): function _relaxed_binomial (line 102) | def _relaxed_binomial(total_count, probs): function _relaxed_beta_binomial (line 117) | def _relaxed_beta_binomial(concentration1, concentration0, total_count): function binomial_dist (line 137) | def binomial_dist(total_count, probs, *, overdispersion=0.0): function beta_binomial_dist (line 194) | def beta_binomial_dist( function poisson_dist (line 230) | def poisson_dist(rate, *, overdispersion=0.0): function negative_binomial_dist (line 237) | def negative_binomial_dist( function infection_dist (line 246) | def infection_dist( FILE: pyro/contrib/epidemiology/models.py class SimpleSIRModel (line 16) | class SimpleSIRModel(CompartmentalModel): method __init__ (line 35) | def __init__(self, population, recovery_time, data): method global_model (line 46) | def global_model(self): method initialize (line 52) | def initialize(self, params): method transition (line 56) | def transition(self, params, state, t): class SimpleSEIRModel (line 84) | class SimpleSEIRModel(CompartmentalModel): method __init__ (line 106) | def __init__(self, population, incubation_time, recovery_time, data): method global_model (line 121) | def global_model(self): method initialize (line 128) | def initialize(self, params): method transition (line 132) | def transition(self, params, state, t): class SimpleSEIRDModel (line 162) | class SimpleSEIRDModel(CompartmentalModel): method __init__ (line 189) | def __init__( method global_model (line 210) | def global_model(self): method initialize (line 218) | def initialize(self, params): method transition (line 222) | def transition(self, params, state, t): method compute_flows (line 260) | def compute_flows(self, prev, curr, t): class OverdispersedSIRModel (line 275) | class OverdispersedSIRModel(CompartmentalModel): method __init__ (line 314) | def __init__(self, population, recovery_time, data): method global_model (line 325) | def global_model(self): method initialize (line 332) | def initialize(self, params): method transition (line 336) | def transition(self, params, state, t): class OverdispersedSEIRModel (line 367) | class OverdispersedSEIRModel(CompartmentalModel): method __init__ (line 408) | def __init__(self, population, incubation_time, recovery_time, data): method global_model (line 423) | def global_model(self): method initialize (line 431) | def initialize(self, params): method transition (line 435) | def transition(self, params, state, t): class SuperspreadingSIRModel (line 470) | class SuperspreadingSIRModel(CompartmentalModel): method __init__ (line 509) | def __init__(self, population, recovery_time, data): method global_model (line 520) | def global_model(self): method initialize (line 527) | def initialize(self, params): method transition (line 531) | def transition(self, params, state, t): class SuperspreadingSEIRModel (line 560) | class SuperspreadingSEIRModel(CompartmentalModel): method __init__ (line 610) | def __init__( method global_model (line 642) | def global_model(self): method initialize (line 650) | def initialize(self, params): method transition (line 654) | def transition(self, params, state, t): class HeterogeneousSIRModel (line 696) | class HeterogeneousSIRModel(CompartmentalModel): method __init__ (line 716) | def __init__(self, population, recovery_time, data): method global_model (line 727) | def global_model(self): method initialize (line 753) | def initialize(self, params): method transition (line 759) | def transition(self, params, state, t): class SparseSIRModel (line 797) | class SparseSIRModel(CompartmentalModel): method __init__ (line 825) | def __init__(self, population, recovery_time, data, mask): method global_model (line 838) | def global_model(self): method initialize (line 844) | def initialize(self, params): method transition (line 848) | def transition(self, params, state, t): method compute_flows (line 880) | def compute_flows(self, prev, curr, t): class UnknownStartSIRModel (line 892) | class UnknownStartSIRModel(CompartmentalModel): method __init__ (line 917) | def __init__(self, population, recovery_time, pre_obs_window, data): method global_model (line 943) | def global_model(self): method initialize (line 968) | def initialize(self, params): method transition (line 972) | def transition(self, params, state, t): method predict (line 1000) | def predict(self, forecast=0): class RegionalSIRModel (line 1022) | class RegionalSIRModel(CompartmentalModel): method __init__ (line 1064) | def __init__(self, population, coupling, recovery_time, data): method global_model (line 1084) | def global_model(self): method initialize (line 1100) | def initialize(self, params): method transition (line 1107) | def transition(self, params, state, t): class HeterogeneousRegionalSIRModel (line 1144) | class HeterogeneousRegionalSIRModel(CompartmentalModel): method __init__ (line 1171) | def __init__(self, population, coupling, recovery_time, data): method global_model (line 1191) | def global_model(self): method initialize (line 1205) | def initialize(self, params): method transition (line 1217) | def transition(self, params, state, t): FILE: pyro/contrib/epidemiology/util.py function clamp (line 14) | def clamp(tensor, *, min=None, max=None): function cat2 (line 30) | def cat2(lhs, rhs, *, dim=-1): function align_samples (line 56) | def align_samples(samples, model, particle_dim): function compute_bin_probs (line 174) | def compute_bin_probs(s, num_quant_bins): function _all (line 332) | def _all(x): function _unsqueeze (line 336) | def _unsqueeze(x): function quantize (line 340) | def quantize(name, x_real, min, max, num_quant_bins=4): function quantize_enumerate (line 363) | def quantize_enumerate(x_real, min, max, num_quant_bins=4): FILE: pyro/contrib/examples/bart.py function _load_hourly_od (line 40) | def _load_hourly_od(basename): function load_bart_od (line 91) | def load_bart_od(): function load_fake_od (line 167) | def load_fake_od(): FILE: pyro/contrib/examples/finance.py function load_snp500 (line 17) | def load_snp500(): FILE: pyro/contrib/examples/multi_mnist.py function imresize (line 21) | def imresize(arr, size): function sample_one (line 25) | def sample_one(canvas_size, mnist): function sample_multi (line 42) | def sample_multi(num_digits, canvas_size, mnist): function mk_dataset (line 56) | def mk_dataset(n, mnist, max_digits, canvas_size): function load_mnist (line 67) | def load_mnist(root_path): function load (line 75) | def load(root_path): FILE: pyro/contrib/examples/nextstrain.py function load_nextstrain_counts (line 17) | def load_nextstrain_counts(map_location=None) -> dict: FILE: pyro/contrib/examples/polyphonic_data_loader.py function process_data (line 58) | def process_data(base_path, dataset, min_note=21, note_range=88): function load_data (line 100) | def load_data(dataset): function reverse_sequences (line 119) | def reverse_sequences(mini_batch, seq_lengths): function pad_and_reverse (line 131) | def pad_and_reverse(rnn_output, seq_lengths): function get_mini_batch_mask (line 139) | def get_mini_batch_mask(mini_batch, seq_lengths): function get_mini_batch (line 151) | def get_mini_batch(mini_batch_indices, sequences, seq_lengths, cuda=False): FILE: pyro/contrib/examples/scanvi_data.py class BatchDataLoader (line 18) | class BatchDataLoader(object): method __init__ (line 24) | def __init__(self, data_x, data_y, batch_size, num_classes=4, missing_... method size (line 43) | def size(self): method __len__ (line 46) | def __len__(self): method _sample_batch_indices (line 49) | def _sample_batch_indices(self): method __iter__ (line 66) | def __iter__(self): function _get_score (line 81) | def _get_score(normalized_adata, gene_set): function _get_cell_mask (line 95) | def _get_cell_mask(normalized_adata, gene_set): function get_data (line 107) | def get_data(dataset="pbmc", batch_size=100, cuda=False): FILE: pyro/contrib/examples/util.py class MNIST (line 12) | class MNIST(datasets.MNIST): method download (line 15) | def download(self) -> None: function get_data_loader (line 43) | def get_data_loader( function print_and_log (line 64) | def print_and_log(logger, msg): function get_data_directory (line 73) | def get_data_directory(filepath=None): function _mkdir_p (line 79) | def _mkdir_p(dirname): FILE: pyro/contrib/forecast/evaluate.py function eval_mae (line 19) | def eval_mae(pred, truth): function eval_rmse (line 32) | def eval_rmse(pred, truth): function eval_crps (line 46) | def eval_crps(pred, truth): function backtest (line 71) | def backtest( FILE: pyro/contrib/forecast/forecaster.py class _ForecastingModelMeta (line 33) | class _ForecastingModelMeta(type(PyroModule), ABCMeta): class ForecastingModel (line 37) | class ForecastingModel(PyroModule, metaclass=_ForecastingModelMeta): method __init__ (line 44) | def __init__(self): method model (line 49) | def model(self, zero_data, covariates): method time_plate (line 71) | def time_plate(self): method predict (line 82) | def predict(self, noise_dist, prediction): method forward (line 169) | def forward(self, data, covariates): class Forecaster (line 197) | class Forecaster(nn.Module): method __init__ (line 262) | def __init__( method __call__ (line 340) | def __call__(self, data, covariates, num_samples, batch_size=None): method forward (line 365) | def forward(self, data, covariates, num_samples, batch_size=None): class HMCForecaster (line 395) | class HMCForecaster(nn.Module): method __init__ (line 427) | def __init__( method __call__ (line 487) | def __call__(self, data, covariates, num_samples, batch_size=None): method forward (line 512) | def forward(self, data, covariates, num_samples, batch_size=None): FILE: pyro/contrib/forecast/util.py function time_reparam_dct (line 17) | def time_reparam_dct(msg): function time_reparam_haar (line 30) | def time_reparam_haar(msg): class MarkDCTParamMessenger (line 43) | class MarkDCTParamMessenger(Messenger): method __init__ (line 52) | def __init__(self, name): method _postprocess_message (line 56) | def _postprocess_message(self, msg): class PrefixWarmStartMessenger (line 70) | class PrefixWarmStartMessenger(Messenger): method _pyro_param (line 77) | def _pyro_param(self, msg): class PrefixReplayMessenger (line 113) | class PrefixReplayMessenger(Messenger): method __init__ (line 124) | def __init__(self, trace): method _pyro_post_sample (line 128) | def _pyro_post_sample(self, msg): class PrefixConditionMessenger (line 154) | class PrefixConditionMessenger(Messenger): method __init__ (line 162) | def __init__(self, data): method _pyro_sample (line 166) | def _pyro_sample(self, msg): function prefix_condition (line 205) | def prefix_condition(d, data): function _ (line 227) | def _(d, data): function _ (line 234) | def _(d, data): function _ (line 240) | def _(d, data): function _ (line 247) | def _(d, data): function _ (line 253) | def _(d, data): function _prefix_condition_univariate (line 260) | def _prefix_condition_univariate(d, data): function _ (line 271) | def _(d, data): function reshape_batch (line 279) | def reshape_batch(d, batch_shape): function _ (line 298) | def _(d, batch_shape): function _ (line 307) | def _(d, batch_shape): function _ (line 314) | def _(d, batch_shape): function _ (line 321) | def _(d, batch_shape): function _ (line 327) | def _(d, batch_shape): function _reshape_batch_univariate (line 337) | def _reshape_batch_univariate(d, batch_shape): function _ (line 351) | def _(d, batch_shape): function _ (line 359) | def _(d, batch_shape): function _ (line 388) | def _(d, batch_shape): function reshape_transform_batch (line 431) | def reshape_transform_batch(t, old_shape, new_shape): function _reshape_batch_univariate_transform (line 451) | def _reshape_batch_univariate_transform(t, old_shape, new_shape): function _ (line 463) | def _(t, old_shape, new_shape): function _ (line 468) | def _(t, old_shape, new_shape): function _ (line 480) | def _(t, old_shape, new_shape): FILE: pyro/contrib/funsor/__init__.py function plate (line 24) | def plate(*args, **kwargs): FILE: pyro/contrib/funsor/handlers/__init__.py function enum (line 26) | def enum(fn=None, first_available_dim=None): ... function markov (line 30) | def markov(fn=None, history=1, keep=False): ... function named (line 34) | def named(fn=None, first_available_dim=None): ... function plate (line 38) | def plate( function replay (line 51) | def replay(fn=None, trace=None, params=None): ... function trace (line 55) | def trace(fn=None, graph_type=None, param_only=None, pack_online=True): ... function vectorized_markov (line 59) | def vectorized_markov(fn=None, name=None, size=None, dim=None, history=1... FILE: pyro/contrib/funsor/handlers/enum_messenger.py function _get_support_value (line 28) | def _get_support_value(funsor_dist, name, **kwargs): function _get_support_value_contraction (line 35) | def _get_support_value_contraction(funsor_dist, name, **kwargs): function _get_support_value_delta (line 46) | def _get_support_value_delta(funsor_dist, name, **kwargs): function _get_support_value_tensor (line 52) | def _get_support_value_tensor(funsor_dist, name, **kwargs): function _get_support_value_distribution (line 62) | def _get_support_value_distribution(funsor_dist, name, expand=False): function _enum_strategy_default (line 67) | def _enum_strategy_default(dist, msg): function _enum_strategy_diagonal (line 78) | def _enum_strategy_diagonal(dist, msg): function _enum_strategy_mixture (line 103) | def _enum_strategy_mixture(dist, msg): function _enum_strategy_full (line 146) | def _enum_strategy_full(dist, msg): function _enum_strategy_exact (line 156) | def _enum_strategy_exact(dist, msg): function enumerate_site (line 162) | def enumerate_site(dist, msg): class EnumMessenger (line 182) | class EnumMessenger(NamedMessenger): method _pyro_sample (line 188) | def _pyro_sample(self, msg): function queue (line 213) | def queue( FILE: pyro/contrib/funsor/handlers/named_messenger.py class NamedMessenger (line 16) | class NamedMessenger(ReentrantMessenger): method __init__ (line 27) | def __init__(self, first_available_dim=None): method __enter__ (line 35) | def __enter__(self): method __exit__ (line 48) | def __exit__(self, *args, **kwargs): method _pyro_to_data (line 61) | def _pyro_to_data(msg): method _pyro_to_funsor (line 83) | def _pyro_to_funsor(msg): class MarkovMessenger (line 117) | class MarkovMessenger(NamedMessenger): method __init__ (line 130) | def __init__(self, history=1, keep=False): method __call__ (line 137) | def __call__(self, fn): method __iter__ (line 143) | def __iter__(self): method __enter__ (line 152) | def __enter__(self): method __exit__ (line 166) | def __exit__(self, *args, **kwargs): class GlobalNamedMessenger (line 174) | class GlobalNamedMessenger(NamedMessenger): method __init__ (line 185) | def __init__(self, first_available_dim=None): method __enter__ (line 189) | def __enter__(self): method __exit__ (line 198) | def __exit__(self, *args): FILE: pyro/contrib/funsor/handlers/plate_messenger.py class IndepMessenger (line 29) | class IndepMessenger(GlobalNamedMessenger): method __init__ (line 35) | def __init__(self, name=None, size=None, dim=None, indices=None): method __enter__ (line 57) | def __enter__(self): method _pyro_sample (line 65) | def _pyro_sample(self, msg): method _pyro_param (line 69) | def _pyro_param(self, msg): class SubsampleMessenger (line 75) | class SubsampleMessenger(IndepMessenger): method __init__ (line 76) | def __init__( method _pyro_sample (line 95) | def _pyro_sample(self, msg): method _pyro_param (line 99) | def _pyro_param(self, msg): method _subsample_site_value (line 103) | def _subsample_site_value(self, value, event_dim=None): method _pyro_post_param (line 115) | def _pyro_post_param(self, msg): method _pyro_post_subsample (line 131) | def _pyro_post_subsample(self, msg): class PlateMessenger (line 136) | class PlateMessenger(SubsampleMessenger): method __enter__ (line 143) | def __enter__(self): method _pyro_sample (line 147) | def _pyro_sample(self, msg): method __iter__ (line 151) | def __iter__(self): class _SequentialPlateMessenger (line 159) | class _SequentialPlateMessenger(Messenger): method __init__ (line 164) | def __init__(self, name, size, indices, scale): method __iter__ (line 172) | def __iter__(self): method _pyro_sample (line 179) | def _pyro_sample(self, msg): method _pyro_param (line 184) | def _pyro_param(self, msg): class VectorizedMarkovMessenger (line 190) | class VectorizedMarkovMessenger(NamedMessenger): method __init__ (line 296) | def __init__(self, name=None, size=None, dim=None, history=1): method _markov_chain (line 305) | def _markov_chain(name=None, markov_vars=set(), suffixes=list()): method __iter__ (line 325) | def __iter__(self): method _pyro_sample (line 348) | def _pyro_sample(self, msg): method _pyro_post_sample (line 366) | def _pyro_post_sample(self, msg): FILE: pyro/contrib/funsor/handlers/primitives.py function to_funsor (line 9) | def to_funsor(x, output=None, dim_to_name=None, dim_type=DimType.LOCAL): function to_data (line 21) | def to_data(x, name_to_dim=None, dim_type=DimType.LOCAL): FILE: pyro/contrib/funsor/handlers/replay_messenger.py class ReplayMessenger (line 8) | class ReplayMessenger(OrigReplayMessenger): method _pyro_sample (line 15) | def _pyro_sample(self, msg): FILE: pyro/contrib/funsor/handlers/runtime.py class StackFrame (line 8) | class StackFrame: method __init__ (line 14) | def __init__(self, name_to_dim, dim_to_name, history=1, keep=False): method __setitem__ (line 28) | def __setitem__(self, key, value): method __getitem__ (line 37) | def __getitem__(self, key): method __delitem__ (line 41) | def __delitem__(self, key): method __contains__ (line 51) | def __contains__(self, key): class DimType (line 56) | class DimType(Enum): class DimStack (line 68) | class DimStack: method __init__ (line 77) | def __init__(self): method set_first_available_dim (line 93) | def set_first_available_dim(self, dim): method push_global (line 98) | def push_global(self, frame): method pop_global (line 101) | def pop_global(self): method push_iter (line 105) | def push_iter(self, frame): method pop_iter (line 108) | def pop_iter(self): method push_local (line 112) | def push_local(self, frame): method pop_local (line 115) | def pop_local(self): method global_frame (line 120) | def global_frame(self): method local_frame (line 124) | def local_frame(self): method current_write_env (line 128) | def current_write_env(self): method current_read_env (line 136) | def current_read_env(self): method _genvalue (line 147) | def _genvalue(self, key, value_request): method allocate (line 183) | def allocate(self, key_to_value_request): method names_from_batch_shape (line 227) | def names_from_batch_shape(self, batch_shape, dim_type=DimType.LOCAL): FILE: pyro/contrib/funsor/handlers/trace_messenger.py function _mask_fn (line 13) | def _mask_fn(fn, mask): class TraceMessenger (line 20) | class TraceMessenger(OrigTraceMessenger): method __init__ (line 30) | def __init__(self, graph_type=None, param_only=None, pack_online=True): method _pyro_post_sample (line 34) | def _pyro_post_sample(self, msg): method _pyro_post_markov_chain (line 81) | def _pyro_post_markov_chain(self, msg): FILE: pyro/contrib/funsor/infer/discrete.py function _sample_posterior (line 15) | def _sample_posterior(model, first_available_dim, temperature, *args, **... function infer_discrete (line 71) | def infer_discrete(model, first_available_dim=None, temperature=1): FILE: pyro/contrib/funsor/infer/elbo.py class ELBO (line 9) | class ELBO(_OrigELBO): method _get_trace (line 10) | def _get_trace(self, *args, **kwargs): method differentiable_loss (line 13) | def differentiable_loss(self, model, guide, *args, **kwargs): method loss (line 16) | def loss(self, model, guide, *args, **kwargs): method loss_and_grads (line 19) | def loss_and_grads(self, model, guide, *args, **kwargs): class Jit_ELBO (line 25) | class Jit_ELBO(ELBO): method differentiable_loss (line 26) | def differentiable_loss(self, model, guide, *args, **kwargs): FILE: pyro/contrib/funsor/infer/trace_elbo.py class Trace_ELBO (line 19) | class Trace_ELBO(ELBO): method differentiable_loss (line 20) | def differentiable_loss(self, model, guide, *args, **kwargs): class JitTrace_ELBO (line 51) | class JitTrace_ELBO(Jit_ELBO, Trace_ELBO): FILE: pyro/contrib/funsor/infer/traceenum_elbo.py function apply_optimizer (line 20) | def apply_optimizer(x): function terms_from_trace (line 28) | def terms_from_trace(tr): class TraceMarkovEnum_ELBO (line 93) | class TraceMarkovEnum_ELBO(ELBO): method differentiable_loss (line 94) | def differentiable_loss(self, model, guide, *args, **kwargs): class TraceEnum_ELBO (line 172) | class TraceEnum_ELBO(ELBO): method differentiable_loss (line 173) | def differentiable_loss(self, model, guide, *args, **kwargs): class JitTraceEnum_ELBO (line 278) | class JitTraceEnum_ELBO(Jit_ELBO, TraceEnum_ELBO): class JitTraceMarkovEnum_ELBO (line 282) | class JitTraceMarkovEnum_ELBO(Jit_ELBO, TraceMarkovEnum_ELBO): FILE: pyro/contrib/funsor/infer/tracetmc_elbo.py class TraceTMC_ELBO (line 17) | class TraceTMC_ELBO(ELBO): method differentiable_loss (line 18) | def differentiable_loss(self, model, guide, *args, **kwargs): class JitTraceTMC_ELBO (line 53) | class JitTraceTMC_ELBO(Jit_ELBO, TraceTMC_ELBO): FILE: pyro/contrib/gp/kernels/brownian.py class Brownian (line 11) | class Brownian(Kernel): method __init__ (line 26) | def __init__(self, input_dim, variance=None, active_dims=None): method forward (line 34) | def forward(self, X, Z=None, diag=False): FILE: pyro/contrib/gp/kernels/coregionalize.py class Coregionalize (line 12) | class Coregionalize(Kernel): method __init__ (line 48) | def __init__( method forward (line 80) | def forward(self, X, Z=None, diag=False): FILE: pyro/contrib/gp/kernels/dot_product.py class DotProduct (line 11) | class DotProduct(Kernel): method __init__ (line 16) | def __init__(self, input_dim, variance=None, active_dims=None): method _dot_product (line 22) | def _dot_product(self, X, Z=None, diag=False): class Linear (line 39) | class Linear(DotProduct): method __init__ (line 53) | def __init__(self, input_dim, variance=None, active_dims=None): method forward (line 56) | def forward(self, X, Z=None, diag=False): class Polynomial (line 60) | class Polynomial(DotProduct): method __init__ (line 70) | def __init__(self, input_dim, variance=None, bias=None, degree=1, acti... method forward (line 82) | def forward(self, X, Z=None, diag=False): FILE: pyro/contrib/gp/kernels/isotropic.py function _torch_sqrt (line 11) | def _torch_sqrt(x, eps=1e-12): class Isotropy (line 20) | class Isotropy(Kernel): method __init__ (line 32) | def __init__(self, input_dim, variance=None, lengthscale=None, active_... method _square_scaled_dist (line 41) | def _square_scaled_dist(self, X, Z=None): method _scaled_dist (line 60) | def _scaled_dist(self, X, Z=None): method _diag (line 66) | def _diag(self, X): class RBF (line 73) | class RBF(Isotropy): method __init__ (line 82) | def __init__(self, input_dim, variance=None, lengthscale=None, active_... method forward (line 85) | def forward(self, X, Z=None, diag=False): class RationalQuadratic (line 93) | class RationalQuadratic(Isotropy): method __init__ (line 104) | def __init__( method forward (line 118) | def forward(self, X, Z=None, diag=False): class Exponential (line 128) | class Exponential(Isotropy): method __init__ (line 135) | def __init__(self, input_dim, variance=None, lengthscale=None, active_... method forward (line 138) | def forward(self, X, Z=None, diag=False): class Matern32 (line 146) | class Matern32(Isotropy): method __init__ (line 154) | def __init__(self, input_dim, variance=None, lengthscale=None, active_... method forward (line 157) | def forward(self, X, Z=None, diag=False): class Matern52 (line 166) | class Matern52(Isotropy): method __init__ (line 174) | def __init__(self, input_dim, variance=None, lengthscale=None, active_... method forward (line 177) | def forward(self, X, Z=None, diag=False): FILE: pyro/contrib/gp/kernels/kernel.py class Kernel (line 9) | class Kernel(Parameterized): method __init__ (line 30) | def __init__(self, input_dim, active_dims=None): method forward (line 42) | def forward(self, X, Z=None, diag=False): method _slice_input (line 57) | def _slice_input(self, X): class Combination (line 74) | class Combination(Kernel): method __init__ (line 83) | def __init__(self, kern0, kern1): class Sum (line 105) | class Sum(Combination): method forward (line 111) | def forward(self, X, Z=None, diag=False): class Product (line 118) | class Product(Combination): method forward (line 124) | def forward(self, X, Z=None, diag=False): class Transforming (line 131) | class Transforming(Kernel): method __init__ (line 139) | def __init__(self, kern): class Exponent (line 145) | class Exponent(Transforming): method forward (line 152) | def forward(self, X, Z=None, diag=False): class VerticalScaling (line 156) | class VerticalScaling(Transforming): method __init__ (line 167) | def __init__(self, kern, vscaling_fn): method forward (line 172) | def forward(self, X, Z=None, diag=False): function _Horner_evaluate (line 188) | def _Horner_evaluate(x, coef): class Warping (line 200) | class Warping(Transforming): method __init__ (line 229) | def __init__(self, kern, iwarping_fn=None, owarping_coef=None): method forward (line 248) | def forward(self, X, Z=None, diag=False): FILE: pyro/contrib/gp/kernels/periodic.py class Cosine (line 14) | class Cosine(Isotropy): method __init__ (line 23) | def __init__(self, input_dim, variance=None, lengthscale=None, active_... method forward (line 26) | def forward(self, X, Z=None, diag=False): class Periodic (line 34) | class Periodic(Kernel): method __init__ (line 51) | def __init__( method forward (line 65) | def forward(self, X, Z=None, diag=False): FILE: pyro/contrib/gp/kernels/static.py class Constant (line 11) | class Constant(Kernel): method __init__ (line 18) | def __init__(self, input_dim, variance=None, active_dims=None): method forward (line 24) | def forward(self, X, Z=None, diag=False): class WhiteNoise (line 33) | class WhiteNoise(Kernel): method __init__ (line 42) | def __init__(self, input_dim, variance=None, active_dims=None): method forward (line 48) | def forward(self, X, Z=None, diag=False): FILE: pyro/contrib/gp/likelihoods/binary.py class Binary (line 11) | class Binary(Likelihood): method __init__ (line 24) | def __init__(self, response_function=None): method forward (line 30) | def forward(self, f_loc, f_var, y=None): FILE: pyro/contrib/gp/likelihoods/gaussian.py class Gaussian (line 13) | class Gaussian(Likelihood): method __init__ (line 23) | def __init__(self, variance=None): method forward (line 29) | def forward(self, f_loc, f_var, y=None): FILE: pyro/contrib/gp/likelihoods/likelihood.py class Likelihood (line 7) | class Likelihood(Parameterized): method __init__ (line 15) | def __init__(self): method forward (line 18) | def forward(self, f_loc, f_var, y=None): FILE: pyro/contrib/gp/likelihoods/multi_class.py function _softmax (line 11) | def _softmax(x): class MultiClass (line 15) | class MultiClass(Likelihood): method __init__ (line 29) | def __init__(self, num_classes, response_function=None): method forward (line 36) | def forward(self, f_loc, f_var, y=None): FILE: pyro/contrib/gp/likelihoods/poisson.py class Poisson (line 11) | class Poisson(Likelihood): method __init__ (line 23) | def __init__(self, response_function=None): method forward (line 29) | def forward(self, f_loc, f_var, y=None): FILE: pyro/contrib/gp/models/gplvm.py class GPLVM (line 9) | class GPLVM(Parameterized): method __init__ (line 59) | def __init__(self, base_model): method model (line 75) | def model(self): method guide (line 82) | def guide(self): method forward (line 88) | def forward(self, **kwargs): FILE: pyro/contrib/gp/models/gpr.py class GPRegression (line 16) | class GPRegression(GPModel): method __init__ (line 69) | def __init__(self, X, y, kernel, noise=None, mean_function=None, jitte... method model (line 83) | def model(self): method guide (line 106) | def guide(self): method forward (line 110) | def forward(self, Xnew, full_cov=False, noiseless=True): method iter_sample (line 159) | def iter_sample(self, noiseless=True): FILE: pyro/contrib/gp/models/model.py function _zero_mean_function (line 9) | def _zero_mean_function(x): class GPModel (line 13) | class GPModel(Parameterized): method __init__ (line 92) | def __init__(self, X, y, kernel, mean_function=None, jitter=1e-6): method model (line 109) | def model(self): method guide (line 116) | def guide(self): method forward (line 123) | def forward(self, Xnew, full_cov=False): method set_data (line 144) | def set_data(self, X, y=None): method _check_Xnew_shape (line 207) | def _check_Xnew_shape(self, Xnew): FILE: pyro/contrib/gp/models/sgpr.py class SparseGPRegression (line 14) | class SparseGPRegression(GPModel): method __init__ (line 98) | def __init__( method model (line 130) | def model(self): method guide (line 178) | def guide(self): method forward (line 182) | def forward(self, Xnew, full_cov=False, noiseless=True): FILE: pyro/contrib/gp/models/vgp.py class VariationalGP (line 16) | class VariationalGP(GPModel): method __init__ (line 63) | def __init__( method model (line 100) | def model(self): method guide (line 137) | def guide(self): method forward (line 148) | def forward(self, Xnew, full_cov=False): FILE: pyro/contrib/gp/models/vsgp.py class VariationalSparseGP (line 17) | class VariationalSparseGP(GPModel): method __init__ (line 82) | def __init__( method model (line 127) | def model(self): method guide (line 174) | def guide(self): method forward (line 185) | def forward(self, Xnew, full_cov=False): FILE: pyro/contrib/gp/parameterized.py function _is_real_support (line 17) | def _is_real_support(support): function _get_sample_fn (line 24) | def _get_sample_fn(module, name): class Parameterized (line 57) | class Parameterized(PyroModule): method __init__ (line 92) | def __init__(self): method set_prior (line 98) | def set_prior(self, name, prior): method __setattr__ (line 113) | def __setattr__(self, name, value): method autoguide (line 122) | def autoguide(self, name, dist_constructor): method _load_pyro_samples (line 181) | def _load_pyro_samples(self): method set_mode (line 190) | def set_mode(self, mode): method mode (line 207) | def mode(self): method mode (line 211) | def mode(self, mode): FILE: pyro/contrib/gp/util.py function conditional (line 10) | def conditional( function train (line 161) | def train(gpmodule, optimizer=None, loss_fn=None, retain_graph=None, num... FILE: pyro/contrib/minipyro.py function get_param_store (line 38) | def get_param_store(): class Messenger (line 43) | class Messenger: method __init__ (line 44) | def __init__(self, fn=None): method __enter__ (line 49) | def __enter__(self): method __exit__ (line 52) | def __exit__(self, *args, **kwargs): method process_message (line 56) | def process_message(self, msg): method postprocess_message (line 59) | def postprocess_message(self, msg): method __call__ (line 62) | def __call__(self, *args, **kwargs): class trace (line 70) | class trace(Messenger): method __enter__ (line 71) | def __enter__(self): method postprocess_message (line 78) | def postprocess_message(self, msg): method get_trace (line 84) | def get_trace(self, *args, **kwargs): class replay (line 94) | class replay(Messenger): method __init__ (line 95) | def __init__(self, fn, guide_trace): method process_message (line 99) | def process_message(self, msg): class block (line 107) | class block(Messenger): method __init__ (line 108) | def __init__(self, fn=None, hide_fn=lambda msg: True): method process_message (line 112) | def process_message(self, msg): class seed (line 118) | class seed(Messenger): method __init__ (line 119) | def __init__(self, fn=None, rng_seed=None): method __enter__ (line 123) | def __enter__(self): method __exit__ (line 133) | def __exit__(self, type, value, traceback): class PlateMessenger (line 143) | class PlateMessenger(Messenger): method __init__ (line 144) | def __init__(self, fn, size, dim): method process_message (line 150) | def process_message(self, msg): method __iter__ (line 158) | def __iter__(self): function apply_stack (line 164) | def apply_stack(msg): function sample (line 186) | def sample(name, fn, *args, **kwargs): function param (line 210) | def param( function plate (line 259) | def plate(name, size, dim=None): class Adam (line 268) | class Adam: method __init__ (line 269) | def __init__(self, optim_args): method __call__ (line 275) | def __call__(self, params): class SVI (line 293) | class SVI: method __init__ (line 294) | def __init__(self, model, guide, optim, loss): method step (line 302) | def step(self, *args, **kwargs): function elbo (line 328) | def elbo(model, guide, *args, **kwargs): function Trace_ELBO (line 358) | def Trace_ELBO(**kwargs): class JitTrace_ELBO (line 365) | class JitTrace_ELBO: method __init__ (line 366) | def __init__(self, **kwargs): method __call__ (line 371) | def __call__(self, model, guide, *args): FILE: pyro/contrib/mue/dataloaders.py class BiosequenceDataset (line 37) | class BiosequenceDataset(Dataset): method __init__ (line 56) | def __init__( method _load_fasta (line 102) | def _load_fasta(self, source): method _one_hot (line 122) | def _one_hot(self, seq, alphabet, length): method __len__ (line 136) | def __len__(self): method __getitem__ (line 139) | def __getitem__(self, ind): function write (line 143) | def write(x, alphabet, file, truncate_stop=False, append=False, scores=N... FILE: pyro/contrib/mue/missingdatahmm.py class MissingDataDiscreteHMM (line 13) | class MissingDataDiscreteHMM(TorchDistribution): method __init__ (line 47) | def __init__( method log_prob (line 85) | def log_prob(self, value): method sample (line 115) | def sample(self, sample_shape=torch.Size([])): method filter (line 147) | def filter(self, value): method smooth (line 188) | def smooth(self, value): method sample_states (line 220) | def sample_states(self, value): method map_states (line 244) | def map_states(self, value): method given_states (line 290) | def given_states(self, states): method sample_given_states (line 308) | def sample_given_states(self, states): FILE: pyro/contrib/mue/models.py class ProfileHMM (line 26) | class ProfileHMM(nn.Module): method __init__ (line 47) | def __init__( method model (line 79) | def model(self, seq_data, local_scale): method guide (line 132) | def guide(self, seq_data, local_scale): method fit_svi (line 173) | def fit_svi( method evaluate (line 242) | def evaluate(self, dataset_train, dataset_test=None, jit=False): method _local_variables (line 276) | def _local_variables(self, name, site): method _evaluate_local_elbo (line 280) | def _evaluate_local_elbo(self, svi, dataload, data_size): class Encoder (line 309) | class Encoder(nn.Module): method __init__ (line 310) | def __init__(self, data_length, alphabet_length, z_dim): method forward (line 317) | def forward(self, data): class FactorMuE (line 325) | class FactorMuE(nn.Module): method __init__ (line 371) | def __init__( method decoder (line 452) | def decoder(self, z, W, B, inverse_temp): method model (line 488) | def model(self, seq_data, local_scale, local_prior_scale): method guide (line 610) | def guide(self, seq_data, local_scale, local_prior_scale): method fit_svi (line 681) | def fit_svi( method _beta_anneal (line 763) | def _beta_anneal(self, step, batch_size, data_size, anneal_length): method evaluate (line 770) | def evaluate(self, dataset_train, dataset_test=None, jit=False): method _local_variables (line 809) | def _local_variables(self, name, site): method _evaluate_local_elbo (line 813) | def _evaluate_local_elbo(self, svi, dataload, data_size): method embed (line 841) | def embed(self, dataset, batch_size=None): method _reconstruct_regressor_seq (line 863) | def _reconstruct_regressor_seq(self, data, ind, param): FILE: pyro/contrib/mue/statearrangers.py class Profile (line 8) | class Profile(nn.Module): method __init__ (line 32) | def __init__(self, M, epsilon=1e-32): method _make_transfer (line 40) | def _make_transfer(self): method forward (line 135) | def forward( function mg2k (line 205) | def mg2k(m, g, M): FILE: pyro/contrib/oed/eig.py function laplace_eig (line 29) | def laplace_eig( function _eig_from_ape (line 87) | def _eig_from_ape(model, design, target_labels, ape, eig, prior_entropy_... function _laplace_vi_ape (line 108) | def _laplace_vi_ape( function vi_eig (line 152) | def vi_eig( function _vi_ape (line 230) | def _vi_ape( function nmc_eig (line 268) | def nmc_eig( function donsker_varadhan_eig (line 376) | def donsker_varadhan_eig( function posterior_eig (line 442) | def posterior_eig( function _posterior_ape (line 525) | def _posterior_ape( function marginal_eig (line 555) | def marginal_eig( function marginal_likelihood_eig (line 620) | def marginal_likelihood_eig( function lfire_eig (line 683) | def lfire_eig( function vnmc_eig (line 756) | def vnmc_eig( function opt_eig_ape_loss (line 826) | def opt_eig_ape_loss( function monte_carlo_entropy (line 869) | def monte_carlo_entropy(model, design, target_labels, num_prior_samples=... function _donsker_varadhan_loss (line 884) | def _donsker_varadhan_loss(model, T, observation_labels, target_labels): function _posterior_loss (line 927) | def _posterior_loss( function _marginal_loss (line 965) | def _marginal_loss(model, guide, observation_labels, target_labels): function _marginal_likelihood_loss (line 994) | def _marginal_likelihood_loss( function _lfire_loss (line 1034) | def _lfire_loss( function _vnmc_eig_loss (line 1082) | def _vnmc_eig_loss(model, guide, observation_labels, target_labels): function _safe_mean_terms (line 1123) | def _safe_mean_terms(terms): function xexpx (line 1135) | def xexpx(a): class _EwmaLogFn (line 1149) | class _EwmaLogFn(torch.autograd.Function): method forward (line 1151) | def forward(ctx, input, ewma): method backward (line 1156) | def backward(ctx, grad_output): class EwmaLog (line 1164) | class EwmaLog: method __init__ (line 1180) | def __init__(self, alpha): method __call__ (line 1186) | def __call__(self, inputs, s, dim=0, keepdim=False): FILE: pyro/contrib/oed/glmm/glmm.py function known_covariance_linear_model (line 22) | def known_covariance_linear_model( function normal_guide (line 57) | def normal_guide(observation_sd, coef_shape, coef_label="w"): function group_linear_model (line 65) | def group_linear_model( function group_normal_guide (line 92) | def group_normal_guide( function zero_mean_unit_obs_sd_lm (line 102) | def zero_mean_unit_obs_sd_lm(coef_sd, coef_label="w"): function normal_inverse_gamma_linear_model (line 110) | def normal_inverse_gamma_linear_model( function normal_inverse_gamma_guide (line 123) | def normal_inverse_gamma_guide(coef_shape, coef_label="w", **kwargs): function logistic_regression_model (line 132) | def logistic_regression_model( function lmer_model (line 145) | def lmer_model( function sigmoid_model (line 168) | def sigmoid_model( function bayesian_linear_model (line 209) | def bayesian_linear_model( function normal_inv_gamma_family_guide (line 348) | def normal_inv_gamma_family_guide(design, obs_sd, w_sizes, mf=False): function group_assignment_matrix (line 409) | def group_assignment_matrix(design): function rf_group_assignments (line 432) | def rf_group_assignments(n, random_intercept=True): function analytic_posterior_cov (line 448) | def analytic_posterior_cov(prior_cov, x, obs_sd): function broadcast_cat (line 464) | def broadcast_cat(ws): FILE: pyro/contrib/oed/glmm/guides.py class LinearModelPosteriorGuide (line 23) | class LinearModelPosteriorGuide(nn.Module): method __init__ (line 24) | def __init__( method get_params (line 71) | def get_params(self, y_dict, design, target_labels): method linear_model_formula (line 75) | def linear_model_formula(self, y, design, target_labels): method forward (line 84) | def forward(self, y_dict, design, observation_labels, target_labels): class LinearModelLaplaceGuide (line 95) | class LinearModelLaplaceGuide(nn.Module): method __init__ (line 107) | def __init__(self, d, w_sizes, tau_label=None, init_value=0.1, **kwargs): method _hessian_diag (line 124) | def _hessian_diag(y, x, event_shape): method finalize (line 164) | def finalize(self, loss, target_labels): method forward (line 184) | def forward(self, design, target_labels=None): class SigmoidGuide (line 214) | class SigmoidGuide(LinearModelPosteriorGuide): method __init__ (line 215) | def __init__(self, d, n, w_sizes, **kwargs): method get_params (line 221) | def get_params(self, y_dict, design, target_labels): class NormalInverseGammaGuide (line 235) | class NormalInverseGammaGuide(LinearModelPosteriorGuide): method __init__ (line 236) | def __init__( method get_params (line 252) | def get_params(self, y_dict, design, target_labels): method forward (line 267) | def forward(self, y_dict, design, observation_labels, target_labels): class GuideDV (line 290) | class GuideDV(nn.Module): method __init__ (line 295) | def __init__(self, guide): method forward (line 299) | def forward(self, design, trace, observation_labels, target_labels): FILE: pyro/contrib/oed/search.py class Search (line 14) | class Search(TracePosterior): method __init__ (line 19) | def __init__(self, model, max_tries=int(1e6), **kwargs): method _traces (line 24) | def _traces(self, *args, **kwargs): FILE: pyro/contrib/oed/util.py function linear_model_ground_truth (line 13) | def linear_model_ground_truth( FILE: pyro/contrib/randomvariable/random_variable.py class RVMagicOps (line 21) | class RVMagicOps: method __add__ (line 24) | def __add__(self, x: Union[float, Tensor]): method __radd__ (line 29) | def __radd__(self, x: Union[float, Tensor]): method __sub__ (line 34) | def __sub__(self, x: Union[float, Tensor]): method __rsub__ (line 39) | def __rsub__(self, x: Union[float, Tensor]): method __mul__ (line 44) | def __mul__(self, x: Union[float, Tensor]): method __rmul__ (line 49) | def __rmul__(self, x: Union[float, Tensor]): method __truediv__ (line 54) | def __truediv__(self, x: Union[float, Tensor]): method __neg__ (line 59) | def __neg__(self): method __abs__ (line 64) | def __abs__(self): method __pow__ (line 69) | def __pow__(self, x): class RVChainOps (line 75) | class RVChainOps: method add (line 80) | def add(self, x): method sub (line 83) | def sub(self, x): method mul (line 86) | def mul(self, x): method div (line 89) | def div(self, x): method abs (line 92) | def abs(self): method pow (line 95) | def pow(self, x): method neg (line 98) | def neg(self): method exp (line 101) | def exp(self): method log (line 104) | def log(self): method sigmoid (line 107) | def sigmoid(self): method tanh (line 110) | def tanh(self): method softmax (line 113) | def softmax(self): class RandomVariable (line 117) | class RandomVariable(RVMagicOps, RVChainOps): method __init__ (line 144) | def __init__(self, distribution): method transform (line 152) | def transform(self, t: Transform): method dist (line 168) | def dist(self): FILE: pyro/contrib/timeseries/base.py class TimeSeriesModel (line 7) | class TimeSeriesModel(PyroModule): method log_prob (line 13) | def log_prob(self, targets): method forecast (line 28) | def forecast(self, targets, dts): method get_dist (line 43) | def get_dist(self): FILE: pyro/contrib/timeseries/gp.py class IndependentMaternGP (line 17) | class IndependentMaternGP(TimeSeriesModel): method __init__ (line 35) | def __init__( method _get_init_dist (line 68) | def _get_init_dist(self): method _get_obs_dist (line 74) | def _get_obs_dist(self): method get_dist (line 80) | def get_dist(self, duration=None): method log_prob (line 107) | def log_prob(self, targets): method _filter (line 118) | def _filter(self, targets): method _forecast (line 126) | def _forecast(self, dts, filtering_state, include_observation_noise=Tr... method forecast (line 154) | def forecast(self, targets, dts): class LinearlyCoupledMaternGP (line 171) | class LinearlyCoupledMaternGP(TimeSeriesModel): method __init__ (line 196) | def __init__( method _get_obs_matrix (line 235) | def _get_obs_matrix(self): method _stationary_covariance (line 245) | def _stationary_covariance(self): method _get_init_dist (line 248) | def _get_init_dist(self): method _get_obs_dist (line 252) | def _get_obs_dist(self): method get_dist (line 256) | def get_dist(self, duration=None): method log_prob (line 282) | def log_prob(self, targets): method _filter (line 293) | def _filter(self, targets): method _forecast (line 301) | def _forecast( method forecast (line 338) | def forecast(self, targets, dts): class DependentMaternGP (line 356) | class DependentMaternGP(TimeSeriesModel): method __init__ (line 379) | def __init__( method _get_obs_matrix (line 428) | def _get_obs_matrix(self): method _get_init_dist (line 438) | def _get_init_dist(self, stationary_covariance): method _get_obs_dist (line 443) | def _get_obs_dist(self): method _get_wiener_cov (line 448) | def _get_wiener_cov(self): method _stationary_covariance (line 456) | def _stationary_covariance(self): method _get_trans_dist (line 470) | def _get_trans_dist(self, trans_matrix, stationary_covariance): method _trans_matrix_distribution_stat_covar (line 477) | def _trans_matrix_distribution_stat_covar(self, dts): method get_dist (line 484) | def get_dist(self, duration=None): method log_prob (line 507) | def log_prob(self, targets): method _filter (line 518) | def _filter(self, targets): method _forecast (line 526) | def _forecast(self, dts, filtering_state, include_observation_noise=Tr... method forecast (line 554) | def forecast(self, targets, dts): FILE: pyro/contrib/timeseries/lgssm.py class GenericLGSSM (line 14) | class GenericLGSSM(TimeSeriesModel): method __init__ (line 26) | def __init__( method _get_init_dist (line 61) | def _get_init_dist(self): method _get_obs_dist (line 65) | def _get_obs_dist(self): method _get_trans_dist (line 68) | def _get_trans_dist(self): method get_dist (line 72) | def get_dist(self, duration=None): method log_prob (line 90) | def log_prob(self, targets): method _filter (line 101) | def _filter(self, targets): method _forecast (line 109) | def _forecast(self, N_timesteps, filtering_state, include_observation_... method forecast (line 142) | def forecast(self, targets, N_timesteps): FILE: pyro/contrib/timeseries/lgssmgp.py class GenericLGSSMWithGPNoiseModel (line 15) | class GenericLGSSMWithGPNoiseModel(TimeSeriesModel): method __init__ (line 44) | def __init__( method _get_obs_matrix (line 102) | def _get_obs_matrix(self): method _get_init_dist (line 106) | def _get_init_dist(self): method _get_obs_dist (line 117) | def _get_obs_dist(self): method get_dist (line 120) | def get_dist(self, duration=None): method log_prob (line 165) | def log_prob(self, targets): method _filter (line 176) | def _filter(self, targets): method _forecast (line 184) | def _forecast(self, N_timesteps, filtering_state, include_observation_... method forecast (line 260) | def forecast(self, targets, N_timesteps): FILE: pyro/contrib/tracking/assignment.py function _product (line 14) | def _product(factors): function _exp (line 21) | def _exp(value): class MarginalAssignment (line 27) | class MarginalAssignment: method __init__ (line 56) | def __init__(self, exists_logits, assign_logits, bp_iters=None): class MarginalAssignmentSparse (line 81) | class MarginalAssignmentSparse: method __init__ (line 108) | def __init__( class MarginalAssignmentPersistent (line 142) | class MarginalAssignmentPersistent: method __init__ (line 180) | def __init__(self, exists_logits, assign_logits, bp_iters=None, bp_mom... function compute_marginals (line 207) | def compute_marginals(exists_logits, assign_logits): function compute_marginals_bp (line 249) | def compute_marginals_bp(exists_logits, assign_logits, bp_iters): function compute_marginals_sparse_bp (line 284) | def compute_marginals_sparse_bp( function compute_marginals_persistent (line 334) | def compute_marginals_persistent(exists_logits, assign_logits): function compute_marginals_persistent_bp (line 389) | def compute_marginals_persistent_bp( FILE: pyro/contrib/tracking/distributions.py class EKFDistribution (line 13) | class EKFDistribution(TorchDistribution): method __init__ (line 38) | def __init__( method rsample (line 60) | def rsample(self, sample_shape=torch.Size()): method filter_states (line 63) | def filter_states(self, value): method log_prob (line 83) | def log_prob(self, value): FILE: pyro/contrib/tracking/dynamic_models.py class DynamicModel (line 14) | class DynamicModel(nn.Module, metaclass=ABCMeta): method __init__ (line 25) | def __init__(self, dimension, dimension_pv, num_process_noise_paramete... method dimension (line 32) | def dimension(self): method dimension_pv (line 39) | def dimension_pv(self): method num_process_noise_parameters (line 46) | def num_process_noise_parameters(self): method forward (line 53) | def forward(self, x, dt, do_normalization=True): method geodesic_difference (line 67) | def geodesic_difference(self, x1, x0): method mean2pv (line 79) | def mean2pv(self, x): method cov2pv (line 91) | def cov2pv(self, P): method process_noise_cov (line 103) | def process_noise_cov(self, dt=0.0): method process_noise_dist (line 115) | def process_noise_dist(self, dt=0.0): class DifferentiableDynamicModel (line 129) | class DifferentiableDynamicModel(DynamicModel): method jacobian (line 136) | def jacobian(self, dt): class Ncp (line 147) | class Ncp(DifferentiableDynamicModel): method __init__ (line 159) | def __init__(self, dimension, sv2): method forward (line 168) | def forward(self, x, dt, do_normalization=True): method mean2pv (line 182) | def mean2pv(self, x): method cov2pv (line 196) | def cov2pv(self, P): method jacobian (line 211) | def jacobian(self, dt): method process_noise_cov (line 222) | def process_noise_cov(self, dt=0.0): class Ncv (line 233) | class Ncv(DifferentiableDynamicModel): method __init__ (line 245) | def __init__(self, dimension, sa2): method forward (line 254) | def forward(self, x, dt, do_normalization=True): method mean2pv (line 270) | def mean2pv(self, x): method cov2pv (line 281) | def cov2pv(self, P): method jacobian (line 292) | def jacobian(self, dt): method process_noise_cov (line 310) | def process_noise_cov(self, dt=0.0): class NcpContinuous (line 321) | class NcpContinuous(Ncp): method process_noise_cov (line 336) | def process_noise_cov(self, dt=0.0): class NcvContinuous (line 355) | class NcvContinuous(Ncv): method process_noise_cov (line 370) | def process_noise_cov(self, dt=0.0): class NcpDiscrete (line 398) | class NcpDiscrete(Ncp): method process_noise_cov (line 413) | def process_noise_cov(self, dt=0.0): class NcvDiscrete (line 428) | class NcvDiscrete(Ncv): method process_noise_cov (line 443) | def process_noise_cov(self, dt=0.0): FILE: pyro/contrib/tracking/extended_kalman_filter.py class EKFState (line 11) | class EKFState: method __init__ (line 27) | def __init__(self, dynamic_model, mean, cov, time=None, frame_num=None): method dynamic_model (line 37) | def dynamic_model(self): method dimension (line 44) | def dimension(self): method mean (line 51) | def mean(self): method cov (line 58) | def cov(self): method dimension_pv (line 65) | def dimension_pv(self): method mean_pv (line 72) | def mean_pv(self): method cov_pv (line 79) | def cov_pv(self): method time (line 86) | def time(self): method frame_num (line 93) | def frame_num(self): method predict (line 99) | def predict(self, dt=None, destination_time=None, destination_frame_nu... method innovation (line 139) | def innovation(self, measurement): method log_likelihood_of_update (line 165) | def log_likelihood_of_update(self, measurement): method update (line 180) | def update(self, measurement): FILE: pyro/contrib/tracking/hashing.py class LSH (line 12) | class LSH: method __init__ (line 48) | def __init__(self, radius): method _hash (line 57) | def _hash(self, point): method add (line 61) | def add(self, key, point): method remove (line 75) | def remove(self, key): method nearby (line 88) | def nearby(self, key): class ApproxSet (line 110) | class ApproxSet: method __init__ (line 119) | def __init__(self, radius): method _hash (line 127) | def _hash(self, point): method try_add (line 131) | def try_add(self, point): function merge_points (line 147) | def merge_points(points, radius): FILE: pyro/contrib/tracking/measurements.py class Measurement (line 11) | class Measurement(object, metaclass=ABCMeta): method __init__ (line 23) | def __init__(self, mean, cov, time=None, frame_num=None): method dimension (line 33) | def dimension(self): method mean (line 40) | def mean(self): method cov (line 47) | def cov(self): method time (line 54) | def time(self): method frame_num (line 61) | def frame_num(self): method __call__ (line 68) | def __call__(self, x, do_normalization=True): method geodesic_difference (line 80) | def geodesic_difference(self, z1, z0): class DifferentiableMeasurement (line 92) | class DifferentiableMeasurement(Measurement): method jacobian (line 99) | def jacobian(self, x=None): class PositionMeasurement (line 111) | class PositionMeasurement(DifferentiableMeasurement): method __init__ (line 120) | def __init__(self, mean, cov, time=None, frame_num=None): method __call__ (line 132) | def __call__(self, x, do_normalization=True): method jacobian (line 144) | def jacobian(self, x=None): FILE: pyro/contrib/util.py function get_indices (line 12) | def get_indices(labels, sizes=None, tensors=None): function tensor_to_dict (line 25) | def tensor_to_dict(sizes, tensor, subset=None): function rmm (line 38) | def rmm(A, B): function rmv (line 43) | def rmv(A, b): function rvv (line 48) | def rvv(a, b): function lexpand (line 53) | def lexpand(A, *dimensions): function rexpand (line 58) | def rexpand(A, *dimensions): function rdiag (line 63) | def rdiag(v): function rtril (line 68) | def rtril(M, diagonal=0, upper=False): function iter_plates_to_shape (line 75) | def iter_plates_to_shape(shape): function check_no_weakref (line 81) | def check_no_weakref(obj, path="", avoid_ids=None): FILE: pyro/contrib/zuko.py class ZukoToPyro (line 18) | class ZukoToPyro(pyro.distributions.TorchDistribution): method __init__ (line 48) | def __init__(self, dist: torch.distributions.Distribution): method has_rsample (line 53) | def has_rsample(self) -> bool: method event_shape (line 57) | def event_shape(self) -> Size: method batch_shape (line 61) | def batch_shape(self) -> Size: method __call__ (line 64) | def __call__(self, shape: Size = ()) -> Tensor: method log_prob (line 74) | def log_prob(self, x: Tensor) -> Tensor: method expand (line 80) | def expand(self, *args, **kwargs): FILE: pyro/distributions/affine_beta.py class AffineBeta (line 12) | class AffineBeta(TransformedDistribution): method __init__ (line 39) | def __init__(self, concentration1, concentration0, loc, scale, validat... method infer_shapes (line 48) | def infer_shapes(concentration1, concentration0, loc, scale): method expand (line 53) | def expand(self, batch_shape, _instance=None): method sample (line 57) | def sample(self, sample_shape=torch.Size()): method rsample (line 71) | def rsample(self, sample_shape=torch.Size()): method support (line 85) | def support(self): method concentration1 (line 89) | def concentration1(self): method concentration0 (line 93) | def concentration0(self): method sample_size (line 97) | def sample_size(self): method loc (line 101) | def loc(self): method scale (line 105) | def scale(self): method low (line 109) | def low(self): method high (line 113) | def high(self): method mean (line 117) | def mean(self): method variance (line 121) | def variance(self): FILE: pyro/distributions/asymmetriclaplace.py class AsymmetricLaplace (line 13) | class AsymmetricLaplace(TorchDistribution): method __init__ (line 36) | def __init__(self, loc, scale, asymmetry, *, validate_args=None): method left_scale (line 41) | def left_scale(self): method right_scale (line 45) | def right_scale(self): method expand (line 48) | def expand(self, batch_shape, _instance=None): method log_prob (line 58) | def log_prob(self, value): method rsample (line 65) | def rsample(self, sample_shape=torch.Size()): method mean (line 71) | def mean(self): method variance (line 76) | def variance(self): class SoftAsymmetricLaplace (line 85) | class SoftAsymmetricLaplace(TorchDistribution): method __init__ (line 121) | def __init__(self, loc, scale, asymmetry=1.0, softness=1.0, *, validat... method left_scale (line 131) | def left_scale(self): method right_scale (line 135) | def right_scale(self): method soft_scale (line 139) | def soft_scale(self): method expand (line 142) | def expand(self, batch_shape, _instance=None): method log_prob (line 153) | def log_prob(self, value): method rsample (line 182) | def rsample(self, sample_shape=torch.Size()): method mean (line 191) | def mean(self): method variance (line 196) | def variance(self): function _logerfc (line 205) | def _logerfc(x): FILE: pyro/distributions/avf_mvn.py class AVFMultivariateNormal (line 13) | class AVFMultivariateNormal(MultivariateNormal): method __init__ (line 48) | def __init__(self, loc, scale_tril, control_var): method rsample (line 64) | def rsample(self, sample_shape=torch.Size()): class _AVFMVNSample (line 70) | class _AVFMVNSample(Function): method forward (line 72) | def forward(ctx, loc, scale_tril, control_var, shape): method backward (line 80) | def backward(ctx, grad_output): FILE: pyro/distributions/coalescent.py class CoalescentTimesConstraint (line 17) | class CoalescentTimesConstraint(constraints.Constraint): method __init__ (line 18) | def __init__(self, leaf_times, *, ordered=True): method check (line 22) | def check(self, value): class CoalescentTimes (line 35) | class CoalescentTimes(TorchDistribution): method __init__ (line 65) | def __init__(self, leaf_times, rate=1.0, *, validate_args=None): method support (line 74) | def support(self): method log_prob (line 77) | def log_prob(self, value): method sample (line 96) | def sample(self, sample_shape=torch.Size()): class CoalescentTimesWithRate (line 102) | class CoalescentTimesWithRate(TorchDistribution): method __init__ (line 149) | def __init__(self, leaf_times, rate_grid, *, validate_args=None): method support (line 157) | def support(self): method duration (line 161) | def duration(self): method expand (line 164) | def expand(self, batch_shape, _instance=None): method log_prob (line 174) | def log_prob(self, value): class CoalescentRateLikelihood (line 213) | class CoalescentRateLikelihood: method __init__ (line 249) | def __init__(self, leaf_times, coal_times, duration, *, validate_args=... method __call__ (line 292) | def __call__(self, rate_grid, t=slice(None)): function bio_phylo_to_times (line 326) | def bio_phylo_to_times(tree, *, get_time=None): function _gather (line 374) | def _gather(tensor, dim, index): function _interpolate_gather (line 386) | def _interpolate_gather(array, x): function _interpolate_scatter_add_ (line 399) | def _interpolate_scatter_add_(dst, x, src): function _weak_memoize (line 412) | def _weak_memoize(fn): function _make_phylogeny (line 450) | def _make_phylogeny(leaf_times, coal_times): function _sample_coalescent_times (line 487) | def _sample_coalescent_times(leaf_times): FILE: pyro/distributions/conditional.py class ConditionalDistribution (line 13) | class ConditionalDistribution(ABC): method condition (line 15) | def condition(self, context): class ConditionalTransform (line 20) | class ConditionalTransform(ABC): method condition (line 22) | def condition(self, context): class ConditionalTransformModule (line 27) | class ConditionalTransformModule(ConditionalTransform, torch.nn.Module): method __init__ (line 34) | def __init__(self, *args, **kwargs): method __hash__ (line 37) | def __hash__(self): method inv (line 41) | def inv(self) -> "ConditionalTransformModule": class _ConditionalInverseTransformModule (line 45) | class _ConditionalInverseTransformModule(ConditionalTransformModule): method __init__ (line 46) | def __init__(self, transform: ConditionalTransform): method inv (line 51) | def inv(self) -> ConditionalTransform: method condition (line 54) | def condition(self, context: torch.Tensor): class ConditionalComposeTransformModule (line 58) | class ConditionalComposeTransformModule( method __init__ (line 83) | def __init__(self, transforms, cache_size: int = 0): method condition (line 101) | def condition(self, context: torch.Tensor) -> ComposeTransformModule: class ConstantConditionalDistribution (line 107) | class ConstantConditionalDistribution(ConditionalDistribution): method __init__ (line 108) | def __init__(self, base_dist): method condition (line 112) | def condition(self, context): class ConstantConditionalTransform (line 116) | class ConstantConditionalTransform(ConditionalTransform): method __init__ (line 117) | def __init__(self, transform): method condition (line 121) | def condition(self, context): method clear_cache (line 124) | def clear_cache(self): class ConditionalTransformedDistribution (line 128) | class ConditionalTransformedDistribution(ConditionalDistribution): method __init__ (line 129) | def __init__(self, base_dist, transforms): method condition (line 144) | def condition(self, context): method clear_cache (line 149) | def clear_cache(self): FILE: pyro/distributions/conjugate.py function _log_beta_1 (line 17) | def _log_beta_1(alpha, value, is_sparse): class BetaBinomial (line 34) | class BetaBinomial(TorchDistribution): method __init__ (line 65) | def __init__( method concentration1 (line 76) | def concentration1(self): method concentration0 (line 80) | def concentration0(self): method expand (line 83) | def expand(self, batch_shape, _instance=None): method sample (line 92) | def sample(self, sample_shape=()): method log_prob (line 96) | def log_prob(self, value): method mean (line 112) | def mean(self): method variance (line 116) | def variance(self): method enumerate_support (line 123) | def enumerate_support(self, expand=True): class DirichletMultinomial (line 140) | class DirichletMultinomial(TorchDistribution): method __init__ (line 161) | def __init__( method concentration (line 178) | def concentration(self): method infer_shapes (line 182) | def infer_shapes(concentration, total_count=()): method expand (line 187) | def expand(self, batch_shape, _instance=None): method sample (line 199) | def sample(self, sample_shape=()): method log_prob (line 208) | def log_prob(self, value): method mean (line 217) | def mean(self): method variance (line 221) | def variance(self): class GammaPoisson (line 229) | class GammaPoisson(TorchDistribution): method __init__ (line 252) | def __init__(self, concentration, rate, validate_args=None): method concentration (line 259) | def concentration(self): method rate (line 263) | def rate(self): method expand (line 266) | def expand(self, batch_shape, _instance=None): method sample (line 274) | def sample(self, sample_shape=()): method log_prob (line 278) | def log_prob(self, value): method mean (line 290) | def mean(self): method variance (line 294) | def variance(self): FILE: pyro/distributions/constraints.py class _Integer (line 50) | class _Integer(Constraint): method check (line 57) | def check(self, value): method __repr__ (line 60) | def __repr__(self): class _Sphere (line 64) | class _Sphere(Constraint): method check (line 72) | def check(self, value): method __repr__ (line 78) | def __repr__(self): class _CorrMatrix (line 82) | class _CorrMatrix(Constraint): method check (line 89) | def check(self, value): class _OrderedVector (line 98) | class _OrderedVector(Constraint): method check (line 106) | def check(self, value): class _PositiveOrderedVector (line 115) | class _PositiveOrderedVector(Constraint): method check (line 121) | def check(self, value): class _SoftplusPositive (line 125) | class _SoftplusPositive(type(positive)): method __init__ (line 126) | def __init__(self): class _SoftplusLowerCholesky (line 130) | class _SoftplusLowerCholesky(type(lower_cholesky)): class _UnitLowerCholesky (line 134) | class _UnitLowerCholesky(Constraint): method check (line 141) | def check(self, value): FILE: pyro/distributions/delta.py class Delta (line 14) | class Delta(TorchDistribution): method __init__ (line 32) | def __init__(self, v, log_density=0.0, event_dim=0, validate_args=None): method support (line 57) | def support(self): method expand (line 60) | def expand(self, batch_shape, _instance=None): method rsample (line 69) | def rsample(self, sample_shape=torch.Size()): method log_prob (line 73) | def log_prob(self, x): method mean (line 80) | def mean(self): method variance (line 84) | def variance(self): FILE: pyro/distributions/diag_normal_mixture.py class MixtureOfDiagNormals (line 15) | class MixtureOfDiagNormals(TorchDistribution): method __init__ (line 51) | def __init__(self, locs, coord_scale, component_logits): method expand (line 87) | def expand(self, batch_shape, _instance=None): method log_prob (line 107) | def log_prob(self, value): method rsample (line 122) | def rsample(self, sample_shape=torch.Size()): class _MixDiagNormalSample (line 134) | class _MixDiagNormalSample(Function): method forward (line 136) | def forward(ctx, locs, scales, component_logits, pis, which, noise_sha... method backward (line 151) | def backward(ctx, grad_output): FILE: pyro/distributions/diag_normal_mixture_shared_cov.py class MixtureOfDiagNormalsSharedCovariance (line 15) | class MixtureOfDiagNormalsSharedCovariance(TorchDistribution): method __init__ (line 50) | def __init__(self, locs, coord_scale, component_logits): method expand (line 85) | def expand(self, batch_shape, _instance=None): method log_prob (line 108) | def log_prob(self, value): method rsample (line 124) | def rsample(self, sample_shape=torch.Size()): class _MixDiagNormalSharedCovarianceSample (line 136) | class _MixDiagNormalSharedCovarianceSample(Function): method forward (line 138) | def forward(ctx, locs, coord_scale, component_logits, pis, which, nois... method backward (line 152) | def backward(ctx, grad_output): FILE: pyro/distributions/distribution.py class DistributionMeta (line 15) | class DistributionMeta(ABCMeta): method __init__ (line 16) | def __init__(cls, *args, **kwargs): method __call__ (line 21) | def __call__(cls, *args, **kwargs): class Distribution (line 29) | class Distribution(metaclass=DistributionMeta): method __call__ (line 55) | def __call__(self, *args, **kwargs): method sample (line 68) | def sample(self, *args, **kwargs): method log_prob (line 85) | def log_prob(self, x, *args, **kwargs): method score_parts (line 98) | def score_parts(self, x, *args, **kwargs): method enumerate_support (line 127) | def enumerate_support(self, expand: bool = True) -> torch.Tensor: method conjugate_update (line 147) | def conjugate_update(self, other): method has_rsample_ (line 180) | def has_rsample_(self, value): method rv (line 199) | def rv(self): FILE: pyro/distributions/empirical.py class Empirical (line 13) | class Empirical(TorchDistribution): method __init__ (line 53) | def __init__(self, samples, log_weights, validate_args=None): method sample_size (line 77) | def sample_size(self): method sample (line 85) | def sample(self, sample_shape=torch.Size()): method log_prob (line 103) | def log_prob(self, value): method _weighted_mean (line 128) | def _weighted_mean(self, value, keepdim=False): method event_shape (line 141) | def event_shape(self): method mean (line 145) | def mean(self): method variance (line 157) | def variance(self): method log_weights (line 171) | def log_weights(self): method enumerate_support (line 174) | def enumerate_support(self, expand=True): FILE: pyro/distributions/extended.py class ExtendedBinomial (line 12) | class ExtendedBinomial(Binomial): method log_prob (line 27) | def log_prob(self, value): class ExtendedBetaBinomial (line 33) | class ExtendedBetaBinomial(BetaBinomial): method log_prob (line 48) | def log_prob(self, value): FILE: pyro/distributions/folded.py class FoldedDistribution (line 10) | class FoldedDistribution(TransformedDistribution): method __init__ (line 21) | def __init__(self, base_dist, validate_args=None): method expand (line 26) | def expand(self, batch_shape, _instance=None): method log_prob (line 30) | def log_prob(self, value): FILE: pyro/distributions/gaussian_scale_mixture.py class GaussianScaleMixture (line 15) | class GaussianScaleMixture(TorchDistribution): method __init__ (line 60) | def __init__(self, coord_scale, component_logits, component_scale): method _compute_coeffs (line 83) | def _compute_coeffs(self): method log_prob (line 93) | def log_prob(self, value): method rsample (line 108) | def rsample(self, sample_shape=torch.Size()): class _GSMSample (line 121) | class _GSMSample(Function): method forward (line 123) | def forward( method backward (line 136) | def backward(ctx, grad_output): FILE: pyro/distributions/grouped_normal_normal.py class GroupedNormalNormal (line 15) | class GroupedNormalNormal(TorchDistribution): method __init__ (line 59) | def __init__( method expand (line 98) | def expand(self, batch_shape, _instance=None): method sample (line 101) | def sample(self, sample_shape=()): method get_posterior (line 104) | def get_posterior(self, value): method log_prob (line 131) | def log_prob(self, value): FILE: pyro/distributions/hmm.py function _linear_integrate (line 32) | def _linear_integrate(init, trans, shift): function _logmatmulexp (line 51) | def _logmatmulexp(x, y): function _sequential_logmatmulexp (line 65) | def _sequential_logmatmulexp(logits): function _markov_index (line 88) | def _markov_index(x, y): function _sequential_index (line 96) | def _sequential_index(samples): function _sequential_gamma_gaussian_tensordot (line 164) | def _sequential_gamma_gaussian_tensordot(gamma_gaussian): class HiddenMarkovModel (line 189) | class HiddenMarkovModel(TorchDistribution): method __init__ (line 200) | def __init__(self, duration, batch_shape, event_shape, validate_args=N... method duration (line 218) | def duration(self): method _validate_sample (line 224) | def _validate_sample(self, value): class DiscreteHMM (line 243) | class DiscreteHMM(HiddenMarkovModel): method __init__ (line 293) | def __init__( method support (line 333) | def support(self): method expand (line 336) | def expand(self, batch_shape, _instance=None): method log_prob (line 352) | def log_prob(self, value): method filter (line 371) | def filter(self, value): method sample (line 400) | def sample(self, sample_shape=torch.Size()): class GaussianHMM (line 434) | class GaussianHMM(HiddenMarkovModel): method __init__ (line 498) | def __init__( method expand (line 546) | def expand(self, batch_shape, _instance=None): method log_prob (line 565) | def log_prob(self, value): method rsample (line 584) | def rsample(self, sample_shape=torch.Size()): method rsample_posterior (line 596) | def rsample_posterior(self, value, sample_shape=torch.Size()): method filter (line 606) | def filter(self, value): method conjugate_update (line 638) | def conjugate_update(self, other): method prefix_condition (line 690) | def prefix_condition(self, data): class GammaGaussianHMM (line 744) | class GammaGaussianHMM(HiddenMarkovModel): method __init__ (line 817) | def __init__( method expand (line 862) | def expand(self, batch_shape, _instance=None): method log_prob (line 879) | def log_prob(self, value): method filter (line 901) | def filter(self, value): class LinearHMM (line 939) | class LinearHMM(HiddenMarkovModel): method __init__ (line 1011) | def __init__( method support (line 1094) | def support(self): # noqa: F811 method expand (line 1097) | def expand(self, batch_shape, _instance=None): method log_prob (line 1119) | def log_prob(self, value): method rsample (line 1122) | def rsample(self, sample_shape=torch.Size()): class IndependentHMM (line 1141) | class IndependentHMM(TorchDistribution): method __init__ (line 1159) | def __init__(self, base_dist): method support (line 1169) | def support(self): method has_rsample (line 1173) | def has_rsample(self): method duration (line 1177) | def duration(self): method expand (line 1180) | def expand(self, batch_shape, _instance=None): method rsample (line 1192) | def rsample(self, sample_shape=torch.Size()): method log_prob (line 1196) | def log_prob(self, value): class GaussianMRF (line 1201) | class GaussianMRF(TorchDistribution): method __init__ (line 1244) | def __init__( method support (line 1270) | def support(self): method expand (line 1273) | def expand(self, batch_shape, _instance=None): method log_prob (line 1291) | def log_prob(self, value): FILE: pyro/distributions/improper_uniform.py class ImproperUniform (line 11) | class ImproperUniform(TorchDistribution): method __init__ (line 46) | def __init__(self, support, batch_shape, event_shape): method support (line 52) | def support(self): method expand (line 55) | def expand(self, batch_shape, _instance=None): method log_prob (line 62) | def log_prob(self, value): method sample (line 67) | def sample(self, sample_shape=torch.Size()): FILE: pyro/distributions/inverse_gamma.py class InverseGamma (line 11) | class InverseGamma(TransformedDistribution): method __init__ (line 30) | def __init__(self, concentration, rate, validate_args=None): method expand (line 38) | def expand(self, batch_shape, _instance=None): method concentration (line 43) | def concentration(self): method rate (line 47) | def rate(self): FILE: pyro/distributions/kl.py function _kl_delta (line 20) | def _kl_delta(p, q): function _kl_independent_independent (line 25) | def _kl_independent_independent(p, q): function _kl_independent_mvn (line 38) | def _kl_independent_mvn(p, q): FILE: pyro/distributions/lkj.py class LKJCorrCholesky (line 14) | class LKJCorrCholesky(LKJCholesky): # DEPRECATED method __init__ (line 15) | def __init__(self, d, eta, validate_args=None): class LKJ (line 24) | class LKJ(TransformedDistribution): method __init__ (line 49) | def __init__(self, dim, concentration=1.0, validate_args=None): method expand (line 56) | def expand(self, batch_shape, _instance=None): method mean (line 61) | def mean(self): FILE: pyro/distributions/log_normal_negative_binomial.py class LogNormalNegativeBinomial (line 14) | class LogNormalNegativeBinomial(TorchDistribution): method __init__ (line 77) | def __init__( method log_prob (line 114) | def log_prob(self, value): method sample (line 118) | def sample(self, sample_shape=torch.Size()): method expand (line 121) | def expand(self, batch_shape, _instance=None): method mean (line 139) | def mean(self): method variance (line 147) | def variance(self): FILE: pyro/distributions/logistic.py class Logistic (line 14) | class Logistic(TorchDistribution): method __init__ (line 40) | def __init__(self, loc, scale, *, validate_args=None): method expand (line 44) | def expand(self, batch_shape, _instance=None): method log_prob (line 53) | def log_prob(self, value): method rsample (line 59) | def rsample(self, sample_shape=torch.Size()): method cdf (line 64) | def cdf(self, value): method icdf (line 70) | def icdf(self, value): method mean (line 74) | def mean(self): method variance (line 78) | def variance(self): method entropy (line 81) | def entropy(self): class SkewLogistic (line 85) | class SkewLogistic(TorchDistribution): method __init__ (line 124) | def __init__(self, loc, scale, asymmetry=1.0, *, validate_args=None): method expand (line 128) | def expand(self, batch_shape, _instance=None): method log_prob (line 138) | def log_prob(self, value): method rsample (line 145) | def rsample(self, sample_shape=torch.Size()): method cdf (line 150) | def cdf(self, value): method icdf (line 156) | def icdf(self, value): FILE: pyro/distributions/mixture.py class MaskedConstraint (line 12) | class MaskedConstraint(constraints.Constraint): method __init__ (line 23) | def __init__(self, mask, constraint0, constraint1): method check (line 28) | def check(self, value): class MaskedMixture (line 39) | class MaskedMixture(TorchDistribution): method __init__ (line 66) | def __init__(self, mask, component0, component1, validate_args=None): method has_rsample (line 98) | def has_rsample(self): method support (line 102) | def support(self): method expand (line 109) | def expand(self, batch_shape): method sample (line 118) | def sample(self, sample_shape=torch.Size()): method rsample (line 128) | def rsample(self, sample_shape=torch.Size()): method log_prob (line 138) | def log_prob(self, value): method mean (line 154) | def mean(self): method variance (line 160) | def variance(self): FILE: pyro/distributions/multivariate_studentt.py class MultivariateStudentT (line 15) | class MultivariateStudentT(TorchDistribution): method __init__ (line 34) | def __init__(self, df, loc, scale_tril, validate_args=None): method scale_tril (line 50) | def scale_tril(self): method covariance_matrix (line 56) | def covariance_matrix(self): method precision_matrix (line 65) | def precision_matrix(self): method infer_shapes (line 74) | def infer_shapes(df, loc, scale_tril): method expand (line 79) | def expand(self, batch_shape, _instance=None): method rsample (line 100) | def rsample(self, sample_shape=torch.Size()): method log_prob (line 107) | def log_prob(self, value): method mean (line 124) | def mean(self): method variance (line 130) | def variance(self): FILE: pyro/distributions/nanmasked.py class NanMaskedNormal (line 9) | class NanMaskedNormal(Normal): method log_prob (line 24) | def log_prob(self, value: torch.Tensor) -> torch.Tensor: class NanMaskedMultivariateNormal (line 40) | class NanMaskedMultivariateNormal(MultivariateNormal): method log_prob (line 65) | def log_prob(self, value: torch.Tensor) -> torch.Tensor: FILE: pyro/distributions/omt_mvn.py class OMTMultivariateNormal (line 13) | class OMTMultivariateNormal(MultivariateNormal): method __init__ (line 30) | def __init__(self, loc, scale_tril): method rsample (line 37) | def rsample(self, sample_shape=torch.Size()): class _OMTMVNSample (line 43) | class _OMTMVNSample(Function): method forward (line 45) | def forward(ctx, loc, scale_tril, shape): method backward (line 53) | def backward(ctx, grad_output): FILE: pyro/distributions/one_one_matching.py class OneOneMatchingConstraint (line 18) | class OneOneMatchingConstraint(constraints.Constraint): method __init__ (line 19) | def __init__(self, num_nodes): method check (line 22) | def check(self, value): class OneOneMatching (line 41) | class OneOneMatching(TorchDistribution): method __init__ (line 84) | def __init__(self, logits, *, bp_iters=None, validate_args=None): method support (line 97) | def support(self): method log_partition_function (line 101) | def log_partition_function(self): method log_prob (line 133) | def log_prob(self, value): method enumerate_support (line 140) | def enumerate_support(self, expand=True): method sample (line 143) | def sample(self, sample_shape=torch.Size()): method mode (line 161) | def mode(self): function maximum_weight_matching (line 169) | def maximum_weight_matching(logits): FILE: pyro/distributions/one_two_matching.py class OneTwoMatchingConstraint (line 18) | class OneTwoMatchingConstraint(constraints.Constraint): method __init__ (line 19) | def __init__(self, num_destins): method check (line 23) | def check(self, value): class OneTwoMatching (line 42) | class OneTwoMatching(TorchDistribution): method __init__ (line 85) | def __init__(self, logits, *, bp_iters=None, validate_args=None): method support (line 98) | def support(self): method log_partition_function (line 102) | def log_partition_function(self): method log_prob (line 142) | def log_prob(self, value): method enumerate_support (line 149) | def enumerate_support(self, expand=True): method sample (line 152) | def sample(self, sample_shape=torch.Size()): method mode (line 170) | def mode(self): function enumerate_one_two_matchings (line 177) | def enumerate_one_two_matchings(num_destins): function maximum_weight_matching (line 204) | def maximum_weight_matching(logits): FILE: pyro/distributions/ordered_logistic.py class OrderedLogistic (line 10) | class OrderedLogistic(Categorical): method __init__ (line 41) | def __init__(self, predictor, cutpoints, validate_args=None): method expand (line 56) | def expand(self, batch_shape, _instance=None): FILE: pyro/distributions/polya_gamma.py class TruncatedPolyaGamma (line 13) | class TruncatedPolyaGamma(TorchDistribution): method __init__ (line 41) | def __init__(self, prototype, validate_args=None): method expand (line 47) | def expand(self, batch_shape, _instance=None): method sample (line 56) | def sample(self, sample_shape=()): method log_prob (line 65) | def log_prob(self, value): FILE: pyro/distributions/projected_normal.py class ProjectedNormal (line 14) | class ProjectedNormal(TorchDistribution): method __init__ (line 56) | def __init__(self, concentration, *, validate_args=None): method infer_shapes (line 64) | def infer_shapes(concentration): method expand (line 69) | def expand(self, batch_shape, _instance=None): method mean (line 80) | def mean(self): method mode (line 88) | def mode(self): method rsample (line 91) | def rsample(self, sample_shape=torch.Size()): method log_prob (line 98) | def log_prob(self, value): method _register_log_prob (line 118) | def _register_log_prob(cls, dim, fn=None): function _dot (line 125) | def _dot(x, y): function _safe_log (line 129) | def _safe_log(x): function _log_prob_2 (line 134) | def _log_prob_2(concentration, value): function _log_prob_3 (line 157) | def _log_prob_3(concentration, value): function _log_prob_4 (line 179) | def _log_prob_4(concentration, value): FILE: pyro/distributions/rejector.py class Rejector (line 10) | class Rejector(TorchDistribution): method __init__ (line 25) | def __init__( method _log_prob_accept (line 41) | def _log_prob_accept(self, x): method _propose_log_prob (line 46) | def _propose_log_prob(self, x): method rsample (line 51) | def rsample(self, sample_shape=torch.Size()): method log_prob (line 67) | def log_prob(self, x): method score_parts (line 70) | def score_parts(self, x): FILE: pyro/distributions/relaxed_straight_through.py class RelaxedOneHotCategoricalStraightThrough (line 12) | class RelaxedOneHotCategoricalStraightThrough(RelaxedOneHotCategorical): method rsample (line 34) | def rsample(self, sample_shape=torch.Size()): method log_prob (line 40) | def log_prob(self, value): class QuantizeCategorical (line 45) | class QuantizeCategorical(torch.autograd.Function): method forward (line 47) | def forward(ctx, soft_value): method backward (line 56) | def backward(ctx, grad): class RelaxedBernoulliStraightThrough (line 61) | class RelaxedBernoulliStraightThrough(RelaxedBernoulli): method rsample (line 83) | def rsample(self, sample_shape=torch.Size()): method log_prob (line 89) | def log_prob(self, value): class QuantizeBernoulli (line 94) | class QuantizeBernoulli(torch.autograd.Function): method forward (line 96) | def forward(ctx, soft_value): method backward (line 102) | def backward(ctx, grad): FILE: pyro/distributions/score_parts.py class ScoreParts (line 11) | class ScoreParts(NamedTuple): method scale_and_mask (line 21) | def scale_and_mask( FILE: pyro/distributions/sine_bivariate_von_mises.py class SineBivariateVonMises (line 18) | class SineBivariateVonMises(TorchDistribution): method __init__ (line 84) | def __init__( method norm_const (line 143) | def norm_const(self): method log_prob (line 166) | def log_prob(self, value): method sample (line 179) | def sample(self, sample_shape=torch.Size()): method mean (line 295) | def mean(self): method infer_shapes (line 299) | def infer_shapes(cls, **arg_shapes): method expand (line 303) | def expand(self, batch_shape, _instance=None): method _bfind (line 315) | def _bfind(self, eig): method _lbinoms (line 326) | def _lbinoms(n): FILE: pyro/distributions/sine_skewed.py class SineSkewed (line 16) | class SineSkewed(TorchDistribution): method __init__ (line 92) | def __init__(self, base_dist: TorchDistribution, skewness, validate_ar... method __repr__ (line 112) | def __repr__(self): method sample (line 134) | def sample(self, sample_shape=torch.Size()): method log_prob (line 149) | def log_prob(self, value): method expand (line 161) | def expand(self, batch_shape, _instance=None): FILE: pyro/distributions/softlaplace.py class SoftLaplace (line 13) | class SoftLaplace(TorchDistribution): method __init__ (line 35) | def __init__(self, loc, scale, *, validate_args=None): method expand (line 39) | def expand(self, batch_shape, _instance=None): method log_prob (line 48) | def log_prob(self, value): method rsample (line 54) | def rsample(self, sample_shape=torch.Size()): method cdf (line 59) | def cdf(self, value): method icdf (line 65) | def icdf(self, value): method mean (line 69) | def mean(self): method variance (line 73) | def variance(self): FILE: pyro/distributions/spanning_tree.cpp function make_complete_graph (line 11) | at::Tensor make_complete_graph(int num_vertices) { function _remove_edge (line 26) | int _remove_edge(at::Tensor grid, at::Tensor edge_ids, function _add_edge (line 49) | void _add_edge(at::Tensor grid, at::Tensor edge_ids, function _find_valid_edges (line 60) | int _find_valid_edges(const std::vector &components, at::Tensor va... function sample_tree_mcmc (line 77) | at::Tensor sample_tree_mcmc(at::Tensor edge_logits, at::Tensor edges) { function sample_tree_approx (line 134) | at::Tensor sample_tree_approx(at::Tensor edge_logits) { function find_best_tree (line 178) | at::Tensor find_best_tree(at::Tensor edge_logits) { function PYBIND11_MODULE (line 221) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { FILE: pyro/distributions/spanning_tree.py class SpanningTree (line 14) | class SpanningTree(TorchDistribution): method __init__ (line 57) | def __init__(self, edge_logits, sampler_options=None, validate_args=No... method validate_edges (line 78) | def validate_edges(self, edges): method log_partition_function (line 119) | def log_partition_function(self): method log_prob (line 142) | def log_prob(self, edges): method sample (line 150) | def sample(self, sample_shape=torch.Size()): method enumerate_support (line 177) | def enumerate_support(self, expand=True): method mode (line 185) | def mode(self): method edge_mean (line 194) | def edge_mean(self): function _get_cpp_module (line 225) | def _get_cpp_module(): function make_complete_graph (line 244) | def make_complete_graph(num_vertices, backend="python"): function _make_complete_graph (line 261) | def _make_complete_graph(num_vertices): function _remove_edge (line 277) | def _remove_edge(grid, edge_ids, neighbors, components, e): function _add_edge (line 297) | def _add_edge(grid, edge_ids, neighbors, components, e, k): function _find_valid_edges (line 309) | def _find_valid_edges(components, valid_edge_ids): function _sample_tree_mcmc (line 332) | def _sample_tree_mcmc(edge_logits, edges): function sample_tree_mcmc (line 381) | def sample_tree_mcmc(edge_logits, edges, backend="python"): function _sample_tree_approx (line 415) | def _sample_tree_approx(edge_logits): function sample_tree_approx (line 452) | def sample_tree_approx(edge_logits, backend="python"): function sample_tree (line 473) | def sample_tree(edge_logits, init_edges=None, mcmc_steps=1, backend="pyt... function _find_best_tree (line 483) | def _find_best_tree(edge_logits): function find_best_tree (line 519) | def find_best_tree(edge_logits, backend="python"): function _permute_tree (line 593) | def _permute_tree(perm, tree): function _close_under_permutations (line 599) | def _close_under_permutations(V, tree_generators): function enumerate_spanning_trees (line 610) | def enumerate_spanning_trees(V): FILE: pyro/distributions/stable.py function _unsafe_standard_stable (line 14) | def _unsafe_standard_stable(alpha, beta, V, W, coords): function _standard_stable (line 51) | def _standard_stable(alpha, beta, aux_uniform, aux_exponential, coords): class Stable (line 96) | class Stable(TorchDistribution): method __init__ (line 161) | def __init__( method expand (line 171) | def expand(self, batch_shape, _instance=None): method log_prob (line 181) | def log_prob(self, value): method rsample (line 209) | def rsample(self, sample_shape=torch.Size()): method mean (line 224) | def mean(self): method variance (line 233) | def variance(self): class StableWithLogProb (line 238) | class StableWithLogProb(Stable): FILE: pyro/distributions/stable_log_prob.py function create_integrator (line 19) | def create_integrator(num_points): function set_integrator (line 40) | def set_integrator(num_points): function integrate (line 47) | def integrate(*args, **kwargs): # noqa: F811 function _stable_log_prob (line 52) | def _stable_log_prob(alpha, beta, value, coords): function _unsafe_alpha_stable_log_prob_S0 (line 90) | def _unsafe_alpha_stable_log_prob_S0(alpha, beta, Z): function _unsafe_stable_log_prob (line 130) | def _unsafe_stable_log_prob(alpha, beta, Z): function _unsafe_stable_given_uniform_log_prob (line 154) | def _unsafe_stable_given_uniform_log_prob(V, alpha, beta, Z): function _unsafe_alpha_stable_log_prob_at_zero (line 188) | def _unsafe_alpha_stable_log_prob_at_zero(alpha, beta): FILE: pyro/distributions/testing/fakes.py class NonreparameterizedBeta (line 7) | class NonreparameterizedBeta(Beta): class NonreparameterizedDirichlet (line 11) | class NonreparameterizedDirichlet(Dirichlet): class NonreparameterizedGamma (line 15) | class NonreparameterizedGamma(Gamma): class NonreparameterizedNormal (line 19) | class NonreparameterizedNormal(Normal): FILE: pyro/distributions/testing/gof.py class InvalidTest (line 68) | class InvalidTest(ValueError): function print_histogram (line 72) | def print_histogram(probs, counts): function multinomial_goodness_of_fit (line 81) | def multinomial_goodness_of_fit( function unif01_goodness_of_fit (line 137) | def unif01_goodness_of_fit(samples, *, plot=False): function exp_goodness_of_fit (line 160) | def exp_goodness_of_fit(samples, plot=False): function density_goodness_of_fit (line 176) | def density_goodness_of_fit(samples, probs, plot=False): function volume_of_sphere (line 205) | def volume_of_sphere(dim, radius): function get_nearest_neighbor_distances (line 209) | def get_nearest_neighbor_distances(samples): function vector_density_goodness_of_fit (line 224) | def vector_density_goodness_of_fit(samples, probs, *, dim=None, plot=Fal... function auto_goodness_of_fit (line 266) | def auto_goodness_of_fit(samples, probs, *, dim=None, plot=False): FILE: pyro/distributions/testing/naive_dirichlet.py class NaiveDirichlet (line 11) | class NaiveDirichlet(Dirichlet): method __init__ (line 19) | def __init__(self, concentration, validate_args=None): method rsample (line 25) | def rsample(self, sample_shape=torch.Size()): class NaiveBeta (line 31) | class NaiveBeta(Beta): method __init__ (line 39) | def __init__(self, concentration1, concentration0, validate_args=None): method rsample (line 44) | def rsample(self, sample_shape=torch.Size()): FILE: pyro/distributions/testing/rejection_exponential.py class RejectionExponential (line 14) | class RejectionExponential(Rejector): method __init__ (line 18) | def __init__(self, rate, factor): method log_prob_accept (line 26) | def log_prob_accept(self, x): method batch_shape (line 32) | def batch_shape(self): method event_shape (line 36) | def event_shape(self): FILE: pyro/distributions/testing/rejection_gamma.py class RejectionStandardGamma (line 13) | class RejectionStandardGamma(Rejector): method __init__ (line 19) | def __init__(self, concentration): method expand (line 42) | def expand(self, batch_shape, _instance=None): method propose (line 63) | def propose(self, sample_shape=torch.Size()): method propose_log_prob (line 74) | def propose_log_prob(self, value): method log_prob_accept (line 87) | def log_prob_accept(self, value): method log_prob (line 95) | def log_prob(self, x): class RejectionGamma (line 100) | class RejectionGamma(Gamma): method __init__ (line 103) | def __init__(self, concentration, rate, validate_args=None): method expand (line 108) | def expand(self, batch_shape, _instance=None): method rsample (line 115) | def rsample(self, sample_shape=torch.Size()): method log_prob (line 118) | def log_prob(self, x): method score_parts (line 121) | def score_parts(self, x): class ShapeAugmentedGamma (line 128) | class ShapeAugmentedGamma(Gamma): method __init__ (line 136) | def __init__(self, concentration, rate, boost=1, validate_args=None): method expand (line 145) | def expand(self, batch_shape, _instance=None): method rsample (line 155) | def rsample(self, sample_shape=torch.Size()): method score_parts (line 164) | def score_parts(self, boosted_x=None): class ShapeAugmentedDirichlet (line 175) | class ShapeAugmentedDirichlet(Dirichlet): method __init__ (line 183) | def __init__(self, concentration, boost=1, validate_args=None): method expand (line 189) | def expand(self, batch_shape, _instance=None): method rsample (line 199) | def rsample(self, sample_shape=torch.Size()): class ShapeAugmentedBeta (line 205) | class ShapeAugmentedBeta(Beta): method __init__ (line 213) | def __init__(self, concentration1, concentration0, boost=1, validate_a... method expand (line 220) | def expand(self, batch_shape, _instance=None): method rsample (line 230) | def rsample(self, sample_shape=torch.Size()): FILE: pyro/distributions/testing/special.py function log (line 41) | def log(x): function incomplete_gamma (line 49) | def incomplete_gamma(x, s): function chi2sf (line 82) | def chi2sf(x, s): FILE: pyro/distributions/torch.py function _clamp_by_zero (line 18) | def _clamp_by_zero(x): class Beta (line 23) | class Beta(torch.distributions.Beta, TorchDistributionMixin): method conjugate_update (line 24) | def conjugate_update(self, other): class Binomial (line 44) | class Binomial(torch.distributions.Binomial, TorchDistributionMixin): method sample (line 56) | def sample(self, sample_shape=torch.Size()): method log_prob (line 83) | def log_prob(self, value): function _validate_thresh (line 107) | def _validate_thresh(thresh): function _validate_tol (line 115) | def _validate_tol(tol): class Categorical (line 124) | class Categorical(torch.distributions.Categorical, TorchDistributionMixin): method log_prob (line 127) | def log_prob(self, value): method enumerate_support (line 145) | def enumerate_support(self, expand=True): class Dirichlet (line 152) | class Dirichlet(torch.distributions.Dirichlet, TorchDistributionMixin): method infer_shapes (line 154) | def infer_shapes(concentration): method conjugate_update (line 159) | def conjugate_update(self, other): class Gamma (line 177) | class Gamma(torch.distributions.Gamma, TorchDistributionMixin): method conjugate_update (line 178) | def conjugate_update(self, other): class Geometric (line 197) | class Geometric(torch.distributions.Geometric, TorchDistributionMixin): method log_prob (line 199) | def log_prob(self, value): class LogNormal (line 205) | class LogNormal(torch.distributions.LogNormal, TorchDistributionMixin): method __init__ (line 206) | def __init__(self, loc, scale, validate_args=None): method expand (line 216) | def expand(self, batch_shape, _instance=None): class LowRankMultivariateNormal (line 223) | class LowRankMultivariateNormal( method infer_shapes (line 227) | def infer_shapes(loc, cov_factor, cov_diag): class MultivariateNormal (line 233) | class MultivariateNormal( method infer_shapes (line 237) | def infer_shapes( class Multinomial (line 247) | class Multinomial(torch.distributions.Multinomial, TorchDistributionMixin): method infer_shapes (line 248) | def infer_shapes(total_count=None, probs=None, logits=None): class Normal (line 256) | class Normal(torch.distributions.Normal, TorchDistributionMixin): class OneHotCategorical (line 260) | class OneHotCategorical(torch.distributions.OneHotCategorical, TorchDist... method infer_shapes (line 262) | def infer_shapes(probs=None, logits=None): class Poisson (line 269) | class Poisson(torch.distributions.Poisson, TorchDistributionMixin): method __init__ (line 270) | def __init__(self, rate, *, is_sparse=False, validate_args=None): method expand (line 274) | def expand(self, batch_shape, _instance=None): method log_prob (line 280) | def log_prob(self, value): class Independent (line 297) | class Independent(torch.distributions.Independent, TorchDistributionMixin): method infer_shapes (line 299) | def infer_shapes(**kwargs): method _validate_args (line 303) | def _validate_args(self): method _validate_args (line 307) | def _validate_args(self, value): method conjugate_update (line 310) | def conjugate_update(self, other): class Uniform (line 321) | class Uniform(torch.distributions.Uniform, TorchDistributionMixin): method __init__ (line 322) | def __init__(self, low, high, validate_args=None): method expand (line 327) | def expand(self, batch_shape, _instance=None): method support (line 335) | def support(self): function _cat_docstrings (line 339) | def _cat_docstrings(*docstrings): FILE: pyro/distributions/torch_distribution.py class TorchDistributionMixin (line 19) | class TorchDistributionMixin(Distribution, Callable): method __call__ (line 31) | def __call__(self, sample_shape: torch.Size = torch.Size()) -> torch.T... method batch_shape (line 55) | def batch_shape(self) -> torch.Size: method event_shape (line 63) | def event_shape(self) -> torch.Size: method event_dim (line 71) | def event_dim(self) -> int: method shape (line 78) | def shape(self, sample_shape=torch.Size()): method infer_shapes (line 95) | def infer_shapes(cls, **arg_shapes): method expand (line 122) | def expand(self, batch_shape, _instance=None) -> "ExpandedDistribution": method expand_by (line 135) | def expand_by(self, sample_shape): method reshape (line 156) | def reshape(self, sample_shape=None, extra_event_dims=None): method to_event (line 163) | def to_event(self, reinterpreted_batch_ndims=None): method independent (line 215) | def independent(self, reinterpreted_batch_ndims=None): method mask (line 221) | def mask(self, mask): class TorchDistribution (line 235) | class TorchDistribution(torch.distributions.Distribution, TorchDistribut... class MaskedDistribution (line 302) | class MaskedDistribution(TorchDistribution): method __init__ (line 317) | def __init__(self, base_dist, mask): method expand (line 330) | def expand(self, batch_shape, _instance=None): method has_rsample (line 344) | def has_rsample(self): method has_enumerate_support (line 348) | def has_enumerate_support(self): method support (line 352) | def support(self): method sample (line 355) | def sample(self, sample_shape=torch.Size()): method rsample (line 358) | def rsample(self, sample_shape=torch.Size()): method log_prob (line 361) | def log_prob(self, value): method score_parts (line 371) | def score_parts(self, value): method enumerate_support (line 376) | def enumerate_support(self, expand=True): method mean (line 380) | def mean(self): method variance (line 384) | def variance(self): method conjugate_update (line 387) | def conjugate_update(self, other): class ExpandedDistribution (line 399) | class ExpandedDistribution(TorchDistribution): method __init__ (line 402) | def __init__(self, base_dist, batch_shape=torch.Size()): method expand (line 408) | def expand(self, batch_shape, _instance=None): method _broadcast_shape (line 421) | def _broadcast_shape(existing_shape, new_shape): method has_rsample (line 451) | def has_rsample(self): method has_enumerate_support (line 455) | def has_enumerate_support(self): method support (line 459) | def support(self): method _sample (line 462) | def _sample(self, sample_fn, sample_shape): method sample (line 477) | def sample(self, sample_shape=torch.Size()): method rsample (line 480) | def rsample(self, sample_shape=torch.Size()): method log_prob (line 483) | def log_prob(self, value): method score_parts (line 490) | def score_parts(self, value): method enumerate_support (line 503) | def enumerate_support(self, expand=True): method mean (line 512) | def mean(self): method variance (line 516) | def variance(self): method conjugate_update (line 519) | def conjugate_update(self, other): function _kl_masked_masked (line 530) | def _kl_masked_masked(p, q): FILE: pyro/distributions/torch_patch.py function patch_dependency (line 12) | def patch_dependency(target, root_module=torch): function _Transform__getstate__ (line 46) | def _Transform__getstate__(self): function _Transform_clear_cache (line 57) | def _Transform_clear_cache(self): function _TransformedDistribution_clear_cache (line 64) | def _TransformedDistribution_clear_cache(self): function _HalfCauchy_logprob (line 71) | def _HalfCauchy_logprob(self, value): function _CorrCholesky_check (line 83) | def _CorrCholesky_check(self, value): function _lazy_property__call__ (line 91) | def _lazy_property__call__(self): FILE: pyro/distributions/torch_transform.py class TransformModule (line 7) | class TransformModule(torch.distributions.Transform, torch.nn.Module): method __init__ (line 13) | def __init__(self, *args, **kwargs): method __hash__ (line 16) | def __hash__(self): class ComposeTransformModule (line 20) | class ComposeTransformModule(torch.distributions.ComposeTransform, torch... method __init__ (line 28) | def __init__(self, parts, cache_size=0): method __hash__ (line 34) | def __hash__(self): method with_cache (line 37) | def with_cache(self, cache_size=1): FILE: pyro/distributions/transforms/__init__.py function _transform_to_sphere (line 112) | def _transform_to_sphere(constraint): function _transform_to_corr_matrix (line 118) | def _transform_to_corr_matrix(constraint): function _transform_to_ordered_vector (line 126) | def _transform_to_ordered_vector(constraint): function _transform_to_positive_ordered_vector (line 132) | def _transform_to_positive_ordered_vector(constraint): function _transform_to_positive_definite (line 138) | def _transform_to_positive_definite(constraint): function _transform_to_softplus_positive (line 144) | def _transform_to_softplus_positive(constraint): function _transform_to_softplus_lower_cholesky (line 149) | def _transform_to_softplus_lower_cholesky(constraint): function _transform_to_unit_lower_cholesky (line 154) | def _transform_to_unit_lower_cholesky(constraint): function iterated (line 158) | def iterated(repeats, base_fn, *args, **kwargs): FILE: pyro/distributions/transforms/affine_autoregressive.py class AffineAutoregressive (line 19) | class AffineAutoregressive(TransformModule): method __init__ (line 100) | def __init__( method _call (line 122) | def _call(self, x): method _inverse (line 141) | def _inverse(self, y): method log_abs_det_jacobian (line 174) | def log_abs_det_jacobian(self, x, y): method _call_stable (line 196) | def _call_stable(self, x): method _inverse_stable (line 214) | def _inverse_stable(self, y): class ConditionalAffineAutoregressive (line 238) | class ConditionalAffineAutoregressive(ConditionalTransformModule): method __init__ (line 326) | def __init__(self, autoregressive_nn, **kwargs): method condition (line 331) | def condition(self, context): function affine_autoregressive (line 343) | def affine_autoregressive(input_dim, hidden_dims=None, **kwargs): function conditional_affine_autoregressive (line 376) | def conditional_affine_autoregressive( FILE: pyro/distributions/transforms/affine_coupling.py class AffineCoupling (line 20) | class AffineCoupling(TransformModule): method __init__ (line 91) | def __init__( method domain (line 112) | def domain(self): method codomain (line 116) | def codomain(self): method _call (line 119) | def _call(self, x): method _inverse (line 146) | def _inverse(self, y): method log_abs_det_jacobian (line 172) | def log_abs_det_jacobian(self, x, y): class ConditionalAffineCoupling (line 192) | class ConditionalAffineCoupling(ConditionalTransformModule): method __init__ (line 270) | def __init__(self, split_dim, hypernet, **kwargs): method condition (line 276) | def condition(self, context): function affine_coupling (line 281) | def affine_coupling(input_dim, hidden_dims=None, split_dim=None, dim=-1,... function conditional_affine_coupling (line 337) | def conditional_affine_coupling( FILE: pyro/distributions/transforms/basic.py class ELUTransform (line 15) | class ELUTransform(Transform): method __eq__ (line 25) | def __eq__(self, other): method _call (line 28) | def _call(self, x): method _inverse (line 31) | def _inverse(self, y, eps=1e-8): method log_abs_det_jacobian (line 36) | def log_abs_det_jacobian(self, x, y): function elu (line 40) | def elu(): class LeakyReLUTransform (line 52) | class LeakyReLUTransform(Transform): method __eq__ (line 62) | def __eq__(self, other): method _call (line 65) | def _call(self, x): method _inverse (line 68) | def _inverse(self, y): method log_abs_det_jacobian (line 71) | def log_abs_det_jacobian(self, x, y): function leaky_relu (line 77) | def leaky_relu(): function tanh (line 86) | def tanh(): FILE: pyro/distributions/transforms/batchnorm.py class BatchNorm (line 14) | class BatchNorm(TransformModule): method __init__ (line 77) | def __init__(self, input_dim, momentum=0.1, epsilon=1e-5): method constrained_gamma (line 90) | def constrained_gamma(self): method _call (line 93) | def _call(self, x): method _inverse (line 107) | def _inverse(self, y): method log_abs_det_jacobian (line 131) | def log_abs_det_jacobian(self, x, y): function batchnorm (line 143) | def batchnorm(input_dim, **kwargs): FILE: pyro/distributions/transforms/block_autoregressive.py function log_matrix_product (line 19) | def log_matrix_product(A, B): class BlockAutoregressive (line 29) | class BlockAutoregressive(TransformModule): method __init__ (line 78) | def __init__( method _call (line 126) | def _call(self, x): method _inverse (line 174) | def _inverse(self, y): method log_abs_det_jacobian (line 189) | def log_abs_det_jacobian(self, x, y): class MaskedBlockLinear (line 202) | class MaskedBlockLinear(torch.nn.Module): method __init__ (line 209) | def __init__(self, in_features, out_features, dim, bias=True): method get_weights (line 258) | def get_weights(self): method forward (line 282) | def forward(self, x): function block_autoregressive (line 287) | def block_autoregressive(input_dim, **kwargs): FILE: pyro/distributions/transforms/cholesky.py class CorrLCholeskyTransform (line 13) | class CorrLCholeskyTransform(CorrCholeskyTransform): # DEPRECATED method __init__ (line 14) | def __init__(self, cache_size=0): class CholeskyTransform (line 22) | class CholeskyTransform(Transform): method __eq__ (line 32) | def __eq__(self, other): method _call (line 35) | def _call(self, x): method _inverse (line 38) | def _inverse(self, y): method log_abs_det_jacobian (line 41) | def log_abs_det_jacobian(self, x, y): class CorrMatrixCholeskyTransform (line 50) | class CorrMatrixCholeskyTransform(CholeskyTransform): method __eq__ (line 61) | def __eq__(self, other): method log_abs_det_jacobian (line 64) | def log_abs_det_jacobian(self, x, y): FILE: pyro/distributions/transforms/discrete_cosine.py class DiscreteCosineTransform (line 12) | class DiscreteCosineTransform(Transform): method __init__ (line 30) | def __init__(self, dim=-1, smooth=0.0, cache_size=0): method __hash__ (line 37) | def __hash__(self): method __eq__ (line 40) | def __eq__(self, other): method domain (line 48) | def domain(self): method codomain (line 52) | def codomain(self): method _weight (line 56) | def _weight(self, y): method _call (line 66) | def _call(self, x): method _inverse (line 77) | def _inverse(self, y): method log_abs_det_jacobian (line 88) | def log_abs_det_jacobian(self, x, y): method with_cache (line 91) | def with_cache(self, cache_size=1): method forward_shape (line 96) | def forward_shape(self, shape): method inverse_shape (line 101) | def inverse_shape(self, shape): FILE: pyro/distributions/transforms/generalized_channel_permute.py class ConditionedGeneralizedChannelPermute (line 16) | class ConditionedGeneralizedChannelPermute(Transform): method __init__ (line 21) | def __init__(self, permutation=None, LU=None): method U_diag (line 28) | def U_diag(self): method L (line 32) | def L(self): method U (line 38) | def U(self): method _call (line 41) | def _call(self, x): method _inverse (line 65) | def _inverse(self, y): method log_abs_det_jacobian (line 97) | def log_abs_det_jacobian(self, x, y): class GeneralizedChannelPermute (line 111) | class GeneralizedChannelPermute(ConditionedGeneralizedChannelPermute, Tr... method __init__ (line 169) | def __init__(self, channels=3, permutation=None): class ConditionalGeneralizedChannelPermute (line 200) | class ConditionalGeneralizedChannelPermute(ConditionalTransformModule): method __init__ (line 267) | def __init__(self, nn, channels=3, permutation=None): method condition (line 280) | def condition(self, context): function generalized_channel_permute (line 286) | def generalized_channel_permute(**kwargs): function conditional_generalized_channel_permute (line 300) | def conditional_generalized_channel_permute(context_dim, channels=3, hid... FILE: pyro/distributions/transforms/haar.py class HaarTransform (line 11) | class HaarTransform(Transform): method __init__ (line 30) | def __init__(self, dim=-1, flip=False, cache_size=0): method __hash__ (line 36) | def __hash__(self): method __eq__ (line 39) | def __eq__(self, other): method domain (line 47) | def domain(self): method codomain (line 51) | def codomain(self): method _call (line 54) | def _call(self, x): method _inverse (line 65) | def _inverse(self, y): method log_abs_det_jacobian (line 76) | def log_abs_det_jacobian(self, x, y): method with_cache (line 79) | def with_cache(self, cache_size=1): method forward_shape (line 84) | def forward_shape(self, shape): method inverse_shape (line 89) | def inverse_shape(self, shape): FILE: pyro/distributions/transforms/householder.py class ConditionedHouseholder (line 19) | class ConditionedHouseholder(Transform): method __init__ (line 25) | def __init__(self, u_unnormed=None): method u (line 30) | def u(self): method _call (line 35) | def _call(self, x): method _inverse (line 52) | def _inverse(self, y): method log_abs_det_jacobian (line 70) | def log_abs_det_jacobian(self, x, y): class Householder (line 82) | class Householder(ConditionedHouseholder, TransformModule): method __init__ (line 131) | def __init__(self, input_dim, count_transforms=1): method reset_parameters (line 151) | def reset_parameters(self): class ConditionalHouseholder (line 157) | class ConditionalHouseholder(ConditionalTransformModule): method __init__ (line 217) | def __init__(self, input_dim, nn, count_transforms=1): method _u_unnormed (line 236) | def _u_unnormed(self, context): method condition (line 246) | def condition(self, context): function householder (line 251) | def householder(input_dim, count_transforms=None): function conditional_householder (line 270) | def conditional_householder( FILE: pyro/distributions/transforms/lower_cholesky_affine.py class LowerCholeskyAffine (line 12) | class LowerCholeskyAffine(Transform): method __init__ (line 32) | def __init__(self, loc, scale_tril, cache_size=0): method _call (line 42) | def _call(self, x): method _inverse (line 53) | def _inverse(self, y): method log_abs_det_jacobian (line 64) | def log_abs_det_jacobian(self, x, y): method with_cache (line 74) | def with_cache(self, cache_size=1): FILE: pyro/distributions/transforms/matrix_exponential.py class ConditionedMatrixExponential (line 19) | class ConditionedMatrixExponential(Transform): method __init__ (line 24) | def __init__(self, weights=None, iterations=8, normalization="none", b... method _exp (line 39) | def _exp(self, x, M): method _trace (line 52) | def _trace(self, M): method _call (line 63) | def _call(self, x): method _inverse (line 75) | def _inverse(self, y): method log_abs_det_jacobian (line 85) | def log_abs_det_jacobian(self, x, y): class MatrixExponential (line 95) | class MatrixExponential(ConditionedMatrixExponential, TransformModule): method __init__ (line 154) | def __init__(self, input_dim, iterations=8, normalization="none", boun... method reset_parameters (line 162) | def reset_parameters(self): class ConditionalMatrixExponential (line 168) | class ConditionalMatrixExponential(ConditionalTransformModule): method __init__ (line 235) | def __init__(self, input_dim, nn, iterations=8, normalization="none", ... method _params (line 243) | def _params(self, context): method condition (line 246) | def condition(self, context): function matrix_exponential (line 262) | def matrix_exponential(input_dim, iterations=8, normalization="none", bo... function conditional_matrix_exponential (line 292) | def conditional_matrix_exponential( FILE: pyro/distributions/transforms/neural_autoregressive.py class NeuralAutoregressive (line 23) | class NeuralAutoregressive(TransformModule): method __init__ (line 68) | def __init__(self, autoregressive_nn, hidden_units=16, activation="sig... method _call (line 91) | def _call(self, x): method log_abs_det_jacobian (line 121) | def log_abs_det_jacobian(self, x, y): class ConditionalNeuralAutoregressive (line 144) | class ConditionalNeuralAutoregressive(ConditionalTransformModule): method __init__ (line 194) | def __init__(self, autoregressive_nn, **kwargs): method condition (line 199) | def condition(self, context): function neural_autoregressive (line 212) | def neural_autoregressive(input_dim, hidden_dims=None, activation="sigmo... function conditional_neural_autoregressive (line 239) | def conditional_neural_autoregressive( FILE: pyro/distributions/transforms/normalize.py class Normalize (line 13) | class Normalize(Transform): method __init__ (line 23) | def __init__(self, p=2, cache_size=0): method __eq__ (line 29) | def __eq__(self, other): method _call (line 32) | def _call(self, x): method _inverse (line 35) | def _inverse(self, y): method with_cache (line 38) | def with_cache(self, cache_size=1): FILE: pyro/distributions/transforms/ordered.py class OrderedTransform (line 10) | class OrderedTransform(Transform): method _call (line 23) | def _call(self, x): method _inverse (line 27) | def _inverse(self, y): method log_abs_det_jacobian (line 31) | def log_abs_det_jacobian(self, x, y): FILE: pyro/distributions/transforms/permute.py class Permute (line 14) | class Permute(Transform): method __init__ (line 50) | def __init__(self, permutation, *, dim=-1, cache_size=1): method domain (line 60) | def domain(self): method codomain (line 64) | def codomain(self): method inv_permutation (line 68) | def inv_permutation(self): method _call (line 75) | def _call(self, x): method _inverse (line 87) | def _inverse(self, y): method log_abs_det_jacobian (line 96) | def log_abs_det_jacobian(self, x, y): method with_cache (line 109) | def with_cache(self, cache_size=1): function permute (line 115) | def permute(input_dim, permutation=None, dim=-1): FILE: pyro/distributions/transforms/planar.py class ConditionedPlanar (line 20) | class ConditionedPlanar(Transform): method __init__ (line 25) | def __init__(self, params): method u_hat (line 31) | def u_hat(self, u, w): method _call (line 36) | def _call(self, x): method _inverse (line 67) | def _inverse(self, y): method log_abs_det_jacobian (line 81) | def log_abs_det_jacobian(self, x, y): class Planar (line 95) | class Planar(ConditionedPlanar, TransformModule): method __init__ (line 137) | def __init__(self, input_dim): method _params (line 159) | def _params(self): method reset_parameters (line 162) | def reset_parameters(self): class ConditionalPlanar (line 170) | class ConditionalPlanar(ConditionalTransformModule): method __init__ (line 221) | def __init__(self, nn): method _params (line 225) | def _params(self, context): method condition (line 228) | def condition(self, context): function planar (line 233) | def planar(input_dim): function conditional_planar (line 246) | def conditional_planar(input_dim, context_dim, hidden_dims=None): FILE: pyro/distributions/transforms/polynomial.py class Polynomial (line 17) | class Polynomial(TransformModule): method __init__ (line 76) | def __init__(self, autoregressive_nn, input_dim, count_degree, count_s... method reset_parameters (line 102) | def reset_parameters(self): method _call (line 106) | def _call(self, x): method _inverse (line 142) | def _inverse(self, y): method log_abs_det_jacobian (line 157) | def log_abs_det_jacobian(self, x, y): function polynomial (line 170) | def polynomial(input_dim, hidden_dims=None): FILE: pyro/distributions/transforms/power.py class PositivePowerTransform (line 9) | class PositivePowerTransform(Transform): method __init__ (line 26) | def __init__(self, exponent, *, cache_size=0, validate_args=None): method with_cache (line 38) | def with_cache(self, cache_size=1): method __eq__ (line 43) | def __eq__(self, other): method _call (line 48) | def _call(self, x): method _inverse (line 51) | def _inverse(self, y): method log_abs_det_jacobian (line 54) | def log_abs_det_jacobian(self, x, y): method forward_shape (line 57) | def forward_shape(self, shape): method inverse_shape (line 60) | def inverse_shape(self, shape): FILE: pyro/distributions/transforms/radial.py class ConditionedRadial (line 20) | class ConditionedRadial(Transform): method __init__ (line 25) | def __init__(self, params): method u_hat (line 31) | def u_hat(self, u, w): method _call (line 36) | def _call(self, x): method _inverse (line 66) | def _inverse(self, y): method log_abs_det_jacobian (line 80) | def log_abs_det_jacobian(self, x, y): class Radial (line 94) | class Radial(ConditionedRadial, TransformModule): method __init__ (line 134) | def __init__(self, input_dim): method _params (line 155) | def _params(self): method reset_parameters (line 158) | def reset_parameters(self): class ConditionalRadial (line 166) | class ConditionalRadial(ConditionalTransformModule): method __init__ (line 215) | def __init__(self, nn): method _params (line 219) | def _params(self, context): method condition (line 222) | def condition(self, context): function radial (line 227) | def radial(input_dim): function conditional_radial (line 240) | def conditional_radial(input_dim, context_dim, hidden_dims=None): FILE: pyro/distributions/transforms/simplex_to_ordered.py class SimplexToOrderedTransform (line 12) | class SimplexToOrderedTransform(Transform): method __init__ (line 31) | def __init__(self, anchor_point=None): method _call (line 37) | def _call(self, x): method _inverse (line 42) | def _inverse(self, y): method log_abs_det_jacobian (line 53) | def log_abs_det_jacobian(self, x, y): method __eq__ (line 61) | def __eq__(self, other): method forward_shape (line 66) | def forward_shape(self, shape): method inverse_shape (line 69) | def inverse_shape(self, shape): FILE: pyro/distributions/transforms/softplus.py function softplus_inv (line 9) | def softplus_inv(y): class SoftplusTransform (line 14) | class SoftplusTransform(Transform): method __eq__ (line 24) | def __eq__(self, other): method _call (line 27) | def _call(self, x): method _inverse (line 30) | def _inverse(self, y): method log_abs_det_jacobian (line 33) | def log_abs_det_jacobian(self, x, y): class SoftplusLowerCholeskyTransform (line 37) | class SoftplusLowerCholeskyTransform(Transform): method __eq__ (line 47) | def __eq__(self, other): method _call (line 50) | def _call(self, x): method _inverse (line 54) | def _inverse(self, y): FILE: pyro/distributions/transforms/spline.py function _searchsorted (line 27) | def _searchsorted(sorted_sequence, values): function _select_bins (line 37) | def _select_bins(x, idx): function _calculate_knots (line 59) | def _calculate_knots(lengths, lower, upper): function _monotonic_rational_spline (line 83) | def _monotonic_rational_spline( class ConditionedSpline (line 303) | class ConditionedSpline(Transform): method __init__ (line 313) | def __init__(self, params, bound=3.0, order="linear"): method _call (line 321) | def _call(self, x): method _inverse (line 326) | def _inverse(self, y): method log_abs_det_jacobian (line 338) | def log_abs_det_jacobian(self, x, y): method spline_op (line 350) | def spline_op(self, x, **kwargs): class Spline (line 359) | class Spline(ConditionedSpline, TransformModule): method __init__ (line 412) | def __init__(self, input_dim, count_bins=8, bound=3.0, order="linear"): method _params (line 442) | def _params(self): class ConditionalSpline (line 455) | class ConditionalSpline(ConditionalTransformModule): method __init__ (line 521) | def __init__(self, nn, input_dim, count_bins, bound=3.0, order="linear"): method _params (line 530) | def _params(self, context): method condition (line 567) | def condition(self, context): function spline (line 572) | def spline(input_dim, **kwargs): function conditional_spline (line 588) | def conditional_spline( FILE: pyro/distributions/transforms/spline_autoregressive.py class SplineAutoregressive (line 18) | class SplineAutoregressive(TransformModule): method __init__ (line 78) | def __init__( method _call (line 87) | def _call(self, x): method _inverse (line 101) | def _inverse(self, y): method log_abs_det_jacobian (line 120) | def log_abs_det_jacobian(self, x, y): class ConditionalSplineAutoregressive (line 134) | class ConditionalSplineAutoregressive(ConditionalTransformModule): method __init__ (line 201) | def __init__(self, input_dim, autoregressive_nn, **kwargs): method condition (line 207) | def condition(self, context): function spline_autoregressive (line 220) | def spline_autoregressive( function conditional_spline_autoregressive (line 254) | def conditional_spline_autoregressive( FILE: pyro/distributions/transforms/spline_coupling.py class SplineCoupling (line 15) | class SplineCoupling(TransformModule): method __init__ (line 81) | def __init__( method _call (line 102) | def _call(self, x): method _inverse (line 129) | def _inverse(self, y): method log_abs_det_jacobian (line 155) | def log_abs_det_jacobian(self, x, y): function spline_coupling (line 168) | def spline_coupling( FILE: pyro/distributions/transforms/sylvester.py class Sylvester (line 14) | class Sylvester(Householder): method __init__ (line 60) | def __init__(self, input_dim, count_transforms=1): method dtanh_dx (line 79) | def dtanh_dx(self, x): method R (line 83) | def R(self): method S (line 87) | def S(self): method Q (line 91) | def Q(self, x): method reset_parameters2 (line 105) | def reset_parameters2(self): method _call (line 109) | def _call(self, x): method _inverse (line 133) | def _inverse(self, y): method log_abs_det_jacobian (line 147) | def log_abs_det_jacobian(self, x, y): function sylvester (line 160) | def sylvester(input_dim, count_transforms=None): FILE: pyro/distributions/transforms/unit_cholesky.py class UnitLowerCholeskyTransform (line 11) | class UnitLowerCholeskyTransform(Transform): method __eq__ (line 20) | def __eq__(self, other): method _call (line 23) | def _call(self, x): method _inverse (line 26) | def _inverse(self, y): FILE: pyro/distributions/transforms/utils.py function clamp_preserve_gradients (line 5) | def clamp_preserve_gradients(x, min, max): FILE: pyro/distributions/unit.py class Unit (line 11) | class Unit(TorchDistribution): method __init__ (line 23) | def __init__(self, log_factor, *, has_rsample=None, validate_args=None): method expand (line 32) | def expand(self, batch_shape, _instance=None): method sample (line 42) | def sample(self, sample_shape=torch.Size()): method rsample (line 45) | def rsample(self, sample_shape=torch.Size()): method log_prob (line 48) | def log_prob(self, value): FILE: pyro/distributions/util.py function copy_docs_from (line 33) | def copy_docs_from(source_class, full_text=False): function weakmethod (line 72) | def weakmethod(fn): class _DetachMemo (line 116) | class _DetachMemo(dict): method get (line 117) | def get(self, key, default=None): function detach (line 129) | def detach(obj): class _DeepToMemo (line 141) | class _DeepToMemo(dict): method __init__ (line 142) | def __init__(self, to_args, to_kwargs): method get (line 147) | def get(self, key, default=None): function deep_to (line 159) | def deep_to(obj, *args, **kwargs): function is_identically_zero (line 188) | def is_identically_zero(x): function is_identically_one (line 201) | def is_identically_one(x): function broadcast_shape (line 214) | def broadcast_shape(*shapes, **kwargs): function gather (line 242) | def gather(value, index, dim): function sum_rightmost (line 253) | def sum_rightmost(value, dim): function sum_leftmost (line 279) | def sum_leftmost(value, dim): function scale_and_mask (line 311) | def scale_and_mask(tensor, scale=1.0, mask=None): function scalar_like (line 331) | def scalar_like(prototype, fill_value): function eye_like (line 336) | def eye_like(value, m, n=None): function enable_validation (line 344) | def enable_validation(is_validate): function is_validation_enabled (line 350) | def is_validation_enabled(): function validation_enabled (line 355) | def validation_enabled(is_validate=True): FILE: pyro/distributions/von_mises_3d.py class VonMises3D (line 12) | class VonMises3D(TorchDistribution): method __init__ (line 35) | def __init__(self, concentration, validate_args=None): method log_prob (line 46) | def log_prob(self, value): method expand (line 60) | def expand(self, batch_shape): FILE: pyro/distributions/zero_inflated.py class ZeroInflatedDistribution (line 18) | class ZeroInflatedDistribution(TorchDistribution): method __init__ (line 35) | def __init__(self, base_dist, *, gate=None, gate_logits=None, validate... method support (line 58) | def support(self): method gate (line 62) | def gate(self): method gate_logits (line 66) | def gate_logits(self): method log_prob (line 69) | def log_prob(self, value): method sample (line 86) | def sample(self, sample_shape=torch.Size()): method mean (line 95) | def mean(self): method variance (line 99) | def variance(self): method expand (line 104) | def expand(self, batch_shape, _instance=None): class ZeroInflatedPoisson (line 121) | class ZeroInflatedPoisson(ZeroInflatedDistribution): method __init__ (line 137) | def __init__(self, rate, *, gate=None, gate_logits=None, validate_args... method rate (line 146) | def rate(self): class ZeroInflatedNegativeBinomial (line 150) | class ZeroInflatedNegativeBinomial(ZeroInflatedDistribution): method __init__ (line 171) | def __init__( method total_count (line 194) | def total_count(self): method probs (line 198) | def probs(self): method logits (line 202) | def logits(self): FILE: pyro/infer/abstract_infer.py class EmpiricalMarginal (line 17) | class EmpiricalMarginal(Empirical): method __init__ (line 33) | def __init__(self, trace_posterior, sites=None, validate_args=None): method _get_samples_and_weights (line 46) | def _get_samples_and_weights(self): method _add_sample (line 71) | def _add_sample(self, value, log_weight=None, chain_id=0): method _populate_traces (line 101) | def _populate_traces(self, trace_posterior, sites): class Marginals (line 116) | class Marginals: method __init__ (line 128) | def __init__(self, trace_posterior, sites=None, validate_args=None): method _populate_traces (line 144) | def _populate_traces(self, trace_posterior, validate): method support (line 150) | def support(self, flatten=False): method empirical (line 174) | def empirical(self): class TracePosterior (line 184) | class TracePosterior(object, metaclass=ABCMeta): method __init__ (line 192) | def __init__(self, num_chains=1): method _reset (line 196) | def _reset(self): method marginal (line 205) | def marginal(self, sites=None): method _traces (line 217) | def _traces(self, *args, **kwargs): method __call__ (line 226) | def __call__(self, *args, **kwargs): method run (line 241) | def run(self, *args, **kwargs): method information_criterion (line 265) | def information_criterion(self, pointwise=False): class TracePredictive (line 313) | class TracePredictive(TracePosterior): method __init__ (line 330) | def __init__(self, model, posterior, num_samples, keep_sites=None): method _traces (line 342) | def _traces(self, *args, **kwargs): method _remove_dropped_nodes (line 355) | def _remove_dropped_nodes(self, trace): method _adjust_to_data (line 363) | def _adjust_to_data(self, trace, data_trace): method marginal (line 392) | def marginal(self, sites=None): FILE: pyro/infer/autoguide/effect.py class AutoMessengerMeta (line 21) | class AutoMessengerMeta(type(GuideMessenger), type(PyroModule)): class AutoMessenger (line 25) | class AutoMessenger(GuideMessenger, PyroModule, metaclass=AutoMessengerM... method __init__ (line 35) | def __init__(self, model: Callable, *, amortized_plates: Tuple[str, ..... method __call__ (line 40) | def __call__(self, *args, **kwargs): method call (line 51) | def call(self, *args, **kwargs): method _adjust_plates (line 67) | def _adjust_plates(self, value: torch.Tensor, event_dim: int) -> torch... class AutoNormalMessenger (line 84) | class AutoNormalMessenger(AutoMessenger): method __init__ (line 147) | def __init__( method get_posterior (line 162) | def get_posterior( method _get_params (line 177) | def _get_params(self, name: str, prior: Distribution): method median (line 202) | def median(self, *args, **kwargs): method _get_posterior_median (line 209) | def _get_posterior_median(self, name, prior): class AutoHierarchicalNormalMessenger (line 215) | class AutoHierarchicalNormalMessenger(AutoNormalMessenger): method __init__ (line 249) | def __init__( method get_posterior (line 268) | def get_posterior( method _get_params (line 289) | def _get_params(self, name: str, prior: Distribution): method median (line 348) | def median(self, *args, **kwargs): method _get_posterior_median (line 355) | def _get_posterior_median(self, name, prior): class AutoRegressiveMessenger (line 365) | class AutoRegressiveMessenger(AutoMessenger): method __init__ (line 401) | def __init__( method get_posterior (line 415) | def get_posterior( method _get_params (line 429) | def _get_params(self, name: str, prior: Distribution): FILE: pyro/infer/autoguide/gaussian.py class AutoGaussianMeta (line 36) | class AutoGaussianMeta(type(AutoGuide), ABCMeta): method __init__ (line 40) | def __init__(cls, *args, **kwargs): method __call__ (line 46) | def __call__(cls, *args, **kwargs): class AutoGaussian (line 53) | class AutoGaussian(AutoGuide, metaclass=AutoGaussianMeta): method __init__ (line 109) | def __init__( method _prototype_hide_fn (line 125) | def _prototype_hide_fn(msg): method _setup_prototype (line 130) | def _setup_prototype(self, *args, **kwargs) -> None: method _compress_site (line 233) | def _compress_site(site): method forward (line 247) | def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]: method median (line 268) | def median(self, *args, **kwargs) -> Dict[str, torch.Tensor]: method _transform_values (line 280) | def _transform_values( method _sample_aux_values (line 307) | def _sample_aux_values(self, *, temperature: float) -> Dict[str, torch... class AutoGaussianDense (line 311) | class AutoGaussianDense(AutoGaussian): method _setup_prototype (line 321) | def _setup_prototype(self, *args, **kwargs): method _sample_aux_values (line 385) | def _sample_aux_values(self, *, temperature: float) -> Dict[str, torch... method _dense_unflatten (line 400) | def _dense_unflatten(self, flat_samples: torch.Tensor) -> Dict[str, to... method _dense_flatten (line 415) | def _dense_flatten(self, samples: Dict[str, torch.Tensor]) -> torch.Te... method _dense_get_mvn (line 424) | def _dense_get_mvn(self): class AutoGaussianFunsor (line 444) | class AutoGaussianFunsor(AutoGaussian): method __init__ (line 453) | def __init__(self, *args, **kwargs): method _setup_prototype (line 457) | def _setup_prototype(self, *args, **kwargs): method _sample_aux_values (line 497) | def _sample_aux_values(self, *, temperature: float) -> Dict[str, torch... function _precision_to_scale_tril (line 554) | def _precision_to_scale_tril(P): function _try_possibly_intractable (line 565) | def _try_possibly_intractable(fn, *args, **kwargs): function _plates_to_shape (line 580) | def _plates_to_shape(plates): function _break_plates (line 587) | def _break_plates(x, all_plates, kept_plates): function _import_funsor (line 616) | def _import_funsor(): FILE: pyro/infer/autoguide/guides.py function prototype_hide_fn (line 45) | def prototype_hide_fn(msg): class AutoGuide (line 50) | class AutoGuide(PyroModule): method __init__ (line 67) | def __init__(self, model, *, create_plates=None): method model (line 77) | def model(self): method __getstate__ (line 80) | def __getstate__(self): method __setstate__ (line 86) | def __setstate__(self, state): method _update_master (line 94) | def _update_master(self, master_ref): method call (line 100) | def call(self, *args, **kwargs): method sample_latent (line 115) | def sample_latent(*args, **kwargs): method __setattr__ (line 122) | def __setattr__(self, name, value): method _create_plates (line 128) | def _create_plates(self, *args, **kwargs): method _setup_prototype (line 155) | def _setup_prototype(self, *args, **kwargs): method median (line 174) | def median(self, *args, **kwargs): class AutoGuideList (line 184) | class AutoGuideList(AutoGuide, nn.ModuleList): method _check_prototype (line 198) | def _check_prototype(self, part_trace): method append (line 205) | def append(self, part): method add (line 222) | def add(self, part): method forward (line 230) | def forward(self, *args, **kwargs): method median (line 253) | def median(self, *args, **kwargs): method quantiles (line 265) | def quantiles(self, quantiles, *args, **kwargs): class AutoCallable (line 279) | class AutoCallable(AutoGuide): method __init__ (line 309) | def __init__(self, model, guide, median=lambda *args, **kwargs: {}): method forward (line 314) | def forward(self, *args, **kwargs): class AutoDelta (line 319) | class AutoDelta(AutoGuide): method __init__ (line 352) | def __init__(self, model, init_loc_fn=init_to_median, *, create_plates... method _setup_prototype (line 357) | def _setup_prototype(self, *args, **kwargs): method forward (line 376) | def forward(self, *args, **kwargs): method median (line 404) | def median(self, *args, **kwargs): class AutoNormal (line 415) | class AutoNormal(AutoGuide): method __init__ (line 448) | def __init__( method _setup_prototype (line 460) | def _setup_prototype(self, *args, **kwargs): method _get_loc_and_scale (line 494) | def _get_loc_and_scale(self, name): method forward (line 499) | def forward(self, *args, **kwargs): method median (line 556) | def median(self, *args, **kwargs): method quantiles (line 574) | def quantiles(self, quantiles, *args, **kwargs): class AutoContinuous (line 605) | class AutoContinuous(AutoGuide): method __init__ (line 632) | def __init__(self, model, init_loc_fn=init_to_median): method _setup_prototype (line 636) | def _setup_prototype(self, *args, **kwargs): method _init_loc (line 661) | def _init_loc(self): method get_base_dist (line 674) | def get_base_dist(self): method get_transform (line 688) | def get_transform(self, *args, **kwargs): method get_posterior (line 702) | def get_posterior(self, *args, **kwargs): method sample_latent (line 710) | def sample_latent(self, *args, **kwargs): method _unpack_latent (line 720) | def _unpack_latent(self, latent): method forward (line 748) | def forward(self, *args, **kwargs): method _loc_scale (line 795) | def _loc_scale(self, *args, **kwargs): method median (line 803) | def median(self, *args, **kwargs): method quantiles (line 818) | def quantiles(self, quantiles, *args, **kwargs): class AutoMultivariateNormal (line 844) | class AutoMultivariateNormal(AutoContinuous): method __init__ (line 869) | def __init__(self, model, init_loc_fn=init_to_median, init_scale=0.1): method _setup_prototype (line 875) | def _setup_prototype(self, *args, **kwargs): method get_base_dist (line 886) | def get_base_dist(self): method get_transform (line 891) | def get_transform(self, *args, **kwargs): method get_posterior (line 895) | def get_posterior(self, *args, **kwargs): method _loc_scale (line 902) | def _loc_scale(self, *args, **kwargs): class AutoDiagonalNormal (line 909) | class AutoDiagonalNormal(AutoContinuous): method __init__ (line 932) | def __init__(self, model, init_loc_fn=init_to_median, init_scale=0.1): method _setup_prototype (line 938) | def _setup_prototype(self, *args, **kwargs): method get_base_dist (line 947) | def get_base_dist(self): method get_transform (line 952) | def get_transform(self, *args, **kwargs): method get_posterior (line 955) | def get_posterior(self, *args, **kwargs): method _loc_scale (line 961) | def _loc_scale(self, *args, **kwargs): class AutoLowRankMultivariateNormal (line 965) | class AutoLowRankMultivariateNormal(AutoContinuous): method __init__ (line 993) | def __init__(self, model, init_loc_fn=init_to_median, init_scale=0.1, ... method _setup_prototype (line 1002) | def _setup_prototype(self, *args, **kwargs): method get_posterior (line 1018) | def get_posterior(self, *args, **kwargs): method _loc_scale (line 1027) | def _loc_scale(self, *args, **kwargs): class AutoNormalizingFlow (line 1032) | class AutoNormalizingFlow(AutoContinuous): method __init__ (line 1054) | def __init__(self, model, init_transform_fn): method get_base_dist (line 1060) | def get_base_dist(self): method get_transform (line 1065) | def get_transform(self, *args, **kwargs): method get_posterior (line 1068) | def get_posterior(self, *args, **kwargs): class AutoIAFNormal (line 1079) | class AutoIAFNormal(AutoNormalizingFlow): method __init__ (line 1107) | def __init__( class AutoLaplaceApproximation (line 1133) | class AutoLaplaceApproximation(AutoContinuous): method _setup_prototype (line 1156) | def _setup_prototype(self, *args, **kwargs): method get_posterior (line 1161) | def get_posterior(self, *args, **kwargs): method laplace_approximation (line 1167) | def laplace_approximation(self, *args, **kwargs): class AutoDiscreteParallel (line 1199) | class AutoDiscreteParallel(AutoGuide): method _setup_prototype (line 1205) | def _setup_prototype(self, *args, **kwargs): method forward (line 1254) | def forward(self, *args, **kwargs): FILE: pyro/infer/autoguide/initialization.py function _is_multivariate (line 29) | def _is_multivariate(d): function init_to_feasible (line 35) | def init_to_feasible(site=None): function init_to_sample (line 50) | def init_to_sample(site=None): function init_to_median (line 62) | def init_to_median( function init_to_mean (line 102) | def init_to_mean( function init_to_uniform (line 136) | def init_to_uniform( function init_to_value (line 157) | def init_to_value( class _InitToGenerated (line 184) | class _InitToGenerated: method __init__ (line 185) | def __init__(self, generate): method __call__ (line 190) | def __call__(self, site): function init_to_generated (line 197) | def init_to_generated(site=None, generate=lambda: init_to_uniform): class InitMessenger (line 220) | class InitMessenger(Messenger): method __init__ (line 229) | def __init__(self, init_fn): method _pyro_sample (line 233) | def _pyro_sample(self, msg): method _pyro_get_init_messengers (line 253) | def _pyro_get_init_messengers(self, msg): FILE: pyro/infer/autoguide/structured.py function _config_auxiliary (line 26) | def _config_auxiliary(msg): class AutoStructured (line 30) | class AutoStructured(AutoGuide): method __init__ (line 104) | def __init__( method _auto_config (line 138) | def _auto_config(self, sample_sites, args, kwargs): method _setup_prototype (line 165) | def _setup_prototype(self, *args, **kwargs): method _compress_site (line 255) | def _compress_site(site): method get_deltas (line 268) | def get_deltas(self, save_params=None): method forward (line 352) | def forward(self, *args, **kwargs): method median (line 369) | def median(self, *args, **kwargs): FILE: pyro/infer/autoguide/utils.py function _product (line 11) | def _product(shape): function deep_setattr (line 21) | def deep_setattr(obj, key, val): function mean_field_entropy (line 41) | def mean_field_entropy(model, args, whitelist=None): function helpful_support_errors (line 63) | def helpful_support_errors(site): FILE: pyro/infer/csis.py class CSIS (line 16) | class CSIS(Importance): method __init__ (line 40) | def __init__( method set_validation_batch (line 57) | def set_validation_batch(self, *args, **kwargs): method step (line 68) | def step(self, *args, **kwargs): method loss_and_grads (line 91) | def loss_and_grads(self, grads, batch, *args, **kwargs): method _differentiable_loss_particle (line 142) | def _differentiable_loss_particle(self, guide_trace): method validation_loss (line 145) | def validation_loss(self, *args, **kwargs): method _get_matched_trace (line 161) | def _get_matched_trace(self, model_trace, *args, **kwargs): method _sample_from_joint (line 190) | def _sample_from_joint(self, *args, **kwargs): FILE: pyro/infer/discrete.py function _make_ring (line 24) | def _make_ring(temperature, cache, dim_to_size): class SamplePosteriorMessenger (line 31) | class SamplePosteriorMessenger(ReplayMessenger): method _pyro_sample (line 34) | def _pyro_sample(self, msg): function _sample_posterior (line 41) | def _sample_posterior( function _sample_posterior_from_trace (line 58) | def _sample_posterior_from_trace( function infer_discrete (line 181) | def infer_discrete( class TraceEnumSample_ELBO (line 234) | class TraceEnumSample_ELBO(TraceEnum_ELBO): method _get_trace (line 256) | def _get_trace(self, model, guide, args, kwargs): method sample_saved (line 269) | def sample_saved(self): FILE: pyro/infer/elbo.py class ELBOModule (line 19) | class ELBOModule(torch.nn.Module): method __init__ (line 20) | def __init__(self, model: torch.nn.Module, guide: torch.nn.Module, elb... method forward (line 26) | def forward(self, *args, **kwargs): class ELBO (line 30) | class ELBO(object, metaclass=ABCMeta): method __init__ (line 110) | def __init__( method __call__ (line 139) | def __call__(self, model: torch.nn.Module, guide: torch.nn.Module) -> ... method _guess_max_plate_nesting (line 146) | def _guess_max_plate_nesting(self, model, guide, args, kwargs): method _vectorized_num_particles (line 188) | def _vectorized_num_particles(self, fn): method _get_vectorized_trace (line 207) | def _get_vectorized_trace(self, model, guide, args, kwargs): method _get_trace (line 221) | def _get_trace(self, model, guide, args, kwargs): method _get_traces (line 228) | def _get_traces(self, model, guide, args, kwargs): FILE: pyro/infer/energy_distance.py function _squared_error (line 19) | def _squared_error(x, y, scale, mask): class EnergyDistance (line 29) | class EnergyDistance: method __init__ (line 79) | def __init__( method _pow (line 96) | def _pow(self, x): method _get_traces (line 101) | def _get_traces(self, model, guide, args, kwargs): method __call__ (line 157) | def __call__(self, model, guide, *args, **kwargs): method loss (line 225) | def loss(self, *args, **kwargs): FILE: pyro/infer/enum.py function iter_discrete_escape (line 16) | def iter_discrete_escape(trace, msg): function iter_discrete_extend (line 25) | def iter_discrete_extend(trace, site, **ignored): function get_importance_trace (line 45) | def get_importance_trace( function iter_discrete_traces (line 88) | def iter_discrete_traces(graph_type, fn, *args, **kwargs): function _config_fn (line 114) | def _config_fn(default, expand, num_samples, tmc, site): function _config_enumerate (line 134) | def _config_enumerate(default, expand, num_samples, tmc): function config_enumerate (line 138) | def config_enumerate( FILE: pyro/infer/importance.py class LogWeightsMixin (line 19) | class LogWeightsMixin: method get_log_normalizer (line 26) | def get_log_normalizer(self): method get_normalized_weights (line 45) | def get_normalized_weights(self, log_scale=False): method get_ESS (line 62) | def get_ESS(self): class Importance (line 77) | class Importance(TracePosterior, LogWeightsMixin): method __init__ (line 88) | def __init__(self, model, guide=None, num_samples=None): method _traces (line 105) | def _traces(self, *args, **kwargs): function vectorized_importance_weights (line 118) | def vectorized_importance_weights(model, guide, *args, **kwargs): function psis_diagnostic (line 174) | def psis_diagnostic(model, guide, *args, **kwargs): FILE: pyro/infer/inspect.py function is_sample_site (line 26) | def is_sample_site(msg, *, include_deterministic=False): function site_is_deterministic (line 47) | def site_is_deterministic(msg: dict) -> bool: class TrackProvenance (line 51) | class TrackProvenance(Messenger): method __init__ (line 52) | def __init__(self, *, include_deterministic=False): method _pyro_post_sample (line 55) | def _pyro_post_sample(self, msg): method _pyro_post_param (line 66) | def _pyro_post_param(self, msg): function get_dependencies (line 74) | def get_dependencies( function get_model_relations (line 252) | def get_model_relations( function _get_dist_name (line 381) | def _get_dist_name(fn): function generate_graph_specification (line 389) | def generate_graph_specification( function _deep_merge (line 472) | def _deep_merge(things: list): function render_graph (line 497) | def render_graph( function render_model (line 598) | def render_model( FILE: pyro/infer/mcmc/adaptation.py class WarmupAdapter (line 23) | class WarmupAdapter: method __init__ (line 31) | def __init__( method _build_adaptation_schedule (line 65) | def _build_adaptation_schedule(self): method reset_step_size_adaptation (line 105) | def reset_step_size_adaptation(self, z): method _update_step_size (line 115) | def _update_step_size(self, accept_prob): method _end_adaptation (line 122) | def _end_adaptation(self): method configure (line 127) | def configure( method step (line 166) | def step(self, t, z, accept_prob, z_grad=None): method adaptation_schedule (line 205) | def adaptation_schedule(self): method mass_matrix_adapter (line 209) | def mass_matrix_adapter(self): method mass_matrix_adapter (line 213) | def mass_matrix_adapter(self, value): function _matvecmul (line 218) | def _matvecmul(x, y): function _cholesky (line 222) | def _cholesky(x): function _transpose (line 226) | def _transpose(x): function _triu_inverse (line 230) | def _triu_inverse(x): class BlockMassMatrix (line 238) | class BlockMassMatrix: method __init__ (line 250) | def __init__(self, init_scale=1.0): method mass_matrix_size (line 261) | def mass_matrix_size(self): method inverse_mass_matrix (line 268) | def inverse_mass_matrix(self): method inverse_mass_matrix (line 272) | def inverse_mass_matrix(self, value): method configure (line 282) | def configure(self, mass_matrix_shape, adapt_mass_matrix=True, options... method update (line 306) | def update(self, z, z_grad): method end_adaptation (line 317) | def end_adaptation(self): method kinetic_grad (line 328) | def kinetic_grad(self, r): method scale (line 349) | def scale(self, r_unscaled, r_prototype): method unscale (line 375) | def unscale(self, r): class ArrowheadMassMatrix (line 395) | class ArrowheadMassMatrix: method __init__ (line 406) | def __init__(self, init_scale=1.0): method mass_matrix_size (line 416) | def mass_matrix_size(self): method inverse_mass_matrix (line 423) | def inverse_mass_matrix(self): method mass_matrix (line 434) | def mass_matrix(self): method mass_matrix (line 438) | def mass_matrix(self, value): method configure (line 449) | def configure(self, mass_matrix_shape, adapt_mass_matrix=True, options... method update (line 484) | def update(self, z, z_grad): method end_adaptation (line 495) | def end_adaptation(self): method kinetic_grad (line 505) | def kinetic_grad(self, r): method scale (line 536) | def scale(self, r_unscaled, r_prototype): method unscale (line 562) | def unscale(self, r): FILE: pyro/infer/mcmc/api.py function logger_thread (line 48) | def logger_thread( class _Worker (line 88) | class _Worker: method __init__ (line 89) | def __init__( method run (line 115) | def run(self, *args, **kwargs): function _gen_samples (line 145) | def _gen_samples(kernel, warmup_steps, num_samples, hook, chain_id, *arg... function _add_logging_hook (line 173) | def _add_logging_hook(logger, progress_bar=None, hook=None): class _UnarySampler (line 185) | class _UnarySampler: method __init__ (line 190) | def __init__( method terminate (line 210) | def terminate(self, *args, **kwargs): method run (line 213) | def run(self, *args, **kwargs): class _MultiSampler (line 239) | class _MultiSampler: method __init__ (line 246) | def __init__( method init_workers (line 286) | def init_workers(self, *args, **kwargs): method terminate (line 312) | def terminate(self, terminate_workers=False): method run (line 324) | def run(self, *args, **kwargs): class AbstractMCMC (line 354) | class AbstractMCMC(ABC): method __init__ (line 359) | def __init__(self, kernel, num_chains, transforms): method run (line 365) | def run(self, *args, **kwargs): method diagnostics (line 369) | def diagnostics(self): method _set_transforms (line 372) | def _set_transforms(self, *args, **kwargs): method _validate_kernel (line 385) | def _validate_kernel(self, initial_params): method _validate_initial_params (line 396) | def _validate_initial_params(self, initial_params): class MCMC (line 405) | class MCMC(AbstractMCMC): method __init__ (line 453) | def __init__( method run (line 532) | def run(self, *args, **kwargs): method get_samples (line 608) | def get_samples(self, num_samples=None, group_by_chain=False): method diagnostics (line 617) | def diagnostics(self): method summary (line 630) | def summary(self, prob=0.9): class StreamingMCMC (line 653) | class StreamingMCMC(AbstractMCMC): method __init__ (line 662) | def __init__( method run (line 710) | def run(self, *args, **kwargs): method get_statistics (line 763) | def get_statistics(self, group_by_chain=True): method diagnostics (line 782) | def diagnostics(self): FILE: pyro/infer/mcmc/hmc.py class HMC (line 21) | class HMC(MCMCKernel): method __init__ (line 96) | def __init__( method _kinetic_energy (line 152) | def _kinetic_energy(self, r_unscaled): method _reset (line 158) | def _reset(self): method _find_reasonable_step_size (line 170) | def _find_reasonable_step_size(self, z): method _sample_r (line 231) | def _sample_r(self, name): method mass_matrix_adapter (line 251) | def mass_matrix_adapter(self): method mass_matrix_adapter (line 255) | def mass_matrix_adapter(self, value): method inverse_mass_matrix (line 259) | def inverse_mass_matrix(self): method step_size (line 263) | def step_size(self): method num_steps (line 267) | def num_steps(self): method initial_params (line 271) | def initial_params(self): method initial_params (line 275) | def initial_params(self, params): method _initialize_model_properties (line 278) | def _initialize_model_properties(self, model_args, model_kwargs): method _initialize_adapter (line 296) | def _initialize_adapter(self): method setup (line 342) | def setup(self, warmup_steps, *args, **kwargs): method cleanup (line 355) | def cleanup(self): method _cache (line 358) | def _cache(self, z, potential_energy, z_grads=None): method clear_cache (line 363) | def clear_cache(self): method _fetch_from_cache (line 368) | def _fetch_from_cache(self): method sample (line 371) | def sample(self, params): method logging (line 440) | def logging(self): method diagnostics (line 448) | def diagnostics(self): FILE: pyro/infer/mcmc/logger.py class ProgressBar (line 45) | class ProgressBar: method __init__ (line 59) | def __init__( method __enter__ (line 95) | def __enter__(self): method __exit__ (line 98) | def __exit__(self, *exc): method set_description (line 102) | def set_description(self, *args, **kwargs): method set_postfix (line 107) | def set_postfix(self, *args, **kwargs): method update (line 112) | def update(self, *args, **kwargs): method close (line 117) | def close(self): class QueueHandler (line 125) | class QueueHandler(logging.Handler): method __init__ (line 136) | def __init__(self, queue): method enqueue (line 143) | def enqueue(self, record): method prepare (line 153) | def prepare(self, record): method emit (line 171) | def emit(self, record): class TqdmHandler (line 183) | class TqdmHandler(logging.StreamHandler): method emit (line 189) | def emit(self, record): class MCMCLoggingHandler (line 200) | class MCMCLoggingHandler(logging.Handler): method __init__ (line 212) | def __init__(self, log_handler, progress_bar=None): method emit (line 217) | def emit(self, record): class MetadataFilter (line 233) | class MetadataFilter(logging.Filter): method __init__ (line 239) | def __init__(self, logger_id): method filter (line 243) | def filter(self, record): function initialize_logger (line 250) | def initialize_logger(logger, logger_id, progress_bar=None, log_queue=No... FILE: pyro/infer/mcmc/mcmc_kernel.py class MCMCKernel (line 7) | class MCMCKernel(object, metaclass=ABCMeta): method setup (line 8) | def setup(self, warmup_steps, *args, **kwargs): method cleanup (line 19) | def cleanup(self): method logging (line 25) | def logging(self): method diagnostics (line 35) | def diagnostics(self): method end_warmup (line 42) | def end_warmup(self): method initial_params (line 49) | def initial_params(self): method initial_params (line 58) | def initial_params(self, params): method sample (line 66) | def sample(self, params): method __call__ (line 76) | def __call__(self, params): FILE: pyro/infer/mcmc/nuts.py function _logaddexp (line 15) | def _logaddexp(x, y): class NUTS (line 55) | class NUTS(HMC): method __init__ (line 137) | def __init__( method _is_turning (line 184) | def _is_turning(self, r_left_unscaled, r_right_unscaled, r_sum): method _build_basetree (line 197) | def _build_basetree(self, z, r, z_grads, log_slice, direction, energy_... method _build_tree (line 250) | def _build_tree( method sample (line 367) | def sample(self, params): FILE: pyro/infer/mcmc/rwkernel.py class RandomWalkKernel (line 15) | class RandomWalkKernel(MCMCKernel): method __init__ (line 46) | def __init__( method setup (line 71) | def setup(self, warmup_steps, *args, **kwargs): method sample (line 85) | def sample(self, params): method initial_params (line 125) | def initial_params(self): method initial_params (line 129) | def initial_params(self, params): method logging (line 132) | def logging(self): method diagnostics (line 140) | def diagnostics(self): FILE: pyro/infer/mcmc/util.py class TraceTreeEvaluator (line 29) | class TraceTreeEvaluator: method __init__ (line 43) | def __init__(self, model_trace, has_enumerable_sites=False, max_plate_... method _parse_model_structure (line 54) | def _parse_model_structure(self, model_trace): method _populate_cache (line 75) | def _populate_cache(self, ordinal, parent_ordinal, parent_enum_dims): method _compute_log_prob_terms (line 94) | def _compute_log_prob_terms(self, model_trace): method _reduce (line 118) | def _reduce(self, ordinal, agg_log_prob=torch.tensor(0.0)): method _aggregate_log_probs (line 139) | def _aggregate_log_probs(self, ordinal): method log_prob (line 148) | def log_prob(self, model_trace): class TraceEinsumEvaluator (line 162) | class TraceEinsumEvaluator: method __init__ (line 177) | def __init__(self, model_trace, has_enumerable_sites=False, max_plate_... method _populate_cache (line 185) | def _populate_cache(self, model_trace): method _get_log_factors (line 212) | def _get_log_factors(self, model_trace): method log_prob (line 230) | def log_prob(self, model_trace): function _guess_max_plate_nesting (line 244) | def _guess_max_plate_nesting(model, args, kwargs): class _PEMaker (line 264) | class _PEMaker: method __init__ (line 265) | def __init__( method _potential_fn (line 275) | def _potential_fn(self, params): method _potential_fn_jit (line 288) | def _potential_fn_jit(self, skip_jit_warnings, jit_options, params): method get_potential_fn (line 316) | def get_potential_fn( function _find_valid_initial_params (line 325) | def _find_valid_initial_params( function initialize_model (line 370) | def initialize_model( function _safe (line 485) | def _safe(fn): function diagnostics (line 507) | def diagnostics(samples, group_by_chain=True): function summary (line 531) | def summary(samples, prob=0.9, group_by_chain=True): function print_summary (line 573) | def print_summary(samples, prob=0.9, group_by_chain=True): function _predictive_sequential (line 619) | def _predictive_sequential( function predictive (line 650) | def predictive(model, posterior_samples, *args, **kwargs): function select_samples (line 777) | def select_samples(samples, num_samples=None, group_by_chain=False): function diagnostics_from_stats (line 807) | def diagnostics_from_stats(statistics, num_samples, num_chains): FILE: pyro/infer/predictive.py function _guess_max_plate_nesting (line 20) | def _guess_max_plate_nesting(model, args, kwargs): class _predictiveResults (line 41) | class _predictiveResults: function _predictive_sequential (line 50) | def _predictive_sequential( function _predictive (line 79) | def _predictive( class Predictive (line 162) | class Predictive(torch.nn.Module): method __init__ (line 188) | def __init__( method call (line 239) | def call(self, *args, **kwargs): method forward (line 254) | def forward(self, *args, **kwargs): method get_samples (line 291) | def get_samples(self, *args, **kwargs): method get_vectorized_trace (line 298) | def get_vectorized_trace(self, *args, **kwargs): class WeighedPredictiveResults (line 327) | class WeighedPredictiveResults(LogWeightsMixin, CloneMixin): class WeighedPredictive (line 338) | class WeighedPredictive(Predictive): method call (line 365) | def call(self, *args, **kwargs): method forward (line 382) | def forward(self, *args, **kwargs): class MHResampler (line 459) | class MHResampler(torch.nn.Module): method __init__ (line 549) | def __init__( method forward (line 564) | def forward(self, *args, **kwargs): method get_min_sample_transition_count (line 600) | def get_min_sample_transition_count(self): method get_total_transition_count (line 606) | def get_total_transition_count(self): method get_source_samples (line 612) | def get_source_samples(self): method get_stored_samples (line 618) | def get_stored_samples(self): method get_samples (line 624) | def get_samples(self, samples): FILE: pyro/infer/renyi_elbo.py class RenyiELBO (line 16) | class RenyiELBO(ELBO): method __init__ (line 52) | def __init__( method _get_trace (line 81) | def _get_trace(self, model, guide, args, kwargs): method loss (line 94) | def loss(self, model, guide, *args, **kwargs): method loss_and_grads (line 138) | def loss_and_grads(self, model, guide, *args, **kwargs): FILE: pyro/infer/reparam/conjugate.py class ConjugateReparam (line 11) | class ConjugateReparam(Reparam): method __init__ (line 51) | def __init__(self, guide): method apply (line 54) | def apply(self, msg): FILE: pyro/infer/reparam/discrete_cosine.py class DiscreteCosineReparam (line 9) | class DiscreteCosineReparam(UnitJacobianReparam): method __init__ (line 38) | def __init__(self, dim=-1, smooth=0.0, *, experimental_allow_batch=Fal... FILE: pyro/infer/reparam/haar.py class HaarReparam (line 9) | class HaarReparam(UnitJacobianReparam): method __init__ (line 32) | def __init__(self, dim=-1, flip=False, *, experimental_allow_batch=Fal... FILE: pyro/infer/reparam/hmm.py class LinearHMMReparam (line 9) | class LinearHMMReparam(Reparam): method __init__ (line 56) | def __init__(self, init=None, trans=None, obs=None): method apply (line 64) | def apply(self, msg): FILE: pyro/infer/reparam/loc_scale.py class LocScaleReparam (line 14) | class LocScaleReparam(Reparam): method __init__ (line 35) | def __init__(self, centered=None, shape_params=None): method apply (line 51) | def apply(self, msg): FILE: pyro/infer/reparam/neutra.py class NeuTraReparam (line 18) | class NeuTraReparam(Reparam): method __init__ (line 48) | def __init__(self, guide): method _reparam_config (line 60) | def _reparam_config(self, site): method reparam (line 64) | def reparam(self, fn=None): method apply (line 67) | def apply(self, msg): method transform_sample (line 124) | def transform_sample(self, latent): FILE: pyro/infer/reparam/projected_normal.py class ProjectedNormalReparam (line 13) | class ProjectedNormalReparam(Reparam): method apply (line 21) | def apply(self, msg): FILE: pyro/infer/reparam/reparam.py class ReparamMessage (line 12) | class ReparamMessage(TypedDict): class ReparamResult (line 19) | class ReparamResult(TypedDict): class Reparam (line 25) | class Reparam(ABC): method apply (line 33) | def apply(self, msg: ReparamMessage) -> ReparamResult: method __call__ (line 58) | def __call__(self, name, fn, obs): method _unwrap (line 66) | def _unwrap(self, fn): method _wrap (line 75) | def _wrap(self, fn, event_dim): FILE: pyro/infer/reparam/softmax.py class GumbelSoftmaxReparam (line 12) | class GumbelSoftmaxReparam(Reparam): method apply (line 23) | def apply(self, msg): FILE: pyro/infer/reparam/split.py function same_support (line 15) | def same_support(fn: TorchDistributionMixin, *args): function real_support (line 26) | def real_support(fn: TorchDistributionMixin, *args): function default_support (line 37) | def default_support(fn: TorchDistributionMixin, slice, dim): class SplitReparam (line 65) | class SplitReparam(Reparam): method __init__ (line 91) | def __init__(self, sections, dim, support_fn=default_support): method apply (line 99) | def apply(self, msg): FILE: pyro/infer/reparam/stable.py class LatentStableReparam (line 16) | class LatentStableReparam(Reparam): method apply (line 40) | def apply(self, msg): class SymmetricStableReparam (line 79) | class SymmetricStableReparam(Reparam): method apply (line 101) | def apply(self, msg): class StableReparam (line 143) | class StableReparam(Reparam): method apply (line 162) | def apply(self, msg): function _unsafe_shift (line 242) | def _unsafe_shift(a, skew, t_scale): function _safe_shift (line 247) | def _safe_shift(a, skew, t_scale, skew_abs): FILE: pyro/infer/reparam/strategies.py class Strategy (line 28) | class Strategy(ABC): method __init__ (line 44) | def __init__(self): method configure (line 50) | def configure(self, msg: dict) -> Optional[Reparam]: method __call__ (line 63) | def __call__(self, msg_or_fn: Union[dict, Callable]): class MinimalReparam (line 83) | class MinimalReparam(Strategy): method configure (line 103) | def configure(self, msg: dict) -> Optional[Reparam]: function _minimal_reparam (line 107) | def _minimal_reparam(fn, is_observed): class AutoReparam (line 131) | class AutoReparam(Strategy): method __init__ (line 162) | def __init__(self, *, centered: Optional[float] = None): method configure (line 167) | def configure(self, msg: dict) -> Optional[Reparam]: function _loc_scale_reparam (line 190) | def _loc_scale_reparam(name, fn, centered): function _is_unconstrained (line 212) | def _is_unconstrained(constraint): FILE: pyro/infer/reparam/structured.py class StructuredReparam (line 14) | class StructuredReparam(Reparam): method __init__ (line 49) | def __init__(self, guide: AutoStructured): method _reparam_config (line 57) | def _reparam_config(self, site): method reparam (line 61) | def reparam(self, fn=None): method apply (line 64) | def apply(self, msg): method transform_samples (line 90) | def transform_samples(self, aux_samples, save_params=None): FILE: pyro/infer/reparam/studentt.py class StudentTReparam (line 10) | class StudentTReparam(Reparam): method apply (line 25) | def apply(self, msg): FILE: pyro/infer/reparam/transform.py class TransformReparam (line 12) | class TransformReparam(Reparam): method apply (line 24) | def apply(self, msg): FILE: pyro/infer/reparam/unit_jacobian.py class UnitJacobianReparam (line 16) | class UnitJacobianReparam(Reparam): method __init__ (line 29) | def __init__( method apply (line 36) | def apply(self, msg): FILE: pyro/infer/resampler.py class Resampler (line 14) | class Resampler: method __init__ (line 37) | def __init__( method sample (line 68) | def sample( method _categorical_sample (line 97) | def _categorical_sample( function _log_prob_sum (line 112) | def _log_prob_sum(trace: Trace, batch_size: int) -> torch.Tensor: function _guess_max_plate_nesting (line 124) | def _guess_max_plate_nesting(model: callable) -> int: FILE: pyro/infer/rws.py class ReweightedWakeSleep (line 17) | class ReweightedWakeSleep(ELBO): method __init__ (line 77) | def __init__( method _get_trace (line 106) | def _get_trace(self, model, guide, args, kwargs): method _loss (line 117) | def _loss(self, model, guide, args, kwargs): method loss (line 216) | def loss(self, model, guide, *args, **kwargs): method loss_and_grads (line 229) | def loss_and_grads(self, model, guide, *args, **kwargs): method _vectorized_num_sleep_particles (line 243) | def _vectorized_num_sleep_particles(self, fn): method _get_matched_trace (line 261) | def _get_matched_trace(model_trace, guide, args, kwargs): FILE: pyro/infer/smcfilter.py class SMCFailed (line 16) | class SMCFailed(ValueError): class SMCFilter (line 25) | class SMCFilter: method __init__ (line 56) | def __init__( method init (line 70) | def init(self, *args, **kwargs): method step (line 89) | def step(self, *args, **kwargs): method get_empirical (line 106) | def get_empirical(self): method _update_weights (line 118) | def _update_weights(self, model_trace, guide_trace): method _maybe_importance_resample (line 151) | def _maybe_importance_resample(self): method _importance_resample (line 162) | def _importance_resample(self, probs): function _systematic_sample (line 167) | def _systematic_sample(probs): class SMCState (line 179) | class SMCState(dict): method __init__ (line 191) | def __init__(self, num_particles): method _lock (line 199) | def _lock(self): method __setitem__ (line 206) | def __setitem__(self, key, value): method _resample (line 224) | def _resample(self, index): FILE: pyro/infer/svgd.py function vectorize (line 19) | def vectorize(fn, num_particles, max_plate_nesting): class _SVGDGuide (line 29) | class _SVGDGuide(AutoContinuous): method __init__ (line 35) | def __init__(self, model): method get_posterior (line 38) | def get_posterior(self, *args, **kwargs): class SteinKernel (line 43) | class SteinKernel(object, metaclass=ABCMeta): method log_kernel_and_grad (line 49) | def log_kernel_and_grad(self, particles): class RBFSteinKernel (line 63) | class RBFSteinKernel(SteinKernel): method __init__ (line 78) | def __init__(self, bandwidth_factor=None): method _bandwidth (line 84) | def _bandwidth(self, norm_sq): method log_kernel_and_grad (line 98) | def log_kernel_and_grad(self, particles): method bandwidth_factor (line 109) | def bandwidth_factor(self): method bandwidth_factor (line 113) | def bandwidth_factor(self, bandwidth_factor): class IMQSteinKernel (line 123) | class IMQSteinKernel(SteinKernel): method __init__ (line 144) | def __init__(self, alpha=0.5, beta=-0.5, bandwidth_factor=None): method _bandwidth (line 156) | def _bandwidth(self, norm_sq): method log_kernel_and_grad (line 170) | def log_kernel_and_grad(self, particles): method bandwidth_factor (line 183) | def bandwidth_factor(self): method bandwidth_factor (line 187) | def bandwidth_factor(self, bandwidth_factor): class SVGD (line 196) | class SVGD: method __init__ (line 234) | def __init__( method get_named_particles (line 259) | def get_named_particles(self): method step (line 272) | def step(self, *args, **kwargs): FILE: pyro/infer/svi.py class SVI (line 16) | class SVI(TracePosterior): method __init__ (line 38) | def __init__( method run (line 92) | def run(self, *args, **kwargs): method _traces (line 111) | def _traces(self, *args, **kwargs): method evaluate_loss (line 119) | def evaluate_loss(self, *args, **kwargs): method step (line 134) | def step(self, *args, **kwargs): FILE: pyro/infer/trace_elbo.py function _compute_log_r (line 20) | def _compute_log_r(model_trace, guide_trace): class Trace_ELBO (line 32) | class Trace_ELBO(ELBO): method _get_trace (line 52) | def _get_trace(self, model, guide, args, kwargs): method loss (line 64) | def loss(self, model, guide, *args, **kwargs): method _differentiable_loss_particle (line 82) | def _differentiable_loss_particle(self, model_trace, guide_trace): method differentiable_loss (line 114) | def differentiable_loss(self, model, guide, *args, **kwargs): method loss_and_grads (line 130) | def loss_and_grads(self, model, guide, *args, **kwargs): class JitTrace_ELBO (line 162) | class JitTrace_ELBO(Trace_ELBO): method loss_and_surrogate_loss (line 177) | def loss_and_surrogate_loss(self, model, guide, *args, **kwargs): method differentiable_loss (line 241) | def differentiable_loss(self, model, guide, *args, **kwargs): method loss_and_grads (line 249) | def loss_and_grads(self, model, guide, *args, **kwargs): FILE: pyro/infer/trace_mean_field_elbo.py function _check_mean_field_requirement (line 21) | def _check_mean_field_requirement(model_trace, guide_trace): class TraceMeanField_ELBO (line 49) | class TraceMeanField_ELBO(Trace_ELBO): method _get_trace (line 81) | def _get_trace(self, model, guide, args, kwargs): method loss (line 87) | def loss(self, model, guide, *args, **kwargs): method _differentiable_loss_particle (line 104) | def _differentiable_loss_particle(self, model_trace, guide_trace): class JitTraceMeanField_ELBO (line 159) | class JitTraceMeanField_ELBO(TraceMeanField_ELBO): method differentiable_loss (line 174) | def differentiable_loss(self, model, guide, *args, **kwargs): method loss_and_grads (line 202) | def loss_and_grads(self, model, guide, *args, **kwargs): FILE: pyro/infer/trace_mmd.py function _compute_mmd (line 17) | def _compute_mmd(X, Z, kernel): class Trace_MMD (line 22) | class Trace_MMD(ELBO): method __init__ (line 64) | def __init__( method kernel (line 91) | def kernel(self): method kernel (line 95) | def kernel(self, kernel): method mmd_scale (line 115) | def mmd_scale(self): method mmd_scale (line 119) | def mmd_scale(self, mmd_scale): method _get_trace (line 127) | def _get_trace(self, model, guide, args, kwargs): method _differentiable_loss_parts (line 139) | def _differentiable_loss_parts(self, model, guide, args, kwargs): method differentiable_loss (line 219) | def differentiable_loss(self, model, guide, *args, **kwargs): method loss (line 238) | def loss(self, model, guide, *args, **kwargs): method loss_and_grads (line 254) | def loss_and_grads(self, model, guide, *args, **kwargs): FILE: pyro/infer/trace_tail_adaptive_elbo.py class TraceTailAdaptive_ELBO (line 12) | class TraceTailAdaptive_ELBO(Trace_ELBO): method loss (line 35) | def loss(self, model, guide, *args, **kwargs): method _differentiable_loss_particle (line 45) | def _differentiable_loss_particle(self, model_trace, guide_trace): FILE: pyro/infer/traceenum_elbo.py function _get_common_scale (line 32) | def _get_common_scale(scales): function _check_model_guide_enumeration_constraint (line 50) | def _check_model_guide_enumeration_constraint(model_enum_sites, guide_tr... function _check_tmc_elbo_constraint (line 68) | def _check_tmc_elbo_constraint(model_trace, guide_trace): function _find_ordinal (line 105) | def _find_ordinal(trace, site): function _compute_model_factors (line 112) | def _compute_model_factors(model_trace, guide_trace): function _compute_dice_elbo (line 178) | def _compute_dice_elbo(model_trace, guide_trace): function _make_dist (line 217) | def _make_dist(dist_, logits): function _compute_marginals (line 224) | def _compute_marginals(model_trace, guide_trace): class BackwardSampleMessenger (line 256) | class BackwardSampleMessenger(pyro.poutine.messenger.Messenger): method __init__ (line 262) | def __init__(self, enum_trace, guide_trace): method __enter__ (line 268) | def __enter__(self): method __exit__ (line 272) | def __exit__(self, exc_type, exc_value, traceback): method _pyro_sample (line 277) | def _pyro_sample(self, msg): method _pyro_post_sample (line 300) | def _pyro_post_sample(self, msg): class TraceEnum_ELBO (line 316) | class TraceEnum_ELBO(ELBO): method _get_trace (line 334) | def _get_trace(self, model, guide, args, kwargs): method _get_traces (line 366) | def _get_traces(self, model, guide, args, kwargs): method loss (line 396) | def loss(self, model, guide, *args, **kwargs): method differentiable_loss (line 415) | def differentiable_loss(self, model, guide, *args, **kwargs): method loss_and_grads (line 442) | def loss_and_grads(self, model, guide, *args, **kwargs): method compute_marginals (line 473) | def compute_marginals(self, model, guide, *args, **kwargs): method sample_posterior (line 495) | def sample_posterior(self, model, guide, *args, **kwargs): class JitTraceEnum_ELBO (line 523) | class JitTraceEnum_ELBO(TraceEnum_ELBO): method differentiable_loss (line 538) | def differentiable_loss(self, model, guide, *args, **kwargs): method loss_and_grads (line 563) | def loss_and_grads(self, model, guide, *args, **kwargs): FILE: pyro/infer/tracegraph_elbo.py function _get_baseline_options (line 28) | def _get_baseline_options(site): function _construct_baseline (line 48) | def _construct_baseline(node, guide_site, downstream_cost): function _compute_downstream_costs (line 103) | def _compute_downstream_costs(model_trace, guide_trace, non_reparam_node... function _compute_elbo (line 178) | def _compute_elbo(model_trace, guide_trace): class TrackNonReparam (line 239) | class TrackNonReparam(Messenger): method _pyro_post_sample (line 279) | def _pyro_post_sample(self, msg): class TraceGraph_ELBO (line 290) | class TraceGraph_ELBO(ELBO): method _get_trace (line 313) | def _get_trace(self, model, guide, args, kwargs): method loss (line 326) | def loss(self, model, guide, *args, **kwargs): method loss_and_grads (line 344) | def loss_and_grads(self, model, guide, *args, **kwargs): method _loss_and_surrogate_loss (line 362) | def _loss_and_surrogate_loss(self, model, guide, args, kwargs): method _loss_and_surrogate_loss_particle (line 376) | def _loss_and_surrogate_loss_particle(self, model_trace, guide_trace): class JitTraceGraph_ELBO (line 382) | class JitTraceGraph_ELBO(TraceGraph_ELBO): method loss_and_grads (line 397) | def loss_and_grads(self, model, guide, *args, **kwargs): FILE: pyro/infer/tracetmc_elbo.py function _compute_dice_factors (line 24) | def _compute_dice_factors(model_trace, guide_trace): function _compute_tmc_factors (line 48) | def _compute_tmc_factors(model_trace, guide_trace): function _compute_tmc_estimate (line 75) | def _compute_tmc_estimate(model_trace, guide_trace): class TraceTMC_ELBO (line 105) | class TraceTMC_ELBO(ELBO): method _get_trace (line 133) | def _get_trace(self, model, guide, args, kwargs): method _get_traces (line 165) | def _get_traces(self, model, guide, args, kwargs): method differentiable_loss (line 193) | def differentiable_loss(self, model, guide, *args, **kwargs): method loss (line 217) | def loss(self, model, guide, *args, **kwargs): method loss_and_grads (line 224) | def loss_and_grads(self, model, guide, *args, **kwargs): FILE: pyro/infer/util.py function enable_validation (line 29) | def enable_validation(is_validate): function is_validation_enabled (line 34) | def is_validation_enabled(): function validation_enabled (line 39) | def validation_enabled(is_validate=True): function torch_item (line 48) | def torch_item(x): function torch_backward (line 55) | def torch_backward(x, retain_graph=None): function torch_exp (line 64) | def torch_exp(x): function torch_sum (line 75) | def torch_sum(tensor, dims): function zero_grads (line 85) | def zero_grads(tensors): function get_plate_stacks (line 94) | def get_plate_stacks(trace): function get_dependent_plate_dims (line 108) | def get_dependent_plate_dims(sites): class MultiFrameTensor (line 122) | class MultiFrameTensor(dict): method __init__ (line 138) | def __init__(self, *items): method add (line 142) | def add(self, *items): method sum_to (line 156) | def sum_to(self, target_frames): method __repr__ (line 167) | def __repr__(self): function compute_site_dice_factor (line 174) | def compute_site_dice_factor(site): class Dice (line 199) | class Dice: method __init__ (line 228) | def __init__(self, guide_trace, ordering): method _get_log_factors (line 248) | def _get_log_factors(self, target_ordinal): method compute_expectation (line 264) | def compute_expectation(self, costs): function _fulldot (line 329) | def _fulldot(x, y): function check_fully_reparametrized (line 336) | def check_fully_reparametrized(guide_site): function plate_log_prob_sum (line 349) | def plate_log_prob_sum(trace: Trace, plate_symbol: str) -> torch.Tensor: class CloneMixin (line 364) | class CloneMixin: method clone (line 370) | def clone(self): FILE: pyro/nn/auto_reg_nn.py function sample_mask_indices (line 12) | def sample_mask_indices( function create_mask (line 36) | def create_mask( class MaskedLinear (line 103) | class MaskedLinear(nn.Linear): method __init__ (line 117) | def __init__( method forward (line 123) | def forward(self, _input: torch.Tensor) -> torch.Tensor: class ConditionalAutoRegressiveNN (line 128) | class ConditionalAutoRegressiveNN(nn.Module): method __init__ (line 174) | def __init__( method get_permutation (line 255) | def get_permutation(self) -> torch.LongTensor: method forward (line 261) | def forward( method _forward (line 272) | def _forward(self, x: torch.Tensor) -> Union[Sequence[torch.Tensor], t... class AutoRegressiveNN (line 301) | class AutoRegressiveNN(ConditionalAutoRegressiveNN): method __init__ (line 340) | def __init__( method forward (line 359) | def forward(self, x: torch.Tensor) -> Union[Sequence[torch.Tensor], to... FILE: pyro/nn/dense_nn.py class ConditionalDenseNN (line 9) | class ConditionalDenseNN(torch.nn.Module): method __init__ (line 38) | def __init__( method forward (line 70) | def forward( method _forward (line 79) | def _forward(self, x: torch.Tensor) -> Union[Sequence[torch.Tensor], t... class DenseNN (line 101) | class DenseNN(ConditionalDenseNN): method __init__ (line 128) | def __init__( method forward (line 139) | def forward(self, x: torch.Tensor) -> Union[Sequence[torch.Tensor], to... FILE: pyro/nn/module.py function _copy_to_script_wrapper (line 28) | def _copy_to_script_wrapper(fn): function _validate_module_local_params (line 71) | def _validate_module_local_params(value: bool) -> None: function _is_module_local_param_enabled (line 75) | def _is_module_local_param_enabled() -> bool: class PyroParam (line 79) | class PyroParam(NamedTuple): method __get__ (line 130) | def __get__( method __call__ (line 147) | def __call__( class PyroSample (line 155) | class PyroSample: method __post_init__ (line 197) | def __post_init__(self) -> None: method __get__ (line 214) | def __get__( function _make_name (line 235) | def _make_name(prefix: str, name: str) -> str: function _unconstrain (line 239) | def _unconstrain( class _Context (line 250) | class _Context: method __init__ (line 255) | def __init__(self) -> None: method __enter__ (line 262) | def __enter__(self) -> None: method __exit__ (line 269) | def __exit__( method get (line 282) | def get(self, name: str) -> Optional[torch.Tensor]: method set (line 287) | def set(self, name: str, value: torch.Tensor) -> None: function _get_pyro_params (line 292) | def _get_pyro_params( class _PyroModuleMeta (line 307) | class _PyroModuleMeta(type): class _New (line 311) | class _New: method __init__ (line 312) | def __init__(self, Module): method __getitem__ (line 315) | def __getitem__(cls, Module: Type[torch.nn.Module]) -> Type["PyroModul... class PyroModule (line 339) | class PyroModule(torch.nn.Module, metaclass=_PyroModuleMeta): method __init__ (line 473) | def __init__(self, name: str = "") -> None: method add_module (line 482) | def add_module(self, name: str, module: Optional[torch.nn.Module]) -> ... method named_pyro_params (line 492) | def named_pyro_params( method _pyro_set_supermodule (line 509) | def _pyro_set_supermodule(self, name: str, context: _Context) -> None: method _pyro_get_fullname (line 521) | def _pyro_get_fullname(self, name: str) -> str: method __call__ (line 525) | def __call__(self, *args: Any, **kwargs: Any) -> Any: method _check_module_local_param_usage (line 536) | def _check_module_local_param_usage(self) -> None: method __getattr__ (line 548) | def __getattr__(self, name: str) -> Any: method __setattr__ (line 673) | def __setattr__( method __delattr__ (line 784) | def __delattr__(self, name: str) -> None: method __getstate__ (line 819) | def __getstate__(self) -> Dict[str, Any]: function pyro_method (line 826) | def pyro_method( function clear (line 845) | def clear(mod: PyroModule) -> None: function to_pyro_module_ (line 860) | def to_pyro_module_(m: torch.nn.Module, recurse: bool = True) -> None: class _FlatWeightsDescriptor (line 925) | class _FlatWeightsDescriptor: method __get__ (line 926) | def __get__( method __set__ (line 935) | def __set__(self, obj: object, value: Any) -> None: class PyroModuleList (line 953) | class PyroModuleList(torch.nn.ModuleList, PyroModule): method __init__ (line 954) | def __init__(self, modules): method __getitem__ (line 958) | def __getitem__( FILE: pyro/ops/arrowhead.py function sqrt (line 12) | def sqrt(x): function triu_inverse (line 60) | def triu_inverse(x): function triu_matvecmul (line 84) | def triu_matvecmul(x, y, transpose=False): function triu_gram (line 108) | def triu_gram(x): FILE: pyro/ops/contract.py function _check_plates_are_sensible (line 16) | def _check_plates_are_sensible(output_dims, nonoutput_ordinal): function _check_tree_structure (line 26) | def _check_tree_structure(parent, leaf): function _partition_terms (line 38) | def _partition_terms(ring, terms, dims): function _contract_component (line 79) | def _contract_component(ring, tensor_tree, sum_dims, target_dims): function contract_tensor_tree (line 163) | def contract_tensor_tree(tensor_tree, sum_dims, cache=None, ring=None): function contract_to_tensor (line 205) | def contract_to_tensor( function einsum (line 276) | def einsum(equation, *operands, **kwargs): function ubersum (line 429) | def ubersum(equation, *operands, **kwargs): function _select (line 446) | def _select(tensor, dims, indices): class _DimUnroller (line 452) | class _DimUnroller: method __init__ (line 460) | def __init__(self, dim_to_ordinal): method __call__ (line 467) | def __call__(self, dim, indices): function naive_ubersum (line 486) | def naive_ubersum(equation, *operands, **kwargs): FILE: pyro/ops/dual_averaging.py class DualAveraging (line 5) | class DualAveraging: method __init__ (line 43) | def __init__(self, prox_center=0, t0=10, kappa=0.75, gamma=0.05): method reset (line 50) | def reset(self): method step (line 55) | def step(self, g): method get_state (line 74) | def get_state(self): FILE: pyro/ops/einsum/__init__.py function contract_expression (line 11) | def contract_expression(equation, *shapes, **kwargs): function contract (line 33) | def contract(equation, *operands, **kwargs): FILE: pyro/ops/einsum/adjoint.py class Backward (line 15) | class Backward(object, metaclass=ABCMeta): method __call__ (line 18) | def __call__(self): method process (line 29) | def process(self, message): class _LeafBackward (line 33) | class _LeafBackward(Backward): method __init__ (line 36) | def __init__(self, target): method process (line 39) | def process(self, message): function require_backward (line 46) | def require_backward(tensor): class _TransposeBackward (line 53) | class _TransposeBackward(Backward): method __init__ (line 54) | def __init__(self, a, axes): method process (line 58) | def process(self, message): function transpose (line 69) | def transpose(a, axes): function einsum_backward_sample (line 77) | def einsum_backward_sample(operands, sample1, sample2): function unflatten (line 132) | def unflatten(flat_sample, output_dims, contract_dims, contract_shape): FILE: pyro/ops/einsum/torch_log.py function transpose (line 10) | def transpose(a, axes): function einsum (line 14) | def einsum(equation, *operands): FILE: pyro/ops/einsum/torch_map.py class _EinsumBackward (line 17) | class _EinsumBackward(Backward): method __init__ (line 18) | def __init__(self, operands, argmax): method process (line 22) | def process(self, message): function einsum (line 28) | def einsum(equation, *operands): FILE: pyro/ops/einsum/torch_marginal.py class _EinsumBackward (line 9) | class _EinsumBackward(Backward): method __init__ (line 10) | def __init__(self, equation, operands): method process (line 14) | def process(self, message): function einsum (line 49) | def einsum(equation, *operands): FILE: pyro/ops/einsum/torch_sample.py class _EinsumBackward (line 20) | class _EinsumBackward(Backward): method __init__ (line 21) | def __init__(self, output, operands): method process (line 25) | def process(self, message): function einsum (line 61) | def einsum(equation, *operands): FILE: pyro/ops/einsum/util.py class Tensordot (line 7) | class Tensordot: method __init__ (line 12) | def __init__(self, einsum): method __call__ (line 19) | def __call__(self, x, y, axes=2): FILE: pyro/ops/gamma_gaussian.py class Gamma (line 16) | class Gamma: method __init__ (line 23) | def __init__(self, log_normalizer, concentration, rate): method log_density (line 28) | def log_density(self, s): method logsumexp (line 36) | def logsumexp(self): class GammaGaussian (line 47) | class GammaGaussian: method __init__ (line 85) | def __init__(self, log_normalizer, info_vec, precision, alpha, beta): method dim (line 96) | def dim(self): method batch_shape (line 100) | def batch_shape(self): method expand (line 109) | def expand(self, batch_shape): method reshape (line 118) | def reshape(self, batch_shape): method __getitem__ (line 127) | def __getitem__(self, index): method cat (line 140) | def cat(parts, dim=0): method event_pad (line 152) | def event_pad(self, left=0, right=0): method event_permute (line 167) | def event_permute(self, perm): method __add__ (line 179) | def __add__(self, other): method log_density (line 193) | def log_density(self, value, s): method condition (line 213) | def condition(self, value): method marginalize (line 251) | def marginalize(self, left=0, right=0): method compound (line 298) | def compound(self): method event_logsumexp (line 317) | def event_logsumexp(self): function gamma_and_mvn_to_gamma_gaussian (line 343) | def gamma_and_mvn_to_gamma_gaussian(gamma, mvn): function scale_mvn (line 375) | def scale_mvn(mvn, s): function matrix_and_mvn_to_gamma_gaussian (line 390) | def matrix_and_mvn_to_gamma_gaussian(matrix, mvn): function gamma_gaussian_tensordot (line 434) | def gamma_gaussian_tensordot(x, y, dims=0): FILE: pyro/ops/gaussian.py class Gaussian (line 15) | class Gaussian: method __init__ (line 32) | def __init__( method dim (line 46) | def dim(self): method batch_shape (line 50) | def batch_shape(self): method expand (line 57) | def expand(self, batch_shape) -> "Gaussian": method reshape (line 64) | def reshape(self, batch_shape) -> "Gaussian": method __getitem__ (line 71) | def __getitem__(self, index) -> "Gaussian": method cat (line 82) | def cat(parts, dim=0) -> "Gaussian": method event_pad (line 94) | def event_pad(self, left=0, right=0) -> "Gaussian": method event_permute (line 104) | def event_permute(self, perm) -> "Gaussian": method __add__ (line 114) | def __add__(self, other: Union["Gaussian", int, float, torch.Tensor]) ... method __sub__ (line 129) | def __sub__(self, other: Union["Gaussian", int, float, torch.Tensor]) ... method log_density (line 134) | def log_density(self, value: torch.Tensor) -> torch.Tensor: method rsample (line 151) | def rsample( method condition (line 168) | def condition(self, value: torch.Tensor) -> "Gaussian": method left_condition (line 205) | def left_condition(self, value: torch.Tensor) -> "Gaussian": method marginalize (line 233) | def marginalize(self, left=0, right=0) -> "Gaussian": method event_logsumexp (line 275) | def event_logsumexp(self) -> torch.Tensor: class AffineNormal (line 294) | class AffineNormal: method __init__ (line 314) | def __init__(self, matrix, loc, scale): method batch_shape (line 324) | def batch_shape(self): method condition (line 327) | def condition(self, value): method left_condition (line 342) | def left_condition(self, value): method rsample (line 356) | def rsample( method to_gaussian (line 372) | def to_gaussian(self): method expand (line 381) | def expand(self, batch_shape): method reshape (line 387) | def reshape(self, batch_shape): method __getitem__ (line 393) | def __getitem__(self, index): method event_permute (line 400) | def event_permute(self, perm): method __add__ (line 403) | def __add__(self, other): method marginalize (line 406) | def marginalize(self, left=0, right=0): function mvn_to_gaussian (line 417) | def mvn_to_gaussian(mvn): function matrix_and_gaussian_to_gaussian (line 449) | def matrix_and_gaussian_to_gaussian( function matrix_and_mvn_to_gaussian (line 477) | def matrix_and_mvn_to_gaussian(matrix, mvn): function gaussian_tensordot (line 510) | def gaussian_tensordot(x: Gaussian, y: Gaussian, dims: int = 0) -> Gauss... function sequential_gaussian_tensordot (line 573) | def sequential_gaussian_tensordot(gaussian: Gaussian) -> Gaussian: function sequential_gaussian_filter_sample (line 600) | def sequential_gaussian_filter_sample( FILE: pyro/ops/hessian.py function hessian (line 7) | def hessian(y, xs): FILE: pyro/ops/indexing.py function _is_batched (line 7) | def _is_batched(arg): function _flatten (line 11) | def _flatten(args, out): function index (line 22) | def index(tensor, args): class Index (line 62) | class Index: method __init__ (line 75) | def __init__(self, tensor): method __getitem__ (line 78) | def __getitem__(self, args): function vindex (line 82) | def vindex(tensor, args): class Vindex (line 200) | class Vindex: method __init__ (line 213) | def __init__(self, tensor): method __getitem__ (line 216) | def __getitem__(self, args): FILE: pyro/ops/integrator.py function velocity_verlet (line 14) | def velocity_verlet( function _single_step_verlet (line 45) | def _single_step_verlet(z, r, potential_fn, kinetic_grad, step_size, z_g... function potential_grad (line 68) | def potential_grad(potential_fn, z): function register_exception_handler (line 97) | def register_exception_handler( function _handle_torch_singular (line 119) | def _handle_torch_singular(exception: Exception) -> bool: FILE: pyro/ops/jit.py function _hash (line 15) | def _hash(value, allow_id): function _hashable_args_kwargs (line 37) | def _hashable_args_kwargs(args, kwargs): class CompiledFunction (line 48) | class CompiledFunction: method __init__ (line 59) | def __init__(self, fn, ignore_warnings=False, jit_options=None): method __call__ (line 68) | def __call__(self, *args, **kwargs): function trace (line 132) | def trace(fn=None, ignore_warnings=False, jit_options=None): FILE: pyro/ops/linalg.py function ignore_torch_deprecation_warnings (line 12) | def ignore_torch_deprecation_warnings(): function rinverse (line 19) | def rinverse(M, sym=False): function determinant_3d (line 43) | def determinant_3d(H): function eig_3d (line 55) | def eig_3d(H): function inv3d (line 83) | def inv3d(H, sym=False): FILE: pyro/ops/newton.py function newton_step (line 11) | def newton_step(loss, x, trust_radius=None): function newton_step_1d (line 77) | def newton_step_1d(loss, x, trust_radius=None): function newton_step_2d (line 121) | def newton_step_2d(loss, x, trust_radius=None): function newton_step_3d (line 185) | def newton_step_3d(loss, x, trust_radius=None): FILE: pyro/ops/packed.py function pack (line 12) | def pack(value, dim_to_symbol): function unpack (line 51) | def unpack(value, symbol_to_dim): function broadcast_all (line 73) | def broadcast_all(*values, **kwargs): function gather (line 101) | def gather(value, index, dim): function mul (line 119) | def mul(lhs, rhs): function scale_and_mask (line 137) | def scale_and_mask(tensor, scale=1.0, mask=None): function neg (line 165) | def neg(value): function exp (line 175) | def exp(value): function rename_equation (line 187) | def rename_equation(equation, *operands): FILE: pyro/ops/provenance.py class ProvenanceTensor (line 13) | class ProvenanceTensor(torch.Tensor): method __new__ (line 51) | def __new__(cls, data: torch.Tensor, provenance=frozenset(), **kwargs): method __repr__ (line 63) | def __repr__(self): method __torch_function__ (line 67) | def __torch_function__(cls, func, types, args=(), kwargs=None): function track_provenance (line 74) | def track_provenance(x, provenance: frozenset): function _track_provenance_set (line 90) | def _track_provenance_set(x, provenance: frozenset): function _track_provenance_pytree (line 97) | def _track_provenance_pytree(x, provenance: frozenset): function _track_provenance_provenancetensor (line 102) | def _track_provenance_provenancetensor(x: ProvenanceTensor, provenance: ... function extract_provenance (line 108) | def extract_provenance(x) -> Tuple[object, frozenset]: function _extract_provenance_tensor (line 122) | def _extract_provenance_tensor(x): function _extract_provenance_set (line 128) | def _extract_provenance_set(x): function _extract_provenance_pytree (line 142) | def _extract_provenance_pytree(x): function get_provenance (line 152) | def get_provenance(x) -> frozenset: function detach_provenance (line 165) | def detach_provenance(x: _Tensor) -> _Tensor: FILE: pyro/ops/rings.py class Ring (line 14) | class Ring(object, metaclass=ABCMeta): method __init__ (line 31) | def __init__(self, cache=None): method _hash_by_id (line 34) | def _hash_by_id(self, tensor): method sumproduct (line 44) | def sumproduct(self, terms, dims): method product (line 55) | def product(self, term, ordinal): method broadcast (line 65) | def broadcast(self, term, ordinal): method inv (line 88) | def inv(self, term): method global_local (line 96) | def global_local(self, term, dims, ordinal): class LinearRing (line 126) | class LinearRing(Ring): method __init__ (line 137) | def __init__(self, cache=None, dim_to_size=None): method sumproduct (line 141) | def sumproduct(self, terms, dims): method product (line 149) | def product(self, term, ordinal): method inv (line 164) | def inv(self, term): class LogRing (line 178) | class LogRing(Ring): method __init__ (line 191) | def __init__(self, cache=None, dim_to_size=None): method sumproduct (line 195) | def sumproduct(self, terms, dims): method product (line 203) | def product(self, term, ordinal): method inv (line 218) | def inv(self, term): class _SampleProductBackward (line 232) | class _SampleProductBackward(Backward): method __init__ (line 241) | def __init__(self, ring, term, ordinal): method process (line 246) | def process(self, message): class MapRing (line 260) | class MapRing(LogRing): method product (line 267) | def product(self, term, ordinal): class SampleRing (line 274) | class SampleRing(LogRing): method product (line 281) | def product(self, term, ordinal): class _MarginalProductBackward (line 288) | class _MarginalProductBackward(Backward): method __init__ (line 293) | def __init__(self, ring, term, ordinal, result): method process (line 299) | def process(self, message): class MarginalRing (line 316) | class MarginalRing(LogRing): method product (line 323) | def product(self, term, ordinal): FILE: pyro/ops/special.py class _SafeLog (line 15) | class _SafeLog(torch.autograd.Function): method forward (line 17) | def forward(ctx, x): method backward (line 22) | def backward(ctx, grad): function safe_log (line 27) | def safe_log(x): function log_beta (line 35) | def log_beta(x, y, tol=0.0): function log_binomial (line 93) | def log_binomial(n, k, tol=0.0): function log_I1 (line 113) | def log_I1(orders: int, value: torch.Tensor, terms=250): function get_quad_rule (line 160) | def get_quad_rule(num_quad, prototype_tensor): function sparse_multinomial_likelihood (line 186) | def sparse_multinomial_likelihood(total_count, nonzero_logits, nonzero_v... function _log_factorial_sum (line 211) | def _log_factorial_sum(x: torch.Tensor) -> torch.Tensor: FILE: pyro/ops/ssm_gp.py class MaternKernel (line 16) | class MaternKernel(PyroModule): method __init__ (line 36) | def __init__( method transition_matrix (line 72) | def transition_matrix(self, dt): method stationary_covariance (line 118) | def stationary_covariance(self): method process_covariance (line 145) | def process_covariance(self, A): method transition_matrix_and_covariance (line 158) | def transition_matrix_and_covariance(self, dt): FILE: pyro/ops/stats.py function _compute_chain_variance_stats (line 14) | def _compute_chain_variance_stats(input): function gelman_rubin (line 32) | def gelman_rubin(input, chain_dim=0, sample_dim=1): function split_gelman_rubin (line 58) | def split_gelman_rubin(input, chain_dim=0, sample_dim=1): function autocorrelation (line 87) | def autocorrelation(input, dim=0): function autocovariance (line 131) | def autocovariance(input, dim=0): function _cummin (line 142) | def _cummin(input): function effective_sample_size (line 162) | def effective_sample_size(input, chain_dim=0, sample_dim=1): function resample (line 222) | def resample(input, num_samples, dim=0, replacement=False): function quantile (line 236) | def quantile(input, probs, dim=0): function weighed_quantile (line 265) | def weighed_quantile( function pi (line 328) | def pi(input, prob, dim=0): function hpdi (line 341) | def hpdi(input, prob, dim=0): function _weighted_mean (line 368) | def _weighted_mean(input, log_weights, dim=0, keepdim=False): function _weighted_variance (line 376) | def _weighted_variance(input, log_weights, dim=0, keepdim=False, unbiase... function waic (line 385) | def waic(input, log_weights=None, pointwise=False, dim=0): function fit_generalized_pareto (line 419) | def fit_generalized_pareto(X): function crps_empirical (line 468) | def crps_empirical(pred, truth): function energy_score_empirical (line 513) | def energy_score_empirical( FILE: pyro/ops/streaming.py class StreamingStats (line 14) | class StreamingStats(ABC): method update (line 23) | def update(self, sample) -> None: method merge (line 38) | def merge(self, other) -> "StreamingStats": method get (line 51) | def get(self) -> Any: class CountStats (line 58) | class CountStats(StreamingStats): method __init__ (line 70) | def __init__(self): method update (line 74) | def update(self, sample) -> None: method merge (line 77) | def merge(self, other: "CountStats") -> "CountStats": method get (line 83) | def get(self) -> Dict[str, int]: class StatsOfDict (line 91) | class StatsOfDict(StreamingStats): method __init__ (line 117) | def __init__( method update (line 126) | def update(self, sample: Dict[Hashable, Any]) -> None: method merge (line 130) | def merge(self, other: "StatsOfDict") -> "StatsOfDict": method get (line 140) | def get(self) -> Dict[Hashable, Any]: class StackStats (line 150) | class StackStats(StreamingStats): method __init__ (line 155) | def __init__(self): method update (line 158) | def update(self, sample: torch.Tensor) -> None: method merge (line 162) | def merge(self, other: "StackStats") -> "StackStats": method get (line 168) | def get(self) -> Dict[str, Union[int, torch.Tensor]]: class CountMeanStats (line 179) | class CountMeanStats(StreamingStats): method __init__ (line 184) | def __init__(self): method update (line 189) | def update(self, sample: torch.Tensor) -> None: method merge (line 194) | def merge(self, other: "CountMeanStats") -> "CountMeanStats": method get (line 203) | def get(self) -> Dict[str, Union[int, torch.Tensor]]: class CountMeanVarianceStats (line 214) | class CountMeanVarianceStats(StreamingStats): method __init__ (line 220) | def __init__(self): method update (line 225) | def update(self, sample: torch.Tensor) -> None: method merge (line 232) | def merge(self, other: "CountMeanVarianceStats") -> "CountMeanVariance... method get (line 254) | def get(self) -> Dict[str, Union[int, torch.Tensor]]: FILE: pyro/ops/tensor_utils.py function _validate_jitter (line 16) | def _validate_jitter(value): function as_complex (line 21) | def as_complex(x): function block_diag_embed (line 35) | def block_diag_embed(mat): function block_diagonal (line 49) | def block_diagonal(mat, block_size): function periodic_repeat (line 68) | def periodic_repeat(tensor, size, dim): function periodic_cumsum (line 101) | def periodic_cumsum(tensor, period, dim): function periodic_features (line 140) | def periodic_features(duration, max_period=None, min_period=None, **opti... function next_fast_len (line 185) | def next_fast_len(size): function convolve (line 213) | def convolve(signal, kernel, mode="full"): function repeated_matmul (line 253) | def repeated_matmul(M, n): function dct (line 282) | def dct(x, dim=-1): function idct (line 323) | def idct(x, dim=-1): function haar_transform (line 366) | def haar_transform(x): function inverse_haar_transform (line 386) | def inverse_haar_transform(x): function safe_cholesky (line 405) | def safe_cholesky(x): function cholesky_solve (line 421) | def cholesky_solve(x, y): function matmul (line 427) | def matmul(x, y): function matvecmul (line 433) | def matvecmul(x, y): function triangular_solve (line 439) | def triangular_solve(x, y, upper=False, transpose=False): function precision_to_scale_tril (line 448) | def precision_to_scale_tril(P): function safe_normalize (line 457) | def safe_normalize(x, *, p=2): function broadcast_tensors_without_dim (line 475) | def broadcast_tensors_without_dim(tensors, dim): FILE: pyro/ops/welford.py class WelfordCovariance (line 7) | class WelfordCovariance: method __init__ (line 18) | def __init__(self, diagonal=True): method reset (line 22) | def reset(self): method update (line 27) | def update(self, sample): method get_covariance (line 38) | def get_covariance(self, regularize=True): class WelfordArrowheadCovariance (line 54) | class WelfordArrowheadCovariance: method __init__ (line 59) | def __init__(self, head_size=0): method reset (line 63) | def reset(self): method update (line 69) | def update(self, sample): method get_covariance (line 85) | def get_covariance(self, regularize=True): FILE: pyro/optim/adagrad_rmsprop.py class AdagradRMSProp (line 10) | class AdagradRMSProp(Optimizer): method __init__ (line 36) | def __init__( method share_memory (line 48) | def share_memory(self) -> None: method step (line 54) | def step(self, closure: Optional[Callable] = None) -> Optional[Any]: FILE: pyro/optim/clipped_adam.py class ClippedAdam (line 11) | class ClippedAdam(Optimizer): method __init__ (line 39) | def __init__( method step (line 61) | def step(self, closure: Optional[Callable] = None) -> Optional[Any]: FILE: pyro/optim/dct_adam.py function _transform_forward (line 14) | def _transform_forward(x: torch.Tensor, dim: int, duration: int) -> torc... function _transform_inverse (line 28) | def _transform_inverse(x: torch.Tensor, dim: int, duration: int): function _get_mask (line 37) | def _get_mask(x, indices): class DCTAdam (line 55) | class DCTAdam(Optimizer): method __init__ (line 77) | def __init__( method step (line 97) | def step(self, closure: Optional[Callable] = None) -> Optional[float]: method _step_param (line 122) | def _step_param(self, group: Dict, p) -> None: method _step_param_subsample (line 163) | def _step_param_subsample(self, group: Dict, p, subsample) -> None: FILE: pyro/optim/horovod.py class HorovodOptimizer (line 13) | class HorovodOptimizer(PyroOptim): method __init__ (line 33) | def __init__(self, pyro_optim: PyroOptim, **horovod_kwargs): method __call__ (line 52) | def __call__(self, params: Union[List, ValuesView], *args, **kwargs) -... FILE: pyro/optim/lr_scheduler.py class PyroLRScheduler (line 11) | class PyroLRScheduler(PyroOptim): method __init__ (line 33) | def __init__( method __call__ (line 48) | def __call__(self, params: Union[List, ValuesView], *args, **kwargs) -... method _get_optim (line 51) | def _get_optim( method step (line 57) | def step(self, *args, **kwargs) -> None: FILE: pyro/optim/multi.py class MultiOptimizer (line 12) | class MultiOptimizer: method step (line 35) | def step(self, loss: torch.Tensor, params: Dict) -> None: method get_step (line 53) | def get_step(self, loss: torch.Tensor, params: Dict) -> Dict: class PyroMultiOptimizer (line 71) | class PyroMultiOptimizer(MultiOptimizer): method __init__ (line 77) | def __init__(self, optim: PyroOptim) -> None: method step (line 84) | def step(self, loss: torch.Tensor, params: Dict) -> None: class TorchMultiOptimizer (line 92) | class TorchMultiOptimizer(PyroMultiOptimizer): method __init__ (line 98) | def __init__(self, optim_constructor: torch.optim.Optimizer, optim_arg... class MixedMultiOptimizer (line 103) | class MixedMultiOptimizer(MultiOptimizer): method __init__ (line 116) | def __init__(self, parts: List) -> None: method step (line 131) | def step(self, loss: torch.Tensor, params: Dict): method get_step (line 135) | def get_step(self, loss: torch.Tensor, params: Dict) -> Dict: class Newton (line 144) | class Newton(MultiOptimizer): method __init__ (line 159) | def __init__(self, trust_radii: Dict = {}): method get_step (line 162) | def get_step(self, loss: torch.Tensor, params: Dict): FILE: pyro/optim/optim.py function is_scheduler (line 33) | def is_scheduler(optimizer) -> bool: function _get_state_dict (line 45) | def _get_state_dict(optimizer) -> dict: function _load_state_dict (line 60) | def _load_state_dict(optimizer, state: dict) -> None: class PyroOptim (line 72) | class PyroOptim: method __init__ (line 83) | def __init__( method __call__ (line 117) | def __call__(self, params: Union[List, ValuesView], *args, **kwargs) -... method get_state (line 157) | def get_state(self) -> Dict: method set_state (line 168) | def set_state(self, state_dict: Dict) -> None: method save (line 175) | def save(self, filename: str) -> None: method load (line 185) | def load(self, filename: str, map_location=None) -> None: method _get_optim (line 200) | def _get_optim(self, param: Union[Iterable[Tensor], Iterable[Dict[Any,... method _get_optim_args (line 204) | def _get_optim_args(self, param: Union[Iterable[Tensor], Iterable[Dict... method _get_grad_clip (line 227) | def _get_grad_clip(self, param: str): method _get_grad_clip_args (line 238) | def _get_grad_clip_args(self, param: str) -> Dict: method _clip_grad (line 259) | def _clip_grad( function AdagradRMSProp (line 270) | def AdagradRMSProp(optim_args: Dict) -> PyroOptim: function ClippedAdam (line 277) | def ClippedAdam(optim_args: Dict) -> PyroOptim: function DCTAdam (line 284) | def DCTAdam(optim_args: Dict) -> PyroOptim: FILE: pyro/params/param_store.py class StateDict (line 25) | class StateDict(TypedDict): class ParamStoreDict (line 30) | class ParamStoreDict: method __init__ (line 59) | def __init__(self) -> None: method clear (line 73) | def clear(self) -> None: method items (line 81) | def items(self) -> Iterator[Tuple[str, torch.Tensor]]: method keys (line 89) | def keys(self) -> KeysView[str]: method values (line 95) | def values(self) -> Iterator[torch.Tensor]: method __bool__ (line 102) | def __bool__(self) -> bool: method __len__ (line 105) | def __len__(self) -> int: method __contains__ (line 108) | def __contains__(self, name: str) -> bool: method __iter__ (line 111) | def __iter__(self) -> Iterator[str]: method __delitem__ (line 117) | def __delitem__(self, name) -> None: method __getitem__ (line 125) | def __getitem__(self, name: str) -> torch.Tensor: method __setitem__ (line 138) | def __setitem__(self, name: str, new_constrained_value: torch.Tensor) ... method setdefault (line 158) | def setdefault( method named_parameters (line 201) | def named_parameters(self) -> ItemsView[str, torch.Tensor]: method get_all_param_names (line 209) | def get_all_param_names(self) -> KeysView[str]: method replace_param (line 216) | def replace_param( method get_param (line 226) | def get_param( method match (line 253) | def match(self, name: str) -> Dict[str, torch.Tensor]: method param_name (line 264) | def param_name(self, p: torch.Tensor) -> Optional[str]: method get_state (line 276) | def get_state(self) -> StateDict: method set_state (line 287) | def set_state(self, state: StateDict) -> None: method save (line 306) | def save(self, filename: str) -> None: method load (line 316) | def load(self, filename: str, map_location: MAP_LOCATION = None) -> None: method scope (line 338) | def scope(self, state: Optional[StateDict] = None) -> Iterator[StateDi... function param_with_module_name (line 380) | def param_with_module_name(pyro_name: str, param_name: str) -> str: function module_from_param_with_module_name (line 384) | def module_from_param_with_module_name(param_name: str) -> str: function user_param_name (line 388) | def user_param_name(param_name: str) -> str: function normalize_param_name (line 394) | def normalize_param_name(name: str) -> str: FILE: pyro/poutine/block_messenger.py function _block_fn (line 13) | def _block_fn( function _make_default_hide_fn (line 41) | def _make_default_hide_fn( function _negate_fn (line 86) | def _negate_fn( class BlockMessenger (line 96) | class BlockMessenger(Messenger): method __init__ (line 145) | def __init__( method _process_message (line 168) | def _process_message(self, msg: "Message") -> None: FILE: pyro/poutine/broadcast_messenger.py class BroadcastMessenger (line 14) | class BroadcastMessenger(Messenger): method _pyro_sample (line 46) | def _pyro_sample(msg: "Message") -> None: FILE: pyro/poutine/collapse_messenger.py function _substitute (line 37) | def _substitute(x, subs): function _ (line 42) | def _(x, subs): function _ (line 47) | def _(x, subs): function _ (line 52) | def _(x, subs): function _extract_deltas (line 57) | def _extract_deltas(f): function _ (line 62) | def _(f): function _ (line 67) | def _(f): class CollapseMessenger (line 73) | class CollapseMessenger(TraceMessenger): method __init__ (line 88) | def __init__(self, *args: Any, **kwargs: Any) -> None: method _process_message (line 98) | def _process_message(self, msg: "Message") -> None: method _pyro_sample (line 105) | def _pyro_sample(self, msg: "Message") -> None: method _pyro_post_sample (line 121) | def _pyro_post_sample(self, msg: "Message") -> None: method _pyro_barrier (line 128) | def _pyro_barrier(self, msg: "Message") -> None: method __enter__ (line 149) | def __enter__(self) -> Self: method __exit__ (line 156) | def __exit__(self, *args) -> None: method _get_log_prob (line 166) | def _get_log_prob(self) -> Tuple[str, Funsor, Funsor, FrozenSet[str]]: FILE: pyro/poutine/condition_messenger.py class ConditionMessenger (line 15) | class ConditionMessenger(Messenger): method __init__ (line 41) | def __init__(self, data: Union[Dict[str, torch.Tensor], Trace]) -> None: method _pyro_sample (line 51) | def _pyro_sample(self, msg: "Message") -> None: FILE: pyro/poutine/do_messenger.py class DoMessenger (line 14) | class DoMessenger(Messenger): method __init__ (line 52) | def __init__(self, data: Dict[str, Union[torch.Tensor, numbers.Number]... method _pyro_sample (line 57) | def _pyro_sample(self, msg: Message) -> None: FILE: pyro/poutine/enum_messenger.py function _tmc_mixture_sample (line 17) | def _tmc_mixture_sample(msg: Message) -> torch.Tensor: function _tmc_diagonal_sample (line 67) | def _tmc_diagonal_sample(msg: Message) -> torch.Tensor: function enumerate_site (line 114) | def enumerate_site(msg: Message) -> torch.Tensor: class EnumMessenger (line 136) | class EnumMessenger(Messenger): method __init__ (line 147) | def __init__(self, first_available_dim: Optional[int] = None) -> None: method __enter__ (line 154) | def __enter__(self) -> Self: method _pyro_sample (line 169) | def _pyro_sample(self, msg: Message) -> None: method _pyro_post_sample (line 233) | def _pyro_post_sample(self, msg: Message) -> None: FILE: pyro/poutine/equalize_messenger.py class EqualizeMessenger (line 14) | class EqualizeMessenger(Messenger): method __init__ (line 67) | def __init__( method __enter__ (line 78) | def __enter__(self) -> Self: method _is_matching (line 82) | def _is_matching(self, msg: Message) -> bool: method _postprocess_message (line 89) | def _postprocess_message(self, msg: Message) -> None: method _process_message (line 95) | def _process_message(self, msg: Message) -> None: FILE: pyro/poutine/escape_messenger.py class EscapeMessenger (line 10) | class EscapeMessenger(Messenger): method __init__ (line 15) | def __init__(self, escape_fn: Callable[[Message], bool]) -> None: method _pyro_sample (line 25) | def _pyro_sample(self, msg: Message) -> None: FILE: pyro/poutine/guide.py class GuideMessenger (line 19) | class GuideMessenger(TraceMessenger, ABC): method __init__ (line 26) | def __init__(self, model: Callable) -> None: method model (line 32) | def model(self) -> Callable: method __getstate__ (line 35) | def __getstate__(self) -> Dict[str, object]: method __call__ (line 41) | def __call__(self, *args, **kwargs) -> Dict[str, torch.Tensor]: # typ... method _pyro_sample (line 65) | def _pyro_sample(self, msg: "Message") -> None: method _pyro_post_sample (line 81) | def _pyro_post_sample(self, msg: "Message") -> None: method get_posterior (line 93) | def get_posterior( method upstream_value (line 125) | def upstream_value(self, name: str) -> Optional[torch.Tensor]: method get_traces (line 134) | def get_traces(self) -> Tuple[Trace, Trace]: FILE: pyro/poutine/handlers.py function _make_handler (line 110) | def _make_handler(msngr_cls, module=None): function block (line 141) | def block( function block (line 155) | def block( function block (line 169) | def block( # type: ignore[empty-body] function broadcast (line 183) | def broadcast( function broadcast (line 189) | def broadcast( function broadcast (line 195) | def broadcast( # type: ignore[empty-body] function collapse (line 201) | def collapse( function collapse (line 209) | def collapse( function collapse (line 217) | def collapse( # type: ignore[empty-body] function condition (line 225) | def condition( function condition (line 231) | def condition( function condition (line 238) | def condition( # type: ignore[empty-body] function do (line 245) | def do( function do (line 251) | def do( function do (line 258) | def do( # type: ignore[empty-body] function enum (line 265) | def enum( function enum (line 272) | def enum( function enum (line 279) | def enum( # type: ignore[empty-body] function escape (line 286) | def escape( function escape (line 292) | def escape( function escape (line 299) | def escape( # type: ignore[empty-body] function equalize (line 306) | def equalize( function equalize (line 314) | def equalize( function equalize (line 323) | def equalize( # type: ignore[empty-body] function infer_config (line 332) | def infer_config( function infer_config (line 338) | def infer_config( function infer_config (line 345) | def infer_config( # type: ignore[empty-body] function lift (line 352) | def lift( function lift (line 358) | def lift( function lift (line 365) | def lift( # type: ignore[empty-body] function mask (line 372) | def mask( function mask (line 378) | def mask( function mask (line 385) | def mask( # type: ignore[empty-body] function reparam (line 392) | def reparam( function reparam (line 398) | def reparam( function reparam (line 405) | def reparam( # type: ignore[empty-body] function replay (line 412) | def replay( function replay (line 420) | def replay( function replay (line 428) | def replay( # type: ignore[empty-body] function scale (line 436) | def scale( function scale (line 442) | def scale( function scale (line 449) | def scale( # type: ignore[empty-body] function seed (line 456) | def seed( function seed (line 462) | def seed( function seed (line 469) | def seed( # type: ignore[empty-body] function substitute (line 476) | def substitute( function substitute (line 482) | def substitute( function substitute (line 489) | def substitute( # type: ignore[empty-body] function trace (line 496) | def trace( function trace (line 504) | def trace( function trace (line 512) | def trace( # type: ignore[empty-body] function uncondition (line 520) | def uncondition( function uncondition (line 526) | def uncondition( function uncondition (line 532) | def uncondition( # type: ignore[empty-body] function queue (line 542) | def queue( function markov (line 610) | def markov( function markov (line 620) | def markov( function markov (line 630) | def markov( function markov (line 639) | def markov( FILE: pyro/poutine/indep_messenger.py class CondIndepStackFrame (line 14) | class CondIndepStackFrame(NamedTuple): method vectorized (line 22) | def vectorized(self) -> bool: method _key (line 25) | def _key(self) -> Tuple[str, Optional[int], int, int]: method __eq__ (line 32) | def __eq__(self, other: object) -> bool: method __ne__ (line 37) | def __ne__(self, other: object) -> bool: method __hash__ (line 40) | def __hash__(self) -> int: method __str__ (line 43) | def __str__(self) -> str: class IndepMessenger (line 47) | class IndepMessenger(Messenger): method __init__ (line 67) | def __init__( method next_context (line 89) | def next_context(self) -> None: method __enter__ (line 95) | def __enter__(self) -> Self: method __exit__ (line 104) | def __exit__(self, *args) -> None: method __iter__ (line 110) | def __iter__(self) -> Iterator[Union[int, float]]: method _reset (line 132) | def _reset(self) -> None: method indices (line 140) | def indices(self) -> torch.Tensor: method _process_message (line 145) | def _process_message(self, msg: Message) -> None: FILE: pyro/poutine/infer_config_messenger.py class InferConfigMessenger (line 12) | class InferConfigMessenger(Messenger): method __init__ (line 23) | def __init__(self, config_fn: Callable[["Message"], "InferDict"]) -> N... method _pyro_sample (line 33) | def _pyro_sample(self, msg: "Message") -> None: method _pyro_param (line 46) | def _pyro_param(self, msg: "Message") -> None: FILE: pyro/poutine/lift_messenger.py class LiftMessenger (line 18) | class LiftMessenger(Messenger): method __init__ (line 48) | def __init__( method __enter__ (line 62) | def __enter__(self) -> Self: method __exit__ (line 69) | def __exit__(self, *args, **kwargs) -> None: method _pyro_sample (line 82) | def _pyro_sample(self, msg: "Message") -> None: method _pyro_param (line 85) | def _pyro_param(self, msg: "Message") -> None: FILE: pyro/poutine/markov_messenger.py class MarkovMessenger (line 16) | class MarkovMessenger(ReentrantMessenger): method __init__ (line 36) | def __init__( method generator (line 61) | def generator(self, iterable: Iterable[int]) -> Self: method __iter__ (line 65) | def __iter__(self) -> Iterator[int]: method __enter__ (line 72) | def __enter__(self) -> Self: method __exit__ (line 78) | def __exit__(self, *args, **kwargs) -> None: method _pyro_sample (line 84) | def _pyro_sample(self, msg: "Message") -> None: FILE: pyro/poutine/mask_messenger.py class MaskMessenger (line 14) | class MaskMessenger(Messenger): method __init__ (line 25) | def __init__(self, mask: Union[bool, torch.BoolTensor]) -> None: method _process_message (line 39) | def _process_message(self, msg: "Message") -> None: FILE: pyro/poutine/messenger.py function _context_wrap (line 25) | def _context_wrap( class _bound_partial (line 35) | class _bound_partial(partial): method __init__ (line 46) | def __init__(self, func): method __get__ (line 49) | def __get__( function unwrap (line 59) | def unwrap(fn: Callable) -> Callable: class Messenger (line 73) | class Messenger: method __call__ (line 88) | def __call__(self, fn: Callable[_P, _T]) -> Callable[_P, _T]: method __enter__ (line 96) | def __enter__(self) -> Self: method __exit__ (line 128) | def __exit__( method _reset (line 176) | def _reset(self) -> None: method _process_message (line 179) | def _process_message(self, msg: Message) -> None: method _postprocess_message (line 191) | def _postprocess_message(self, msg: Message) -> None: method register (line 197) | def register( method unregister (line 230) | def unregister( function block_messengers (line 264) | def block_messengers( FILE: pyro/poutine/plate_messenger.py class PlateMessenger (line 17) | class PlateMessenger(SubsampleMessenger): method _process_message (line 23) | def _process_message(self, msg: "Message") -> None: method __enter__ (line 27) | def __enter__(self) -> Optional["torch.Tensor"]: # type: ignore[overr... function block_plate (line 35) | def block_plate( FILE: pyro/poutine/reentrant_messenger.py class ReentrantMessenger (line 16) | class ReentrantMessenger(Messenger): method __init__ (line 17) | def __init__(self) -> None: method __call__ (line 21) | def __call__(self, fn: Callable[_P, _T]) -> Callable[_P, _T]: method __enter__ (line 24) | def __enter__(self) -> Self: method __exit__ (line 30) | def __exit__( FILE: pyro/poutine/reparam_messenger.py function _get_init_messengers (line 32) | def _get_init_messengers() -> List[Messenger]: class ReparamMessenger (line 36) | class ReparamMessenger(Messenger): method __init__ (line 61) | def __init__( method __call__ (line 70) | def __call__(self, fn: Callable[_P, _T]) -> "ReparamHandler[_P, _T]": method _pyro_sample (line 73) | def _pyro_sample(self, msg: "Message") -> None: class ReparamHandler (line 148) | class ReparamHandler(Generic[_P, _T]): method __init__ (line 153) | def __init__(self, msngr, fn: Callable[_P, _T]) -> None: method __call__ (line 158) | def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: FILE: pyro/poutine/replay_messenger.py class ReplayMessenger (line 15) | class ReplayMessenger(Messenger): method __init__ (line 43) | def __init__( method _pyro_sample (line 60) | def _pyro_sample(self, msg: "Message") -> None: method _pyro_param (line 83) | def _pyro_param(self, msg: "Message") -> None: FILE: pyro/poutine/runtime.py class InferDict (line 45) | class InferDict(TypedDict, total=False): class Message (line 108) | class Message(TypedDict, Generic[_P, _T], total=False): class _DimAllocator (line 184) | class _DimAllocator: method __init__ (line 192) | def __init__(self) -> None: method allocate (line 196) | def allocate(self, name: str, dim: Optional[int]) -> int: method free (line 231) | def free(self, name: str, dim: int) -> None: class _EnumAllocator (line 246) | class _EnumAllocator: method set_first_available_dim (line 255) | def set_first_available_dim(self, first_available_dim: int) -> None: method allocate (line 267) | def allocate(self, scope_dims: Optional[Set[int]] = None) -> Tuple[int... class NonlocalExit (line 306) | class NonlocalExit(Exception): method __init__ (line 313) | def __init__(self, site: Message, *args, **kwargs) -> None: method reset_stack (line 321) | def reset_stack(self) -> None: function default_process_message (line 334) | def default_process_message(msg: Message) -> None: function apply_stack (line 351) | def apply_stack(initial_msg: Message) -> None: function am_i_wrapped (line 393) | def am_i_wrapped() -> bool: function effectful (line 402) | def effectful( function effectful (line 408) | def effectful( function effectful (line 413) | def effectful( function _inspect (line 470) | def _inspect() -> Message: function get_mask (line 500) | def get_mask() -> Union[bool, torch.Tensor, None]: function get_plates (line 520) | def get_plates() -> Tuple["CondIndepStackFrame", ...]: FILE: pyro/poutine/scale_messenger.py class ScaleMessenger (line 15) | class ScaleMessenger(Messenger): method __init__ (line 40) | def __init__(self, scale: Union[float, torch.Tensor]) -> None: method _process_message (line 52) | def _process_message(self, msg: "Message") -> None: FILE: pyro/poutine/seed_messenger.py class SeedMessenger (line 11) | class SeedMessenger(Messenger): method __init__ (line 23) | def __init__(self, rng_seed: int) -> None: method __enter__ (line 28) | def __enter__(self) -> None: # type: ignore[override] method __exit__ (line 32) | def __exit__( FILE: pyro/poutine/subsample_messenger.py class _Subsample (line 15) | class _Subsample(Distribution): method __init__ (line 22) | def __init__( method sample (line 51) | def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Ten... method log_prob (line 67) | def log_prob(self, x: torch.Tensor) -> torch.Tensor: class SubsampleMessenger (line 74) | class SubsampleMessenger(IndepMessenger): method __init__ (line 79) | def __init__( method _subsample (line 101) | def _subsample( method _reset (line 155) | def _reset(self) -> None: method _process_message (line 159) | def _process_message(self, msg: Message) -> None: method _postprocess_message (line 176) | def _postprocess_message(self, msg: Message) -> None: FILE: pyro/poutine/substitute_messenger.py class SubstituteMessenger (line 19) | class SubstituteMessenger(Messenger): method __init__ (line 38) | def __init__(self, data: Dict[str, "torch.Tensor"]) -> None: method __enter__ (line 47) | def __enter__(self) -> Self: method __exit__ (line 54) | def __exit__(self, *args, **kwargs) -> None: method _pyro_sample (line 67) | def _pyro_sample(self, msg: "Message") -> None: method _pyro_param (line 70) | def _pyro_param(self, msg: "Message") -> None: FILE: pyro/poutine/trace_messenger.py function identify_dense_edges (line 20) | def identify_dense_edges(trace: Trace) -> None: class TraceMessenger (line 49) | class TraceMessenger(Messenger): method __init__ (line 75) | def __init__( method __enter__ (line 95) | def __enter__(self) -> Self: method __exit__ (line 99) | def __exit__(self, *args, **kwargs) -> None: method __call__ (line 113) | def __call__(self, fn: Callable[_P, _T]) -> "TraceHandler[_P, _T]": method get_trace (line 119) | def get_trace(self) -> Trace: method _reset (line 129) | def _reset(self) -> None: method _pyro_post_sample (line 142) | def _pyro_post_sample(self, msg: "Message") -> None: method _pyro_post_param (line 153) | def _pyro_post_param(self, msg: "Message") -> None: class TraceHandler (line 158) | class TraceHandler(Generic[_P, _T]): method __init__ (line 170) | def __init__(self, msngr: TraceMessenger, fn: Callable[_P, _T]) -> None: method __call__ (line 174) | def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: method trace (line 205) | def trace(self) -> Trace: method get_trace (line 208) | def get_trace(self, *args, **kwargs) -> Trace: FILE: pyro/poutine/trace_struct.py function allow_all_sites (line 36) | def allow_all_sites(name: str, site: "Message") -> bool: class Trace (line 40) | class Trace: method __init__ (line 97) | def __init__(self, graph_type: Literal["flat", "dense"] = "flat") -> N... method __contains__ (line 106) | def __contains__(self, name: str) -> bool: method __iter__ (line 109) | def __iter__(self) -> Iterable[str]: method __len__ (line 112) | def __len__(self) -> int: method edges (line 116) | def edges(self) -> Iterable[Tuple[str, str]]: method add_node (line 121) | def add_node(self, site_name: str, **kwargs: Any) -> None: method add_edge (line 148) | def add_edge(self, site1: str, site2: str) -> None: method remove_node (line 155) | def remove_node(self, site_name: str) -> None: method predecessors (line 164) | def predecessors(self, site_name: str) -> Set[str]: method successors (line 167) | def successors(self, site_name: str) -> Set[str]: method copy (line 170) | def copy(self) -> "Trace": method _dfs (line 180) | def _dfs(self, site: str, visited: Set[str]) -> Iterable[str]: method topological_sort (line 189) | def topological_sort(self, reverse: bool = False) -> List[str]: method log_prob_sum (line 203) | def log_prob_sum( method compute_log_prob (line 248) | def compute_log_prob( method compute_score_parts (line 290) | def compute_score_parts(self) -> None: method detach_ (line 330) | def detach_(self) -> None: method observation_nodes (line 340) | def observation_nodes(self) -> List[str]: method param_nodes (line 351) | def param_nodes(self) -> List[str]: method stochastic_nodes (line 358) | def stochastic_nodes(self) -> List[str]: method reparameterized_nodes (line 369) | def reparameterized_nodes(self) -> List[str]: method nonreparam_stochastic_nodes (line 383) | def nonreparam_stochastic_nodes(self) -> List[str]: method iter_stochastic_nodes (line 390) | def iter_stochastic_nodes(self) -> Iterator[Tuple[str, "Message"]]: method symbolize_dims (line 398) | def symbolize_dims(self, plate_to_symbol: Optional[Dict[str, str]] = N... method pack_tensors (line 435) | def pack_tensors(self, plate_to_symbol: Optional[Dict[str, str]] = Non... method format_shapes (line 475) | def format_shapes( function _format_table (line 534) | def _format_table(rows: List[List[Optional[str]]]) -> str: FILE: pyro/poutine/uncondition_messenger.py class UnconditionMessenger (line 12) | class UnconditionMessenger(Messenger): method __init__ (line 18) | def __init__(self) -> None: method _pyro_sample (line 21) | def _pyro_sample(self, msg: "Message") -> None: FILE: pyro/poutine/util.py function enable_validation (line 17) | def enable_validation(is_validate: bool) -> None: function is_validation_enabled (line 22) | def is_validation_enabled() -> bool: function site_is_subsample (line 26) | def site_is_subsample(site: "Message") -> bool: function site_is_factor (line 33) | def site_is_factor(site: "Message") -> bool: function prune_subsample_sites (line 40) | def prune_subsample_sites(trace: "Trace") -> "Trace": function enum_extend (line 51) | def enum_extend( function mc_extend (line 83) | def mc_extend( function discrete_escape (line 111) | def discrete_escape(trace: "Trace", msg: "Message") -> bool: function all_escape (line 131) | def all_escape(trace: "Trace", msg: "Message") -> bool: FILE: pyro/primitives.py function get_param_store (line 35) | def get_param_store() -> ParamStoreDict: function clear_param_store (line 42) | def clear_param_store() -> None: function param (line 57) | def param( function _masked_observe (line 94) | def _masked_observe( function sample (line 125) | def sample( function factor (line 195) | def factor( function deterministic (line 221) | def deterministic( function subsample (line 250) | def subsample(data: torch.Tensor, event_dim: int) -> torch.Tensor: class plate (line 283) | class plate(PlateMessenger): class iarange (line 392) | class iarange(plate): method __init__ (line 393) | def __init__(self, *args, **kwargs): class irange (line 400) | class irange(SubsampleMessenger): method __init__ (line 401) | def __init__(self, *args, **kwargs): function plate_stack (line 409) | def plate_stack( function module (line 429) | def module( function random_module (line 506) | def random_module(name, nn_module, prior, *args, **kwargs): function barrier (line 547) | def barrier(data: torch.Tensor) -> torch.Tensor: function enable_validation (line 556) | def enable_validation(is_validate: bool = True) -> None: function validation_enabled (line 583) | def validation_enabled(is_validate: bool = True) -> Iterator[None]: FILE: pyro/settings.py function get (line 61) | def get(alias: Optional[str] = None) -> Any: function set (line 79) | def set(**kwargs) -> None: function context (line 97) | def context(**kwargs) -> Iterator[None]: function register (line 112) | def register( FILE: pyro/util.py function set_rng_seed (line 37) | def set_rng_seed(rng_seed: int) -> None: function get_rng_state (line 48) | def get_rng_state() -> Dict[str, Any]: function set_rng_state (line 56) | def set_rng_state(state: Dict[str, Any]) -> None: function torch_isnan (line 66) | def torch_isnan(x: numbers.Number) -> bool: ... function torch_isnan (line 68) | def torch_isnan(x: torch.Tensor) -> torch.Tensor: ... function torch_isnan (line 69) | def torch_isnan(x: Union[torch.Tensor, numbers.Number]) -> Union[bool, t... function torch_isinf (line 79) | def torch_isinf(x: numbers.Number) -> bool: ... function torch_isinf (line 81) | def torch_isinf(x: torch.Tensor) -> torch.Tensor: ... function torch_isinf (line 82) | def torch_isinf(x: Union[torch.Tensor, numbers.Number]) -> Union[bool, t... function warn_if_nan (line 92) | def warn_if_nan( function warn_if_nan (line 100) | def warn_if_nan( function warn_if_nan (line 107) | def warn_if_nan( function warn_if_inf (line 148) | def warn_if_inf( function warn_if_inf (line 158) | def warn_if_inf( function warn_if_inf (line 167) | def warn_if_inf( function save_visualization (line 230) | def save_visualization(trace: "Trace", graph_output: str) -> None: function check_traces_match (line 284) | def check_traces_match(trace1: "Trace", trace2: "Trace") -> None: function check_model_guide_match (line 314) | def check_model_guide_match( function check_site_shape (line 465) | def check_site_shape(site: "Message", max_plate_nesting: int) -> None: function _are_independent (line 548) | def _are_independent(counters1: Dict[str, int], counters2: Dict[str, int... function check_traceenum_requirements (line 556) | def check_traceenum_requirements(model_trace: "Trace", guide_trace: "Tra... function check_if_enumerated (line 620) | def check_if_enumerated(guide_trace: "Trace") -> None: function ignore_jit_warnings (line 639) | def ignore_jit_warnings(filter=None): function jit_iter (line 665) | def jit_iter(tensor: torch.Tensor) -> List[torch.Tensor]: class optional (line 677) | class optional: method __init__ (line 682) | def __init__(self, context_manager, condition): method __enter__ (line 686) | def __enter__(self): method __exit__ (line 690) | def __exit__(self, exc_type, exc_val, exc_tb): class ExperimentalWarning (line 695) | class ExperimentalWarning(UserWarning): function ignore_experimental_warning (line 700) | def ignore_experimental_warning(): class timed (line 706) | class timed: method __enter__ (line 707) | def __enter__(self, timer=timeit.default_timer): method __exit__ (line 711) | def __exit__(self, exc_type, exc_val, exc_tb): function torch_float (line 718) | def torch_float(x: Union[float, int]) -> float: ... function torch_float (line 720) | def torch_float(x: torch.Tensor) -> torch.Tensor: ... function torch_float (line 721) | def torch_float( FILE: tests/common.py function xfail_param (line 31) | def xfail_param(*args, **kwargs): function str_erase_pointers (line 36) | def str_erase_pointers(x): function skipif_param (line 46) | def skipif_param(*args, **kwargs): function suppress_warnings (line 50) | def suppress_warnings(fn): function get_cpu_type (line 86) | def get_cpu_type(t): function get_gpu_type (line 91) | def get_gpu_type(t): function default_dtype (line 97) | def default_dtype(dtype): function freeze_rng_state (line 112) | def freeze_rng_state(): function xfail_if_not_implemented (line 123) | def xfail_if_not_implemented(msg="Not implemented"): function iter_indices (line 130) | def iter_indices(tensor): function is_iterable (line 138) | def is_iterable(obj): function assert_tensors_equal (line 146) | def assert_tensors_equal(a, b, prec=0.0, msg=""): function _safe_coalesce (line 173) | def _safe_coalesce(t): function assert_close (line 198) | def assert_close(actual, expected, atol=1e-7, rtol=0, msg=""): function assert_equal (line 246) | def assert_equal(actual, expected, prec=1e-5, msg=""): function assert_not_equal (line 289) | def assert_not_equal(x, y, prec=1e-5, msg=""): FILE: tests/conftest.py function pytest_configure (line 17) | def pytest_configure(config): function pytest_runtest_setup (line 29) | def pytest_runtest_setup(item): function pytest_addoption (line 41) | def pytest_addoption(parser): function _get_highest_specificity_marker (line 58) | def _get_highest_specificity_marker(stage_marker): function _add_marker (line 81) | def _add_marker(marker, items): function pytest_collection_modifyitems (line 86) | def pytest_collection_modifyitems(config, items): FILE: tests/contrib/autoname/test_autoname.py function test_basic_scope (line 12) | def test_basic_scope(): function test_repeat_names (line 32) | def test_repeat_names(): function test_compose_scopes (line 58) | def test_compose_scopes(): function test_basic_loop (line 92) | def test_basic_loop(): function test_named_loop (line 116) | def test_named_loop(): function test_sequential_plate (line 141) | def test_sequential_plate(): function test_nested_plate (line 168) | def test_nested_plate(): function test_model_guide (line 199) | def test_model_guide(): function test_context_manager (line 218) | def test_context_manager(): function test_multi_nested (line 234) | def test_multi_nested(): function test_recur_multi (line 269) | def test_recur_multi(): function test_only_withs (line 301) | def test_only_withs(): function test_mutual_recur (line 314) | def test_mutual_recur(): function test_simple_recur (line 341) | def test_simple_recur(): function test_no_param (line 358) | def test_no_param(): FILE: tests/contrib/autoname/test_named.py function get_sample_names (line 12) | def get_sample_names(tr): function get_observe_names (line 22) | def get_observe_names(tr): function get_param_names (line 32) | def get_param_names(tr): function test_named_object (line 36) | def test_named_object(): function test_named_list (line 52) | def test_named_list(): function test_named_dict (line 68) | def test_named_dict(): function test_nested (line 84) | def test_nested(): function test_eval_str (line 101) | def test_eval_str(): FILE: tests/contrib/autoname/test_scoping.py function test_multi_nested (line 16) | def test_multi_nested(): function test_recur_multi (line 49) | def test_recur_multi(): function test_only_withs (line 79) | def test_only_withs(): function test_mutual_recur (line 92) | def test_mutual_recur(): function test_simple_recur (line 116) | def test_simple_recur(): function test_basic_scope (line 134) | def test_basic_scope(): function test_nested_traces (line 152) | def test_nested_traces(): function test_no_param (line 172) | def test_no_param(): FILE: tests/contrib/bnn/test_hidden_layer.py function test_hidden_layer_rsample (line 15) | def test_hidden_layer_rsample( function test_hidden_layer_log_prob (line 52) | def test_hidden_layer_log_prob(non_linearity, include_hidden_bias, B=2, ... FILE: tests/contrib/cevae/test_cevae.py function generate_data (line 18) | def generate_data(num_data, feature_dim): function test_smoke (line 29) | def test_smoke(num_data, feature_dim, outcome_dist): function test_serialization (line 42) | def test_serialization(jit, feature_dim, outcome_dist): FILE: tests/contrib/conftest.py function pytest_collection_modifyitems (line 7) | def pytest_collection_modifyitems(items): FILE: tests/contrib/easyguide/test_easyguide.py function model (line 21) | def model(batch, subsample, full_size): function check_guide (line 36) | def check_guide(guide): function test_delta_smoke (line 58) | def test_delta_smoke(init_fn): class PickleGuide (line 71) | class PickleGuide(EasyGuide): method __init__ (line 72) | def __init__(self, model): method guide (line 76) | def guide(self, batch, subsample, full_size): function test_serialize (line 82) | def test_serialize(): function test_subsample_smoke (line 101) | def test_subsample_smoke(init_fn): function test_amortized_smoke (line 133) | def test_amortized_smoke(init_fn): function test_overlapping_plates_ok (line 170) | def test_overlapping_plates_ok(): function test_overlapping_plates_error (line 215) | def test_overlapping_plates_error(): FILE: tests/contrib/epidemiology/test_distributions.py function assert_dist_close (line 19) | def assert_dist_close(d1, d2): function test_binomial_vs_poisson (line 44) | def test_binomial_vs_poisson(R0, I): function test_beta_binomial_vs_negative_binomial (line 75) | def test_beta_binomial_vs_negative_binomial(R0, I, k): function test_beta_binomial_vs_binomial (line 102) | def test_beta_binomial_vs_binomial(R0, I): function test_negative_binomial_vs_poisson (line 131) | def test_negative_binomial_vs_poisson(R0, I): function test_overdispersed_bound (line 145) | def test_overdispersed_bound(probs, overdispersion): function test_overdispersed_asymptote (line 159) | def test_overdispersed_asymptote(probs, overdispersion): function test_beta_binomial (line 180) | def test_beta_binomial(concentration1, concentration0, total_count): function test_overdispersed_beta_binomial (line 199) | def test_overdispersed_beta_binomial(probs, total_count, overdispersion): function test_relaxed_binomial (line 218) | def test_relaxed_binomial(): function test_relaxed_overdispersed_binomial (line 233) | def test_relaxed_overdispersed_binomial(overdispersion): function test_relaxed_beta_binomial (line 247) | def test_relaxed_beta_binomial(): function test_relaxed_overdispersed_beta_binomial (line 263) | def test_relaxed_overdispersed_beta_binomial(overdispersion): FILE: tests/contrib/epidemiology/test_models.py function test_simple_sir_smoke (line 62) | def test_simple_sir_smoke(duration, forecast, options, algo): function test_simple_seir_smoke (line 106) | def test_simple_seir_smoke(duration, forecast, options, algo): function test_simple_seird_smoke (line 150) | def test_simple_seird_smoke(duration, forecast, options, algo): function test_overdispersed_sir_smoke (line 198) | def test_overdispersed_sir_smoke(duration, forecast, options): function test_overdispersed_seir_smoke (line 233) | def test_overdispersed_seir_smoke(duration, forecast, options): function test_superspreading_sir_smoke (line 273) | def test_superspreading_sir_smoke(duration, forecast, options): function test_superspreading_seir_smoke (line 309) | def test_superspreading_seir_smoke(duration, forecast, options): function test_coalescent_likelihood_smoke (line 349) | def test_coalescent_likelihood_smoke(duration, forecast, options, algo): function test_heterogeneous_sir_smoke (line 404) | def test_heterogeneous_sir_smoke(duration, forecast, options, algo): function test_sparse_smoke (line 442) | def test_sparse_smoke(duration, forecast, options): function test_unknown_start_smoke (line 489) | def test_unknown_start_smoke(duration, pre_obs_window, forecast, options): function test_regional_smoke (line 541) | def test_regional_smoke(duration, forecast, options, algo): class RegionalSIRModelWithFinalize (line 575) | class RegionalSIRModelWithFinalize(RegionalSIRModel): method finalize (line 576) | def finalize(self, params, prev, curr): function test_regional_finalize_smoke (line 600) | def test_regional_finalize_smoke(duration, forecast, options, algo): function test_hetero_regional_smoke (line 651) | def test_hetero_regional_smoke(duration, forecast, options, algo): FILE: tests/contrib/epidemiology/test_quant.py function test_quantization_scheme (line 11) | def test_quantization_scheme(num_quant_bins, num_samples=1000 * 1000): FILE: tests/contrib/epidemiology/test_util.py function test_clamp (line 14) | def test_clamp(shape, min, max): function test_cat2_scalar (line 35) | def test_cat2_scalar(shape): FILE: tests/contrib/forecast/test_evaluate.py class Model (line 16) | class Model(ForecastingModel): method model (line 17) | def model(self, zero_data, covariates): function test_simple (line 43) | def test_simple( function test_poisson (line 81) | def test_poisson( function test_custom_warm_start (line 129) | def test_custom_warm_start(): FILE: tests/contrib/forecast/test_forecaster.py class Model0 (line 16) | class Model0(ForecastingModel): method model (line 17) | def model(self, zero_data, covariates): class Model1 (line 30) | class Model1(ForecastingModel): method model (line 31) | def model(self, zero_data, covariates): class Model2 (line 44) | class Model2(ForecastingModel): method model (line 45) | def model(self, zero_data, covariates): class Model3 (line 59) | class Model3(ForecastingModel): method model (line 60) | def model(self, zero_data, covariates): class Model4 (line 81) | class Model4(ForecastingModel): method model (line 82) | def model(self, zero_data, covariates): function test_smoke (line 121) | def test_smoke( function test_trace_smoke (line 177) | def test_trace_smoke(Model, batch_shape, t_obs, obs_dim, cov_dim): function test_svi_custom_smoke (line 220) | def test_svi_custom_smoke(subsample_aware): class SubsampleModel3 (line 244) | class SubsampleModel3(ForecastingModel): method model (line 245) | def model(self, zero_data, covariates): class SubsampleModel4 (line 269) | class SubsampleModel4(ForecastingModel): method model (line 270) | def model(self, zero_data, covariates): function test_subsample_smoke (line 301) | def test_subsample_smoke(Model, t_obs, t_forecast, obs_dim, cov_dim): FILE: tests/contrib/forecast/test_util.py function random_dist (line 47) | def random_dist(Dist, shape, transform=None): function test_prefix_condition (line 106) | def test_prefix_condition(Dist, batch_shape, t, f, dim): function test_reshape_batch (line 123) | def test_reshape_batch(Dist, batch_shape, duration, dim): function test_reshape_transform_batch (line 139) | def test_reshape_transform_batch(transform, batch_shape, duration, dim): FILE: tests/contrib/funsor/conftest.py function pytest_collection_modifyitems (line 7) | def pytest_collection_modifyitems(items): FILE: tests/contrib/funsor/test_enum_funsor.py function _check_loss_and_grads (line 34) | def _check_loss_and_grads(expected_loss, actual_loss): function test_elbo_plate_plate (line 73) | def test_elbo_plate_plate(outer_dim, inner_dim): function test_elbo_enumerate_1 (line 120) | def test_elbo_enumerate_1(scale): function test_elbo_enumerate_2 (line 166) | def test_elbo_enumerate_2(scale): function test_elbo_enumerate_3 (line 218) | def test_elbo_enumerate_3(scale): function test_elbo_enumerate_plate_1 (line 272) | def test_elbo_enumerate_plate_1(num_samples, num_masked, scale): function test_elbo_enumerate_plate_2 (line 341) | def test_elbo_enumerate_plate_2(num_samples, num_masked, scale): function test_elbo_enumerate_plate_3 (line 417) | def test_elbo_enumerate_plate_3(num_samples, num_masked, scale): function test_elbo_enumerate_plate_4 (line 506) | def test_elbo_enumerate_plate_4(outer_obs, inner_obs, scale): function test_elbo_enumerate_plate_5 (line 573) | def test_elbo_enumerate_plate_5(): function test_elbo_enumerate_plate_6 (line 645) | def test_elbo_enumerate_plate_6(enumerate1): function test_elbo_enumerate_plate_7 (line 707) | def test_elbo_enumerate_plate_7(scale): function test_elbo_enumerate_plates_1 (line 814) | def test_elbo_enumerate_plates_1(scale): function test_elbo_enumerate_plates_2 (line 878) | def test_elbo_enumerate_plates_2(scale): function test_elbo_enumerate_plates_3 (line 934) | def test_elbo_enumerate_plates_3(scale): function test_elbo_enumerate_plates_4 (line 986) | def test_elbo_enumerate_plates_4(scale): function test_elbo_enumerate_plates_5 (line 1045) | def test_elbo_enumerate_plates_5(scale): function test_elbo_enumerate_plates_6 (line 1108) | def test_elbo_enumerate_plates_6(scale): function test_elbo_enumerate_plates_7 (line 1246) | def test_elbo_enumerate_plates_7(scale): function test_elbo_enumerate_plates_8 (line 1404) | def test_elbo_enumerate_plates_8( function test_elbo_enumerate_plate_9 (line 1529) | def test_elbo_enumerate_plate_9(): function test_elbo_enumerate_plate_10 (line 1603) | def test_elbo_enumerate_plate_10(): function test_elbo_enumerate_plate_11 (line 1679) | def test_elbo_enumerate_plate_11(): function test_elbo_enumerate_plate_12 (line 1755) | def test_elbo_enumerate_plate_12(): function test_elbo_enumerate_plate_13 (line 1849) | def test_elbo_enumerate_plate_13(): FILE: tests/contrib/funsor/test_infer_discrete.py function test_hmm_smoke (line 34) | def test_hmm_smoke(length, temperature): function test_distribution_1 (line 65) | def test_distribution_1(temperature): function test_distribution_2 (line 117) | def test_distribution_2(temperature): function test_distribution_3_simple (line 184) | def test_distribution_3_simple(temperature): function test_distribution_3 (line 242) | def test_distribution_3(temperature): function model_zzxx (line 305) | def model_zzxx(): function model2 (line 326) | def model2(): function test_svi_model_side_enumeration (line 344) | def test_svi_model_side_enumeration(model, temperature): function test_mcmc_model_side_enumeration (line 379) | def test_mcmc_model_side_enumeration(model, temperature): function test_distribution_masked (line 412) | def test_distribution_masked(temperature): FILE: tests/contrib/funsor/test_named_handlers.py function test_iteration (line 27) | def test_iteration(): function test_nesting (line 52) | def test_nesting(): function test_staggered (line 97) | def test_staggered(): function test_fresh_inputs_to_funsor (line 113) | def test_fresh_inputs_to_funsor(): function test_iteration_fresh (line 126) | def test_iteration_fresh(): function test_staggered_fresh (line 147) | def test_staggered_fresh(): FILE: tests/contrib/funsor/test_pyroapi_funsor.py function backend (line 21) | def backend(request): FILE: tests/contrib/funsor/test_tmc.py function test_tmc_categoricals (line 33) | def test_tmc_categoricals(depth, max_plate_nesting, num_samples, tmc_str... function test_tmc_normals_chain_gradient (line 117) | def test_tmc_normals_chain_gradient( FILE: tests/contrib/funsor/test_valid_models_enum.py function assert_ok (line 38) | def assert_ok(model, guide=None, max_plate_nesting=None, **kwargs): function _check_traces (line 94) | def _check_traces(tr_pyro, tr_funsor): function test_enum_recycling_chain_iter (line 217) | def test_enum_recycling_chain_iter(history): function test_enum_recycling_chain_iter_interleave_parallel_sequential (line 231) | def test_enum_recycling_chain_iter_interleave_parallel_sequential(history): function test_enum_recycling_chain_while (line 250) | def test_enum_recycling_chain_while(history): function test_enum_recycling_chain_recur (line 267) | def test_enum_recycling_chain_recur(history): function test_enum_recycling_dbn (line 287) | def test_enum_recycling_dbn(markov, use_vindex): function test_enum_recycling_nested (line 316) | def test_enum_recycling_nested(): function test_enum_recycling_grid (line 360) | def test_enum_recycling_grid(grid_size, use_vindex): function test_enum_recycling_reentrant_history (line 389) | def test_enum_recycling_reentrant_history(max_plate_nesting, depth, hist... function test_enum_recycling_mutual_recursion (line 422) | def test_enum_recycling_mutual_recursion(max_plate_nesting, depth): function test_enum_recycling_interleave (line 470) | def test_enum_recycling_interleave(max_plate_nesting): function test_markov_history (line 486) | def test_markov_history(max_plate_nesting, history): FILE: tests/contrib/funsor/test_valid_models_plate.py function test_enum_discrete_non_enumerated_plate_ok (line 30) | def test_enum_discrete_non_enumerated_plate_ok(enumerate_): function test_plate_dim_allocation_ok (line 54) | def test_plate_dim_allocation_ok(plate_dims): function test_enum_recycling_plate (line 75) | def test_enum_recycling_plate(subsampling, reuse_plate, tmc_strategy): function test_enum_discrete_plates_dependency_ok (line 130) | def test_enum_discrete_plates_dependency_ok(enumerate_, reuse_plate): function test_enum_discrete_plate_shape_broadcasting_ok (line 152) | def test_enum_discrete_plate_shape_broadcasting_ok(subsampling, enumerat... function test_plate_subsample_primitive_ok (line 188) | def test_plate_subsample_primitive_ok(subsample_size, num_samples): FILE: tests/contrib/funsor/test_valid_models_sequential_plate.py function test_enum_discrete_iplate_plate_dependency_ok (line 30) | def test_enum_discrete_iplate_plate_dependency_ok(subsampling, enumerate_): function test_enum_iplate_iplate_ok (line 48) | def test_enum_iplate_iplate_ok(): function test_enum_plate_iplate_ok (line 77) | def test_enum_plate_iplate_ok(): function test_enum_iplate_plate_ok (line 103) | def test_enum_iplate_plate_ok(): FILE: tests/contrib/funsor/test_vectorized_markov.py function model_0 (line 30) | def model_0(data, history, vectorized): function model_1 (line 68) | def model_1(data, history, vectorized): function model_2 (line 100) | def model_2(data, history, vectorized): function model_3 (line 148) | def model_3(data, history, vectorized): function model_4 (line 202) | def model_4(data, history, vectorized): function model_5 (line 261) | def model_5(data, history, vectorized): function model_6 (line 306) | def model_6(data, history, vectorized): function model_7 (line 350) | def model_7(data, history, vectorized): function _guide_from_model (line 398) | def _guide_from_model(model): function test_enumeration (line 425) | def test_enumeration(model, data, var, history, use_replay): function model_8 (line 510) | def model_8(weeks_data, days_data, history, vectorized): function test_enumeration_multi (line 584) | def test_enumeration_multi( function guide_empty (line 707) | def guide_empty(data, history, vectorized): function test_model_enumerated_elbo (line 727) | def test_model_enumerated_elbo(model, guide, data, history): function guide_empty_multi (line 749) | def guide_empty_multi(weeks_data, days_data, history, vectorized): function test_model_enumerated_elbo_multi (line 761) | def test_model_enumerated_elbo_multi(model, guide, weeks_data, days_data... function model_10 (line 787) | def model_10(data, history, vectorized): function test_guide_enumerated_elbo (line 825) | def test_guide_enumerated_elbo(model, guide, data, history): FILE: tests/contrib/gp/test_conditional.py function test_conditional (line 52) | def test_conditional(Xnew, X, kernel, f_loc, f_scale_tril, loc, cov): function test_conditional_whiten (line 71) | def test_conditional_whiten(Xnew, X, kernel, f_loc, f_scale_tril, loc, c... FILE: tests/contrib/gp/test_kernels.py function test_kernel_forward (line 93) | def test_kernel_forward(kernel, X, Z, K_sum): function test_combination (line 107) | def test_combination(): function test_active_dims_overlap_ok (line 121) | def test_active_dims_overlap_ok(): function test_active_dims_disjoint_ok (line 127) | def test_active_dims_disjoint_ok(): function test_transforming (line 133) | def test_transforming(): FILE: tests/contrib/gp/test_likelihoods.py function test_inference (line 55) | def test_inference(model_class, X, y, kernel, likelihood): function test_inference_with_empty_latent_shape (line 71) | def test_inference_with_empty_latent_shape(model_class, X, y, kernel, li... function test_forward (line 87) | def test_forward(model_class, X, y, kernel, likelihood): function test_forward_with_empty_latent_shape (line 108) | def test_forward_with_empty_latent_shape(model_class, X, y, kernel, like... FILE: tests/contrib/gp/test_models.py function _kernel (line 36) | def _kernel(): function _likelihood (line 40) | def _likelihood(): function _TEST_CASES (line 44) | def _TEST_CASES(): function test_model (line 65) | def test_model(model_class, X, y, kernel, likelihood): function test_forward (line 83) | def test_forward(model_class, X, y, kernel, likelihood): function test_forward_with_empty_latent_shape (line 134) | def test_forward_with_empty_latent_shape(model_class, X, y, kernel, like... function test_inference (line 161) | def test_inference(model_class, X, y, kernel, likelihood): function test_inference_sgpr (line 185) | def test_inference_sgpr(): function test_inference_vsgp (line 206) | def test_inference_vsgp(): function test_inference_whiten_vsgp (line 228) | def test_inference_whiten_vsgp(): function test_inference_with_empty_latent_shape (line 251) | def test_inference_with_empty_latent_shape(model_class, X, y, kernel, li... function test_inference_with_whiten (line 268) | def test_inference_with_whiten(model_class, X, y, kernel, likelihood): function test_hmc (line 283) | def test_hmc(model_class, X, y, kernel, likelihood): function test_inference_deepGP (line 302) | def test_inference_deepGP(): function test_gplvm (line 335) | def test_gplvm(model_class, X, y, kernel, likelihood): function _pre_test_mean_function (line 348) | def _pre_test_mean_function(): function _mape (line 372) | def _mape(y_true, y_pred): function _post_test_mean_function (line 376) | def _post_test_mean_function(gpmodule, Xnew, y_true): function test_mean_function_GPR (line 384) | def test_mean_function_GPR(): function test_mean_function_SGPR (line 391) | def test_mean_function_SGPR(): function test_mean_function_SGPR_DTC (line 399) | def test_mean_function_SGPR_DTC(): function test_mean_function_SGPR_FITC (line 407) | def test_mean_function_SGPR_FITC(): function test_mean_function_VGP (line 417) | def test_mean_function_VGP(): function test_mean_function_VGP_whiten (line 425) | def test_mean_function_VGP_whiten(): function test_mean_function_VSGP (line 436) | def test_mean_function_VSGP(): function test_mean_function_VSGP_whiten (line 446) | def test_mean_function_VSGP_whiten(): FILE: tests/contrib/gp/test_parameterized.py function test_parameterized (line 15) | def test_parameterized(): function test_nested_parameterized (line 79) | def test_nested_parameterized(): function test_inference (line 116) | def test_inference(): FILE: tests/contrib/mue/test_dataloaders.py function test_biosequencedataset (line 13) | def test_biosequencedataset(source_type, alphabet, include_stop): function test_write (line 74) | def test_write(): FILE: tests/contrib/mue/test_missingdatahmm.py function test_hmm_log_prob (line 11) | def test_hmm_log_prob(): function test_shapes (line 84) | def test_shapes(batch_initial, batch_transition, batch_observation, batc... function test_DiscreteHMM_comparison (line 127) | def test_DiscreteHMM_comparison( function test_samples (line 193) | def test_samples(batch_data): function indiv_filter (line 239) | def indiv_filter(a0, a, e, x): function indiv_smooth (line 257) | def indiv_smooth(a0, a, e, x): function indiv_map_states (line 273) | def indiv_map_states(a0, a, e, x): function test_state_infer (line 298) | def test_state_infer(): function test_sample_given_states (line 442) | def test_sample_given_states(): function test_sample_states (line 495) | def test_sample_states(): FILE: tests/contrib/mue/test_models.py function test_ProfileHMM_smoke (line 16) | def test_ProfileHMM_smoke(jit): function test_FactorMuE_smoke (line 53) | def test_FactorMuE_smoke( FILE: tests/contrib/mue/test_statearrangers.py function simpleprod (line 10) | def simpleprod(lst): function test_profile_alternate_imp (line 21) | def test_profile_alternate_imp(M, batch_size, substitute): function test_profile_shapes (line 198) | def test_profile_shapes( function test_profile_trivial_cases (line 237) | def test_profile_trivial_cases(M): FILE: tests/contrib/oed/test_ewma.py function test_ewma (line 14) | def test_ewma(alpha, NS=10000, D=1): function test_ewma_log (line 30) | def test_ewma_log(): function test_ewma_log_with_s (line 40) | def test_ewma_log_with_s(): FILE: tests/contrib/oed/test_finite_spaces_eig.py function finite_space_model (line 26) | def finite_space_model(): function one_point_design (line 40) | def one_point_design(): function true_eig (line 45) | def true_eig(): function posterior_guide (line 49) | def posterior_guide(y_dict, design, observation_labels, target_labels): function marginal_guide (line 55) | def marginal_guide(design, observation_labels, target_labels): function likelihood_guide (line 60) | def likelihood_guide(theta_dict, design, observation_labels, target_labe... function make_lfire_classifier (line 66) | def make_lfire_classifier(n_theta_samples): function dv_critic (line 79) | def dv_critic(design, trace, observation_labels, target_labels): function test_posterior_finite_space_model (line 97) | def test_posterior_finite_space_model(finite_space_model, one_point_desi... function test_marginal_finite_space_model (line 126) | def test_marginal_finite_space_model(finite_space_model, one_point_desig... function test_marginal_likelihood_finite_space_model (line 155) | def test_marginal_likelihood_finite_space_model( function test_vnmc_finite_space_model (line 192) | def test_vnmc_finite_space_model(finite_space_model, one_point_design, t... function test_nmc_eig_finite_space_model (line 221) | def test_nmc_eig_finite_space_model(finite_space_model, one_point_design... function test_lfire_finite_space_model (line 230) | def test_lfire_finite_space_model(finite_space_model, one_point_design, ... function test_dv_finite_space_model (line 248) | def test_dv_finite_space_model(finite_space_model, one_point_design, tru... FILE: tests/contrib/oed/test_glmm.py function lm_2p_10_10_1 (line 22) | def lm_2p_10_10_1(design): function lm_2p_10_10_1_w12 (line 31) | def lm_2p_10_10_1_w12(design): function nz_lm_2p_10_10_1 (line 44) | def nz_lm_2p_10_10_1(design): function normal_inv_gamma_2_2_10_10 (line 54) | def normal_inv_gamma_2_2_10_10(design): function lr_10_10 (line 68) | def lr_10_10(design): function sigmoid_example (line 78) | def sigmoid_example(design): function test_log_prob_matches (line 165) | def test_log_prob_matches(model1, model2, design): FILE: tests/contrib/oed/test_linear_models_eig.py function linear_model (line 29) | def linear_model(): function one_point_design (line 38) | def one_point_design(): function posterior_guide (line 44) | def posterior_guide(y_dict, design, observation_labels, target_labels): function marginal_guide (line 56) | def marginal_guide(design, observation_labels, target_labels): function likelihood_guide (line 66) | def likelihood_guide(theta_dict, design, observation_labels, target_labe... function make_lfire_classifier (line 85) | def make_lfire_classifier(n_theta_samples): function dv_critic (line 106) | def dv_critic(design, trace, observation_labels, target_labels): function test_posterior_linear_model (line 120) | def test_posterior_linear_model(linear_model, one_point_design): function test_marginal_linear_model (line 150) | def test_marginal_linear_model(linear_model, one_point_design): function test_marginal_likelihood_linear_model (line 180) | def test_marginal_likelihood_linear_model(linear_model, one_point_design): function test_vnmc_linear_model (line 212) | def test_vnmc_linear_model(linear_model, one_point_design): function test_nmc_eig_linear_model (line 242) | def test_nmc_eig_linear_model(linear_model, one_point_design): function test_laplace_linear_model (line 250) | def test_laplace_linear_model(linear_model, one_point_design): function test_lfire_linear_model (line 269) | def test_lfire_linear_model(linear_model, one_point_design): function test_dv_linear_model (line 288) | def test_dv_linear_model(linear_model, one_point_design): FILE: tests/contrib/oed/test_xexpx.py function test_xexpx (line 18) | def test_xexpx(argument, output): FILE: tests/contrib/randomvariable/test_random_variable.py function test_add (line 13) | def test_add(): function test_subtract (line 22) | def test_subtract(): function test_multiply_divide (line 31) | def test_multiply_divide(): function test_abs (line 39) | def test_abs(): function test_neg (line 47) | def test_neg(): function test_pow (line 54) | def test_pow(): function test_tensor_ops (line 61) | def test_tensor_ops(): function test_chaining (line 72) | def test_chaining(): FILE: tests/contrib/test_hessian.py function test_hessian_mvn (line 11) | def test_hessian_mvn(): function test_hessian_multi_variables (line 21) | def test_hessian_multi_variables(): FILE: tests/contrib/test_minipyro.py function build_svi (line 19) | def build_svi(model, guide, elbo): function assert_ok (line 25) | def assert_ok(model, guide, elbo, steps=2, *args, **kwargs): function assert_error (line 34) | def assert_error(model, guide, elbo, match=None): function assert_warning (line 46) | def assert_warning(model, guide, elbo): function constrained_model (line 59) | def constrained_model(data): function guide_constrained_model (line 69) | def guide_constrained_model(data): function test_generate_data (line 75) | def test_generate_data(backend): function test_generate_data_plate (line 88) | def test_generate_data_plate(backend): function test_nonempty_model_empty_guide_ok (line 108) | def test_nonempty_model_empty_guide_ok(backend, jit): function test_plate_ok (line 126) | def test_plate_ok(backend, jit): function test_nested_plate_plate_ok (line 149) | def test_nested_plate_plate_ok(backend, jit): function test_local_param_ok (line 179) | def test_local_param_ok(backend, jit): function test_constraints (line 207) | def test_constraints(backend, jit): function test_elbo_jit (line 218) | def test_elbo_jit(backend): function test_elbo_equivalence (line 237) | def test_elbo_equivalence(backend, jit): function elbo_test_case (line 247) | def elbo_test_case(backend, jit, expected_elbo, data, steps=None): FILE: tests/contrib/test_util.py function test_get_indices_sizes (line 22) | def test_get_indices_sizes(): function test_tensor_to_dict (line 33) | def test_tensor_to_dict(): function test_rmv (line 52) | def test_rmv(A, b): function test_rvv (line 61) | def test_rvv(a, b): function test_lexpand (line 69) | def test_lexpand(): function test_rexpand (line 76) | def test_rexpand(): function test_rtril (line 85) | def test_rtril(): function test_rdiag (line 93) | def test_rdiag(): FILE: tests/contrib/test_zuko.py function test_ZukoToPyro (line 16) | def test_ZukoToPyro(multivariate: bool, rsample_and_log_prob: bool): FILE: tests/contrib/timeseries/test_gp.py function test_timeseries_models (line 42) | def test_timeseries_models(model, nu_statedim, obs_dim, T): function test_dependent_matern_gp (line 148) | def test_dependent_matern_gp(obs_dim): FILE: tests/contrib/timeseries/test_lgssm.py function test_generic_lgssm_forecast (line 15) | def test_generic_lgssm_forecast(model_class, state_dim, obs_dim, T): FILE: tests/contrib/tracking/test_assignment.py function assert_finite (line 23) | def assert_finite(tensor, name): function logit (line 27) | def logit(p): function dense_to_sparse (line 31) | def dense_to_sparse(assign_logits): function sparse_to_dense (line 41) | def sparse_to_dense(num_objects, num_detections, edges, assign_logits): function test_dense_smoke (line 47) | def test_dense_smoke(): function test_sparse_smoke (line 79) | def test_sparse_smoke(): function test_sparse_grid_smoke (line 111) | def test_sparse_grid_smoke(): function test_persistent_smoke (line 151) | def test_persistent_smoke(bp_iters): function test_flat_exact_1_1 (line 194) | def test_flat_exact_1_1(e, a): function test_flat_exact_2_1 (line 206) | def test_flat_exact_2_1(e, a11, a21): function test_flat_exact_1_2 (line 219) | def test_flat_exact_1_2(e1, e2, a11, a12): function test_flat_exact_2_2 (line 233) | def test_flat_exact_2_2(e1, e2, a11, a12, a22): function test_flat_bp_vs_exact (line 245) | def test_flat_bp_vs_exact(num_objects, num_detections): function test_flat_vs_persistent (line 258) | def test_flat_vs_persistent(num_objects, num_frames, bp_iters): function test_persistent_bp_vs_exact (line 272) | def test_persistent_bp_vs_exact(num_objects, num_frames, num_detections): function test_persistent_exact_5_4_3 (line 288) | def test_persistent_exact_5_4_3(e1, e2, e3, bp_iters, bp_momentum): function test_persistent_independent_subproblems (line 317) | def test_persistent_independent_subproblems( FILE: tests/contrib/tracking/test_distributions.py function test_EKFDistribution_smoke (line 14) | def test_EKFDistribution_smoke(Model, dim, time): FILE: tests/contrib/tracking/test_dynamic_models.py function assert_cov_validity (line 15) | def assert_cov_validity(cov, eigenvalue_lbnd=0.0, condition_number_ubnd=... function test_NcpContinuous (line 47) | def test_NcpContinuous(): function test_NcvContinuous (line 85) | def test_NcvContinuous(): function test_NcpDiscrete (line 120) | def test_NcpDiscrete(): function test_NcvDiscrete (line 158) | def test_NcvDiscrete(): FILE: tests/contrib/tracking/test_ekf.py function test_EKFState_with_NcpContinuous (line 12) | def test_EKFState_with_NcpContinuous(): function test_EKFState_with_NcvContinuous (line 44) | def test_EKFState_with_NcvContinuous(): FILE: tests/contrib/tracking/test_em.py function make_args (line 22) | def make_args(): function model (line 39) | def model(detections, args): function compute_exists_logits (line 72) | def compute_exists_logits(objects, args): function compute_assign_logits (line 81) | def compute_assign_logits(objects, detections, noise_scale, args): function guide (line 91) | def guide(detections, args): function generate_data (line 115) | def generate_data(args): function test_em (line 132) | def test_em(assignment_grad): function test_em_nested_in_svi (line 157) | def test_em_nested_in_svi(assignment_grad): function test_svi_multi (line 197) | def test_svi_multi(): FILE: tests/contrib/tracking/test_hashing.py function test_lsh_init (line 16) | def test_lsh_init(scale): function test_lsh_add (line 22) | def test_lsh_add(scale): function test_lsh_hash_nearby (line 30) | def test_lsh_hash_nearby(scale): function test_lsh_overwrite (line 62) | def test_lsh_overwrite(): function test_lsh_remove (line 74) | def test_lsh_remove(): function test_aps_init (line 86) | def test_aps_init(scale): function test_aps_hash (line 92) | def test_aps_hash(scale): function test_aps_try_add (line 111) | def test_aps_try_add(scale): function test_merge_points_small (line 125) | def test_merge_points_small(): function test_merge_points_large (line 147) | def test_merge_points_large(dim, radius): FILE: tests/contrib/tracking/test_measurements.py function test_PositionMeasurement (line 9) | def test_PositionMeasurement(): FILE: tests/distributions/conftest.py class FoldedNormal (line 22) | class FoldedNormal(dist.FoldedDistribution): method __init__ (line 25) | def __init__(self, loc, scale): method loc (line 29) | def loc(self): method scale (line 33) | def scale(self): class SparsePoisson (line 37) | class SparsePoisson(dist.Poisson): method __init__ (line 38) | def __init__(self, rate, *, validate_args=None): class SineSkewedUniform (line 42) | class SineSkewedUniform(dist.SineSkewed): method __init__ (line 43) | def __init__(self, lower, upper, skewness, *args, **kwargs): class SineSkewedVonMises (line 48) | class SineSkewedVonMises(dist.SineSkewed): method __init__ (line 49) | def __init__(self, von_loc, von_conc, skewness): function all_distributions (line 1053) | def all_distributions(request): function continuous_distributions (line 1062) | def continuous_distributions(request): function discrete_distributions (line 1071) | def discrete_distributions(request): function pytest_collection_modifyitems (line 1075) | def pytest_collection_modifyitems(items): FILE: tests/distributions/dist_fixture.py class Fixture (line 16) | class Fixture: method __init__ (line 17) | def __init__( method __repr__ (line 43) | def __repr__(self): method get_batch_data_indices (line 46) | def get_batch_data_indices(self): method get_test_data_indices (line 51) | def get_test_data_indices(self): method _extract_fixture_data (line 56) | def _extract_fixture_data(self, examples): method get_num_test_data (line 63) | def get_num_test_data(self): method get_samples (line 66) | def get_samples(self, num_samples, **dist_params): method get_test_data (line 71) | def get_test_data(self, idx, wrap_tensor=True): method get_dist_params (line 76) | def get_dist_params(self, idx, wrap_tensor=True): method _convert_logits_to_ps (line 81) | def _convert_logits_to_ps(self, dist_params): method get_scipy_logpdf (line 92) | def get_scipy_logpdf(self, idx): method get_scipy_batch_logpdf (line 108) | def get_scipy_batch_logpdf(self, idx): method get_num_samples (line 132) | def get_num_samples(self, idx): method get_test_distribution_name (line 160) | def get_test_distribution_name(self): function tensor_wrap (line 164) | def tensor_wrap(*args, **kwargs): FILE: tests/distributions/test_binomial.py function test_binomial_approx_sample (line 17) | def test_binomial_approx_sample(total_count, prob): function test_beta_binomial_approx_sample (line 31) | def test_beta_binomial_approx_sample(concentration1, concentration0, tot... function test_binomial_approx_log_prob (line 57) | def test_binomial_approx_log_prob(tol): FILE: tests/distributions/test_categorical.py class TestCategorical (line 15) | class TestCategorical(TestCase): method setUp (line 20) | def setUp(self): method test_log_prob_sum (line 39) | def test_log_prob_sum(self): method test_mean_and_var (line 50) | def test_mean_and_var(self): method test_support_non_vectorized (line 61) | def test_support_non_vectorized(self): method test_support (line 65) | def test_support(self): function wrap_nested (line 70) | def wrap_nested(x, dim): function dim (line 77) | def dim(request): function probs (line 82) | def probs(request): function modify_params_using_dims (line 86) | def modify_params_using_dims(probs, dim): function test_support_dims (line 90) | def test_support_dims(dim, probs): function test_sample_dims (line 96) | def test_sample_dims(dim, probs): function test_batch_log_dims (line 103) | def test_batch_log_dims(dim, probs): function test_view_reshape_bug (line 111) | def test_view_reshape_bug(): FILE: tests/distributions/test_coalescent.py function test_sample_is_valid (line 23) | def test_sample_is_valid(num_leaves): function test_simple_smoke (line 43) | def test_simple_smoke(num_leaves, num_steps, batch_shape, sample_shape): function test_with_rate_smoke (line 58) | def test_with_rate_smoke( function test_log_prob_unit_rate (line 78) | def test_log_prob_unit_rate(num_leaves, num_steps, batch_shape, sample_s... function test_log_prob_scale (line 93) | def test_log_prob_scale(num_leaves, num_steps, batch_shape, sample_shape): function test_log_prob_constant_rate_1 (line 114) | def test_log_prob_constant_rate_1(num_leaves, num_steps, batch_shape, sa... function test_log_prob_constant_rate_2 (line 136) | def test_log_prob_constant_rate_2(num_leaves, num_steps, batch_shape, sa... function test_likelihood_vectorized (line 155) | def test_likelihood_vectorized(num_leaves, num_steps, batch_shape, clamp... function test_likelihood_sequential (line 179) | def test_likelihood_sequential(num_leaves, num_steps, batch_shape, clamp... function tree (line 241) | def tree(): function test_bio_phylo_to_times (line 249) | def test_bio_phylo_to_times(tree): function test_bio_phylo_to_times_custom (line 262) | def test_bio_phylo_to_times_custom(tree): FILE: tests/distributions/test_conjugate.py function test_mean (line 31) | def test_mean(dist): function test_variance (line 55) | def test_variance(dist): function test_log_prob_support (line 71) | def test_log_prob_support(dist, values): function test_beta_binomial_log_prob (line 80) | def test_beta_binomial_log_prob(total_count, shape): function test_dirichlet_multinomial_log_prob (line 97) | def test_dirichlet_multinomial_log_prob(total_count, batch_shape, is_spa... function test_gamma_poisson_log_prob (line 116) | def test_gamma_poisson_log_prob(shape): FILE: tests/distributions/test_conjugate_update.py function test_beta_binomial (line 13) | def test_beta_binomial(sample_shape, batch_shape): function test_dirichlet_multinomial (line 29) | def test_dirichlet_multinomial(sample_shape, batch_shape): function test_gamma_poisson (line 45) | def test_gamma_poisson(sample_shape, batch_shape): FILE: tests/distributions/test_constraints.py function test_sphere_check (line 12) | def test_sphere_check(dim): function test_constraints (line 31) | def test_constraints(constraint, batch_shape, event_shape): FILE: tests/distributions/test_cuda.py function test_sample (line 16) | def test_sample(dist): function test_rsample (line 38) | def test_rsample(dist): function test_log_prob (line 79) | def test_log_prob(dist): FILE: tests/distributions/test_delta.py class TestDelta (line 14) | class TestDelta(TestCase): method setUp (line 15) | def setUp(self): method test_log_prob_sum (line 29) | def test_log_prob_sum(self): method test_batch_log_prob (line 33) | def test_batch_log_prob(self): method test_batch_log_prob_shape (line 43) | def test_batch_log_prob_shape(self): method test_mean_and_var (line 47) | def test_mean_and_var(self): function test_shapes (line 62) | def test_shapes(batch_dim, event_dim, has_log_density): function test_expand (line 75) | def test_expand(batch_shape): FILE: tests/distributions/test_distributions.py function _log_prob_shape (line 21) | def _log_prob_shape(dist, x_size=torch.Size()): function test_support_shape (line 32) | def test_support_shape(dist): function test_infer_shapes (line 45) | def test_infer_shapes(dist): function test_batch_log_prob (line 60) | def test_batch_log_prob(dist): function test_batch_log_prob_shape (line 74) | def test_batch_log_prob_shape(dist): function test_batch_entropy_shape (line 86) | def test_batch_entropy_shape(dist): function test_score_errors_event_dim_mismatch (line 97) | def test_score_errors_event_dim_mismatch(dist): function test_score_errors_non_broadcastable_data_shape (line 115) | def test_score_errors_non_broadcastable_data_shape(dist): function test_support_is_not_discrete (line 131) | def test_support_is_not_discrete(continuous_dist): function test_gof (line 138) | def test_gof(continuous_dist): function test_mean (line 166) | def test_mean(continuous_dist): function test_variance (line 188) | def test_variance(continuous_dist): function test_cdf_icdf (line 209) | def test_cdf_icdf(continuous_dist): function test_support_is_discrete (line 225) | def test_support_is_discrete(discrete_dist): function test_enumerate_support (line 232) | def test_enumerate_support(discrete_dist): function test_enumerate_support_shape (line 246) | def test_enumerate_support_shape(dist): function test_distribution_validate_args (line 277) | def test_distribution_validate_args(dist_class, args, validate_args): function check_sample_shapes (line 286) | def check_sample_shapes(small, large): function test_expand_by (line 309) | def test_expand_by(dist, sample_shape, shape_type): function test_expand_new_dim (line 320) | def test_expand_new_dim(dist, sample_shape, shape_type, default): function test_expand_existing_dim (line 336) | def test_expand_existing_dim(dist, shape_type, default): function test_subsequent_expands_ok (line 362) | def test_subsequent_expands_ok(dist, sample_shapes, default): function test_expand_error (line 388) | def test_expand_error(dist, initial_shape, proposed_shape, default): function test_expand_reshaped_distribution (line 411) | def test_expand_reshaped_distribution(extra_event_dims, expand_shape, de... function test_expand_enumerate_support (line 438) | def test_expand_enumerate_support(): FILE: tests/distributions/test_empirical.py function test_unweighted_mean_and_var (line 13) | def test_unweighted_mean_and_var(size, dtype): function test_unweighted_samples (line 37) | def test_unweighted_samples(batch_shape, event_shape, sample_shape, dtype): function test_sample_examples (line 71) | def test_sample_examples(sample, weights, expected_mean, expected_var): function test_log_prob (line 92) | def test_log_prob(batch_shape, event_shape, dtype): function test_weighted_sample_coherence (line 117) | def test_weighted_sample_coherence(event_shape, dtype): function test_weighted_mean_var (line 143) | def test_weighted_mean_var(event_shape, dtype, batch_shape): function test_mean_var_non_nan (line 164) | def test_mean_var_non_nan(): FILE: tests/distributions/test_extended.py function check_grad (line 15) | def check_grad(value, *params): function test_extended_binomial (line 21) | def test_extended_binomial(tol): function test_extended_beta_binomial (line 59) | def test_extended_beta_binomial(tol): FILE: tests/distributions/test_gaussian_mixtures.py function test_mean_gradient (line 29) | def test_mean_gradient(K, D, flat_logits, cost_function, mix_dist, batch... function test_mix_of_diag_normals_shared_cov_log_prob (line 172) | def test_mix_of_diag_normals_shared_cov_log_prob(batch_size): function test_gsm_log_prob (line 198) | def test_gsm_log_prob(): function test_mix_of_diag_normals_log_prob (line 215) | def test_mix_of_diag_normals_log_prob(batch_size): FILE: tests/distributions/test_grouped_normal_normal.py function test_grouped_normal_normal (line 12) | def test_grouped_normal_normal(num_groups=3, num_samples=10**5): FILE: tests/distributions/test_haar.py function test_haar_ortho (line 12) | def test_haar_ortho(size): FILE: tests/distributions/test_hmm.py function check_expand (line 42) | def check_expand(old_dist, old_data): function check_sample_shape (line 56) | def check_sample_shape(d): function test_sequential_logmatmulexp (line 68) | def test_sequential_logmatmulexp(batch_shape, state_dim, num_steps): function test_sequential_gamma_gaussian_tensordot (line 97) | def test_sequential_gamma_gaussian_tensordot(batch_shape, state_dim, num... function test_discrete_hmm_shape (line 138) | def test_discrete_hmm_shape( function test_discrete_hmm_homogeneous_trick (line 189) | def test_discrete_hmm_homogeneous_trick( function empty_guide (line 209) | def empty_guide(*args, **kwargs): function test_discrete_hmm_categorical (line 214) | def test_discrete_hmm_categorical(num_steps): function test_discrete_hmm_diag_normal (line 248) | def test_discrete_hmm_diag_normal(num_steps): function test_discrete_hmm_distribution (line 285) | def test_discrete_hmm_distribution(): function test_gaussian_hmm_shape (line 324) | def test_gaussian_hmm_shape( function test_gaussian_hmm_high_obs_dim (line 399) | def test_gaussian_hmm_high_obs_dim(): function test_gaussian_hmm_distribution (line 424) | def test_gaussian_hmm_distribution( function test_gaussian_mrf_shape (line 558) | def test_gaussian_mrf_shape(init_shape, trans_shape, obs_shape, hidden_d... function test_gaussian_mrf_log_prob (line 583) | def test_gaussian_mrf_log_prob( function test_gaussian_mrf_log_prob_block_diag (line 644) | def test_gaussian_mrf_log_prob_block_diag( function test_gamma_gaussian_hmm_shape (line 691) | def test_gamma_gaussian_hmm_shape( function test_gamma_gaussian_hmm_log_prob (line 745) | def test_gamma_gaussian_hmm_log_prob( function random_stable (line 807) | def random_stable(stability, skew_scale_loc_shape): function test_stable_hmm_shape (line 837) | def test_stable_hmm_shape( function random_studentt (line 875) | def random_studentt(shape): function test_studentt_hmm_shape (line 904) | def test_studentt_hmm_shape( function test_independent_hmm_shape (line 964) | def test_independent_hmm_shape( FILE: tests/distributions/test_ig.py function test_sample (line 15) | def test_sample(concentration, rate, n_samples=int(1e6)): function test_log_prob (line 27) | def test_log_prob(concentration, rate, value): FILE: tests/distributions/test_improper_uniform.py function test_improper_uniform (line 23) | def test_improper_uniform(constraint, batch_shape, event_shape): FILE: tests/distributions/test_independent.py function test_independent (line 25) | def test_independent(base_dist, sample_shape, batch_shape, reinterpreted... function test_to_event (line 68) | def test_to_event(base_dist): function test_expand (line 111) | def test_expand(sample_shape, batch_shape, event_shape): FILE: tests/distributions/test_kl.py function test_kl_delta_normal_shape (line 14) | def test_kl_delta_normal_shape(batch_shape): function test_kl_delta_mvn_shape (line 25) | def test_kl_delta_mvn_shape(batch_shape, size): function test_kl_independent_normal (line 38) | def test_kl_independent_normal(batch_shape, event_shape): function test_kl_independent_delta_mvn_shape (line 51) | def test_kl_independent_delta_mvn_shape(batch_shape, size): function test_kl_independent_normal_mvn (line 64) | def test_kl_independent_normal_mvn(batch_shape, size): function test_kl_transformed_transformed (line 85) | def test_kl_transformed_transformed(shape, event_dim, transform): FILE: tests/distributions/test_lkj.py function test_constraint (line 22) | def test_constraint(value_shape): function _autograd_log_det (line 30) | def _autograd_log_det(ys, x): function test_unconstrained_to_corr_cholesky_transform (line 41) | def test_unconstrained_to_corr_cholesky_transform(y_shape): function test_corr_cholesky_transform (line 74) | def test_corr_cholesky_transform(x_shape, mapping): function test_log_prob_conc1 (line 95) | def test_log_prob_conc1(dim): function test_log_prob_d2 (line 123) | def test_log_prob_d2(concentration): function test_sample_batch (line 138) | def test_sample_batch(): FILE: tests/distributions/test_log_normal_negative_binomial.py function test_lnnb_shapes (line 13) | def test_lnnb_shapes(num_quad_points, shape): function test_lnnb_mean_variance (line 31) | def test_lnnb_mean_variance( FILE: tests/distributions/test_lowrank_mvn.py function test_scale_tril (line 10) | def test_scale_tril(): function test_log_prob (line 22) | def test_log_prob(): function test_variance (line 35) | def test_variance(): FILE: tests/distributions/test_mask.py function checker_mask (line 14) | def checker_mask(shape): function test_mask (line 25) | def test_mask(batch_dim, event_dim, mask_dim): function test_mask_type (line 67) | def test_mask_type(mask): function test_broadcast (line 89) | def test_broadcast(event_shape, dist_shape, mask_shape): function test_kl_divergence (line 102) | def test_kl_divergence(): function test_kl_divergence_type (line 119) | def test_kl_divergence_type(p_mask, q_mask): class NormalBomb (line 137) | class NormalBomb(Normal): method log_prob (line 138) | def log_prob(self, value): method score_parts (line 141) | def score_parts(self, value): function test_mask_noop (line 146) | def test_mask_noop(shape): FILE: tests/distributions/test_mixture.py function test_masked_mixture_univariate (line 25) | def test_masked_mixture_univariate(component0, component1, sample_shape,... function test_masked_mixture_multivariate (line 52) | def test_masked_mixture_multivariate(sample_shape, batch_shape): function test_broadcast (line 88) | def test_broadcast(mask_shape, component0_shape, component1_shape, value... function test_expand (line 105) | def test_expand(sample_shape, batch_shape, event_shape): FILE: tests/distributions/test_mvn.py function random_mvn (line 11) | def random_mvn(loc_shape, cov_shape, dim): function test_shape (line 46) | def test_shape(loc_shape, cov_shape, dim): FILE: tests/distributions/test_mvt.py function random_mvt (line 14) | def random_mvt(df_shape, loc_shape, cov_shape, dim): function test_shape (line 59) | def test_shape(df_shape, loc_shape, cov_shape, dim): function test_log_prob (line 83) | def test_log_prob(batch_shape, dim): function test_rsample (line 108) | def test_rsample(dim, df, num_samples=200 * 1000): function test_log_prob_normalization (line 129) | def test_log_prob_normalization(dim, df=6.1, grid_size=2000, domain_widt... function test_mean_var (line 161) | def test_mean_var(batch_shape): FILE: tests/distributions/test_nanmasked.py function test_normal (line 18) | def test_normal(batch_shape): function test_multivariate_normal (line 46) | def test_multivariate_normal(batch_shape, p): function test_multivariate_normal_model (line 77) | def test_multivariate_normal_model(): FILE: tests/distributions/test_omt_mvn.py function analytic_grad (line 16) | def analytic_grad(L11=1.0, L22=1.0, L21=1.0, omega1=1.0, omega2=1.0): function test_mean_gradient (line 32) | def test_mean_gradient( function test_mean_single_gradient (line 73) | def test_mean_single_gradient( function test_log_prob (line 114) | def test_log_prob(mvn_dist): FILE: tests/distributions/test_one_hot_categorical.py class TestOneHotCategorical (line 14) | class TestOneHotCategorical(TestCase): method setUp (line 19) | def setUp(self): method test_support_non_vectorized (line 52) | def test_support_non_vectorized(self): method test_support (line 56) | def test_support(self): function wrap_nested (line 61) | def wrap_nested(x, dim): function assert_correct_dimensions (line 67) | def assert_correct_dimensions(sample, probs): function dim (line 74) | def dim(request): function probs (line 79) | def probs(request): function modify_params_using_dims (line 83) | def modify_params_using_dims(probs, dim): function test_support_dims (line 87) | def test_support_dims(dim, probs): function test_sample_dims (line 101) | def test_sample_dims(dim, probs): function test_batch_log_dims (line 107) | def test_batch_log_dims(dim, probs): FILE: tests/distributions/test_one_one_matching.py function _hash (line 16) | def _hash(value): function test_enumerate (line 22) | def test_enumerate(num_nodes, dtype): function test_sample_shape_smoke (line 35) | def test_sample_shape_smoke(num_nodes, sample_shape, dtype, bp_iters): function test_log_prob_full (line 47) | def test_log_prob_full(num_nodes, dtype, bp_iters): function test_log_prob_hard (line 60) | def test_log_prob_hard(dtype, bp_iters): function assert_grads_ok (line 72) | def assert_grads_ok(logits, bp_iters=None): function assert_grads_agree (line 80) | def assert_grads_agree(logits): function test_grad_full (line 91) | def test_grad_full(num_nodes): function test_grad_hard (line 101) | def test_grad_hard(num_nodes): function test_mode (line 114) | def test_mode(num_nodes, dtype): function test_mode_smoke (line 127) | def test_mode_smoke(num_nodes, dtype): function test_sample (line 137) | def test_sample(num_nodes, dtype, bp_iters): FILE: tests/distributions/test_one_two_matching.py function _hash (line 16) | def _hash(value): function random_phylo_logits (line 20) | def random_phylo_logits(num_leaves, dtype): function test_enumerate (line 43) | def test_enumerate(num_destins, dtype): function test_sample_shape_smoke (line 57) | def test_sample_shape_smoke(num_destins, sample_shape, dtype, bp_iters): function test_log_prob_full (line 70) | def test_log_prob_full(num_destins, dtype, bp_iters): function test_log_prob_hard (line 84) | def test_log_prob_hard(dtype, bp_iters): function test_log_prob_phylo (line 99) | def test_log_prob_phylo(num_leaves, dtype, bp_iters): function test_log_prob_phylo_smoke (line 112) | def test_log_prob_phylo_smoke(num_leaves, dtype): function assert_grads_ok (line 122) | def assert_grads_ok(logits, bp_iters=None): function assert_grads_agree (line 130) | def assert_grads_agree(logits): function test_grad_full (line 141) | def test_grad_full(num_destins): function test_grad_hard (line 152) | def test_grad_hard(num_destins): function test_grad_phylo (line 166) | def test_grad_phylo(num_leaves): function test_mode_full (line 177) | def test_mode_full(num_destins, dtype): function test_mode_phylo (line 190) | def test_mode_phylo(num_leaves, dtype): function test_mode_full_smoke (line 202) | def test_mode_full_smoke(num_destins, dtype): function test_mode_phylo_smoke (line 212) | def test_mode_phylo_smoke(num_leaves, dtype): function test_sample_full (line 222) | def test_sample_full(num_destins, dtype, bp_iters): function test_sample_phylo (line 248) | def test_sample_phylo(num_leaves, dtype, bp_iters): FILE: tests/distributions/test_ordered_logistic.py function test_sample (line 16) | def test_sample(n_cutpoints, pred_shape): function test_constraints (line 26) | def test_constraints(): function test_broadcast (line 37) | def test_broadcast(): function test_expand (line 59) | def test_expand(): function test_autograd (line 71) | def test_autograd(): function test_transform_bijection (line 92) | def test_transform_bijection(batch_shape, event_shape): function cjald (line 102) | def cjald(func, X): function test_transform_log_abs_det (line 118) | def test_transform_log_abs_det(batch_shape, event_shape): FILE: tests/distributions/test_pickle.py function test_pickle (line 74) | def test_pickle(Dist): FILE: tests/distributions/test_polya_gamma.py function test_polya_gamma (line 12) | def test_polya_gamma(batch_shape, num_points=20000): FILE: tests/distributions/test_projected_normal.py function test_log_prob (line 14) | def test_log_prob(dtype, dim, strength): FILE: tests/distributions/test_rejector.py function test_rejection_standard_gamma_sample_shape (line 23) | def test_rejection_standard_gamma_sample_shape(sample_shape, batch_shape): function test_rejection_exponential_sample_shape (line 32) | def test_rejection_exponential_sample_shape(sample_shape, batch_shape): function compute_elbo_grad (line 40) | def compute_elbo_grad(model, guide, variables): function test_rejector (line 51) | def test_rejector(rate, factor): function test_exponential_elbo (line 67) | def test_exponential_elbo(rate, factor): function test_standard_gamma_elbo (line 86) | def test_standard_gamma_elbo(alpha): function test_gamma_elbo (line 104) | def test_gamma_elbo(alpha, beta): function test_shape_augmented_gamma_elbo (line 133) | def test_shape_augmented_gamma_elbo(alpha, beta): function test_shape_augmented_beta (line 162) | def test_shape_augmented_beta(alpha, beta): FILE: tests/distributions/test_relaxed_straight_through.py function test_onehot_shapes (line 33) | def test_onehot_shapes(probs): function test_onehot_entropy_grad (line 44) | def test_onehot_entropy_grad(temp): function test_onehot_svi_usage (line 67) | def test_onehot_svi_usage(): function test_bernoulli_shapes (line 96) | def test_bernoulli_shapes(probs): function test_bernoulli_entropy_grad (line 107) | def test_bernoulli_entropy_grad(temp): FILE: tests/distributions/test_reshape.py function test_sample_shape_order (line 11) | def test_sample_shape_order(): function test_idempotent (line 25) | def test_idempotent(batch_dim, event_dim): function test_reshape (line 44) | def test_reshape(sample_dim, extra_event_dims): function test_reshape_reshape (line 86) | def test_reshape_reshape(sample_dim, extra_event_dims): function test_extra_event_dim_overflow (line 129) | def test_extra_event_dim_overflow(sample_dim, batch_dim, event_dim): function test_independent_entropy (line 152) | def test_independent_entropy(): FILE: tests/distributions/test_shapes.py function test_categorical_shape (line 9) | def test_categorical_shape(): function test_one_hot_categorical_shape (line 18) | def test_one_hot_categorical_shape(): function test_normal_shape (line 27) | def test_normal_shape(): function test_dirichlet_shape (line 37) | def test_dirichlet_shape(): function test_zip_shape (line 46) | def test_zip_shape(): function test_bernoulli_log_prob_shape (line 56) | def test_bernoulli_log_prob_shape(): function test_categorical_log_prob_shape (line 63) | def test_categorical_log_prob_shape(): function test_one_hot_categorical_log_prob_shape (line 70) | def test_one_hot_categorical_log_prob_shape(): function test_normal_log_prob_shape (line 78) | def test_normal_log_prob_shape(): function test_diag_normal_log_prob_shape (line 86) | def test_diag_normal_log_prob_shape(): FILE: tests/distributions/test_sine_bivariate_von_mises.py function _unnorm_log_prob (line 17) | def _unnorm_log_prob(value, loc1, loc2, conc1, conc2, corr): function test_log_binomial (line 28) | def test_log_binomial(n): function test_bvm_unnorm_log_prob (line 35) | def test_bvm_unnorm_log_prob(batch_dim): function test_bvm_multidim (line 54) | def test_bvm_multidim(): function test_mle_bvm (line 76) | def test_mle_bvm(): function test_sine_bivariate_von_mises_norm (line 136) | def test_sine_bivariate_von_mises_norm(conc): FILE: tests/distributions/test_sine_skewed.py function _skewness (line 18) | def _skewness(event_shape): function test_ss_multidim_log_prob (line 52) | def test_ss_multidim_log_prob(expand_shape, dist): function test_ss_mle (line 69) | def test_ss_mle(dim, dist): FILE: tests/distributions/test_spanning_tree.py function test_make_complete_graph (line 36) | def test_make_complete_graph(num_vertices, expected_grid, backend): function test_sample_tree_mcmc_smoke (line 48) | def test_sample_tree_mcmc_smoke(num_edges, backend): function test_sample_tree_approx_smoke (line 62) | def test_sample_tree_approx_smoke(num_edges, backend): function test_find_best_tree_smoke (line 75) | def test_find_best_tree_smoke(num_edges, backend): function test_enumerate_support (line 86) | def test_enumerate_support(num_edges): function test_partition_function (line 101) | def test_partition_function(num_edges): function test_log_prob (line 119) | def test_log_prob(num_edges): function test_edge_mean_function (line 135) | def test_edge_mean_function(num_edges): function test_mode (line 159) | def test_mode(num_edges, backend): function test_sample_tree_gof (line 181) | def test_sample_tree_gof(method, backend, num_edges, pattern): FILE: tests/distributions/test_stable.py function test_shape (line 18) | def test_shape(sample_shape, batch_shape): function test_sample (line 35) | def test_sample(alpha, beta): function test_sample_2 (line 76) | def test_sample_2(alpha, beta): function test_normal (line 97) | def test_normal(loc, scale): function test_additive (line 111) | def test_additive(stability, skew0, skew1, scale0, scale1): function test_mean (line 131) | def test_mean(stability, skew, scale, coords): function test_variance (line 143) | def test_variance(stability, scale): FILE: tests/distributions/test_stable_log_prob.py function test_stable_gof (line 27) | def test_stable_gof(stability, skew): function test_stable_with_log_prob_param_fit (line 83) | def test_stable_with_log_prob_param_fit(alpha, beta, c, mu, alpha_0, bet... FILE: tests/distributions/test_tensor_type.py function test_data (line 13) | def test_data(): function alpha (line 18) | def alpha(): function beta (line 26) | def beta(): function float_test_data (line 34) | def float_test_data(test_data): function float_alpha (line 39) | def float_alpha(alpha): function float_beta (line 44) | def float_beta(beta): function test_double_type (line 48) | def test_double_type(test_data, alpha, beta): function test_float_type (line 60) | def test_float_type(float_test_data, float_alpha, float_beta, test_data,... function test_conflicting_types (line 75) | def test_conflicting_types(test_data, float_alpha, beta): FILE: tests/distributions/test_torch_patch.py function test_dirichlet_grad_cuda (line 12) | def test_dirichlet_grad_cuda(): function test_linspace (line 18) | def test_linspace(): function test_lower_cholesky_transform (line 25) | def test_lower_cholesky_transform(batch_shape, dim): FILE: tests/distributions/test_transforms.py class Flatten (line 19) | class Flatten(dist.TransformModule): method __init__ (line 27) | def __init__(self, transform, input_shape): method _call (line 38) | def _call(self, x): method _inverse (line 44) | def _inverse(self, y): method log_abs_det_jacobian (line 50) | def log_abs_det_jacobian(self, x, y): method parameters (line 55) | def parameters(self): class TransformTests (line 59) | class TransformTests(TestCase): method setUp (line 60) | def setUp(self): method _test_jacobian (line 67) | def _test_jacobian(self, input_dim, transform): method _test_inverse (line 115) | def _test_inverse(self, shape, transform, base_dist_type="normal"): method _test_shape (line 140) | def _test_shape(self, base_shape, transform, base_dist_type="normal"): method _test_autodiff (line 161) | def _test_autodiff( method _test (line 189) | def _test( method _test_conditional (line 229) | def _test_conditional( method test_affine_autoregressive (line 244) | def test_affine_autoregressive(self): method test_affine_coupling (line 248) | def test_affine_coupling(self): method test_batchnorm (line 252) | def test_batchnorm(self): method test_block_autoregressive_jacobians (line 268) | def test_block_autoregressive_jacobians(self): method test_conditional_affine_autoregressive (line 279) | def test_conditional_affine_autoregressive(self): method test_conditional_affine_coupling (line 282) | def test_conditional_affine_coupling(self): method test_conditional_generalized_channel_permute (line 288) | def test_conditional_generalized_channel_permute(self, context_dim=3): method test_conditional_householder (line 305) | def test_conditional_householder(self): method test_conditional_matrix_exponential (line 309) | def test_conditional_matrix_exponential(self): method test_conditional_neural_autoregressive (line 312) | def test_conditional_neural_autoregressive(self): method test_conditional_planar (line 315) | def test_conditional_planar(self): method test_conditional_radial (line 318) | def test_conditional_radial(self): method test_conditional_spline (line 321) | def test_conditional_spline(self): method test_conditional_spline_autoregressive (line 325) | def test_conditional_spline_autoregressive(self): method test_discrete_cosine (line 328) | def test_discrete_cosine(self): method test_haar_transform (line 336) | def test_haar_transform(self): method test_elu (line 341) | def test_elu(self): method test_generalized_channel_permute (line 345) | def test_generalized_channel_permute(self): method test_householder (line 359) | def test_householder(self): method test_leaky_relu (line 362) | def test_leaky_relu(self): method test_lower_cholesky_affine (line 366) | def test_lower_cholesky_affine(self): method test_matrix_exponential (line 378) | def test_matrix_exponential(self): method test_neural_autoregressive (line 381) | def test_neural_autoregressive(self): method test_ordered_transform (line 387) | def test_ordered_transform(self): method test_permute (line 391) | def test_permute(self): method test_planar (line 395) | def test_planar(self): method test_polynomial (line 398) | def test_polynomial(self): method test_radial (line 401) | def test_radial(self): method test_simplex_to_ordered (line 404) | def test_simplex_to_ordered(self): method test_spline (line 417) | def test_spline(self): method test_spline_coupling (line 421) | def test_spline_coupling(self): method test_spline_autoregressive (line 424) | def test_spline_autoregressive(self): method test_sylvester (line 427) | def test_sylvester(self): method test_normalize_transform (line 430) | def test_normalize_transform(self): method test_softplus (line 433) | def test_softplus(self): method test_positive_power (line 436) | def test_positive_power(self): function test_cholesky_transform (line 451) | def test_cholesky_transform(batch_shape, dim, transform): function test_lower_cholesky_transform (line 501) | def test_lower_cholesky_transform(transform, batch_shape, dim): function test_inverse_conditional_transform_module (line 516) | def test_inverse_conditional_transform_module(batch_shape, input_dim, co... function test_conditional_compose_transform_module (line 538) | def test_conditional_compose_transform_module( FILE: tests/distributions/test_unit.py function test_shapes (line 12) | def test_shapes(batch_shape): function test_expand (line 23) | def test_expand(sample_shape, batch_shape): FILE: tests/distributions/test_util.py function test_broadcast_shape (line 48) | def test_broadcast_shape(shapes): function test_broadcast_shape_error (line 59) | def test_broadcast_shape_error(shapes): function test_broadcast_shape_strict (line 84) | def test_broadcast_shape_strict(shapes): function test_broadcast_shape_strict_error (line 105) | def test_broadcast_shape_strict_error(shapes): function test_sum_rightmost (line 110) | def test_sum_rightmost(): function test_sum_leftmost (line 120) | def test_sum_leftmost(): function test_weakmethod (line 130) | def test_weakmethod(): function test_detach_normal (line 150) | def test_detach_normal(shape): function test_detach_beta (line 166) | def test_detach_beta(shape): function test_detach_transformed (line 183) | def test_detach_transformed(shape): function test_detach_jit (line 211) | def test_detach_jit(shape): function test_deep_to_normal (line 233) | def test_deep_to_normal(shape, dtype): function test_deep_to_beta (line 265) | def test_deep_to_beta(shape, dtype): function test_deep_to_transformed (line 284) | def test_deep_to_transformed(shape, dtype): function test_deep_to_structure (line 313) | def test_deep_to_structure(dtype): function test_deep_to_jit (line 341) | def test_deep_to_jit(shape): function test_deep_to_module (line 372) | def test_deep_to_module(dtype): function test_deep_to_pyro_module (line 392) | def test_deep_to_pyro_module(dtype): FILE: tests/distributions/test_von_mises.py function _eval_poly (line 16) | def _eval_poly(y, coef): function _log_modified_bessel_fn (line 69) | def _log_modified_bessel_fn(x, order=0): function _fit_params_from_samples (line 94) | def _fit_params_from_samples(samples, n_iter): function test_sample (line 143) | def test_sample(loc, concentration, n_samples=int(1e6), n_iter=50): function test_log_prob_normalized (line 154) | def test_log_prob_normalized(concentration): function test_von_mises_gof (line 163) | def test_von_mises_gof(loc, concentration): function test_von_mises_3d (line 172) | def test_von_mises_3d(scale): function test_von_mises_3d_gof (line 188) | def test_von_mises_3d_gof(scale): FILE: tests/distributions/test_zero_inflated.py function test_zid_shape (line 24) | def test_zid_shape(gate_shape, base_shape): function test_zip_0_gate (line 37) | def test_zip_0_gate(rate): function test_zip_1_gate (line 51) | def test_zip_1_gate(rate): function test_zip_mean_variance (line 66) | def test_zip_mean_variance(gate, rate): function test_zinb_0_gate (line 80) | def test_zinb_0_gate(total_count, probs): function test_zinb_1_gate (line 103) | def test_zinb_1_gate(total_count, probs): function test_zinb_mean_variance (line 127) | def test_zinb_mean_variance(gate, total_count, logits): FILE: tests/distributions/testing/test_gof.py function test_multinomial_goodness_of_fit (line 10) | def test_multinomial_goodness_of_fit(): FILE: tests/distributions/testing/test_special.py function test_chi2sf (line 13) | def test_chi2sf(): FILE: tests/doctest_fixtures.py function add_imports (line 23) | def add_imports(doctest_namespace): FILE: tests/infer/autoguide/conftest.py function pytest_collection_modifyitems (line 7) | def pytest_collection_modifyitems(items): FILE: tests/infer/autoguide/test_gaussian.py function test_break_plates (line 29) | def test_break_plates(): function test_backend_dispatch (line 77) | def test_backend_dispatch(backend): function check_structure (line 94) | def check_structure(model, expected_str, expected_dependencies=None): function check_backends_agree (line 122) | def check_backends_agree(model): function test_structure_0 (line 186) | def test_structure_0(backend): function test_structure_1 (line 206) | def test_structure_1(backend): function test_structure_2 (line 227) | def test_structure_2(backend): function test_structure_3 (line 253) | def test_structure_3(backend): function test_structure_4 (line 283) | def test_structure_4(backend): function test_structure_5 (line 312) | def test_structure_5(backend): function test_structure_6 (line 334) | def test_structure_6(backend): function test_structure_7 (line 369) | def test_structure_7(backend): function test_structure_8 (line 400) | def test_structure_8(backend): function test_broken_plates_smoke (line 422) | def test_broken_plates_smoke(backend): function test_intractable_smoke (line 439) | def test_intractable_smoke(backend): function pyrocov_model (line 462) | def pyrocov_model(dataset): function pyrocov_model_relaxed (line 507) | def pyrocov_model_relaxed(dataset): function pyrocov_model_plated (line 554) | def pyrocov_model_plated(dataset): function pyrocov_model_poisson (line 597) | def pyrocov_model_poisson(dataset): class PoissonGuide (line 643) | class PoissonGuide(AutoGuideList): method __init__ (line 644) | def __init__(self, model, backend): method hide_fn_1 (line 654) | def hide_fn_1(msg): method hide_fn_2 (line 658) | def hide_fn_2(msg): function test_pyrocov_smoke (line 672) | def test_pyrocov_smoke(model, Guide, backend): function test_pyrocov_reparam (line 692) | def test_pyrocov_reparam(model, Guide, backend): function test_pyrocov_structure (line 720) | def test_pyrocov_structure(): function test_profile (line 809) | def test_profile(backend, jit, n=1, num_steps=1, log_every=1): FILE: tests/infer/autoguide/test_inference.py class AutoGaussianChain (line 33) | class AutoGaussianChain(GaussianChain): method compute_target (line 35) | def compute_target(self, N): method test_multivariatate_normal_auto (line 50) | def test_multivariatate_normal_auto(self): method do_test_auto (line 53) | def do_test_auto(self, N, reparameterized, n_steps): function test_auto_diagonal_gaussians (line 118) | def test_auto_diagonal_gaussians(auto_class, Elbo): function test_auto_transform (line 182) | def test_auto_transform(auto_class): function test_auto_dirichlet (line 221) | def test_auto_dirichlet(auto_class, Elbo): FILE: tests/infer/autoguide/test_mean_field_entropy.py function mean_field_guide (line 14) | def mean_field_guide(batch_tensor, design): function h (line 22) | def h(p): function test_guide_entropy (line 37) | def test_guide_entropy(guide, args, expected_entropy): FILE: tests/infer/conftest.py function pytest_collection_modifyitems (line 7) | def pytest_collection_modifyitems(items): FILE: tests/infer/mcmc/test_adaptation.py function test_adaptation_schedule (line 27) | def test_adaptation_schedule(adapt_step_size, adapt_mass, warmup_steps, ... function test_arrowhead_mass_matrix (line 37) | def test_arrowhead_mass_matrix(diagonal): FILE: tests/infer/mcmc/test_hmc.py function mark_jit (line 21) | def mark_jit(*args, **kwargs): function jit_idfn (line 30) | def jit_idfn(param): class GaussianChain (line 34) | class GaussianChain: method __init__ (line 35) | def __init__(self, dim, chain_len, num_obs): method model (line 42) | def model(self, data): method data (line 52) | def data(self): method id_fn (line 55) | def id_fn(self): function rmse (line 61) | def rmse(t1, t2): function test_hmc_conjugate_gaussian (line 135) | def test_hmc_conjugate_gaussian( function test_logistic_regression (line 181) | def test_logistic_regression( function test_dirichlet_categorical (line 216) | def test_dirichlet_categorical(jit): function test_beta_bernoulli (line 235) | def test_beta_bernoulli(jit): function test_gamma_normal (line 259) | def test_gamma_normal(): function test_bernoulli_latent_model (line 277) | def test_bernoulli_latent_model(jit): function test_unnormalized_normal (line 309) | def test_unnormalized_normal(kernel, jit): function test_singular_matrix_catch (line 342) | def test_singular_matrix_catch(jit, op): FILE: tests/infer/mcmc/test_mcmc_api.py class PriorKernel (line 22) | class PriorKernel(MCMCKernel): method __init__ (line 28) | def __init__(self, model): method setup (line 35) | def setup(self, warmup_steps, data): method diagnostics (line 46) | def diagnostics(self): method initial_params (line 50) | def initial_params(self): method initial_params (line 54) | def initial_params(self, params): method cleanup (line 57) | def cleanup(self): method sample_params (line 60) | def sample_params(self): method sample (line 64) | def sample(self, params): function normal_normal_model (line 72) | def normal_normal_model(data): function run_default_mcmc (line 79) | def run_default_mcmc( function run_streaming_mcmc (line 106) | def run_streaming_mcmc( function test_mcmc_interface (line 158) | def test_mcmc_interface(run_mcmc_cls, num_draws, group_by_chain, num_cha... function test_num_chains (line 207) | def test_num_chains(num_chains, cpu_count, default_init_params, monkeypa... function _empty_model (line 236) | def _empty_model(): function _hook (line 240) | def _hook(iters, kernel, samples, stage, i): function test_null_model_with_hook (line 256) | def test_null_model_with_hook(run_mcmc_cls, kernel, model, jit, num_chai... function test_mcmc_diagnostics (line 289) | def test_mcmc_diagnostics(run_mcmc_cls, num_chains): function test_sequential_consistent (line 328) | def test_sequential_consistent(run_mcmc_cls, monkeypatch): function test_model_with_potential_fn (line 370) | def test_model_with_potential_fn(run_mcmc_cls): function test_save_params (line 393) | def test_save_params(save_params, Kernel, options): FILE: tests/infer/mcmc/test_mcmc_util.py function beta_bernoulli (line 28) | def beta_bernoulli(): function test_predictive (line 44) | def test_predictive(num_samples, parallel): function model_with_param (line 73) | def model_with_param(): function test_model_with_param (line 81) | def test_model_with_param(jit_compile, num_chains): function test_model_with_subsample (line 88) | def test_model_with_subsample(subsample_size): function test_init_to_value (line 104) | def test_init_to_value(): function test_init_strategy_smoke (line 133) | def test_init_strategy_smoke(init_strategy): FILE: tests/infer/mcmc/test_nuts.py function mark_jit (line 91) | def mark_jit(*args, **kwargs): function jit_idfn (line 100) | def jit_idfn(param): function test_nuts_conjugate_gaussian (line 111) | def test_nuts_conjugate_gaussian( function test_logistic_regression (line 150) | def test_logistic_regression(jit, use_multinomial_sampling): function test_beta_bernoulli (line 184) | def test_beta_bernoulli(step_size, adapt_step_size, adapt_mass_matrix, f... function test_gamma_normal (line 209) | def test_gamma_normal(jit, use_multinomial_sampling): function test_dirichlet_categorical (line 232) | def test_dirichlet_categorical(jit): function test_gamma_beta (line 250) | def test_gamma_beta(jit): function test_gaussian_mixture_model (line 274) | def test_gaussian_mixture_model(jit): function test_bernoulli_latent_model (line 307) | def test_bernoulli_latent_model(jit): function test_gaussian_hmm (line 331) | def test_gaussian_hmm(num_steps): function test_beta_binomial (line 394) | def test_beta_binomial(hyperpriors): function test_gamma_poisson (line 434) | def test_gamma_poisson(hyperpriors): function test_structured_mass (line 465) | def test_structured_mass(): function test_arrowhead_mass (line 506) | def test_arrowhead_mass(): function test_dirichlet_categorical_grad_adapt (line 550) | def test_dirichlet_categorical_grad_adapt(): FILE: tests/infer/mcmc/test_rwkernel.py function test_beta_bernoulli (line 13) | def test_beta_bernoulli(): FILE: tests/infer/mcmc/test_valid_models.py function assert_ok (line 28) | def assert_ok(mcmc_kernel): function assert_error (line 35) | def assert_error(mcmc_kernel): function print_debug_info (line 43) | def print_debug_info(model_trace): function test_model_error_stray_batch_dims (line 57) | def test_model_error_stray_batch_dims(kernel, kwargs): function test_model_error_enum_dim_clash (line 84) | def test_model_error_enum_dim_clash(kernel, kwargs): function test_log_prob_eval_iterates_in_correct_order (line 101) | def test_log_prob_eval_iterates_in_correct_order(): function test_all_discrete_sites_log_prob (line 147) | def test_all_discrete_sites_log_prob(Eval): function test_enumeration_in_tree (line 179) | def test_enumeration_in_tree(Eval): function test_enumeration_in_dag (line 221) | def test_enumeration_in_dag(Eval): function test_enum_log_prob_continuous_observed (line 258) | def test_enum_log_prob_continuous_observed(data, expected_log_prob, Eval): function test_enum_log_prob_continuous_sampled (line 289) | def test_enum_log_prob_continuous_sampled(data, expected_log_prob, Eval): function test_enum_log_prob_discrete_observed (line 320) | def test_enum_log_prob_discrete_observed(data, expected_log_prob, Eval): function test_enum_log_prob_multiple_plate (line 348) | def test_enum_log_prob_multiple_plate(data, expected_log_prob, Eval): function test_enum_log_prob_nested_plate (line 379) | def test_enum_log_prob_nested_plate(data, expected_log_prob, Eval): function _beta_bernoulli (line 402) | def _beta_bernoulli(data): function test_potential_fn_pickling (line 412) | def test_potential_fn_pickling(jit): function test_reparam_stable (line 434) | def test_reparam_stable(kernel, kwargs): function test_potential_fn_initial_params (line 448) | def test_potential_fn_initial_params(Kernel): function test_obs_mask_ok (line 484) | def test_obs_mask_ok(Kernel, options, mask): FILE: tests/infer/reparam/test_conjugate.py function test_beta_binomial_static_sample (line 21) | def test_beta_binomial_static_sample(): function test_beta_binomial_dependent_sample (line 45) | def test_beta_binomial_dependent_sample(): function test_beta_binomial_elbo (line 75) | def test_beta_binomial_elbo(): function test_gaussian_hmm_elbo (line 114) | def test_gaussian_hmm_elbo(batch_shape, num_steps, hidden_dim, obs_dim): function random_stable (line 157) | def random_stable(shape): function test_stable_hmm_smoke (line 168) | def test_stable_hmm_smoke(batch_shape, num_steps, hidden_dim, obs_dim): function test_beta_binomial_hmc (line 211) | def test_beta_binomial_hmc(): function test_init (line 238) | def test_init(): FILE: tests/infer/reparam/test_discrete_cosine.py function get_moments (line 18) | def get_moments(x): function test_normal (line 46) | def test_normal(shape, dim, smooth): function test_uniform (line 91) | def test_uniform(shape, dim, smooth): function test_init (line 128) | def test_init(shape, dim, smooth): FILE: tests/infer/reparam/test_haar.py function get_moments (line 20) | def get_moments(x): function test_normal (line 48) | def test_normal(shape, dim, flip): function test_uniform (line 93) | def test_uniform(shape, dim, flip): function test_init (line 128) | def test_init(shape, dim, flip): function test_nested (line 139) | def test_nested(): FILE: tests/infer/reparam/test_hmm.py function random_studentt (line 22) | def random_studentt(shape): function random_stable (line 29) | def random_stable(shape, stability, skew=None): function test_transformed_hmm_shape (line 41) | def test_transformed_hmm_shape(batch_shape, duration, hidden_dim, obs_dim): function test_studentt_hmm_shape (line 72) | def test_studentt_hmm_shape(batch_shape, duration, hidden_dim, obs_dim): function test_stable_hmm_shape (line 104) | def test_stable_hmm_shape(skew, batch_shape, duration, hidden_dim, obs_d... function test_independent_hmm_shape (line 145) | def test_independent_hmm_shape(skew, batch_shape, duration, hidden_dim, ... function get_hmm_moments (line 188) | def get_hmm_moments(samples): function test_stable_hmm_distribution (line 203) | def test_stable_hmm_distribution(stability, skew, duration, hidden_dim, ... function test_stable_hmm_shape_error (line 236) | def test_stable_hmm_shape_error(batch_shape, duration, hidden_dim, obs_d... function test_init_shape (line 269) | def test_init_shape(skew, batch_shape, duration, hidden_dim, obs_dim): FILE: tests/infer/reparam/test_loc_scale.py function get_moments (line 20) | def get_moments(x): function test_moments (line 35) | def test_moments(dist_type, centered, shape): function test_init (line 79) | def test_init(dist_type, centered, shape): function test_init_with_reparam_inside_plate (line 101) | def test_init_with_reparam_inside_plate(): FILE: tests/infer/reparam/test_neutra.py function neals_funnel (line 24) | def neals_funnel(dim=10): function dirichlet_categorical (line 30) | def dirichlet_categorical(data): function test_neals_funnel_smoke (line 43) | def test_neals_funnel_smoke(Guide, jit): function test_reparam_log_joint (line 73) | def test_reparam_log_joint(model, kwargs): function test_init (line 92) | def test_init(): FILE: tests/infer/reparam/test_projected_normal.py function get_moments (line 18) | def get_moments(x): function test_projected_normal (line 28) | def test_projected_normal(shape, dim): function test_init (line 54) | def test_init(shape, dim): FILE: tests/infer/reparam/test_softmax.py function get_moments (line 18) | def get_moments(x): function test_gumbel_softmax (line 29) | def test_gumbel_softmax(temperature, shape, dim): function test_init (line 57) | def test_init(temperature, shape, dim): FILE: tests/infer/reparam/test_split.py function test_normal (line 41) | def test_normal(batch_shape, event_shape, splits, dim): function test_init (line 83) | def test_init(batch_shape, event_shape, splits, dim): function test_observe (line 95) | def test_observe(): function test_transformed_distribution (line 124) | def test_transformed_distribution(batch_shape): function test_predictive (line 175) | def test_predictive(batch_shape, event_shape, splits, dim): FILE: tests/infer/reparam/test_stable.py function get_moments (line 27) | def get_moments(x): function test_stable (line 43) | def test_stable(Reparam, shape): function test_symmetric_stable (line 84) | def test_symmetric_stable(shape): function test_distribution (line 119) | def test_distribution(stability, skew, Reparam): function test_subsample_smoke (line 139) | def test_subsample_smoke(Reparam, subsample): function test_init (line 158) | def test_init(stability, skew, Reparam): FILE: tests/infer/reparam/test_strategies.py function trace_name_is_observed (line 16) | def trace_name_is_observed(model): function normal_model (line 25) | def normal_model(): function test_normal_minimal (line 41) | def test_normal_minimal(): function test_normal_auto (line 60) | def test_normal_auto(centered): function stable_model (line 116) | def stable_model(): function test_stable_minimal (line 155) | def test_stable_minimal(): function test_stable_auto (line 201) | def test_stable_auto(): function projected_normal_model (line 255) | def projected_normal_model(): function test_projected_normal_minimal (line 262) | def test_projected_normal_minimal(): function test_projected_normal_auto (line 275) | def test_projected_normal_auto(): function softmax_model (line 291) | def softmax_model(): function test_softmax_minimal (line 302) | def test_softmax_minimal(): function test_softmax_auto (line 309) | def test_softmax_auto(): function test_end_to_end (line 330) | def test_end_to_end(model): FILE: tests/infer/reparam/test_structured.py function neals_funnel (line 17) | def neals_funnel(dim=10): function test_neals_funnel_smoke (line 24) | def test_neals_funnel_smoke(jit): function test_init (line 57) | def test_init(): FILE: tests/infer/reparam/test_studentt.py function get_moments (line 20) | def get_moments(x): function test_moments (line 27) | def test_moments(shape): function test_distribution (line 60) | def test_distribution(df, loc, scale): function test_init (line 72) | def test_init(shape): FILE: tests/infer/reparam/test_transform.py function get_moments (line 18) | def get_moments(x): function test_log_normal (line 34) | def test_log_normal(batch_shape, event_shape): function test_init (line 67) | def test_init(batch_shape, event_shape): FILE: tests/infer/reparam/test_unit_jacobian.py function get_moments (line 19) | def get_moments(x): function test_normal (line 31) | def test_normal(shape): function test_init (line 61) | def test_init(shape): FILE: tests/infer/reparam/util.py function check_init_reparam (line 14) | def check_init_reparam(model, reparam): FILE: tests/infer/test_abstract_infer.py function model (line 19) | def model(num_trials): function test_nesting (line 26) | def test_nesting(): function test_information_criterion (line 45) | def test_information_criterion(): FILE: tests/infer/test_autoguide.py function xfail_messenger (line 68) | def xfail_messenger(auto_class, Elbo): function test_scores (line 85) | def test_scores(auto_class): function test_factor (line 127) | def test_factor(auto_class, Elbo): class AutoStructured_shapes (line 150) | class AutoStructured_shapes(AutoStructured): method __init__ (line 151) | def __init__(self, model, *, init_loc_fn): function test_shapes (line 229) | def test_shapes(auto_class, init_loc_fn, Elbo, num_particles): function test_iplate_smoke (line 270) | def test_iplate_smoke(auto_class, Elbo): function auto_guide_list_x (line 293) | def auto_guide_list_x(model): function auto_guide_callable (line 300) | def auto_guide_callable(model): class GuideX (line 317) | class GuideX(AutoGuide): method __init__ (line 318) | def __init__(self, model): method forward (line 323) | def forward(self, *args, **kwargs): method median (line 326) | def median(self, *args, **kwargs): function auto_guide_module_callable (line 330) | def auto_guide_module_callable(model): function nested_auto_guide_callable (line 337) | def nested_auto_guide_callable(model): class AutoStructured_median (line 346) | class AutoStructured_median(AutoStructured): method __init__ (line 347) | def __init__(self, model): function test_median (line 387) | def test_median(auto_class, Elbo): function serialization_model (line 420) | def serialization_model(): function test_serialization (line 461) | def test_serialization(auto_class, jit): function AutoGuideList_x (line 509) | def AutoGuideList_x(model): function test_quantiles (line 528) | def test_quantiles(auto_class, Elbo): function test_discrete_parallel (line 594) | def test_discrete_parallel(continuous_class): function test_guide_list (line 631) | def test_guide_list(auto_class): function test_callable (line 655) | def test_callable(auto_class): function test_callable_return_dict (line 685) | def test_callable_return_dict(auto_class): function test_empty_model_error (line 702) | def test_empty_model_error(): function test_unpack_latent (line 711) | def test_unpack_latent(): function test_init_loc_fn (line 735) | def test_init_loc_fn(auto_class): class AutoLowRankMultivariateNormal_100 (line 753) | class AutoLowRankMultivariateNormal_100(AutoLowRankMultivariateNormal): method __init__ (line 754) | def __init__(self, *args, **kwargs): function test_init_scale (line 768) | def test_init_scale(auto_class, init_scale): function test_median_module (line 802) | def test_median_module(auto_class, Elbo): function test_nested_autoguide (line 833) | def test_nested_autoguide(Elbo): function test_linear_regression_smoke (line 895) | def test_linear_regression_smoke(auto_class, Elbo): class AutoStructured_predictive (line 929) | class AutoStructured_predictive(AutoStructured): method __init__ (line 930) | def __init__(self, model): function test_predictive (line 969) | def test_predictive(auto_class): function test_replay_plates (line 1039) | def test_replay_plates(auto_class, sample_shape): function test_subsample_model (line 1071) | def test_subsample_model(auto_class): function test_subsample_model_amortized (line 1109) | def test_subsample_model_amortized(auto_class): function test_subsample_guide (line 1148) | def test_subsample_guide(auto_class, init_fn): function test_subsample_guide_2 (line 1198) | def test_subsample_guide_2(auto_class, independent): function test_discrete_helpful_error (line 1250) | def test_discrete_helpful_error(auto_class, init_loc_fn): function test_sphere_helpful_error (line 1289) | def test_sphere_helpful_error(auto_class, init_loc_fn): function test_sphere_reparam_ok (line 1325) | def test_sphere_reparam_ok(auto_class, init_loc_fn): function test_sphere_raw_ok (line 1348) | def test_sphere_raw_ok(auto_class, init_loc_fn): class AutoStructured_exact_normal (line 1360) | class AutoStructured_exact_normal(AutoStructured): method __init__ (line 1361) | def __init__(self, model): class AutoStructured_exact_mvn (line 1369) | class AutoStructured_exact_mvn(AutoStructured): method __init__ (line 1370) | def __init__(self, model): function test_exact (line 1394) | def test_exact(Guide): function test_exact_batch (line 1462) | def test_exact_batch(Guide): function test_exact_tree (line 1526) | def test_exact_tree(Guide): function test_autonormal_dynamic_model (line 1595) | def test_autonormal_dynamic_model(): FILE: tests/infer/test_compute_downstream_costs.py function _brute_force_compute_downstream_costs (line 20) | def _brute_force_compute_downstream_costs( function _provenance_compute_downstream_costs (line 73) | def _provenance_compute_downstream_costs(model_trace, guide_trace): function big_model_guide (line 98) | def big_model_guide( function test_compute_downstream_costs_big_model_guide_pair (line 177) | def test_compute_downstream_costs_big_model_guide_pair( function diamond_model (line 380) | def diamond_model(dim): function diamond_guide (line 391) | def diamond_guide(dim): function test_compute_downstream_costs_duplicates (line 401) | def test_compute_downstream_costs_duplicates(dim): function nested_model_guide (line 481) | def nested_model_guide(include_obs=True, dim1=11, dim2=7): function test_compute_downstream_costs_plate_in_iplate (line 498) | def test_compute_downstream_costs_plate_in_iplate(dim1): function nested_model_guide2 (line 594) | def nested_model_guide2(include_obs=True, dim1=3, dim2=2): function test_compute_downstream_costs_iplate_in_plate (line 613) | def test_compute_downstream_costs_iplate_in_plate(dim1, dim2): function plate_reuse_model_guide (line 697) | def plate_reuse_model_guide(include_obs=True, dim1=3, dim2=2): function test_compute_downstream_costs_plate_reuse (line 716) | def test_compute_downstream_costs_plate_reuse(dim1, dim2): FILE: tests/infer/test_conjugate_gradients.py class ConjugateChainGradientTests (line 10) | class ConjugateChainGradientTests(GaussianChain): method test_gradients (line 11) | def test_gradients(self): method do_test_gradients (line 16) | def do_test_gradients(self, N, reparameterized): FILE: tests/infer/test_csis.py function model (line 15) | def model(observations={"y1": 0, "y2": 0}): class Guide (line 22) | class Guide(nn.Module): method __init__ (line 23) | def __init__(self): method forward (line 28) | def forward(self, observations={"y1": 0, "y2": 0}): function test_csis_sampling (line 36) | def test_csis_sampling(): function test_csis_parameter_update (line 48) | def test_csis_parameter_update(): function test_csis_validation_batch (line 60) | def test_csis_validation_batch(): FILE: tests/infer/test_discrete.py function elbo_infer_discrete (line 23) | def elbo_infer_discrete(model, first_available_dim, temperature): function log_mean_prob (line 46) | def log_mean_prob(trace, particle_dim): function test_plate_smoke (line 72) | def test_plate_smoke(infer, temperature, plate_size): function test_distribution_1 (line 101) | def test_distribution_1(infer, temperature): function test_distribution_2 (line 150) | def test_distribution_2(infer, temperature): function test_distribution_3 (line 215) | def test_distribution_3(infer, temperature): function test_distribution_masked (line 279) | def test_distribution_masked(infer, temperature): function test_hmm_smoke (line 330) | def test_hmm_smoke(infer, temperature, length): function test_prob (line 363) | def test_prob(nderivs): function test_warning (line 401) | def test_warning(): FILE: tests/infer/test_elbo_mapdata.py function test_elbo_mapdata (line 34) | def test_elbo_mapdata(map_type, batch_size, n_steps, lr): FILE: tests/infer/test_enum.py function _skip_cuda (line 35) | def _skip_cuda(*args): function test_iter_discrete_traces_order (line 45) | def test_iter_discrete_traces_order(depth, graph_type): function test_iter_discrete_traces_scalar (line 60) | def test_iter_discrete_traces_scalar(graph_type): function test_iter_discrete_traces_vector (line 79) | def test_iter_discrete_traces_vector(expand, graph_type): function test_enumerate_sequential_guide (line 105) | def test_enumerate_sequential_guide(): function test_enumerate_sequential_model (line 120) | def test_enumerate_sequential_model(): class UnsafeBernoulli (line 134) | class UnsafeBernoulli(dist.Bernoulli): method log_prob (line 135) | def log_prob(self, value): function test_unsafe_bernoulli (line 142) | def test_unsafe_bernoulli(sample_shape): function test_avoid_nan (line 151) | def test_avoid_nan(enumerate1): function gmm_model (line 175) | def gmm_model(data, verbose=False): function gmm_guide (line 187) | def gmm_guide(data, verbose=False): function test_gmm_iter_discrete_traces (line 199) | def test_gmm_iter_discrete_traces(data_size, graph_type, model): function gmm_batch_model (line 209) | def gmm_batch_model(data): function gmm_batch_guide (line 222) | def gmm_batch_guide(data): function test_gmm_batch_iter_discrete_traces (line 234) | def test_gmm_batch_iter_discrete_traces(model, data_size, graph_type): function test_svi_step_smoke (line 252) | def test_svi_step_smoke(model, guide, enumerate1): function test_differentiable_loss (line 272) | def test_differentiable_loss(model, guide, enumerate1): function test_svi_step_guide_uses_grad (line 307) | def test_svi_step_guide_uses_grad(enumerate1): function test_elbo_bern (line 341) | def test_elbo_bern(method, enumerate1, scale): function test_elbo_normal (line 399) | def test_elbo_normal(method, enumerate1): function test_elbo_bern_bern (line 472) | def test_elbo_bern_bern(method, enumerate1, enumerate2, num_samples1, nu... function test_elbo_berns (line 556) | def test_elbo_berns(method, enumerate1, enumerate2, enumerate3, num_samp... function test_elbo_categoricals (line 635) | def test_elbo_categoricals( function test_elbo_normals (line 721) | def test_elbo_normals(method, enumerate1, enumerate2, enumerate3): function test_elbo_plate (line 807) | def test_elbo_plate(plate_dim, enumerate1, enumerate2, num_samples): function test_elbo_iplate (line 871) | def test_elbo_iplate(plate_dim, enumerate1, enumerate2): function test_elbo_plate_plate (line 946) | def test_elbo_plate_plate( function test_elbo_plate_iplate (line 1036) | def test_elbo_plate_iplate( function test_elbo_iplate_plate (line 1119) | def test_elbo_iplate_plate(outer_dim, inner_dim, enumerate1, enumerate2,... function test_elbo_iplate_iplate (line 1202) | def test_elbo_iplate_iplate(outer_dim, inner_dim, enumerate1, enumerate2... function test_non_mean_field_bern_bern_elbo_gradient (line 1283) | def test_non_mean_field_bern_bern_elbo_gradient(enumerate1, pi1, pi2): function test_non_mean_field_bern_normal_elbo_gradient (line 1351) | def test_non_mean_field_bern_normal_elbo_gradient( function test_non_mean_field_normal_bern_elbo_gradient (line 1442) | def test_non_mean_field_normal_bern_elbo_gradient(pi1, pi2, pi3): function test_elbo_rsvi (line 1497) | def test_elbo_rsvi(enumerate1): function test_elbo_hmm_in_model (line 1569) | def test_elbo_hmm_in_model(enumerate1, num_steps, expand): function test_elbo_hmm_in_guide (line 1646) | def test_elbo_hmm_in_guide(enumerate1, num_steps, expand): function test_hmm_enumerate_model (line 1731) | def test_hmm_enumerate_model(num_steps): function test_hmm_enumerate_model_and_guide (line 1760) | def test_hmm_enumerate_model_and_guide(num_steps): function _check_loss_and_grads (line 1795) | def _check_loss_and_grads(expected_loss, actual_loss): function test_elbo_enumerate_1 (line 1823) | def test_elbo_enumerate_1(scale): function test_elbo_enumerate_2 (line 1868) | def test_elbo_enumerate_2(scale): function test_elbo_enumerate_3 (line 1919) | def test_elbo_enumerate_3(scale): function test_elbo_enumerate_plate_1 (line 1974) | def test_elbo_enumerate_plate_1(num_samples, num_masked, scale): function test_elbo_enumerate_plate_2 (line 2044) | def test_elbo_enumerate_plate_2(num_samples, num_masked, scale): function test_elbo_enumerate_plate_3 (line 2121) | def test_elbo_enumerate_plate_3(num_samples, num_masked, scale): function test_elbo_enumerate_plate_4 (line 2209) | def test_elbo_enumerate_plate_4(outer_obs, inner_obs, scale): function test_elbo_enumerate_plate_5 (line 2274) | def test_elbo_enumerate_plate_5(): function test_elbo_enumerate_plate_6 (line 2345) | def test_elbo_enumerate_plate_6(enumerate1): function test_elbo_enumerate_plate_7 (line 2406) | def test_elbo_enumerate_plate_7(scale): function test_elbo_enumerate_plates_1 (line 2512) | def test_elbo_enumerate_plates_1(scale): function test_elbo_enumerate_plates_2 (line 2575) | def test_elbo_enumerate_plates_2(scale): function test_elbo_enumerate_plates_3 (line 2630) | def test_elbo_enumerate_plates_3(scale): function test_elbo_enumerate_plates_4 (line 2681) | def test_elbo_enumerate_plates_4(scale): function test_elbo_enumerate_plates_5 (line 2739) | def test_elbo_enumerate_plates_5(scale): function test_elbo_enumerate_plates_6 (line 2801) | def test_elbo_enumerate_plates_6(scale): function test_elbo_enumerate_plates_7 (line 2940) | def test_elbo_enumerate_plates_7(scale): function test_elbo_enumerate_plates_8 (line 3097) | def test_elbo_enumerate_plates_8( function test_elbo_scale (line 3225) | def test_elbo_scale(): function test_elbo_hmm_growth (line 3267) | def test_elbo_hmm_growth(): function test_elbo_dbn_growth (line 3338) | def test_elbo_dbn_growth(): function test_bernoulli_pyramid_elbo_gradient (line 3413) | def test_bernoulli_pyramid_elbo_gradient( function test_bernoulli_non_tree_elbo_gradient (line 3513) | def test_bernoulli_non_tree_elbo_gradient( function test_elbo_zip (line 3642) | def test_elbo_zip(gate, rate): function test_mixture_of_diag_normals (line 3681) | def test_mixture_of_diag_normals(mixture, scale): function test_compute_marginals_single (line 3736) | def test_compute_marginals_single(Dist, prior): function test_compute_marginals_restrictions (line 3785) | def test_compute_marginals_restrictions( function test_compute_marginals_hmm (line 3819) | def test_compute_marginals_hmm(size): function test_marginals_2678 (line 3859) | def test_marginals_2678(observed): function test_backwardsample_posterior_smoke (line 3882) | def test_backwardsample_posterior_smoke(data): function test_backwardsample_posterior_2 (line 3916) | def test_backwardsample_posterior_2(): function test_backwardsample_posterior_3 (line 3937) | def test_backwardsample_posterior_3(): function test_backwardsample_posterior_restrictions (line 3975) | def test_backwardsample_posterior_restrictions( function test_vectorized_importance (line 4013) | def test_vectorized_importance(num_samples): function test_multi_dependence_enumeration (line 4066) | def test_multi_dependence_enumeration(): FILE: tests/infer/test_gradient.py function DiffTrace_ELBO (line 34) | def DiffTrace_ELBO(*args, **kwargs): function test_particle_gradient (line 50) | def test_particle_gradient(Elbo, reparameterized, has_rsample): function test_subsample_gradient (line 148) | def test_subsample_gradient( function test_plate (line 224) | def test_plate(Elbo, reparameterized): function test_plate_elbo_vectorized_particles (line 277) | def test_plate_elbo_vectorized_particles(Elbo, reparameterized): function test_subsample_gradient_sequential (line 354) | def test_subsample_gradient_sequential(Elbo, reparameterized, subsample): function test_collapse_beta_binomial (line 396) | def test_collapse_beta_binomial(): FILE: tests/infer/test_inference.py function param_mse (line 47) | def param_mse(name, target): function param_abs_error (line 51) | def param_abs_error(name, target): class NormalNormalTests (line 56) | class NormalNormalTests(TestCase): method setUp (line 57) | def setUp(self): method test_elbo_reparameterized (line 74) | def test_elbo_reparameterized(self): method test_elbo_analytic_kl (line 77) | def test_elbo_analytic_kl(self): method test_elbo_tail_adaptive (line 80) | def test_elbo_tail_adaptive(self): method test_elbo_nonreparameterized (line 87) | def test_elbo_nonreparameterized(self): method test_renyi_reparameterized (line 90) | def test_renyi_reparameterized(self): method test_renyi_nonreparameterized (line 95) | def test_renyi_nonreparameterized(self): method test_rws_reparameterized (line 100) | def test_rws_reparameterized(self): method test_rws_nonreparameterized (line 103) | def test_rws_nonreparameterized(self): method test_mmd_vectorized (line 106) | def test_mmd_vectorized(self): method test_mmd_nonvectorized (line 121) | def test_mmd_nonvectorized(self): method do_elbo_test (line 137) | def do_elbo_test(self, reparameterized, n_steps, loss): method do_fit_prior_test (line 174) | def do_fit_prior_test(self, reparameterized, n_steps, loss, debug=Fals... class TestFixedModelGuide (line 242) | class TestFixedModelGuide(TestCase): method setUp (line 243) | def setUp(self): method do_test_fixedness (line 250) | def do_test_fixedness(self, fixed_parts): method test_model_fixed (line 293) | def test_model_fixed(self): method test_guide_fixed (line 296) | def test_guide_fixed(self): method test_guide_and_model_both_fixed (line 299) | def test_guide_and_model_both_fixed(self): method test_guide_and_model_free (line 302) | def test_guide_and_model_free(self): class PoissonGammaTests (line 307) | class PoissonGammaTests(TestCase): method setUp (line 308) | def setUp(self): method test_elbo_reparameterized (line 321) | def test_elbo_reparameterized(self): method test_elbo_nonreparameterized (line 324) | def test_elbo_nonreparameterized(self): method test_renyi_reparameterized (line 327) | def test_renyi_reparameterized(self): method test_renyi_nonreparameterized (line 330) | def test_renyi_nonreparameterized(self): method test_rws_reparameterized (line 333) | def test_rws_reparameterized(self): method test_rws_nonreparameterized (line 336) | def test_rws_nonreparameterized(self): method test_mmd_vectorized (line 339) | def test_mmd_vectorized(self): method do_elbo_test (line 356) | def do_elbo_test(self, reparameterized, n_steps, loss): method do_fit_prior_test (line 404) | def do_fit_prior_test(self, reparameterized, n_steps, loss, debug=Fals... function test_exponential_gamma (line 516) | def test_exponential_gamma(gamma_dist, n_steps, elbo_impl): class BernoulliBetaTests (line 588) | class BernoulliBetaTests(TestCase): method setUp (line 589) | def setUp(self): method test_elbo_reparameterized (line 605) | def test_elbo_reparameterized(self): method test_elbo_nonreparameterized (line 608) | def test_elbo_nonreparameterized(self): method test_elbo_reparameterized_vectorized (line 612) | def test_elbo_reparameterized_vectorized(self): method test_elbo_nonreparameterized_vectorized (line 620) | def test_elbo_nonreparameterized_vectorized(self): method test_renyi_reparameterized (line 627) | def test_renyi_reparameterized(self): method test_renyi_nonreparameterized (line 630) | def test_renyi_nonreparameterized(self): method test_renyi_reparameterized_vectorized (line 633) | def test_renyi_reparameterized_vectorized(self): method test_renyi_nonreparameterized_vectorized (line 640) | def test_renyi_nonreparameterized_vectorized(self): method test_rws_reparameterized (line 652) | def test_rws_reparameterized(self): method test_rws_nonreparameterized (line 655) | def test_rws_nonreparameterized(self): method test_rws_reparameterized_vectorized (line 658) | def test_rws_reparameterized_vectorized(self): method test_rws_nonreparameterized_vectorized (line 667) | def test_rws_nonreparameterized_vectorized(self): method test_mmd_vectorized (line 676) | def test_mmd_vectorized(self): method do_elbo_test (line 691) | def do_elbo_test(self, reparameterized, n_steps, loss): method do_fit_prior_test (line 718) | def do_fit_prior_test(self, reparameterized, n_steps, loss, debug=False): class SafetyTests (line 786) | class SafetyTests(TestCase): method setUp (line 787) | def setUp(self): method test_duplicate_names (line 812) | def test_duplicate_names(self): method test_extra_samples (line 821) | def test_extra_samples(self): method test_duplicate_obs_name (line 830) | def test_duplicate_obs_name(self): function test_energy_distance_univariate (line 842) | def test_energy_distance_univariate(prior_scale): function test_energy_distance_multivariate (line 887) | def test_energy_distance_multivariate(prior_scale): function test_reparam_stable (line 917) | def test_reparam_stable(): function test_sequential_plating_sum (line 944) | def test_sequential_plating_sum(): function test_non_nested_plating_sum (line 972) | def test_non_nested_plating_sum(): FILE: tests/infer/test_initialization.py function test_init_to_generated (line 15) | def test_init_to_generated(): FILE: tests/infer/test_inspect.py function test_get_dependencies (line 14) | def test_get_dependencies(grad_enabled): function test_docstring_example_1 (line 63) | def test_docstring_example_1(): function test_docstring_example_2 (line 81) | def test_docstring_example_2(): function test_docstring_example_3 (line 105) | def test_docstring_example_3(): function test_factor (line 124) | def test_factor(): function test_discrete_obs (line 144) | def test_discrete_obs(): function test_discrete (line 175) | def test_discrete(): function test_plate_coupling (line 202) | def test_plate_coupling(): function test_plate_coupling_2 (line 232) | def test_plate_coupling_2(): function test_plate_coupling_3 (line 268) | def test_plate_coupling_3(): function test_plate_collider (line 306) | def test_plate_collider(): function test_plate_dependency (line 347) | def test_plate_dependency(): function test_nested_plate_collider (line 389) | def test_nested_plate_collider(): function test_deep_merge (line 450) | def test_deep_merge(things, expected): function test_get_model_relations (line 456) | def test_get_model_relations(include_deterministic): FILE: tests/infer/test_jit.py function constant (line 36) | def constant(*args, **kwargs): function test_simple (line 45) | def test_simple(): function test_multi_output (line 67) | def test_multi_output(): function test_backward (line 89) | def test_backward(): function test_grad (line 112) | def test_grad(): function test_grad_expand (line 129) | def test_grad_expand(): function test_scale_and_mask (line 145) | def test_scale_and_mask(): function test_masked_fill (line 162) | def test_masked_fill(): function test_scatter (line 177) | def test_scatter(): function test_scatter_workaround (line 187) | def test_scatter_workaround(): function test_bernoulli_enumerate (line 206) | def test_bernoulli_enumerate(shape, expand): function test_categorical_enumerate (line 222) | def test_categorical_enumerate(shape, expand): function test_one_hot_categorical_enumerate (line 240) | def test_one_hot_categorical_enumerate(shape, expand): function test_loss (line 269) | def test_loss(Elbo): function test_svi (line 316) | def test_svi(Elbo, num_particles): function test_svi_enum (line 339) | def test_svi_enum(plate_dim, enumerate1, enumerate2): function test_beta_bernoulli (line 402) | def test_beta_bernoulli(Elbo, vectorized): function test_svi_irregular_batch_size (line 450) | def test_svi_irregular_batch_size(Elbo): function test_dirichlet_bernoulli (line 472) | def test_dirichlet_bernoulli(Elbo, vectorized): function test_traceenum_elbo (line 507) | def test_traceenum_elbo(length): function test_infer_discrete (line 554) | def test_infer_discrete(temperature, length): function test_cond_indep_equality (line 601) | def test_cond_indep_equality(x, y): function test_jit_arange_workaround (line 607) | def test_jit_arange_workaround(): FILE: tests/infer/test_multi_sample_elbos.py function check_elbo (line 13) | def check_elbo(model, guide, Elbo): function test_inner_outer (line 28) | def test_inner_outer(Elbo): function test_outer_inner (line 45) | def test_outer_inner(Elbo): FILE: tests/infer/test_predictive.py function model (line 19) | def model(num_trials): function one_hot_model (line 26) | def one_hot_model(pseudocounts, classes=None): function beta_guide (line 33) | def beta_guide(num_trials): function test_posterior_predictive_svi_manual_guide (line 54) | def test_posterior_predictive_svi_manual_guide( function test_posterior_predictive_svi_auto_delta_guide (line 135) | def test_posterior_predictive_svi_auto_delta_guide(parallel, predictive): function test_posterior_predictive_svi_auto_diag_normal_guide (line 160) | def test_posterior_predictive_svi_auto_diag_normal_guide(return_trace, p... function test_posterior_predictive_svi_one_hot (line 188) | def test_posterior_predictive_svi_one_hot(): function test_shapes (line 204) | def test_shapes(parallel, predictive): function test_deterministic (line 241) | def test_deterministic(with_plate, event_shape, predictive): function test_get_mask_optimization (line 272) | def test_get_mask_optimization(): FILE: tests/infer/test_resampler.py function test_resampling_cache (line 16) | def test_resampling_cache(stable): FILE: tests/infer/test_sampling.py class HMMSamplingTestCase (line 16) | class HMMSamplingTestCase(TestCase): method setUp (line 17) | def setUp(self): class NormalNormalSamplingTestCase (line 51) | class NormalNormalSamplingTestCase(TestCase): method setUp (line 52) | def setUp(self): class ImportanceTest (line 74) | class ImportanceTest(NormalNormalSamplingTestCase): method test_importance_guide (line 76) | def test_importance_guide(self): method test_importance_prior (line 91) | def test_importance_prior(self): FILE: tests/infer/test_smcfilter.py function test_systematic_sample (line 16) | def test_systematic_sample(size): class SmokeModel (line 33) | class SmokeModel: method __init__ (line 34) | def __init__(self, state_size, plate_size): method init (line 38) | def init(self, state): method step (line 48) | def step(self, state, x=None, y=None): class SmokeGuide (line 66) | class SmokeGuide: method __init__ (line 67) | def __init__(self, state_size, plate_size): method init (line 71) | def init(self, state): method step (line 81) | def step(self, state, x=None, y=None): function test_smoke (line 92) | def test_smoke(max_plate_nesting, state_size, plate_size, num_steps): class HarmonicModel (line 114) | class HarmonicModel: method __init__ (line 115) | def __init__(self): method init (line 121) | def init(self, state): method step (line 127) | def step(self, state, y=None): class HarmonicGuide (line 142) | class HarmonicGuide: method __init__ (line 143) | def __init__(self): method init (line 146) | def init(self, state): method step (line 150) | def step(self, state, y=None): function generate_data (line 162) | def generate_data(): function score_latent (line 177) | def score_latent(zs, ys): function test_likelihood_ratio (line 189) | def test_likelihood_ratio(): function test_gaussian_filter (line 211) | def test_gaussian_filter(): FILE: tests/infer/test_svgd.py function test_mean_variance (line 26) | def test_mean_variance(latent_dist, mode, stein_kernel, verbose=True): function test_shapes (line 74) | def test_shapes(shape, stein_kernel): function test_conjugate (line 103) | def test_conjugate(mode, stein_kernel, verbose=False): FILE: tests/infer/test_tmc.py function test_tmc_categoricals (line 29) | def test_tmc_categoricals(depth, max_plate_nesting, num_samples, tmc_str... function test_tmc_normals_chain_iwae (line 107) | def test_tmc_normals_chain_iwae( function test_tmc_normals_chain_gradient (line 218) | def test_tmc_normals_chain_gradient( FILE: tests/infer/test_util.py function xy_model (line 17) | def xy_model(): function test_multi_frame_tensor (line 30) | def test_multi_frame_tensor(): function test_psis_diagnostic (line 57) | def test_psis_diagnostic(scale, krange, zdim, max_particles, num_particl... function test_render_model_deterministic_param (line 77) | def test_render_model_deterministic_param(): FILE: tests/infer/test_valid_models.py function EnergyDistance_prior (line 39) | def EnergyDistance_prior(**kwargs): function EnergyDistance_noprior (line 45) | def EnergyDistance_noprior(**kwargs): function assert_ok (line 51) | def assert_ok(model, guide, elbo, **kwargs): function assert_error (line 79) | def assert_error(model, guide, elbo, match=None): function assert_warning (line 92) | def assert_warning(model, guide, elbo): function test_nonempty_model_empty_guide_ok (line 118) | def test_nonempty_model_empty_guide_ok(Elbo, strict_enumeration_warning): function test_nonempty_model_empty_guide_error (line 146) | def test_nonempty_model_empty_guide_error(Elbo, strict_enumeration_warni... function test_empty_model_empty_guide_ok (line 161) | def test_empty_model_empty_guide_ok(Elbo, strict_enumeration_warning): function test_variable_clash_in_model_error (line 178) | def test_variable_clash_in_model_error(Elbo): function test_model_guide_dim_mismatch_error (line 194) | def test_model_guide_dim_mismatch_error(Elbo): function test_model_guide_shape_mismatch_error (line 216) | def test_model_guide_shape_mismatch_error(Elbo): function test_variable_clash_in_guide_error (line 236) | def test_variable_clash_in_guide_error(Elbo): function test_set_has_rsample_ok (line 253) | def test_set_has_rsample_ok(has_rsample, Elbo): function test_not_has_rsample_ok (line 277) | def test_not_has_rsample_ok(Elbo): function test_iplate_ok (line 299) | def test_iplate_ok(subsample_size, Elbo): function test_iplate_variable_clash_error (line 321) | def test_iplate_variable_clash_error(Elbo): function test_plate_ok (line 346) | def test_plate_ok(subsample_size, Elbo): function test_plate_subsample_param_ok (line 369) | def test_plate_subsample_param_ok(subsample_size, Elbo): function test_plate_subsample_primitive_ok (line 395) | def test_plate_subsample_primitive_ok(subsample_size, Elbo): function test_plate_param_size_mismatch_error (line 435) | def test_plate_param_size_mismatch_error(subsample_size, Elbo, shape, ok): function test_plate_no_size_ok (line 461) | def test_plate_no_size_ok(Elbo): function test_iplate_iplate_ok (line 485) | def test_iplate_iplate_ok(subsample_size, Elbo, max_plate_nesting): function test_iplate_iplate_swap_ok (line 515) | def test_iplate_iplate_swap_ok(subsample_size, Elbo, max_plate_nesting): function test_iplate_in_model_not_guide_ok (line 544) | def test_iplate_in_model_not_guide_ok(subsample_size, Elbo): function test_iplate_in_guide_not_model_error (line 568) | def test_iplate_in_guide_not_model_error(subsample_size, Elbo, is_valida... function test_plate_broadcast_error (line 592) | def test_plate_broadcast_error(Elbo): function test_plate_iplate_ok (line 604) | def test_plate_iplate_ok(Elbo): function test_iplate_plate_ok (line 628) | def test_iplate_plate_ok(Elbo): function test_plate_stack_ok (line 655) | def test_plate_stack_ok(Elbo, sizes): function test_plate_stack_and_plate_ok (line 678) | def test_plate_stack_and_plate_ok(Elbo, sizes): function test_plate_stack_sizes (line 700) | def test_plate_stack_sizes(sizes): function test_nested_plate_plate_ok (line 713) | def test_nested_plate_plate_ok(Elbo): function test_plate_reuse_ok (line 736) | def test_plate_reuse_ok(Elbo): function test_nested_plate_plate_dim_error_1 (line 763) | def test_nested_plate_plate_dim_error_1(Elbo): function test_nested_plate_plate_dim_error_2 (line 787) | def test_nested_plate_plate_dim_error_2(Elbo): function test_nested_plate_plate_dim_error_3 (line 805) | def test_nested_plate_plate_dim_error_3(Elbo): function test_nested_plate_plate_dim_error_4 (line 821) | def test_nested_plate_plate_dim_error_4(Elbo): function test_nested_plate_plate_subsample_param_ok (line 837) | def test_nested_plate_plate_subsample_param_ok(Elbo): function test_nonnested_plate_plate_ok (line 864) | def test_nonnested_plate_plate_ok(Elbo): function test_three_indep_plate_at_different_depths_ok (line 876) | def test_three_indep_plate_at_different_depths_ok(): function test_plate_wrong_size_error (line 912) | def test_plate_wrong_size_error(): function test_block_plate_name_ok (line 926) | def test_block_plate_name_ok(): function test_block_plate_dim_ok (line 950) | def test_block_plate_dim_ok(): function test_block_plate_missing_error (line 974) | def test_block_plate_missing_error(): function test_enum_discrete_misuse_warning (line 987) | def test_enum_discrete_misuse_warning(Elbo, enumerate_): function test_enum_discrete_single_ok (line 1002) | def test_enum_discrete_single_ok(): function test_enum_discrete_missing_config_warning (line 1015) | def test_enum_discrete_missing_config_warning(strict_enumeration_warning): function test_enum_discrete_single_single_ok (line 1031) | def test_enum_discrete_single_single_ok(): function test_enum_discrete_iplate_single_ok (line 1045) | def test_enum_discrete_iplate_single_ok(): function test_plate_enum_discrete_batch_ok (line 1059) | def test_plate_enum_discrete_batch_ok(): function test_plate_enum_discrete_no_discrete_vars_warning (line 1074) | def test_plate_enum_discrete_no_discrete_vars_warning(strict_enumeration... function test_no_plate_enum_discrete_batch_error (line 1095) | def test_no_plate_enum_discrete_batch_error(): function test_enum_discrete_parallel_ok (line 1110) | def test_enum_discrete_parallel_ok(max_plate_nesting): function test_enum_discrete_parallel_nested_ok (line 1134) | def test_enum_discrete_parallel_nested_ok(max_plate_nesting): function test_enumerate_parallel_plate_ok (line 1165) | def test_enumerate_parallel_plate_ok(enumerate_, expand, num_samples): function test_enum_discrete_plate_dependency_warning (line 1220) | def test_enum_discrete_plate_dependency_warning( function test_enum_discrete_iplate_plate_dependency_ok (line 1241) | def test_enum_discrete_iplate_plate_dependency_ok(enumerate_, max_plate_... function test_enum_discrete_iplates_plate_dependency_warning (line 1260) | def test_enum_discrete_iplates_plate_dependency_warning( function test_enum_discrete_plates_dependency_ok (line 1287) | def test_enum_discrete_plates_dependency_ok(enumerate_): function test_enum_discrete_non_enumerated_plate_ok (line 1305) | def test_enum_discrete_non_enumerated_plate_ok(enumerate_): function test_plate_shape_broadcasting (line 1325) | def test_plate_shape_broadcasting(): function test_enum_discrete_plate_shape_broadcasting_ok (line 1355) | def test_enum_discrete_plate_shape_broadcasting_ok(enumerate_, expand, n... function test_dim_allocation_ok (line 1417) | def test_dim_allocation_ok(Elbo, expand): function test_dim_allocation_error (line 1461) | def test_dim_allocation_error(Elbo, expand): function test_enum_in_model_ok (line 1487) | def test_enum_in_model_ok(): function test_enum_in_model_plate_ok (line 1525) | def test_enum_in_model_plate_ok(): function test_enum_sequential_in_model_error (line 1565) | def test_enum_sequential_in_model_error(): function test_enum_in_model_plate_reuse_ok (line 1581) | def test_enum_in_model_plate_reuse_ok(): function test_enum_in_model_multi_scale_error (line 1598) | def test_enum_in_model_multi_scale_error(): function test_enum_in_model_diamond_error (line 1618) | def test_enum_in_model_diamond_error(use_vindex): function test_vectorized_num_particles (line 1661) | def test_vectorized_num_particles(Elbo): function test_enum_discrete_vectorized_num_particles (line 1701) | def test_enum_discrete_vectorized_num_particles( function test_enum_recycling_chain (line 1799) | def test_enum_recycling_chain(): function test_enum_recycling_dbn (line 1817) | def test_enum_recycling_dbn(markov, use_vindex): function test_enum_recycling_nested (line 1849) | def test_enum_recycling_nested(): function test_enum_recycling_grid (line 1894) | def test_enum_recycling_grid(use_vindex): function test_enum_recycling_reentrant (line 1923) | def test_enum_recycling_reentrant(): function test_enum_recycling_reentrant_history (line 1954) | def test_enum_recycling_reentrant_history(history): function test_enum_recycling_mutual_recursion (line 1985) | def test_enum_recycling_mutual_recursion(): function test_enum_recycling_interleave (line 2032) | def test_enum_recycling_interleave(): function test_enum_recycling_plate (line 2053) | def test_enum_recycling_plate(): function test_factor_in_model_ok (line 2108) | def test_factor_in_model_ok(Elbo): function test_factor_in_guide_error (line 2128) | def test_factor_in_guide_error(Elbo): function test_factor_in_guide_ok (line 2149) | def test_factor_in_guide_ok(Elbo, has_rsample): function test_markov_history (line 2161) | def test_markov_history(history): function test_mean_field_ok (line 2192) | def test_mean_field_ok(): function test_mean_field_mask_ok (line 2206) | def test_mean_field_mask_ok(mask): function test_mean_field_warn (line 2219) | def test_mean_field_warn(): function test_tail_adaptive_ok (line 2232) | def test_tail_adaptive_ok(): function test_tail_adaptive_error (line 2256) | def test_tail_adaptive_error(): function test_tail_adaptive_warning (line 2278) | def test_tail_adaptive_warning(): function test_reparam_ok (line 2301) | def test_reparam_ok(Elbo): function test_reparam_mask_ok (line 2323) | def test_reparam_mask_ok(Elbo, mask): function test_reparam_mask_plate_ok (line 2355) | def test_reparam_mask_plate_ok(Elbo, mask): function test_obs_mask_ok (line 2392) | def test_obs_mask_ok(Elbo, mask, num_particles): function test_obs_mask_multivariate_ok (line 2437) | def test_obs_mask_multivariate_ok(Elbo, mask, num_particles): function test_obs_mask_multivariate_error (line 2473) | def test_obs_mask_multivariate_error(Elbo): function test_reparam_scale_ok (line 2506) | def test_reparam_scale_ok(Elbo, scale): function test_reparam_scale_plate_ok (line 2537) | def test_reparam_scale_plate_ok(Elbo, scale): function test_no_log_prob_ok (line 2561) | def test_no_log_prob_ok(Elbo): function test_reparam_stable (line 2580) | def test_reparam_stable(): function test_collapse_normal_normal (line 2599) | def test_collapse_normal_normal(num_particles): function test_collapse_normal_normal_plate (line 2620) | def test_collapse_normal_normal_plate(num_particles): function test_collapse_normal_plate_normal (line 2644) | def test_collapse_normal_plate_normal(num_particles): function test_collapse_beta_bernoulli (line 2668) | def test_collapse_beta_bernoulli(num_particles): function test_collapse_beta_binomial (line 2689) | def test_collapse_beta_binomial(num_particles): function test_collapse_beta_binomial_plate (line 2710) | def test_collapse_beta_binomial_plate(num_particles): function test_collapse_barrier (line 2734) | def test_collapse_barrier(num_particles): function test_ordered_logistic_plate (line 2755) | def test_ordered_logistic_plate(): FILE: tests/integration_tests/conftest.py function pytest_collection_modifyitems (line 7) | def pytest_collection_modifyitems(items): FILE: tests/integration_tests/test_conjugate_gaussian_models.py function param_mse (line 24) | def param_mse(name, target): class GaussianChain (line 28) | class GaussianChain(TestCase): method setUp (line 31) | def setUp(self): method setup_chain (line 37) | def setup_chain(self, N): method setup_reparam_mask (line 76) | def setup_reparam_mask(self, N): method model (line 82) | def model(self, reparameterized, difficulty=0.0): method guide (line 98) | def guide(self, reparameterized, difficulty=0.0): class GaussianChainTests (line 137) | class GaussianChainTests(GaussianChain): method test_elbo_reparameterized_N_is_3 (line 138) | def test_elbo_reparameterized_N_is_3(self): method test_elbo_reparameterized_N_is_8 (line 142) | def test_elbo_reparameterized_N_is_8(self): method test_elbo_reparameterized_N_is_17 (line 150) | def test_elbo_reparameterized_N_is_17(self): method test_elbo_nonreparameterized_N_is_3 (line 154) | def test_elbo_nonreparameterized_N_is_3(self): method test_elbo_nonreparameterized_N_is_5 (line 158) | def test_elbo_nonreparameterized_N_is_5(self): method test_elbo_nonreparameterized_N_is_7 (line 166) | def test_elbo_nonreparameterized_N_is_7(self): method do_elbo_test (line 170) | def do_elbo_test(self, reparameterized, n_steps, lr, prec, difficulty=... class GaussianPyramidTests (line 256) | class GaussianPyramidTests(TestCase): method setUp (line 257) | def setUp(self): method setup_pyramid (line 260) | def setup_pyramid(self, N): method setup_reparam_mask (line 288) | def setup_reparam_mask(self, n): method set_model_permutations (line 295) | def set_model_permutations(self): method test_elbo_reparameterized_three_layers (line 310) | def test_elbo_reparameterized_three_layers(self): method test_elbo_reparameterized_four_layers (line 317) | def test_elbo_reparameterized_four_layers(self): method test_elbo_nonreparameterized_two_layers (line 324) | def test_elbo_nonreparameterized_two_layers(self): method test_elbo_nonreparameterized_three_layers (line 330) | def test_elbo_nonreparameterized_three_layers(self): method test_elbo_nonreparameterized_two_layers_model_permuted (line 336) | def test_elbo_nonreparameterized_two_layers_model_permuted(self): method test_elbo_nonreparameterized_three_layers_model_permuted (line 346) | def test_elbo_nonreparameterized_three_layers_model_permuted(self): method calculate_variational_targets (line 352) | def calculate_variational_targets(self): method construct_q_dag (line 424) | def construct_q_dag(self): method model (line 456) | def model(self, reparameterized, model_permutation, difficulty=0.0): method guide (line 503) | def guide(self, reparameterized, model_permutation, difficulty=0.0): method do_elbo_test (line 542) | def do_elbo_test( FILE: tests/integration_tests/test_tracegraph_elbo.py function param_mse (line 23) | def param_mse(name, target): function param_abs_error (line 27) | def param_abs_error(name, target): class NormalNormalTests (line 31) | class NormalNormalTests(TestCase): method setUp (line 32) | def setUp(self): method test_elbo_reparameterized (line 51) | def test_elbo_reparameterized(self): method test_elbo_nonreparameterized (line 55) | def test_elbo_nonreparameterized(self): method do_elbo_test (line 58) | def do_elbo_test(self, reparameterized, n_steps, prec): class NormalNormalNormalTests (line 107) | class NormalNormalNormalTests(TestCase): method setUp (line 108) | def setUp(self): method test_elbo_reparameterized (line 121) | def test_elbo_reparameterized(self): method test_elbo_nonreparameterized_both_baselines (line 124) | def test_elbo_nonreparameterized_both_baselines(self): method test_elbo_nonreparameterized_decaying_baseline (line 135) | def test_elbo_nonreparameterized_decaying_baseline(self): method test_elbo_nonreparameterized_nn_baseline (line 146) | def test_elbo_nonreparameterized_nn_baseline(self): method do_elbo_test (line 157) | def do_elbo_test( class BernoulliBetaTests (line 280) | class BernoulliBetaTests(TestCase): method setUp (line 281) | def setUp(self): method test_elbo_reparameterized (line 296) | def test_elbo_reparameterized(self): method test_elbo_nonreparameterized (line 299) | def test_elbo_nonreparameterized(self): method do_elbo_test (line 302) | def do_elbo_test(self, reparameterized, n_steps, beta1, lr): class ExponentialGammaTests (line 345) | class ExponentialGammaTests(TestCase): method setUp (line 346) | def setUp(self): method test_elbo_reparameterized (line 359) | def test_elbo_reparameterized(self): method test_elbo_nonreparameterized (line 362) | def test_elbo_nonreparameterized(self): method do_elbo_test (line 365) | def do_elbo_test(self, reparameterized, n_steps, beta1, lr): class RaoBlackwellizationTests (line 409) | class RaoBlackwellizationTests(TestCase): method setUp (line 410) | def setUp(self): method test_nested_iplate_in_elbo (line 437) | def test_nested_iplate_in_elbo(self, n_steps=4000): method test_plate_in_elbo_with_superfluous_rvs (line 499) | def test_plate_in_elbo_with_superfluous_rvs(self): method _test_plate_in_elbo (line 504) | def _test_plate_in_elbo( FILE: tests/nn/conftest.py function pytest_collection_modifyitems (line 7) | def pytest_collection_modifyitems(items): FILE: tests/nn/test_autoregressive.py class AutoRegressiveNNTests (line 15) | class AutoRegressiveNNTests(TestCase): method setUp (line 16) | def setUp(self): method _test_jacobian (line 19) | def _test_jacobian(self, input_dim, observed_dim, hidden_dim, param_dim): method _test_masks (line 61) | def _test_masks( method test_jacobians (line 119) | def test_jacobians(self): method test_masks (line 124) | def test_masks(self): FILE: tests/nn/test_module.py function test_svi_smoke (line 24) | def test_svi_smoke(): function test_svi_elbomodule_interface (line 76) | def test_svi_elbomodule_interface( function test_local_param_global_behavior_fails (line 139) | def test_local_param_global_behavior_fails(local_params): function test_names (line 162) | def test_names(local_params): function test_delete (line 217) | def test_delete(): function test_nested (line 225) | def test_nested(): function test_module_cache (line 242) | def test_module_cache(): function test_submodule_contains_torch_module (line 266) | def test_submodule_contains_torch_module(): function test_hierarchy_prior_cached (line 273) | def test_hierarchy_prior_cached(): function test_constraints (line 333) | def test_constraints(shape, constraint_): function test_clear (line 360) | def test_clear(local_params): function test_sample (line 415) | def test_sample(): function test_cache (line 446) | def test_cache(): class AttributeModel (line 484) | class AttributeModel(PyroModule): method __init__ (line 485) | def __init__(self, size): method forward (line 496) | def forward(self): class DecoratorModel (line 500) | class DecoratorModel(PyroModule): method __init__ (line 501) | def __init__(self, size): method x (line 506) | def x(self): method y (line 510) | def y(self): method z (line 514) | def z(self): method s (line 518) | def s(self): method t (line 522) | def t(self): method u (line 526) | def u(self): method forward (line 529) | def forward(self): function test_decorator (line 535) | def test_decorator(Model, size): function test_mixin_factory (line 567) | def test_mixin_factory(): function test_to_pyro_module_ (line 607) | def test_to_pyro_module_(): function test_torch_serialize_attributes (line 669) | def test_torch_serialize_attributes(local_params): function test_torch_serialize_decorators (line 692) | def test_torch_serialize_decorators(local_params): function test_pyro_serialize (line 729) | def test_pyro_serialize(): function test_bayesian_gru (line 766) | def test_bayesian_gru(): function test_functorch_pyroparam (line 800) | def test_functorch_pyroparam(use_local_params): class BNN (line 870) | class BNN(PyroModule): method __init__ (line 872) | def __init__( method forward (line 915) | def forward(self, x: torch.Tensor, obs=None) -> torch.Tensor: class SliceIndexingModuleListBNN (line 927) | class SliceIndexingModuleListBNN(BNN): method __init__ (line 930) | def __init__( method forward (line 941) | def forward(self, x: torch.Tensor, obs=None) -> torch.Tensor: class PositionIndexingModuleListBNN (line 949) | class PositionIndexingModuleListBNN(BNN): method __init__ (line 952) | def __init__( method forward (line 963) | def forward(self, x: torch.Tensor, obs=None) -> torch.Tensor: class NestedBNN (line 971) | class NestedBNN(pyro.nn.module.PyroModule): method __init__ (line 974) | def __init__(self, bnns: Iterable[BNN], use_new_module_list_type: bool... method forward (line 981) | def forward(self, x: torch.Tensor, obs=None) -> torch.Tensor: function train_bnn (line 990) | def train_bnn(model: BNN, input_size: int) -> None: class ModuleListTester (line 1008) | class ModuleListTester: method setup (line 1009) | def setup(self, use_new_module_list_type: bool) -> None: method get_position_indexing_modulelist_bnn (line 1016) | def get_position_indexing_modulelist_bnn(self) -> PositionIndexingModu... method get_slice_indexing_modulelist_bnn (line 1024) | def get_slice_indexing_modulelist_bnn(self) -> SliceIndexingModuleList... method train_nested_bnn (line 1032) | def train_nested_bnn(self, module_getter: Callable[[], BNN]) -> None: class TestTorchModuleList (line 1042) | class TestTorchModuleList(ModuleListTester): method test_with_position_indexing (line 1043) | def test_with_position_indexing(self) -> None: method test_with_slice_indexing (line 1047) | def test_with_slice_indexing(self) -> None: class TestPyroModuleList (line 1054) | class TestPyroModuleList(ModuleListTester): method test_with_position_indexing (line 1055) | def test_with_position_indexing(self) -> None: method test_with_slice_indexing (line 1059) | def test_with_slice_indexing(self) -> None: function test_module_list (line 1064) | def test_module_list() -> None: function test_render_constrained_param (line 1069) | def test_render_constrained_param(use_module_local_params): FILE: tests/ops/conftest.py function pytest_collection_modifyitems (line 7) | def pytest_collection_modifyitems(items): FILE: tests/ops/einsum/conftest.py function pytest_collection_modifyitems (line 7) | def pytest_collection_modifyitems(items): FILE: tests/ops/einsum/test_adjoint.py function test_shape (line 52) | def test_shape(backend, equation): function test_marginal (line 95) | def test_marginal(equation): function test_require_backward_memory_leak (line 121) | def test_require_backward_memory_leak(): FILE: tests/ops/einsum/test_torch_log.py function test_einsum (line 42) | def test_einsum(equation, min_size, infinite): FILE: tests/ops/gamma_gaussian.py function random_gamma_gaussian (line 11) | def random_gamma_gaussian(batch_shape, dim, rank=None): function random_gamma (line 33) | def random_gamma(batch_shape): function assert_close_gamma_gaussian (line 42) | def assert_close_gamma_gaussian(actual, expected): FILE: tests/ops/gaussian.py function random_gaussian (line 11) | def random_gaussian(batch_shape, dim, rank=None, *, requires_grad=False): function random_mvn (line 28) | def random_mvn(batch_shape, dim, *, requires_grad=False): function assert_close_gaussian (line 40) | def assert_close_gaussian(actual, expected): FILE: tests/ops/test_arrowhead.py function test_utilities (line 18) | def test_utilities(head_size): FILE: tests/ops/test_contract.py function deep_copy (line 29) | def deep_copy(x): function deep_equal (line 44) | def deep_equal(x, y): function assert_immutable (line 67) | def assert_immutable(fn): function _normalize (line 87) | def _normalize(tensor, dims, plates): function test_partition_terms (line 119) | def test_partition_terms(inputs, dims, expected_num_components): function frame (line 146) | def frame(dim, size): function test_contract_to_tensor (line 259) | def test_contract_to_tensor(example): function test_contract_tensor_tree (line 279) | def test_contract_tensor_tree(example): function make_example (line 369) | def make_example(equation, fill=None, sizes=(2, 3)): function test_naive_ubersum (line 383) | def test_naive_ubersum(equation, plates): function test_ubersum (line 408) | def test_ubersum(equation, plates): function test_einsum_linear (line 432) | def test_einsum_linear(equation, plates): function test_ubersum_jit (line 459) | def test_ubersum_jit(equation, plates): function test_ubersum_total (line 500) | def test_ubersum_total(equation, plates): function test_ubersum_sizes (line 522) | def test_ubersum_sizes(impl, a, b, c, d): function test_ubersum_1 (line 535) | def test_ubersum_1(impl): function test_ubersum_2 (line 549) | def test_ubersum_2(impl): function test_ubersum_3 (line 564) | def test_ubersum_3(impl): function test_ubersum_4 (line 589) | def test_ubersum_4(impl): function test_ubersum_5 (line 610) | def test_ubersum_5(impl): function test_ubersum_collide_implemented (line 647) | def test_ubersum_collide_implemented(impl, implemented): function test_ubersum_collide_ok_1 (line 668) | def test_ubersum_collide_ok_1(impl): function test_ubersum_collide_ok_2 (line 686) | def test_ubersum_collide_ok_2(impl): function test_ubersum_collide_ok_3 (line 705) | def test_ubersum_collide_ok_3(impl): function test_ubersum_size_error (line 729) | def test_ubersum_size_error(impl, equation, shapes, plates): function test_ubersum_plate_error (line 747) | def test_ubersum_plate_error(impl, equation, plates): function test_adjoint_shape (line 771) | def test_adjoint_shape(backend, equation, plates): function test_adjoint_marginal (line 797) | def test_adjoint_marginal(equation, plates): FILE: tests/ops/test_gamma_gaussian.py function test_expand (line 42) | def test_expand( function test_reshape (line 76) | def test_reshape(old_shape, new_shape, dim): function test_cat (line 98) | def test_cat(shape, cat_dim, split, dim): function test_pad (line 123) | def test_pad(shape, left, right, dim): function test_add (line 135) | def test_add(shape, dim): function test_marginalize_shape (line 148) | def test_marginalize_shape(batch_shape, left, right): function test_marginalize (line 158) | def test_marginalize(batch_shape, left, right): function test_marginalize_condition (line 176) | def test_marginalize_condition(sample_shape, batch_shape, left, right): function test_condition (line 191) | def test_condition(sample_shape, batch_shape, left, right): function test_logsumexp (line 210) | def test_logsumexp(batch_shape, dim): function test_gamma_and_mvn_to_gamma_gaussian (line 232) | def test_gamma_and_mvn_to_gamma_gaussian(sample_shape, batch_shape, dim): function test_matrix_and_mvn_to_gamma_gaussian (line 253) | def test_matrix_and_mvn_to_gamma_gaussian(sample_shape, batch_shape, x_d... function test_gamma_gaussian_tensordot (line 303) | def test_gamma_gaussian_tensordot( FILE: tests/ops/test_gaussian.py function test_expand (line 41) | def test_expand( function test_reshape (line 67) | def test_reshape(old_shape, new_shape, dim): function test_cat (line 89) | def test_cat(shape, cat_dim, split, dim): function test_pad (line 114) | def test_pad(shape, left, right, dim): function test_add (line 126) | def test_add(shape, dim): function test_rsample_shape (line 138) | def test_rsample_shape(sample_shape, batch_shape, dim): function test_rsample_distribution (line 149) | def test_rsample_distribution(batch_shape, dim): function test_marginalize_shape (line 174) | def test_marginalize_shape(batch_shape, left, right): function test_marginalize (line 184) | def test_marginalize(batch_shape, left, right): function test_marginalize_condition (line 195) | def test_marginalize_condition(sample_shape, batch_shape, left, right): function test_condition (line 208) | def test_condition(sample_shape, batch_shape, left, right): function test_logsumexp (line 234) | def test_logsumexp(batch_shape, dim): function test_affine_normal (line 255) | def test_affine_normal(batch_shape, x_dim, y_dim): function test_mvn_to_gaussian (line 293) | def test_mvn_to_gaussian(sample_shape, batch_shape, dim): function test_matrix_and_mvn_to_gaussian (line 306) | def test_matrix_and_mvn_to_gaussian(sample_shape, batch_shape, x_dim, y_... function test_matrix_and_mvn_to_gaussian_2 (line 323) | def test_matrix_and_mvn_to_gaussian_2(sample_shape, batch_shape, x_dim, ... function test_gaussian_tensordot (line 370) | def test_gaussian_tensordot( function test_gaussian_funsor (line 431) | def test_gaussian_funsor(batch_shape): function test_sequential_gaussian_tensordot (line 499) | def test_sequential_gaussian_tensordot(batch_shape, state_dim, num_steps): function test_sequential_gaussian_filter_sample (line 516) | def test_sequential_gaussian_filter_sample( function test_sequential_gaussian_filter_sample_antithetic (line 540) | def test_sequential_gaussian_filter_sample_antithetic( function test_sequential_gaussian_filter_sample_stability (line 565) | def test_sequential_gaussian_filter_sample_stability(num_steps): FILE: tests/ops/test_indexing.py class TensorMock (line 15) | class TensorMock: method __getitem__ (line 16) | def __getitem__(self, args): function z (line 23) | def z(*args): function test_shape (line 105) | def test_shape(expression, expected_shape): function test_value (line 114) | def test_value(x_shape, i_shape, j_shape, event_shape): function test_hmm_example (line 134) | def test_hmm_example(prev_enum_dim, curr_enum_dim): function test_index (line 168) | def test_index(args, expected): FILE: tests/ops/test_integrator.py function register_model (line 25) | def register_model(init_args): class HarmonicOscillator (line 53) | class HarmonicOscillator: method kinetic_grad (line 55) | def kinetic_grad(p): method energy (line 59) | def energy(q, p): method potential_fn (line 63) | def potential_fn(q): class CircularPlanetaryMotion (line 80) | class CircularPlanetaryMotion: method kinetic_grad (line 82) | def kinetic_grad(p): method energy (line 86) | def energy(q, p): method potential_fn (line 94) | def potential_fn(q): class QuarticOscillator (line 111) | class QuarticOscillator: method kinetic_grad (line 113) | def kinetic_grad(p): method energy (line 117) | def energy(q, p): method potential_fn (line 121) | def potential_fn(q): function test_trajectory (line 126) | def test_trajectory(example): function test_energy_conservation (line 143) | def test_energy_conservation(example): function test_time_reversibility (line 161) | def test_time_reversibility(example): FILE: tests/ops/test_jit.py function test_varying_len_args (line 10) | def test_varying_len_args(): function test_varying_kwargs (line 24) | def test_varying_kwargs(): function test_varying_unhashable_kwargs (line 34) | def test_varying_unhashable_kwargs(): FILE: tests/ops/test_linalg.py function test_sym_rinverse (line 31) | def test_sym_rinverse(A, use_sym): FILE: tests/ops/test_newton.py function random_inside_unit_circle (line 17) | def random_inside_unit_circle(shape, requires_grad=False): function test_newton_step (line 30) | def test_newton_step(batch_shape, trust_radius, dims): function test_newton_step_trust (line 88) | def test_newton_step_trust(trust_radius, dims): function test_newton_step_converges (line 118) | def test_newton_step_converges(trust_radius, dims): FILE: tests/ops/test_packed.py function test_unpack_pack (line 22) | def test_unpack_pack(dims): function make_inputs (line 52) | def make_inputs(shapes, num_numbers=0): function test_broadcast_all (line 66) | def test_broadcast_all(shapes): FILE: tests/ops/test_provenance.py function test_provenance_tensor (line 38) | def test_provenance_tensor(dtype1, dtype2): function test_track_provenance (line 64) | def test_track_provenance(x): FILE: tests/ops/test_special.py function test_safe_log (line 14) | def test_safe_log(): function test_log_beta_stirling (line 43) | def test_log_beta_stirling(tol): function test_log_binomial_stirling (line 69) | def test_log_binomial_stirling(tol): function test_log_I1 (line 83) | def test_log_I1(order, value): function test_log_I1_shapes (line 90) | def test_log_I1_shapes(): function test_get_quad_rule (line 99) | def test_get_quad_rule(sigma): FILE: tests/ops/test_ssm_gp.py function test_matern_kernel (line 13) | def test_matern_kernel(num_gps, nu): FILE: tests/ops/test_stats.py function test_resample (line 30) | def test_resample(replacement): function test_quantile (line 50) | def test_quantile(): function test_weighed_quantile (line 63) | def test_weighed_quantile(): function test_pi (line 81) | def test_pi(): function test_hpdi (line 87) | def test_hpdi(): function _quantile (line 95) | def _quantile(x, dim=0): function _pi (line 99) | def _pi(x, dim=0): function _hpdi (line 103) | def _hpdi(x, dim=0): function test_statistics_A_ok_with_sample_shape (line 109) | def test_statistics_A_ok_with_sample_shape(statistics, sample_shape): function test_autocorrelation (line 127) | def test_autocorrelation(): function test_autocorrelation_trivial (line 137) | def test_autocorrelation_trivial(): function test_autocorrelation_vectorized (line 143) | def test_autocorrelation_vectorized(): function test_autocovariance (line 157) | def test_autocovariance(): function test_cummin (line 170) | def test_cummin(): function test_statistics_B_ok_with_sample_shape (line 182) | def test_statistics_B_ok_with_sample_shape(statistics, sample_shape): function test_gelman_rubin (line 202) | def test_gelman_rubin(): function test_split_gelman_rubin_agree_with_gelman_rubin (line 212) | def test_split_gelman_rubin_agree_with_gelman_rubin(): function test_effective_sample_size (line 219) | def test_effective_sample_size(): function test_diagnostics_ok_with_sample_shape (line 231) | def test_diagnostics_ok_with_sample_shape(diagnostics, sample_shape): function test_waic (line 256) | def test_waic(): function test_weighted_waic (line 273) | def test_weighted_waic(): function test_fit_generalized_pareto (line 304) | def test_fit_generalized_pareto(k, sigma, n_samples=5000): function test_crps_univariate_energy_score_empirical (line 317) | def test_crps_univariate_energy_score_empirical(num_samples, event_shape): function test_multivariate_energy_score (line 336) | def test_multivariate_energy_score(sample_dim, num_samples=10000): function test_energy_score_empirical_batched_calculation (line 365) | def test_energy_score_empirical_batched_calculation( FILE: tests/ops/test_streaming.py function generate_data (line 19) | def generate_data(num_samples): function sort_samples_in_place (line 44) | def sort_samples_in_place(x): function test_update_get (line 54) | def test_update_get(make_stats, size): function test_update_merge_get (line 76) | def test_update_merge_get(make_stats, left_size, right_size): function test_stats_of_dict (line 98) | def test_stats_of_dict(): FILE: tests/ops/test_tensor_utils.py function test_block_diag_embed (line 34) | def test_block_diag_embed(batch_size, block_size): function test_block_diag (line 53) | def test_block_diag(batch_shape, mat_size, block_size): function test_periodic_repeat (line 64) | def test_periodic_repeat(period, size, left_shape, right_shape): function test_periodic_features (line 75) | def test_periodic_features(duration): function test_periodic_cumsum (line 94) | def test_periodic_cumsum(period, size, left_shape, right_shape): function test_convolve_shape (line 111) | def test_convolve_shape(m, n, mode): function test_convolve (line 123) | def test_convolve(batch_shape, m, n, mode): function test_repeated_matmul (line 140) | def test_repeated_matmul(size, n): function test_dct (line 152) | def test_dct(shape): function test_idct (line 160) | def test_idct(shape): function test_dct_dim (line 169) | def test_dct_dim(fn, dim): function test_next_fast_len (line 179) | def test_next_fast_len(): function test_precision_to_scale_tril (line 191) | def test_precision_to_scale_tril(batch_shape, event_shape): FILE: tests/ops/test_welford.py function test_welford_diagonal (line 16) | def test_welford_diagonal(n_samples, dim_size): function test_welford_dense (line 36) | def test_welford_dense(n_samples, dim_size): function test_welford_arrowhead (line 57) | def test_welford_arrowhead(n_samples, dim_size, head_size, regularize): FILE: tests/optim/conftest.py function pytest_collection_modifyitems (line 7) | def pytest_collection_modifyitems(items): function pytest_addoption (line 16) | def pytest_addoption(parser): function pytest_generate_tests (line 20) | def pytest_generate_tests(metafunc): FILE: tests/optim/test_multi.py function test_optimizers (line 37) | def test_optimizers(factory): function test_multi_optimizer_disjoint_ok (line 75) | def test_multi_optimizer_disjoint_ok(): function test_multi_optimizer_overlap_error (line 83) | def test_multi_optimizer_overlap_error(): FILE: tests/optim/test_optim.py class OptimTests (line 23) | class OptimTests(TestCase): method setUp (line 24) | def setUp(self): method test_per_param_optim (line 32) | def test_per_param_optim(self): method do_test_per_param_optim (line 37) | def do_test_per_param_optim(self, fixed_param, free_param): function test_dynamic_lr (line 121) | def test_dynamic_lr(scheduler): function test_autowrap (line 171) | def test_autowrap(factory): function test_clip_norm (line 179) | def test_clip_norm(pyro_optim, clip, value): function test_clippedadam_clip (line 197) | def test_clippedadam_clip(clip_norm): function test_clippedadam_pass (line 215) | def test_clippedadam_pass(clip_norm): function test_clippedadam_lrd (line 234) | def test_clippedadam_lrd(lrd): function test_dctadam_param_subsample (line 245) | def test_dctadam_param_subsample(): function test_name_preserved_by_to_pyro_module (line 281) | def test_name_preserved_by_to_pyro_module(): function test_checkpoint (line 372) | def test_checkpoint(Optim, config): function test_centered_clipped_adam (line 440) | def test_centered_clipped_adam(plot): FILE: tests/params/conftest.py function pytest_collection_modifyitems (line 7) | def pytest_collection_modifyitems(items): FILE: tests/params/test_module.py class outest (line 14) | class outest(nn.Module): method __init__ (line 15) | def __init__(self): method forward (line 21) | def forward(self, s): class outer (line 25) | class outer(torch.nn.Module): method __init__ (line 26) | def __init__(self): method forward (line 31) | def forward(self, s): class inner (line 35) | class inner(torch.nn.Module): method __init__ (line 36) | def __init__(self): method forward (line 41) | def forward(self, s): function test_module_nn (line 49) | def test_module_nn(nn_module): function test_param_no_grad (line 59) | def test_param_no_grad(nn_module): function test_module_sequential (line 76) | def test_module_sequential(nn_module): function test_random_module (line 86) | def test_random_module(nn_module): FILE: tests/params/test_param.py class ParamStoreDictTests (line 17) | class ParamStoreDictTests(TestCase): method setUp (line 18) | def setUp(self): method test_save_and_load (line 24) | def test_save_and_load(self): function test_dict_interface (line 90) | def test_dict_interface(): function test_scope (line 154) | def test_scope(): FILE: tests/perf/conftest.py function pytest_collection_modifyitems (line 7) | def pytest_collection_modifyitems(items): FILE: tests/perf/test_benchmark.py function register_model (line 38) | def register_model(**model_kwargs): function poisson_gamma_model (line 65) | def poisson_gamma_model(reparameterized, Elbo): function bernoulli_beta_hmc (line 102) | def bernoulli_beta_hmc(**kwargs): function vsgp_multiclass (line 123) | def vsgp_multiclass(num_steps, whiten): function test_benchmark (line 154) | def test_benchmark(benchmark, model, model_args, id): function profile_fn (line 159) | def profile_fn(test_model): FILE: tests/poutine/conftest.py function pytest_collection_modifyitems (line 7) | def pytest_collection_modifyitems(items): FILE: tests/poutine/test_counterfactual.py function _item (line 13) | def _item(x): function test_counterfactual_query (line 28) | def test_counterfactual_query(intervene, observe, flip): function test_plate_duplication_smoke (line 81) | def test_plate_duplication_smoke(): FILE: tests/poutine/test_mapdata.py function test_nested_iplate (line 17) | def test_nested_iplate(): function plate_model (line 44) | def plate_model(subsample_size): function iplate_model (line 53) | def iplate_model(subsample_size): function nested_iplate_model (line 63) | def nested_iplate_model(subsample_size): function test_cond_indep_stack (line 85) | def test_cond_indep_stack(model, subsample_size): function test_replay (line 100) | def test_replay(model, subsample_size): function plate_custom_model (line 114) | def plate_custom_model(subsample): function iplate_custom_model (line 120) | def iplate_custom_model(subsample): function test_custom_subsample (line 130) | def test_custom_subsample(model): function plate_cuda_model (line 138) | def plate_cuda_model(subsample_size): function iplate_cuda_model (line 145) | def iplate_cuda_model(subsample_size): function test_cuda (line 157) | def test_cuda(model, subsample_size): function test_model_guide_mismatch (line 177) | def test_model_guide_mismatch(behavior, model_size, guide_size, model): FILE: tests/poutine/test_nesting.py function test_nested_reset (line 14) | def test_nested_reset(): FILE: tests/poutine/test_poutines.py function eq (line 27) | def eq(x, y, prec=1e-10): class NormalNormalNormalHandlerTestCase (line 32) | class NormalNormalNormalHandlerTestCase(TestCase): method setUp (line 33) | def setUp(self): class TraceHandlerTests (line 73) | class TraceHandlerTests(NormalNormalNormalHandlerTestCase): method test_trace_full (line 74) | def test_trace_full(self): method test_trace_return (line 91) | def test_trace_return(self): method test_trace_param_only (line 97) | def test_trace_param_only(self): class ReplayHandlerTests (line 102) | class ReplayHandlerTests(NormalNormalNormalHandlerTestCase): method test_replay_full (line 103) | def test_replay_full(self): method test_replay_full_repeat (line 113) | def test_replay_full_repeat(self): class BlockHandlerTests (line 126) | class BlockHandlerTests(NormalNormalNormalHandlerTestCase): method test_block_hide_fn (line 127) | def test_block_hide_fn(self): method test_block_expose_fn (line 139) | def test_block_expose_fn(self): method test_block_full (line 151) | def test_block_full(self): method test_block_full_hide (line 159) | def test_block_full_hide(self): method test_block_full_expose (line 171) | def test_block_full_expose(self): method test_block_full_hide_expose (line 183) | def test_block_full_hide_expose(self): method test_block_partial_hide (line 194) | def test_block_partial_hide(self): method test_block_partial_expose (line 209) | def test_block_partial_expose(self): method test_block_tutorial_case (line 224) | def test_block_tutorial_case(self): class QueueHandlerDiscreteTest (line 236) | class QueueHandlerDiscreteTest(TestCase): method setUp (line 237) | def setUp(self): method test_queue_single (line 272) | def test_queue_single(self): method test_queue_enumerate (line 278) | def test_queue_enumerate(self): method test_queue_max_tries (line 306) | def test_queue_max_tries(self): class Model (line 312) | class Model(nn.Module): method __init__ (line 313) | def __init__(self): method forward (line 317) | def forward(self, x): class LiftHandlerTests (line 321) | class LiftHandlerTests(TestCase): method setUp (line 322) | def setUp(self): method test_splice (line 386) | def test_splice(self): method test_memoize (line 397) | def test_memoize(self): method test_prior_dict (line 400) | def test_prior_dict(self): method test_unlifted_param (line 413) | def test_unlifted_param(self): method test_random_module (line 428) | def test_random_module(self): method test_random_module_warn (line 440) | def test_random_module_warn(self): method test_random_module_prior_dict (line 454) | def test_random_module_prior_dict(self): class QueueHandlerMixedTest (line 467) | class QueueHandlerMixedTest(TestCase): method setUp (line 468) | def setUp(self): method test_queue_single (line 485) | def test_queue_single(self): method test_queue_enumerate (line 491) | def test_queue_enumerate(self): class IndirectLambdaHandlerTests (line 518) | class IndirectLambdaHandlerTests(TestCase): method setUp (line 519) | def setUp(self): method test_graph_structure (line 545) | def test_graph_structure(self): method test_scale_factors (line 558) | def test_scale_factors(self): class SubstituteHandlerTests (line 575) | class SubstituteHandlerTests(NormalNormalNormalHandlerTestCase): method test_substitute (line 576) | def test_substitute(self): method test_stack_overwrite_behavior (line 583) | def test_stack_overwrite_behavior(self): method test_stack_success (line 593) | def test_stack_success(self): class ConditionHandlerTests (line 605) | class ConditionHandlerTests(NormalNormalNormalHandlerTestCase): method test_condition (line 606) | def test_condition(self): method test_trace_data (line 616) | def test_trace_data(self): method test_stack_overwrite_behavior (line 627) | def test_stack_overwrite_behavior(self): method test_stack_success (line 637) | def test_stack_success(self): class UnconditionHandlerTests (line 655) | class UnconditionHandlerTests(NormalNormalNormalHandlerTestCase): method test_uncondition (line 656) | def test_uncondition(self): method test_undo_uncondition (line 663) | def test_undo_uncondition(self): class EscapeHandlerTests (line 672) | class EscapeHandlerTests(TestCase): method setUp (line 673) | def setUp(self): method test_discrete_escape (line 688) | def test_discrete_escape(self): method test_all_escape (line 698) | def test_all_escape(self): method test_trace_compose (line 707) | def test_trace_compose(self): class InferConfigHandlerTests (line 729) | class InferConfigHandlerTests(TestCase): method setUp (line 730) | def setUp(self): method test_infer_config_sample (line 748) | def test_infer_config_sample(self): class EqualizeHandlerTests (line 758) | class EqualizeHandlerTests(TestCase): method setUp (line 759) | def setUp(self): method test_sample_site_equalization (line 774) | def test_sample_site_equalization(self): method test_param_equalization (line 793) | def test_param_equalization(self): method test_render_model (line 801) | def test_render_model(self): function test_condition_by_equalize (line 812) | def test_condition_by_equalize(loc_x, scale_x, loc_y, scale_y, keep_dist): function test_enumerate_poutine (line 854) | def test_enumerate_poutine(depth, first_available_dim): function test_replay_enumerate_poutine (line 880) | def test_replay_enumerate_poutine(depth, first_available_dim): function test_plate_preserves_has_rsample (line 921) | def test_plate_preserves_has_rsample(has_rsample, depth): function test_plate_error_on_enter (line 932) | def test_plate_error_on_enter(): function test_trace_plate (line 946) | def test_trace_plate(graph_type: str, expected: set): function test_decorator_interface_primitives (line 957) | def test_decorator_interface_primitives(): function test_decorator_interface_queue (line 987) | def test_decorator_interface_queue(): function test_method_decorator_interface_condition (line 1008) | def test_method_decorator_interface_condition(): function test_trace_log_prob_err_msg (line 1024) | def test_trace_log_prob_err_msg(): function test_trace_log_prob_sum_err_msg (line 1034) | def test_trace_log_prob_sum_err_msg(): function test_trace_score_parts_err_msg (line 1044) | def test_trace_score_parts_err_msg(): function _model (line 1054) | def _model(a=torch.tensor(1.0), b=torch.tensor(1.0)): function test_pickling (line 1068) | def test_pickling(wrapper): function test_arg_kwarg_error (line 1090) | def test_arg_kwarg_error(): function test_block_class_method (line 1106) | def test_block_class_method(): FILE: tests/poutine/test_properties.py class ExampleModel (line 17) | class ExampleModel: method __init__ (line 18) | def __init__(self, fn, poutine_kwargs): method __call__ (line 22) | def __call__(self, *args, **kwargs): method bind_poutine (line 25) | def bind_poutine(self, poutine_name): function register_model (line 34) | def register_model(**poutine_kwargs): function trivial_model (line 51) | def trivial_model(): function normal_model (line 67) | def normal_model(): function normal_normal_model (line 84) | def normal_normal_model(): function bernoulli_normal_model (line 103) | def bernoulli_normal_model(): function get_trace (line 111) | def get_trace(fn, *args, **kwargs): function test_idempotent (line 125) | def test_idempotent(poutine_name, model): function test_commutes (line 141) | def test_commutes(p1_name, p2_name, model): FILE: tests/poutine/test_runtime.py function test_get_mask (line 12) | def test_get_mask(): function test_get_plates (line 29) | def test_get_plates(): FILE: tests/poutine/test_trace_struct.py function test_topological_sort (line 36) | def test_topological_sort(edges): function test_connectivity_on_removal (line 56) | def test_connectivity_on_removal(edges): FILE: tests/pyroapi/conftest.py function pytest_runtest_call (line 7) | def pytest_runtest_call(item): FILE: tests/pyroapi/test_pyroapi.py function backend (line 12) | def backend(request): FILE: tests/test_examples.py function xfail_jit (line 215) | def xfail_jit(*args, **kwargs): function test_coverage (line 327) | def test_coverage(): function test_cpu (line 361) | def test_cpu(example): function test_cuda (line 371) | def test_cuda(example): function test_jit (line 380) | def test_jit(example): function test_horovod (line 391) | def test_horovod(np, example): function test_funsor (line 404) | def test_funsor(example): FILE: tests/test_generic.py function test_mcmc_interface (line 16) | def test_mcmc_interface(model, backend): function test_not_implemented (line 31) | def test_not_implemented(backend): function test_model_sample (line 41) | def test_model_sample(model, backend): function test_rng_seed (line 50) | def test_rng_seed(model, backend): function test_rng_state (line 65) | def test_rng_state(model, backend): function test_trace_handler (line 84) | def test_trace_handler(model, backend): FILE: tests/test_primitives.py function test_sample_ok (line 16) | def test_sample_ok(): function test_observe_warn (line 22) | def test_observe_warn(): function test_param_ok (line 27) | def test_param_ok(): function test_deterministic_ok (line 33) | def test_deterministic_ok(): function test_obs_mask_shape (line 48) | def test_obs_mask_shape(mask: Optional[torch.Tensor]): FILE: tests/test_settings.py function test_settings (line 13) | def test_settings(): function test_register (line 23) | def test_register(): FILE: tests/test_util.py function test_warn_if_nan (line 14) | def test_warn_if_nan(): function test_warn_if_inf (line 52) | def test_warn_if_inf(): FILE: tutorial/source/cleannb.py function cleannb (line 10) | def cleannb(nbfile): FILE: tutorial/source/search_inference.py function memoize (line 23) | def memoize(fn=None, **kwargs): class HashingMarginal (line 29) | class HashingMarginal(dist.Distribution): method __init__ (line 38) | def __init__(self, trace_dist, sites=None): method _dist_and_values (line 55) | def _dist_and_values(self): method sample (line 86) | def sample(self): method log_prob (line 91) | def log_prob(self, val): method enumerate_support (line 101) | def enumerate_support(self): method _dict_to_tuple (line 105) | def _dict_to_tuple(self, d): method _weighted_mean (line 116) | def _weighted_mean(self, value, dim=0): method mean (line 123) | def mean(self): method variance (line 128) | def variance(self): class Search (line 139) | class Search(TracePosterior): method __init__ (line 144) | def __init__(self, model, max_tries=int(1e6), **kwargs): method _traces (line 149) | def _traces(self, *args, **kwargs): function pqueue (line 163) | def pqueue(fn, queue): class BestFirstSearch (line 202) | class BestFirstSearch(TracePosterior): method __init__ (line 208) | def __init__(self, model, num_samples=None, **kwargs): method _traces (line 215) | def _traces(self, *args, **kwargs):