Copy disabled (too large)
Download .txt
Showing preview only (21,182K chars total). Download the full file to get everything.
Repository: pyro-ppl/pyro
Branch: dev
Commit: 1bbbf38f3c26
Files: 797
Total size: 20.1 MB
Directory structure:
gitextract_cigk8yyx/
├── .codecov.yml
├── .coveragerc
├── .gitattributes
├── .github/
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE/
│ │ ├── config.yml
│ │ └── issue_template.md
│ └── workflows/
│ ├── ci.yml
│ └── publish.yml
├── .gitignore
├── .readthedocs.yml
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE.md
├── LICENSES/
│ ├── Apache-2.0.txt
│ ├── BSD-3-Clause.txt
│ └── MIT.txt
├── MANIFEST.in
├── Makefile
├── README.md
├── RELEASE-MANAGEMENT.md
├── docker/
│ ├── Dockerfile
│ ├── Makefile
│ ├── README.md
│ └── install.sh
├── docs/
│ ├── Makefile
│ ├── README.md
│ ├── requirements.txt
│ └── source/
│ ├── _static/
│ │ ├── css/
│ │ │ └── pyro.css
│ │ └── img/
│ │ └── favicon/
│ │ ├── browserconfig.xml
│ │ └── manifest.json
│ ├── conf.py
│ ├── contrib.autoname.rst
│ ├── contrib.bnn.rst
│ ├── contrib.cevae.rst
│ ├── contrib.easyguide.rst
│ ├── contrib.epidemiology.rst
│ ├── contrib.examples.rst
│ ├── contrib.forecast.rst
│ ├── contrib.funsor.rst
│ ├── contrib.gp.rst
│ ├── contrib.minipyro.rst
│ ├── contrib.mue.rst
│ ├── contrib.oed.rst
│ ├── contrib.randomvariable.rst
│ ├── contrib.timeseries.rst
│ ├── contrib.tracking.rst
│ ├── contrib.zuko.rst
│ ├── distributions.rst
│ ├── getting_started.rst
│ ├── index.rst
│ ├── infer.autoguide.rst
│ ├── infer.reparam.rst
│ ├── infer.util.rst
│ ├── inference.rst
│ ├── inference_algos.rst
│ ├── mcmc.rst
│ ├── nn.rst
│ ├── ops.rst
│ ├── optimization.rst
│ ├── parameters.rst
│ ├── poutine.rst
│ ├── primitives.rst
│ ├── pyro.infer.mcmc.txt
│ ├── pyro.optim.txt
│ ├── pyro.poutine.txt
│ ├── settings.rst
│ └── testing.rst
├── examples/
│ ├── __init__.py
│ ├── air/
│ │ ├── air.py
│ │ ├── main.py
│ │ ├── modules.py
│ │ └── viz.py
│ ├── baseball.py
│ ├── capture_recapture/
│ │ └── cjs.py
│ ├── contrib/
│ │ ├── __init__.py
│ │ ├── autoname/
│ │ │ ├── mixture.py
│ │ │ ├── scoping_mixture.py
│ │ │ └── tree_data.py
│ │ ├── cevae/
│ │ │ └── synthetic.py
│ │ ├── epidemiology/
│ │ │ ├── regional.py
│ │ │ └── sir.py
│ │ ├── forecast/
│ │ │ └── bart.py
│ │ ├── funsor/
│ │ │ ├── __init__.py
│ │ │ └── hmm.py
│ │ ├── gp/
│ │ │ └── sv-dkl.py
│ │ ├── mue/
│ │ │ ├── FactorMuE.py
│ │ │ └── ProfileHMM.py
│ │ ├── oed/
│ │ │ ├── ab_test.py
│ │ │ └── gp_bayes_opt.py
│ │ └── timeseries/
│ │ └── gp_models.py
│ ├── cvae/
│ │ ├── __init__.py
│ │ ├── baseline.py
│ │ ├── cvae.py
│ │ ├── main.py
│ │ ├── mnist.py
│ │ └── util.py
│ ├── dmm.py
│ ├── eight_schools/
│ │ ├── README.md
│ │ ├── data.py
│ │ ├── mcmc.py
│ │ └── svi.py
│ ├── einsum.py
│ ├── hmm.py
│ ├── inclined_plane.py
│ ├── lda.py
│ ├── lkj.py
│ ├── minipyro.py
│ ├── mixed_hmm/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── experiment.py
│ │ ├── model.py
│ │ └── seal_data.py
│ ├── neutra.py
│ ├── rsa/
│ │ ├── README.md
│ │ ├── generics.py
│ │ ├── hyperbole.py
│ │ ├── schelling.py
│ │ ├── schelling_false.py
│ │ ├── search_inference.py
│ │ └── semantic_parsing.py
│ ├── scanvi/
│ │ ├── __init__.py
│ │ └── scanvi.py
│ ├── sir_hmc.py
│ ├── smcfilter.py
│ ├── sparse_gamma_def.py
│ ├── sparse_regression.py
│ ├── svi_horovod.py
│ ├── svi_lightning.py
│ ├── svi_torch.py
│ ├── toy_mixture_model_discrete_enumeration.py
│ └── vae/
│ ├── ss_vae_M2.py
│ ├── utils/
│ │ ├── __init__.py
│ │ ├── custom_mlp.py
│ │ ├── mnist_cached.py
│ │ └── vae_plots.py
│ ├── vae.py
│ └── vae_comparison.py
├── profiler/
│ ├── __init__.py
│ ├── distributions.py
│ ├── gaussianhmm.py
│ ├── hmm.py
│ └── profiling_utils.py
├── pyproject.toml
├── pyro/
│ ├── __init__.py
│ ├── contrib/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── autoguide.py
│ │ ├── autoname/
│ │ │ ├── __init__.py
│ │ │ ├── autoname.py
│ │ │ ├── named.py
│ │ │ └── scoping.py
│ │ ├── bnn/
│ │ │ ├── __init__.py
│ │ │ ├── hidden_layer.py
│ │ │ └── utils.py
│ │ ├── cevae/
│ │ │ └── __init__.py
│ │ ├── conjugate/
│ │ │ ├── __init__.py
│ │ │ └── infer.py
│ │ ├── easyguide/
│ │ │ ├── __init__.py
│ │ │ └── easyguide.py
│ │ ├── epidemiology/
│ │ │ ├── __init__.py
│ │ │ ├── compartmental.py
│ │ │ ├── distributions.py
│ │ │ ├── models.py
│ │ │ └── util.py
│ │ ├── examples/
│ │ │ ├── __init__.py
│ │ │ ├── bart.py
│ │ │ ├── finance.py
│ │ │ ├── multi_mnist.py
│ │ │ ├── nextstrain.py
│ │ │ ├── polyphonic_data_loader.py
│ │ │ ├── scanvi_data.py
│ │ │ └── util.py
│ │ ├── forecast/
│ │ │ ├── __init__.py
│ │ │ ├── evaluate.py
│ │ │ ├── forecaster.py
│ │ │ └── util.py
│ │ ├── funsor/
│ │ │ ├── __init__.py
│ │ │ ├── handlers/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── enum_messenger.py
│ │ │ │ ├── named_messenger.py
│ │ │ │ ├── plate_messenger.py
│ │ │ │ ├── primitives.py
│ │ │ │ ├── replay_messenger.py
│ │ │ │ ├── runtime.py
│ │ │ │ └── trace_messenger.py
│ │ │ └── infer/
│ │ │ ├── __init__.py
│ │ │ ├── discrete.py
│ │ │ ├── elbo.py
│ │ │ ├── trace_elbo.py
│ │ │ ├── traceenum_elbo.py
│ │ │ └── tracetmc_elbo.py
│ │ ├── gp/
│ │ │ ├── __init__.py
│ │ │ ├── kernels/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── brownian.py
│ │ │ │ ├── coregionalize.py
│ │ │ │ ├── dot_product.py
│ │ │ │ ├── isotropic.py
│ │ │ │ ├── kernel.py
│ │ │ │ ├── periodic.py
│ │ │ │ └── static.py
│ │ │ ├── likelihoods/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── binary.py
│ │ │ │ ├── gaussian.py
│ │ │ │ ├── likelihood.py
│ │ │ │ ├── multi_class.py
│ │ │ │ └── poisson.py
│ │ │ ├── models/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── gplvm.py
│ │ │ │ ├── gpr.py
│ │ │ │ ├── model.py
│ │ │ │ ├── sgpr.py
│ │ │ │ ├── vgp.py
│ │ │ │ └── vsgp.py
│ │ │ ├── parameterized.py
│ │ │ └── util.py
│ │ ├── minipyro.py
│ │ ├── mue/
│ │ │ ├── __init__.py
│ │ │ ├── dataloaders.py
│ │ │ ├── missingdatahmm.py
│ │ │ ├── models.py
│ │ │ └── statearrangers.py
│ │ ├── oed/
│ │ │ ├── __init__.py
│ │ │ ├── eig.py
│ │ │ ├── glmm/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── glmm.py
│ │ │ │ └── guides.py
│ │ │ ├── search.py
│ │ │ └── util.py
│ │ ├── randomvariable/
│ │ │ ├── __init__.py
│ │ │ └── random_variable.py
│ │ ├── timeseries/
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── gp.py
│ │ │ ├── lgssm.py
│ │ │ └── lgssmgp.py
│ │ ├── tracking/
│ │ │ ├── __init__.py
│ │ │ ├── assignment.py
│ │ │ ├── distributions.py
│ │ │ ├── dynamic_models.py
│ │ │ ├── extended_kalman_filter.py
│ │ │ ├── hashing.py
│ │ │ └── measurements.py
│ │ ├── util.py
│ │ └── zuko.py
│ ├── distributions/
│ │ ├── __init__.py
│ │ ├── affine_beta.py
│ │ ├── asymmetriclaplace.py
│ │ ├── avf_mvn.py
│ │ ├── coalescent.py
│ │ ├── conditional.py
│ │ ├── conjugate.py
│ │ ├── constraints.py
│ │ ├── delta.py
│ │ ├── diag_normal_mixture.py
│ │ ├── diag_normal_mixture_shared_cov.py
│ │ ├── distribution.py
│ │ ├── empirical.py
│ │ ├── extended.py
│ │ ├── folded.py
│ │ ├── gaussian_scale_mixture.py
│ │ ├── grouped_normal_normal.py
│ │ ├── hmm.py
│ │ ├── improper_uniform.py
│ │ ├── inverse_gamma.py
│ │ ├── kl.py
│ │ ├── lkj.py
│ │ ├── log_normal_negative_binomial.py
│ │ ├── logistic.py
│ │ ├── mixture.py
│ │ ├── multivariate_studentt.py
│ │ ├── nanmasked.py
│ │ ├── omt_mvn.py
│ │ ├── one_one_matching.py
│ │ ├── one_two_matching.py
│ │ ├── ordered_logistic.py
│ │ ├── polya_gamma.py
│ │ ├── projected_normal.py
│ │ ├── rejector.py
│ │ ├── relaxed_straight_through.py
│ │ ├── score_parts.py
│ │ ├── sine_bivariate_von_mises.py
│ │ ├── sine_skewed.py
│ │ ├── softlaplace.py
│ │ ├── spanning_tree.cpp
│ │ ├── spanning_tree.py
│ │ ├── stable.py
│ │ ├── stable_log_prob.py
│ │ ├── testing/
│ │ │ ├── __init__.py
│ │ │ ├── fakes.py
│ │ │ ├── gof.py
│ │ │ ├── naive_dirichlet.py
│ │ │ ├── rejection_exponential.py
│ │ │ ├── rejection_gamma.py
│ │ │ └── special.py
│ │ ├── torch.py
│ │ ├── torch_distribution.py
│ │ ├── torch_patch.py
│ │ ├── torch_transform.py
│ │ ├── transforms/
│ │ │ ├── __init__.py
│ │ │ ├── affine_autoregressive.py
│ │ │ ├── affine_coupling.py
│ │ │ ├── basic.py
│ │ │ ├── batchnorm.py
│ │ │ ├── block_autoregressive.py
│ │ │ ├── cholesky.py
│ │ │ ├── discrete_cosine.py
│ │ │ ├── generalized_channel_permute.py
│ │ │ ├── haar.py
│ │ │ ├── householder.py
│ │ │ ├── lower_cholesky_affine.py
│ │ │ ├── matrix_exponential.py
│ │ │ ├── neural_autoregressive.py
│ │ │ ├── normalize.py
│ │ │ ├── ordered.py
│ │ │ ├── permute.py
│ │ │ ├── planar.py
│ │ │ ├── polynomial.py
│ │ │ ├── power.py
│ │ │ ├── radial.py
│ │ │ ├── simplex_to_ordered.py
│ │ │ ├── softplus.py
│ │ │ ├── spline.py
│ │ │ ├── spline_autoregressive.py
│ │ │ ├── spline_coupling.py
│ │ │ ├── sylvester.py
│ │ │ ├── unit_cholesky.py
│ │ │ └── utils.py
│ │ ├── unit.py
│ │ ├── util.py
│ │ ├── von_mises_3d.py
│ │ └── zero_inflated.py
│ ├── generic.py
│ ├── infer/
│ │ ├── __init__.py
│ │ ├── abstract_infer.py
│ │ ├── autoguide/
│ │ │ ├── __init__.py
│ │ │ ├── effect.py
│ │ │ ├── gaussian.py
│ │ │ ├── guides.py
│ │ │ ├── initialization.py
│ │ │ ├── structured.py
│ │ │ └── utils.py
│ │ ├── csis.py
│ │ ├── discrete.py
│ │ ├── elbo.py
│ │ ├── energy_distance.py
│ │ ├── enum.py
│ │ ├── importance.py
│ │ ├── inspect.py
│ │ ├── mcmc/
│ │ │ ├── __init__.py
│ │ │ ├── adaptation.py
│ │ │ ├── api.py
│ │ │ ├── hmc.py
│ │ │ ├── logger.py
│ │ │ ├── mcmc_kernel.py
│ │ │ ├── nuts.py
│ │ │ ├── rwkernel.py
│ │ │ └── util.py
│ │ ├── predictive.py
│ │ ├── renyi_elbo.py
│ │ ├── reparam/
│ │ │ ├── __init__.py
│ │ │ ├── conjugate.py
│ │ │ ├── discrete_cosine.py
│ │ │ ├── haar.py
│ │ │ ├── hmm.py
│ │ │ ├── loc_scale.py
│ │ │ ├── neutra.py
│ │ │ ├── projected_normal.py
│ │ │ ├── reparam.py
│ │ │ ├── softmax.py
│ │ │ ├── split.py
│ │ │ ├── stable.py
│ │ │ ├── strategies.py
│ │ │ ├── structured.py
│ │ │ ├── studentt.py
│ │ │ ├── transform.py
│ │ │ └── unit_jacobian.py
│ │ ├── resampler.py
│ │ ├── rws.py
│ │ ├── smcfilter.py
│ │ ├── svgd.py
│ │ ├── svi.py
│ │ ├── trace_elbo.py
│ │ ├── trace_mean_field_elbo.py
│ │ ├── trace_mmd.py
│ │ ├── trace_tail_adaptive_elbo.py
│ │ ├── traceenum_elbo.py
│ │ ├── tracegraph_elbo.py
│ │ ├── tracetmc_elbo.py
│ │ └── util.py
│ ├── logger.py
│ ├── nn/
│ │ ├── __init__.py
│ │ ├── auto_reg_nn.py
│ │ ├── dense_nn.py
│ │ └── module.py
│ ├── ops/
│ │ ├── __init__.py
│ │ ├── arrowhead.py
│ │ ├── contract.py
│ │ ├── dual_averaging.py
│ │ ├── einsum/
│ │ │ ├── __init__.py
│ │ │ ├── adjoint.py
│ │ │ ├── torch_log.py
│ │ │ ├── torch_map.py
│ │ │ ├── torch_marginal.py
│ │ │ ├── torch_sample.py
│ │ │ └── util.py
│ │ ├── gamma_gaussian.py
│ │ ├── gaussian.py
│ │ ├── hessian.py
│ │ ├── indexing.py
│ │ ├── integrator.py
│ │ ├── jit.py
│ │ ├── linalg.py
│ │ ├── newton.py
│ │ ├── packed.py
│ │ ├── provenance.py
│ │ ├── rings.py
│ │ ├── special.py
│ │ ├── ssm_gp.py
│ │ ├── stats.py
│ │ ├── streaming.py
│ │ ├── tensor_utils.py
│ │ └── welford.py
│ ├── optim/
│ │ ├── __init__.py
│ │ ├── adagrad_rmsprop.py
│ │ ├── clipped_adam.py
│ │ ├── dct_adam.py
│ │ ├── horovod.py
│ │ ├── lr_scheduler.py
│ │ ├── multi.py
│ │ ├── optim.py
│ │ └── pytorch_optimizers.py
│ ├── params/
│ │ ├── __init__.py
│ │ └── param_store.py
│ ├── poutine/
│ │ ├── __init__.py
│ │ ├── block_messenger.py
│ │ ├── broadcast_messenger.py
│ │ ├── collapse_messenger.py
│ │ ├── condition_messenger.py
│ │ ├── do_messenger.py
│ │ ├── enum_messenger.py
│ │ ├── equalize_messenger.py
│ │ ├── escape_messenger.py
│ │ ├── guide.py
│ │ ├── handlers.py
│ │ ├── indep_messenger.py
│ │ ├── infer_config_messenger.py
│ │ ├── lift_messenger.py
│ │ ├── markov_messenger.py
│ │ ├── mask_messenger.py
│ │ ├── messenger.py
│ │ ├── plate_messenger.py
│ │ ├── reentrant_messenger.py
│ │ ├── reparam_messenger.py
│ │ ├── replay_messenger.py
│ │ ├── runtime.py
│ │ ├── scale_messenger.py
│ │ ├── seed_messenger.py
│ │ ├── subsample_messenger.py
│ │ ├── substitute_messenger.py
│ │ ├── trace_messenger.py
│ │ ├── trace_struct.py
│ │ ├── uncondition_messenger.py
│ │ └── util.py
│ ├── primitives.py
│ ├── py.typed
│ ├── settings.py
│ └── util.py
├── scripts/
│ ├── install_pytorch.sh
│ ├── perf_test.sh
│ ├── profile_model.sh
│ ├── update_headers.py
│ └── update_version.py
├── setup.cfg
├── setup.py
├── tests/
│ ├── README.md
│ ├── __init__.py
│ ├── common.py
│ ├── conftest.py
│ ├── contrib/
│ │ ├── __init__.py
│ │ ├── autoname/
│ │ │ ├── test_autoname.py
│ │ │ ├── test_named.py
│ │ │ └── test_scoping.py
│ │ ├── bnn/
│ │ │ └── test_hidden_layer.py
│ │ ├── cevae/
│ │ │ └── test_cevae.py
│ │ ├── conftest.py
│ │ ├── easyguide/
│ │ │ └── test_easyguide.py
│ │ ├── epidemiology/
│ │ │ ├── __init__.py
│ │ │ ├── test_distributions.py
│ │ │ ├── test_models.py
│ │ │ ├── test_quant.py
│ │ │ └── test_util.py
│ │ ├── forecast/
│ │ │ ├── __init__.py
│ │ │ ├── test_evaluate.py
│ │ │ ├── test_forecaster.py
│ │ │ └── test_util.py
│ │ ├── funsor/
│ │ │ ├── conftest.py
│ │ │ ├── test_enum_funsor.py
│ │ │ ├── test_infer_discrete.py
│ │ │ ├── test_named_handlers.py
│ │ │ ├── test_pyroapi_funsor.py
│ │ │ ├── test_tmc.py
│ │ │ ├── test_valid_models_enum.py
│ │ │ ├── test_valid_models_plate.py
│ │ │ ├── test_valid_models_sequential_plate.py
│ │ │ └── test_vectorized_markov.py
│ │ ├── gp/
│ │ │ ├── __init__.py
│ │ │ ├── test_conditional.py
│ │ │ ├── test_kernels.py
│ │ │ ├── test_likelihoods.py
│ │ │ ├── test_models.py
│ │ │ └── test_parameterized.py
│ │ ├── mue/
│ │ │ ├── test_dataloaders.py
│ │ │ ├── test_missingdatahmm.py
│ │ │ ├── test_models.py
│ │ │ └── test_statearrangers.py
│ │ ├── oed/
│ │ │ ├── test_ewma.py
│ │ │ ├── test_finite_spaces_eig.py
│ │ │ ├── test_glmm.py
│ │ │ ├── test_linear_models_eig.py
│ │ │ └── test_xexpx.py
│ │ ├── randomvariable/
│ │ │ └── test_random_variable.py
│ │ ├── test_hessian.py
│ │ ├── test_minipyro.py
│ │ ├── test_util.py
│ │ ├── test_zuko.py
│ │ ├── timeseries/
│ │ │ ├── test_gp.py
│ │ │ └── test_lgssm.py
│ │ └── tracking/
│ │ ├── __init__.py
│ │ ├── test_assignment.py
│ │ ├── test_distributions.py
│ │ ├── test_dynamic_models.py
│ │ ├── test_ekf.py
│ │ ├── test_em.py
│ │ ├── test_hashing.py
│ │ └── test_measurements.py
│ ├── distributions/
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── dist_fixture.py
│ │ ├── test_binomial.py
│ │ ├── test_categorical.py
│ │ ├── test_coalescent.py
│ │ ├── test_conjugate.py
│ │ ├── test_conjugate_update.py
│ │ ├── test_constraints.py
│ │ ├── test_cuda.py
│ │ ├── test_delta.py
│ │ ├── test_distributions.py
│ │ ├── test_empirical.py
│ │ ├── test_extended.py
│ │ ├── test_gaussian_mixtures.py
│ │ ├── test_grouped_normal_normal.py
│ │ ├── test_haar.py
│ │ ├── test_hmm.py
│ │ ├── test_ig.py
│ │ ├── test_improper_uniform.py
│ │ ├── test_independent.py
│ │ ├── test_kl.py
│ │ ├── test_lkj.py
│ │ ├── test_log_normal_negative_binomial.py
│ │ ├── test_lowrank_mvn.py
│ │ ├── test_mask.py
│ │ ├── test_mixture.py
│ │ ├── test_mvn.py
│ │ ├── test_mvt.py
│ │ ├── test_nanmasked.py
│ │ ├── test_omt_mvn.py
│ │ ├── test_one_hot_categorical.py
│ │ ├── test_one_one_matching.py
│ │ ├── test_one_two_matching.py
│ │ ├── test_ordered_logistic.py
│ │ ├── test_pickle.py
│ │ ├── test_polya_gamma.py
│ │ ├── test_projected_normal.py
│ │ ├── test_rejector.py
│ │ ├── test_relaxed_straight_through.py
│ │ ├── test_reshape.py
│ │ ├── test_shapes.py
│ │ ├── test_sine_bivariate_von_mises.py
│ │ ├── test_sine_skewed.py
│ │ ├── test_spanning_tree.py
│ │ ├── test_stable.py
│ │ ├── test_stable_log_prob.py
│ │ ├── test_tensor_type.py
│ │ ├── test_torch_patch.py
│ │ ├── test_transforms.py
│ │ ├── test_unit.py
│ │ ├── test_util.py
│ │ ├── test_von_mises.py
│ │ ├── test_zero_inflated.py
│ │ └── testing/
│ │ ├── test_gof.py
│ │ └── test_special.py
│ ├── doctest_fixtures.py
│ ├── infer/
│ │ ├── __init__.py
│ │ ├── autoguide/
│ │ │ ├── __init__.py
│ │ │ ├── conftest.py
│ │ │ ├── test_gaussian.py
│ │ │ ├── test_inference.py
│ │ │ └── test_mean_field_entropy.py
│ │ ├── conftest.py
│ │ ├── enum_growth.ipynb
│ │ ├── mcmc/
│ │ │ ├── __init__.py
│ │ │ ├── test_adaptation.py
│ │ │ ├── test_hmc.py
│ │ │ ├── test_mcmc_api.py
│ │ │ ├── test_mcmc_util.py
│ │ │ ├── test_nuts.py
│ │ │ ├── test_rwkernel.py
│ │ │ └── test_valid_models.py
│ │ ├── reparam/
│ │ │ ├── __init__.py
│ │ │ ├── test_conjugate.py
│ │ │ ├── test_discrete_cosine.py
│ │ │ ├── test_haar.py
│ │ │ ├── test_hmm.py
│ │ │ ├── test_loc_scale.py
│ │ │ ├── test_neutra.py
│ │ │ ├── test_projected_normal.py
│ │ │ ├── test_softmax.py
│ │ │ ├── test_split.py
│ │ │ ├── test_stable.py
│ │ │ ├── test_strategies.py
│ │ │ ├── test_structured.py
│ │ │ ├── test_studentt.py
│ │ │ ├── test_transform.py
│ │ │ ├── test_unit_jacobian.py
│ │ │ └── util.py
│ │ ├── test_abstract_infer.py
│ │ ├── test_autoguide.py
│ │ ├── test_compute_downstream_costs.py
│ │ ├── test_conjugate_gradients.py
│ │ ├── test_csis.py
│ │ ├── test_discrete.py
│ │ ├── test_elbo_mapdata.py
│ │ ├── test_enum.py
│ │ ├── test_gradient.py
│ │ ├── test_inference.py
│ │ ├── test_initialization.py
│ │ ├── test_inspect.py
│ │ ├── test_jit.py
│ │ ├── test_multi_sample_elbos.py
│ │ ├── test_predictive.py
│ │ ├── test_resampler.py
│ │ ├── test_sampling.py
│ │ ├── test_smcfilter.py
│ │ ├── test_svgd.py
│ │ ├── test_tmc.py
│ │ ├── test_util.py
│ │ └── test_valid_models.py
│ ├── integration_tests/
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── test_conjugate_gaussian_models.py
│ │ └── test_tracegraph_elbo.py
│ ├── nn/
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── test_autoregressive.py
│ │ └── test_module.py
│ ├── ops/
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── einsum/
│ │ │ ├── conftest.py
│ │ │ ├── test_adjoint.py
│ │ │ └── test_torch_log.py
│ │ ├── gamma_gaussian.py
│ │ ├── gaussian.py
│ │ ├── test_arrowhead.py
│ │ ├── test_contract.py
│ │ ├── test_gamma_gaussian.py
│ │ ├── test_gaussian.py
│ │ ├── test_indexing.py
│ │ ├── test_integrator.py
│ │ ├── test_jit.py
│ │ ├── test_linalg.py
│ │ ├── test_newton.py
│ │ ├── test_packed.py
│ │ ├── test_provenance.py
│ │ ├── test_special.py
│ │ ├── test_ssm_gp.py
│ │ ├── test_stats.py
│ │ ├── test_streaming.py
│ │ ├── test_tensor_utils.py
│ │ └── test_welford.py
│ ├── optim/
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── test_multi.py
│ │ └── test_optim.py
│ ├── params/
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── test_module.py
│ │ └── test_param.py
│ ├── perf/
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ └── test_benchmark.py
│ ├── poutine/
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── test_counterfactual.py
│ │ ├── test_mapdata.py
│ │ ├── test_nesting.py
│ │ ├── test_poutines.py
│ │ ├── test_properties.py
│ │ ├── test_runtime.py
│ │ └── test_trace_struct.py
│ ├── pyroapi/
│ │ ├── conftest.py
│ │ └── test_pyroapi.py
│ ├── test_examples.py
│ ├── test_generic.py
│ ├── test_primitives.py
│ ├── test_settings.py
│ └── test_util.py
└── tutorial/
├── Makefile
├── README.md
└── source/
├── RSA-hyperbole.ipynb
├── RSA-implicature.ipynb
├── _static/
│ ├── css/
│ │ └── pyro.css
│ └── img/
│ └── dmm.tex
├── air.ipynb
├── autoname_examples.rst
├── baseball.rst
├── bayesian_regression.ipynb
├── bayesian_regression_ii.ipynb
├── bo.ipynb
├── boosting_bbvi.ipynb
├── capture_recapture.rst
├── cevae.rst
├── cleannb.py
├── conf.py
├── contrib_funsor_intro_i.ipynb
├── contrib_funsor_intro_ii.ipynb
├── csis.ipynb
├── custom_objectives.ipynb
├── cvae.ipynb
├── dirichlet_process_mixture.ipynb
├── dkl.rst
├── dmm.ipynb
├── easyguide.ipynb
├── effect_handlers.ipynb
├── einsum.rst
├── ekf.ipynb
├── elections.ipynb
├── enumeration.ipynb
├── epi_intro.ipynb
├── epi_regional.rst
├── epi_sir.rst
├── forecast_simple.rst
├── forecasting_dlm.ipynb
├── forecasting_i.ipynb
├── forecasting_ii.ipynb
├── forecasting_iii.ipynb
├── gmm.ipynb
├── gp.ipynb
├── gplvm.ipynb
├── hmm.rst
├── hmm_funsor.rst
├── inclined_plane.rst
├── index.rst
├── intro_long.ipynb
├── intro_part_i.ipynb
├── intro_part_ii.ipynb
├── jit.ipynb
├── lda.rst
├── lkj.rst
├── logistic-growth.ipynb
├── mcmc.rst
├── minipyro.rst
├── mixed_hmm.rst
├── mle_map.ipynb
├── model_rendering.ipynb
├── modules.ipynb
├── mue_factor.rst
├── mue_profile.rst
├── neutra.rst
├── normalizing_flows_intro.ipynb
├── predictive_deterministic.ipynb
├── prior_predictive.ipynb
├── prodlda.ipynb
├── reconciling_experts.ipynb
├── scanvi.ipynb
├── search_inference.py
├── sir_hmc.rst
├── smcfilter.rst
├── sparse_gamma.rst
├── sparse_regression.rst
├── ss-vae.ipynb
├── stable.ipynb
├── svi_flow_guide.ipynb
├── svi_horovod.rst
├── svi_lightning.rst
├── svi_part_i.ipynb
├── svi_part_ii.ipynb
├── svi_part_iii.ipynb
├── svi_part_iv.ipynb
├── svi_torch.rst
├── tensor_shapes.ipynb
├── timeseries.rst
├── toy_mixture_model_discrete_enumeration.rst
├── tracking_1d.ipynb
├── vae.ipynb
├── vae_flow_prior.ipynb
├── workflow.ipynb
└── working_memory.ipynb
================================================
FILE CONTENTS
================================================
================================================
FILE: .codecov.yml
================================================
# Copyright Contributors to the Pyro project.
#
# SPDX-License-Identifier: Apache-2.0
ignore:
- "pyro/docutil.py"
- "pyro/logger.py"
coverage:
range: 60..95
round: nearest
precision: 2
comment: false
================================================
FILE: .coveragerc
================================================
# Copyright Contributors to the Pyro project.
#
# SPDX-License-Identifier: Apache-2.0
[report]
omit =
pyro/docutil.py
pyro/logger.py
exclude_lines =
pragma: no cover
def backward
raise AssertionError
raise NotImplementedError
raise ValueError
except NotImplementedError
except ImportError
except KeyError
except TypeError
warnings\.warn.*
warn_if.*
if __name__ == .__main__.:
================================================
FILE: .gitattributes
================================================
# Copyright Contributors to the Pyro project.
#
# SPDX-License-Identifier: Apache-2.0
*.ipynb linguist-documentation
================================================
FILE: .github/FUNDING.yml
================================================
# Copyright Contributors to the Pyro project.
#
# SPDX-License-Identifier: Apache-2.0
# These are supported funding model platforms
github: [fritzo] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
================================================
FILE: .github/ISSUE_TEMPLATE/config.yml
================================================
blank_issues_enabled: false
contact_links:
- name: Pyro Discussion Forum
url: https://forum.pyro.ai/
about: For general questions about Pyro, please use the forum instead of GitHub issues.
================================================
FILE: .github/ISSUE_TEMPLATE/issue_template.md
================================================
---
name: General Issue
about: Report a bug or request a feature
---
<!--
Copyright Contributors to the Pyro project.
SPDX-License-Identifier: Apache-2.0
-->
### Guidelines
**NOTE:** Issues are for bugs and feature requests only. If you have a question about using Pyro or about modeling in general, please post it on the [forum](https://forum.pyro.ai/).
If you would like to address any minor bugs in the documentation or source, please feel free to contribute a Pull Request without creating an issue first.
Please tag the issue appropriately in the title, e.g. [bug], [feature request], [discussion], etc.
Please provide the following details:
--------------------------------------------------------------------------------------------------
### Issue Description
Provide a brief description of the issue.
### Environment
For any bugs, please provide the following:
- OS and Python version.
- PyTorch version, or if relevant, output of `pip freeze`.
- Pyro version: output of `python -c 'import pyro; print(pyro.__version__)'`
### Code Snippet
Provide any relevant code snippets and commands run to replicate the issue.
================================================
FILE: .github/workflows/ci.yml
================================================
name: CI
on:
push:
branches: [dev, master]
pull_request:
branches: [dev, master]
env:
CXX: g++-9
CC: gcc-9
# See coveralls-python - Github Actions support:
# https://github.com/TheKevJames/coveralls-python/blob/master/docs/usage/configuration.rst#github-actions-support
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
COVERALLS_SERVICE_NAME: github
jobs:
lint:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.12]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip wheel setuptools
pip install ruff black mypy nbstripout nbformat
- name: Lint
run: |
make lint
docs:
runs-on: ubuntu-latest
needs: lint
strategy:
matrix:
python-version: [3.12]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
sudo apt-get update
sudo apt-get install gcc-9 g++-9 ninja-build graphviz
python -m pip install --upgrade pip wheel setuptools
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install .[test]
pip install -r docs/requirements.txt
pip install --upgrade sphinx docutils sphinx-rtd-theme
pip freeze
- name: Build docs and run doctest
run: |
make docs
make doctest
tutorials-html:
runs-on: ubuntu-latest
needs: lint
strategy:
matrix:
python-version: [3.12]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
sudo apt-get update
sudo apt-get install gcc-9 g++-9 ninja-build graphviz pandoc
python -m pip install --upgrade pip wheel setuptools
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install .[test]
pip install -r docs/requirements.txt
# requirements for tutorials (from .[dev])
sudo apt-get install pandoc
pip install nbformat
pip install nbsphinx>=0.3.2
pip install nbstripout
pip install pypandoc
pip install ninja
pip freeze
- name: Build HTML from tutorials
run: |
SPHINXOPTS="-E" make tutorial
unit:
runs-on: ubuntu-latest
needs: docs
strategy:
matrix:
python-version: [3.12]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
sudo apt-get update
sudo apt-get install gcc-9 g++-9 ninja-build
python -m pip install --upgrade pip wheel setuptools
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install .[test]
pip install --upgrade coveralls
pip freeze
- name: Run unit tests
run: |
pytest -vs --cov=pyro --cov-config .coveragerc --stage unit --durations 20
- name: Submit to coveralls
run: coveralls --service=github || true
env:
COVERALLS_PARALLEL: true
COVERALLS_FLAG_NAME: ${{ matrix.test-name }}
examples:
runs-on: ubuntu-latest
needs: [docs, tutorials-html]
strategy:
matrix:
python-version: [3.12]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Setup Graphviz
uses: ts-graphviz/setup-graphviz@v1
- name: Install dependencies
run: |
sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
sudo apt-get update
sudo apt-get install gcc-9 g++-9 ninja-build
python -m pip install --upgrade pip wheel setuptools
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install .[test]
pip install --upgrade coveralls
pip freeze
- name: Run examples
run: |
CI=1 pytest -vs --cov=pyro --cov-config .coveragerc --stage test_examples --durations 10
grep -l smoke_test tutorial/source/*.ipynb | xargs grep -L 'smoke_test = False' \
| CI=1 xargs pytest -vx --nbval-lax --current-env
- name: Submit to coveralls
run: coveralls --service=github || true
env:
COVERALLS_PARALLEL: true
COVERALLS_FLAG_NAME: ${{ matrix.test-name }}
integration_1:
runs-on: ubuntu-latest
needs: docs
strategy:
matrix:
python-version: [3.12]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
sudo apt-get update
sudo apt-get install gcc-9 g++-9 ninja-build
python -m pip install --upgrade pip wheel setuptools
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install .[test]
pip install --upgrade coveralls
pip freeze
- name: Run integration test (batch 1)
run: |
pytest -vs --cov=pyro --cov-config .coveragerc --stage integration_batch_1 --durations 10
- name: Submit to coveralls
run: coveralls --service=github || true
env:
COVERALLS_PARALLEL: true
COVERALLS_FLAG_NAME: ${{ matrix.test-name }}
integration_2:
runs-on: ubuntu-latest
needs: docs
strategy:
matrix:
python-version: [3.12]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
sudo apt-get update
sudo apt-get install gcc-9 g++-9 ninja-build
python -m pip install --upgrade pip wheel setuptools
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install .[test]
pip install --upgrade coveralls
pip freeze
- name: Run integration test (batch 2)
run: |
pytest -vs --cov=pyro --cov-config .coveragerc --stage integration_batch_2 --durations 10
- name: Submit to coveralls
run: coveralls --service=github || true
env:
COVERALLS_PARALLEL: true
COVERALLS_FLAG_NAME: ${{ matrix.test-name }}
funsor:
runs-on: ubuntu-latest
needs: docs
strategy:
matrix:
python-version: [3.12]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
sudo apt-get update
sudo apt-get install gcc-9 g++-9 ninja-build
python -m pip install --upgrade pip wheel setuptools
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install .[test]
pip install -e .[funsor]
pip install --upgrade coveralls
pip freeze
- name: Run funsor tests
run: |
pytest -vs --cov=pyro --cov-config .coveragerc --stage funsor --durations 10
CI=1 pytest -vs --cov=pyro --cov-config .coveragerc --stage test_examples --durations 10 -k funsor
- name: Submit to coveralls
run: coveralls --service=github || true
env:
COVERALLS_PARALLEL: true
COVERALLS_FLAG_NAME: ${{ matrix.test-name }}
finish:
needs: [unit, examples, integration_1, integration_2, funsor]
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.12]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Coveralls Finished
run: |
pip install --upgrade coveralls
coveralls --service=github --finish || true
================================================
FILE: .github/workflows/publish.yml
================================================
# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.
# GitHub recommends pinning actions to a commit SHA.
# To get a newer version, you will need to update the SHA.
# You can also reference a tag or branch, but the action may change without warning.
name: Upload Python Package
on:
release:
types: [published]
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.x'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install build
- name: Build package
run: python -m build
- name: Publish package
uses: pypa/gh-action-pypi-publish@release/v1
with:
password: ${{ secrets.PYPI_API_TOKEN }}
================================================
FILE: .gitignore
================================================
# Copyright Contributors to the Pyro project.
#
# SPDX-License-Identifier: Apache-2.0
run_outputs*
.DS_Store
.benchmarks
data
.data
results
*.csv
examples/*/processed
examples/*/results
examples/*/raw
examples/dmm/*.pkl
examples/*.csv
examples/*.pdf
examples/mixed_hmm/*csv
pyro/contrib/examples/processed
pyro/contrib/examples/results
pyro/contrib/examples/raw
pyro/_version.py
processed
raw
*.pkl
*.fasta
baseline_net_q1.pth
cvae_net_q1.pth
cvae_plot_q1.png
prep_seal_data.csv
results.csv
tutorial/source/model.pdf
# Logs
logs
*.log
# Runtime data
pids
*.pid
*.seed
*.pid.lock
# dotenv environment variables file
.env
#ignore python
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
.pytest_cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/
*.out
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# dotenv
.env
# virtualenv
.venv
venv/
ENV/
# IDE settings
.spyderproject
.idea
.vscode
# Rope project settings
.ropeproject
# tmp files
*.swn
*.swo
*.swp
*~
# mypy cache
.mypy_cache
================================================
FILE: .readthedocs.yml
================================================
# Copyright Contributors to the Pyro project.
#
# SPDX-License-Identifier: Apache-2.0
# Required
version: 2
build:
os: ubuntu-20.04
tools:
python: "3.8"
# Build documentation in the docs/ directory with Sphinx
sphinx:
configuration: docs/source/conf.py
# Optionally build your docs in additional formats such as PDF and ePub
# formats: all
# Optionally set the version of Python and requirements required to build your docs
python:
install:
- requirements: docs/requirements.txt
================================================
FILE: CODE_OF_CONDUCT.md
================================================
<!--
Copyright Contributors to the Pyro project.
SPDX-License-Identifier: Apache-2.0
-->
# Contributor Covenant Code of Conduct
## Our Pledge
In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation.
## Our Standards
Examples of behavior that contributes to creating a positive environment include:
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery and unwelcome sexual attention or advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a professional setting
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful.
## Scope
This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at fritz.obermeyer@gmail.com or fehiepsi@gmail.com. The project team will review and investigate all complaints, and will respond in a way that it deems appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately.
Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at [http://contributor-covenant.org/version/1/4][version]
[homepage]: http://contributor-covenant.org
[version]: http://contributor-covenant.org/version/1/4/
================================================
FILE: CONTRIBUTING.md
================================================
<!--
Copyright Contributors to the Pyro project.
SPDX-License-Identifier: Apache-2.0
-->
# Development
Please follow our established coding style including variable names, module imports, and function definitions.
The Pyro codebase follows the [PEP8 style guide](https://www.python.org/dev/peps/pep-0008/)
(which you can check with `make lint`) and is formatted with [`black`](https://github.com/psf/black),
with automatic lint fixes applied by [`ruff`](https://github.com/astral-sh/ruff) (both of which you can apply with `make format`).
When creating new files please add a license header; this can be done automatically via `make license` or simply `make format`.
# Setup
First install [PyTorch](http://pytorch.org/).
Then, install all the dev dependencies for Pyro.
```sh
make install
```
or explicitly
```sh
pip install -e .[dev]
```
# Testing
Before submitting a pull request, please autoformat code and ensure that unit tests pass locally
```sh
make format # adds license headers, then runs ruff --fix and black
make test # linting, docs build, doctests, and unit tests
```
If you've modified core pyro code, examples, or tutorials, you can run more comprehensive tests locally (after first adding any new files to the appropriate `tests/` script)
```sh
make test-examples # test examples/
make integration-test # longer-running tests (may take hours)
make test-cuda # runs unit tests in cuda mode
```
To run all tests locally in parallel, use the `pytest-xdist` package
```sh
pip install pytest-xdist
pytest -vs -n auto
```
To run a single test from the command line
```sh
pytest -vs {path_to_test}::{test_name}
# or in cuda mode
CUDA_TEST=1 PYRO_DTYPE=float64 PYRO_DEVICE=cuda pytest -vs {path_to_test}::{test_name}
```
To ensure documentation builds correctly, run
```sh
make docs
```
## Testing Tutorials
We run some tutorials in CI (GitHub Actions) to avoid bit rot.
Before submitting a new tutorial, please run `make scrub` from
the top-level pyro directory in order to scrub the metadata in
the notebooks.
To enable a tutorial for testing
1. Add a line `smoke_test = ('CI' in os.environ)` to your tutorial. Our test
scripts only test tutorials that contain the string `smoke_test`.
2. Each time you do something expensive for many iterations, set the number
of iterations like this:
```py
for epoch in range(200 if not smoke_test else 1):
...
```
You can test locally by running `make test-tutorials`.
# Profiling
The profiler module contains scripts to support profiling different
Pyro modules, as well as test for performance regression.
To run the profiling utilities, ensure that all dependencies for profiling are satisfied,
by running `make install`, or more specifically, `pip install -e .[profile]`.
There are some generic test cases available in the `profiler` module. Currently, this supports
only the `distributions` library, but we will be adding test cases for inference methods
soon.
#### Some useful invocations
To get help on the parameters that the profiling script takes, run:
```sh
python -m profiler.distributions --help
```
To run the profiler on all the distributions, simply run:
```sh
python -m profiler.distributions
```
To run the profiler on a few distributions by varying the batch size, run:
```sh
python -m profiler.distributions --dist bernoulli normal --batch_sizes 1000 100000
```
To get more details on the potential sources of slowdown, use the `cProfile` tool
as follows:
```sh
python -m profiler.distributions --dist bernoulli --tool cprofile
```
# Submitting
For larger changes, please open an issue for discussion before submitting a pull request.
For relevant design questions to consider, see past
[design documents](https://github.com/pyro-ppl/pyro/wiki/Design-Docs).
In your pull request description on github, please note:
- Proposed changes
- Links to related issues/PRs
- New and existing tests
Before submitting, please run `make format` and `make lint`, and run the tests as described above.
For speculative changes meant for early-stage review, include `[WIP]` in the PR's title.
(One of the maintainers will add the `WIP` tag.)
================================================
FILE: LICENSE.md
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: LICENSES/Apache-2.0.txt
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: LICENSES/BSD-3-Clause.txt
================================================
Copyright (c) <year> <owner>.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
================================================
FILE: LICENSES/MIT.txt
================================================
MIT License
Copyright (c) <year> <copyright holders>
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
================================================
FILE: MANIFEST.in
================================================
include LICENSE.md MANIFEST.in
recursive-include pyro *.cpp
================================================
FILE: Makefile
================================================
.PHONY: all install docs lint format test integration-test clean FORCE

all: docs test

install: FORCE
	pip install -e .[dev,profile] --config-settings editable_mode=strict

uninstall: FORCE
	pip uninstall pyro-ppl

docs: FORCE
	$(MAKE) -C docs html

apidoc: FORCE
	$(MAKE) -C docs apidoc

tutorial: FORCE
	$(MAKE) -C tutorial html

lint: FORCE
	ruff check .
	black --check *.py pyro examples tests scripts profiler
	python scripts/update_headers.py --check
	mypy --install-types --non-interactive pyro scripts tests

license: FORCE
	python scripts/update_headers.py

format: license FORCE
	ruff check --fix .
	black *.py pyro examples tests scripts profiler

version: FORCE
	python scripts/update_version.py

scrub: FORCE
	find tutorial -name "*.ipynb" | xargs python -m nbstripout --keep-output --keep-count
	find tutorial -name "*.ipynb" | xargs python tutorial/source/cleannb.py

doctest: FORCE
# We skip testing pyro.distributions.torch wrapper classes because
# they include torch docstrings which are tested upstream.
	python -m pytest -p tests.doctest_fixtures --doctest-modules -o filterwarnings=ignore pyro --ignore=pyro/distributions/torch.py

perf-test: FORCE
	bash scripts/perf_test.sh ${ref}

profile: ref=dev
profile: FORCE
	bash scripts/profile_model.sh ${ref} ${models}

test: lint docs doctest FORCE
	pytest -vx -n auto --stage unit

test-examples: lint FORCE
	pytest -vx --stage test_examples

# Run notebooks that support a smoke-test mode (and do not hard-code
# `smoke_test = False`). CI=1 must be set in pytest's environment, not grep's,
# so that the notebooks actually see it.
test-tutorials: lint FORCE
	grep -l smoke_test tutorial/source/*.ipynb | xargs grep -L 'smoke_test = False' \
		| CI=1 xargs pytest -vx --nbval-lax --current-env

integration-test: lint FORCE
	pytest -vx -n auto --stage integration

# Full unit/integration suite followed by the same notebook smoke tests as
# `test-tutorials` (CI=1 is applied to pytest, where the notebooks read it).
test-all: lint FORCE
	pytest -vx -n auto
	grep -l smoke_test tutorial/source/*.ipynb | xargs grep -L 'smoke_test = False' \
		| CI=1 xargs pytest -vx --nbval-lax --current-env

test-cuda: lint FORCE
	CUDA_TEST=1 PYRO_DTYPE=float64 PYRO_DEVICE=cuda pytest -vx --stage unit
	CUDA_TEST=1 pytest -vx tests/test_examples.py::test_cuda

test-cuda-lax: lint FORCE
	CUDA_TEST=1 PYRO_DTYPE=float64 PYRO_DEVICE=cuda pytest -vx --stage unit --lax
	CUDA_TEST=1 pytest -vx tests/test_examples.py::test_cuda

test-jit: FORCE
	@echo See jit.log
	pytest -v -n auto --tb=short --runxfail tests/infer/test_jit.py tests/test_examples.py::test_jit | tee jit.log
	pytest -v -n auto --tb=short --runxfail tests/infer/mcmc/test_hmc.py tests/infer/mcmc/test_nuts.py \
		-k JIT=True | tee -a jit.log

test-funsor: lint FORCE
	pytest -vx -n auto --stage funsor

clean: FORCE
	git clean -dfx -e pyro_ppl.egg-info

FORCE:
================================================
FILE: README.md
================================================
<!--
Copyright Contributors to the Pyro project.
SPDX-License-Identifier: Apache-2.0
-->
<div align="center">
<a href="http://pyro.ai"> <img width="220px" height="220px" src="docs/source/_static/img/pyro_logo_with_text.png"></a>
</div>
-----------------------------------------
[](https://github.com/pyro-ppl/pyro/actions)
[](https://coveralls.io/github/pyro-ppl/pyro?branch=dev)
[](https://pypi.python.org/pypi/pyro-ppl)
[](http://pyro-ppl.readthedocs.io/en/stable/?badge=dev)
[](https://bestpractices.coreinfrastructure.org/projects/3056)
[Getting Started](http://pyro.ai/examples) |
[Documentation](http://docs.pyro.ai/) |
[Community](http://forum.pyro.ai/) |
[Contributing](https://github.com/pyro-ppl/pyro/blob/master/CONTRIBUTING.md)
Pyro is a flexible, scalable deep probabilistic programming library built on PyTorch. Notably, it was designed with these principles in mind:
- **Universal**: Pyro is a universal PPL - it can represent any computable probability distribution.
- **Scalable**: Pyro scales to large data sets with little overhead compared to hand-written code.
- **Minimal**: Pyro is agile and maintainable. It is implemented with a small core of powerful, composable abstractions.
- **Flexible**: Pyro aims for automation when you want it, control when you need it. This is accomplished through high-level abstractions to express generative and inference models, while allowing experts easy-access to customize inference.
Pyro was originally developed at Uber AI and is now actively maintained by community contributors, including a dedicated team at the [Broad Institute](https://www.broadinstitute.org/).
In 2019, Pyro [became](https://www.linuxfoundation.org/press-release/2019/02/pyro-probabilistic-programming-language-becomes-newest-lf-deep-learning-project/) a project of the Linux Foundation, a neutral space for collaboration on open source software, open standards, open data, and open hardware.
For more information about the high level motivation for Pyro, check out our [launch blog post](http://eng.uber.com/pyro).
For additional blog posts, check out work on [experimental design](https://eng.uber.com/oed-pyro-release/) and
[time-to-event modeling](https://eng.uber.com/modeling-censored-time-to-event-data-using-pyro/) in Pyro.
## Installing
### Installing a stable Pyro release
**Install using pip:**
```sh
pip install pyro-ppl
```
**Install from source:**
```sh
git clone https://github.com/pyro-ppl/pyro.git
cd pyro
git checkout master  # master is pinned to the latest release
pip install .
```
**Install with extra packages:**
To install the dependencies required to run the probabilistic models included in the `examples`/`tutorials` directories, please use the following command:
```sh
pip install pyro-ppl[extras]
```
Make sure that the models come from the same release version of the [Pyro source code](https://github.com/pyro-ppl/pyro/releases) as you have installed.
### Installing Pyro dev branch
For recent features you can install Pyro from source.
**Install Pyro using pip:**
```sh
pip install git+https://github.com/pyro-ppl/pyro.git
```
or, with the `extras` dependency to run the probabilistic models included in the `examples`/`tutorials` directories:
```sh
pip install "pyro-ppl[extras] @ git+https://github.com/pyro-ppl/pyro.git"
```
**Install Pyro from source:**
```sh
git clone https://github.com/pyro-ppl/pyro
cd pyro
pip install . # pip install .[extras] for running models in examples/tutorials
```
## Running Pyro from a Docker Container
Refer to the instructions [here](docker/README.md).
## Citation
If you use Pyro, please consider citing:
```
@article{bingham2019pyro,
author = {Eli Bingham and
Jonathan P. Chen and
Martin Jankowiak and
Fritz Obermeyer and
Neeraj Pradhan and
Theofanis Karaletsos and
Rohit Singh and
Paul A. Szerlip and
Paul Horsfall and
Noah D. Goodman},
title = {Pyro: Deep Universal Probabilistic Programming},
journal = {J. Mach. Learn. Res.},
volume = {20},
pages = {28:1--28:6},
year = {2019},
url = {http://jmlr.org/papers/v20/18-403.html}
}
```
================================================
FILE: RELEASE-MANAGEMENT.md
================================================
<!--
Copyright Contributors to the Pyro project.
SPDX-License-Identifier: Apache-2.0
-->
# Pyro release management
This describes the process by which versions of Pyro are officially released to the public.
## Versioning
Releases are versioned according to the `version_prefix` constant in [pyro/__init__.py](pyro/__init__.py).
Pyro releases follow semantic versioning with the following caveats:
- Behavior of documented APIs will remain stable across minor releases, except for bug fixes and features marked EXPERIMENTAL or DEPRECATED.
- Serialization formats will remain stable across patch releases, but may change across minor releases (e.g. if you save a model in 1.0.0, it will be safe to load it in 1.0.1, but not in 1.1.0).
- Undocumented APIs, features marked EXPERIMENTAL or DEPRECATED, and anything in `pyro.contrib` may change at any time (though we aim for stability).
- All deprecated features throw a `FutureWarning` and specify possible work-arounds. Features marked as deprecated will not be maintained, and are likely to be removed in a future release.
- If you want more stability for a particular feature, [contribute](https://github.com/pyro-ppl/pyro/blob/dev/CONTRIBUTING.md) a unit test.
## Release process
Pyro is released at an irregular cadence, typically about 4 times per year.
Releases are managed by:
- [Fritz Obermeyer](https://github.com/fritzo)
- [Neeraj Pradhan](https://github.com/neerajprad)
- [JP Chen](https://github.com/jpchen)
Releases and release notes are published to [GitHub](https://github.com/pyro-ppl/pyro/releases).
Documentation for each release is published to [Read the Docs](https://docs.pyro.ai).
Release builds are published to [PyPI](https://pypi.org/project/pyro-ppl/).
================================================
FILE: docker/Dockerfile
================================================
# Copyright Contributors to the Pyro project.
#
# SPDX-License-Identifier: Apache-2.0
ARG base_img=ubuntu:24.04
FROM ${base_img}

# Optional args
ARG python_version=3
ARG pyro_branch=release
ARG pytorch_whl=cpu
ARG pytorch_branch=release
# Default to 999 (as docker/Makefile does): ubuntu:24.04-based images already
# ship a `ubuntu` user/group with uid/gid 1000, which would make the
# groupadd/useradd below fail.
ARG uid=999
ARG gid=999
ARG ostype=Linux
ARG pyro_git_url=https://github.com/pyro-ppl/pyro.git
ARG trust_hosts=no

# Configurable settings.
# These must stay separate ENV instructions: within a single ENV instruction,
# ${...} references resolve to values from *before* that instruction, and
# WORK_DIR / PATH depend on USER_NAME / CONDA_DIR.
ENV USER_NAME=pyromancer
ENV CONDA_DIR=/opt/conda
ENV WORK_DIR=/home/${USER_NAME}/workspace
ENV PATH=${CONDA_DIR}/bin:${PATH}

# Install linux utils
RUN apt-get update && apt-get install -y --no-install-recommends \
        build-essential \
        cmake \
        git \
        wget \
        ca-certificates && \
    rm -rf /var/lib/apt/lists/*

# Create the non-root user (and, on Linux hosts, its group), give it the conda
# and workspace directories, and switch to it.
RUN bash -c 'if [ ${ostype} == Linux ]; then groupadd -r --gid ${gid} ${USER_NAME}; fi && \
    useradd -r --create-home --shell /bin/bash --uid ${uid} --gid ${gid} ${USER_NAME}' && \
    mkdir -p ${CONDA_DIR} ${WORK_DIR} && chown ${USER_NAME} ${CONDA_DIR} ${WORK_DIR}
USER ${USER_NAME}

# Install conda (Miniconda for the major version of `python_version`)
RUN if [ ${trust_hosts} = yes ] ; then WGET_ARGS="--no-check-certificate" ; fi && \
    wget ${WGET_ARGS} -O ~/miniconda.sh \
        https://repo.anaconda.com/miniconda/Miniconda${python_version%%.*}-latest-Linux-x86_64.sh && \
    bash ~/miniconda.sh -f -b -p ${CONDA_DIR} && \
    rm ~/miniconda.sh

# Trust conda and pip hosts if needed
RUN if [ ${trust_hosts} = yes ] ; \
    then \
        pip config set global.trusted-host "pypi.org files.pythonhosted.org download.pytorch.org" && \
        conda config --set ssl_verify False ; \
    fi

# Install the requested python version (non-interactively)
RUN conda install -y python=${python_version}

# Move to the workspace directory and copy the install script
WORKDIR ${WORK_DIR}
COPY install.sh ${WORK_DIR}/install.sh

# Install PyTorch and Pyro. install.sh reads the pytorch_whl, pytorch_branch,
# pyro_branch and pyro_git_url build args from its environment.
RUN cd ${WORK_DIR} && conda update -y -n base conda -c defaults && bash install.sh

# Run Jupyter notebook
# (Ref: http://jupyter-notebook.readthedocs.io/en/latest/public_server.html#docker-cmd)
EXPOSE 8888
CMD ["jupyter", "notebook", "--port=8888", "--no-browser", "--ip=0.0.0.0"]
================================================
FILE: docker/Makefile
================================================
# Copyright Contributors to the Pyro project.
#
# SPDX-License-Identifier: Apache-2.0
.PHONY: help create-host-workspace build build-gpu run run-gpu notebook notebook-gpu lab lab-gpu

DOCKER_FILE=Dockerfile
BASE_IMG=ubuntu:24.04
BASE_CUDA_IMG=nvidia/cuda:12.6.3-cudnn-runtime-ubuntu24.04
DOCKER_CMD=docker
HOST_WORK_DIR=${HOME}/pyro_docker
UID=999
GID=999
OSTYPE=$(shell uname)
USER=pyromancer
DOCKER_WORK_DIR=/home/${USER}/workspace/shared
pyro_git_url=https://github.com/pyro-ppl/pyro.git

# Optional args
python_version?=3.12
pytorch_branch?=release
pyro_branch?=release
cmd?=bash
trust_hosts?="no"

# Determine name of docker image: a CPU/GPU prefix per target, unless an
# explicit image name is passed via `img`.
build run notebook lab: img_prefix=pyro-cpu
build-gpu run-gpu notebook-gpu lab-gpu: img_prefix=pyro-gpu
ifeq ($(img), )
IMG_NAME=${img_prefix}-${pyro_branch}-${pytorch_branch}-${python_version}
else
IMG_NAME=${img}
endif

# Print the help text embedded in this Makefile (excluding this recipe line).
help:
	@grep -Fh "##" ${MAKEFILE_LIST} | grep -Fv 'grep -F' | sed -e 's/##//'

##
##Available targets:
##

build: ##
## Build a docker image for running Pyro on a CPU.
## Args:
##   python_version: version of python to use. default - python 3.12
##   pytorch_branch: whether to install the released PyTorch wheels or build
##                   PyTorch from source (git branch specified by pytorch_branch)
##                   default - latest pytorch version on the torch python package index
##   pyro_branch: whether to use the released Pyro wheel or a git branch.
##                default - latest pyro-ppl wheel on pypi
##   trust_hosts: If set to yes hosts SSL certificates will be trusted
##                (might be needed when running behind a firewall)
##                default - Verify hosts SSL certificates
##
	${DOCKER_CMD} build -t ${IMG_NAME} \
		--build-arg base_img=${BASE_IMG} \
		--build-arg uid=${UID} \
		--build-arg gid=${GID} \
		--build-arg ostype=${OSTYPE} \
		--build-arg python_version=${python_version} \
		--build-arg pytorch_branch=${pytorch_branch} \
		--build-arg pyro_git_url=${pyro_git_url} \
		--build-arg pyro_branch=${pyro_branch} \
		--build-arg trust_hosts=${trust_hosts} -f ${DOCKER_FILE} .

build-gpu: ##
## Build a docker image for running Pyro on a GPU.
## Args:
##   python_version: version of python to use. default - python 3.12
##   pytorch_branch: whether to install the released PyTorch wheels or build
##                   PyTorch from source (git branch specified by pytorch_branch)
##                   default - latest pytorch version on the torch python package index
##   pyro_branch: whether to use the released Pyro wheel or a git branch.
##                default - latest pyro-ppl wheel on pypi
##   trust_hosts: If set to yes hosts SSL certificates will be trusted
##                (might be needed when running behind a firewall)
##                default - Verify hosts SSL certificates
##
	${DOCKER_CMD} build -t ${IMG_NAME} \
		--build-arg base_img=${BASE_CUDA_IMG} \
		--build-arg uid=${UID} \
		--build-arg gid=${GID} \
		--build-arg ostype=${OSTYPE} \
		--build-arg pytorch_whl=cu126 \
		--build-arg python_version=${python_version} \
		--build-arg pytorch_branch=${pytorch_branch} \
		--build-arg pyro_git_url=${pyro_git_url} \
		--build-arg pyro_branch=${pyro_branch} \
		--build-arg trust_hosts=${trust_hosts} -f ${DOCKER_FILE} .

create-host-workspace: ##
## Create shared volume on the host for sharing files with the container.
##
	mkdir -p ${HOST_WORK_DIR}

run: create-host-workspace
run: ##
## Start a Pyro CPU docker instance, and run the command `cmd`.
## Args:
##   img: use image name given by `img`.
##   cmd: command invoked on running a docker instance.
##        default - bash
##
	${DOCKER_CMD} run --init -it --user ${USER} \
		-v ${HOST_WORK_DIR}:${DOCKER_WORK_DIR} \
		${IMG_NAME} ${cmd}

run-gpu: create-host-workspace
run-gpu: ##
## Start a Pyro GPU docker instance, and run the command `cmd`.
## Args:
##   img: use image name given by `img`.
##   cmd: command invoked on running a docker instance.
##        default - bash
##
	${DOCKER_CMD} run --init --gpus=all -it --user ${USER} \
		-v ${HOST_WORK_DIR}:${DOCKER_WORK_DIR} \
		${IMG_NAME} ${cmd}

notebook: create-host-workspace
notebook: ##
## Start a jupyter notebook on the Pyro CPU docker container.
## Args:
##   img: use image name given by `img`.
##
	${DOCKER_CMD} run --init -it -p 8888:8888 --user ${USER} \
		-v ${HOST_WORK_DIR}:${DOCKER_WORK_DIR} \
		${IMG_NAME}

notebook-gpu: create-host-workspace
notebook-gpu: ##
## Start a jupyter notebook on the Pyro GPU docker container.
## Args:
##   img: use image name given by `img`.
##
	${DOCKER_CMD} run --gpus=all --init -it -p 8888:8888 --user ${USER} \
		-v ${HOST_WORK_DIR}:${DOCKER_WORK_DIR} \
		${IMG_NAME}

lab: create-host-workspace
lab: ##
## Start jupyterlab on the Pyro CPU docker container.
## Args:
##   img: use image name given by `img`.
##
	${DOCKER_CMD} run --init -it -p 8888:8888 --user ${USER} \
		-v ${HOST_WORK_DIR}:${DOCKER_WORK_DIR} \
		${IMG_NAME} jupyter lab --port=8888 --no-browser --ip=0.0.0.0

lab-gpu: create-host-workspace
lab-gpu: ##
## Start jupyterlab on the Pyro GPU docker container.
## Args:
##   img: use image name given by `img`.
##
	${DOCKER_CMD} run --gpus=all --init -it -p 8888:8888 --user ${USER} \
		-v ${HOST_WORK_DIR}:${DOCKER_WORK_DIR} \
		${IMG_NAME} jupyter lab --port=8888 --no-browser --ip=0.0.0.0
================================================
FILE: docker/README.md
================================================
<!--
Copyright Contributors to the Pyro project.
SPDX-License-Identifier: Apache-2.0
-->
## Using Pyro Docker
Some utilities for building docker images and running Pyro inside a Docker container are
included in the `docker` directory. This includes a Dockerfile to build PyTorch and Pyro,
with some common recipes included in the Makefile.
Dependencies for building the docker images:
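- **docker** (>= version 19.03, which is needed for the `--gpus` flag used by the GPU targets)
- **NVIDIA Container Toolkit** (GPU images only; it supersedes *nvidia-docker*). Refer to the
  [NVIDIA Container Toolkit repository](https://github.com/NVIDIA/nvidia-container-toolkit) for
  installation.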
### Building Images
The Makefile can be used to build CPU and CUDA images for Pyro and PyTorch. Some common
options are as follows:
1. **Source:** By default, uses the latest released packages: PyTorch wheels from the
   official PyTorch package index and the PyPI wheel for Pyro. However, both Pyro and
   PyTorch can be built from source from any git branch specified by `pytorch_branch` and
   `pyro_branch`.
2. **CPU / CUDA:** `make build` or `make build-gpu` can be used to specify whether the CPU
   or the CUDA image is to be built. Running the CUDA image requires the
   *NVIDIA Container Toolkit*.
3. **Python Version:** Python version can be specified via the argument `python_version`.
For example, the `make` command to build an image that uses Pyro's `dev` branch over
PyTorch's `main` branch, using python 3.12 to run on a GPU, is as follows:
```sh
make build-gpu pyro_branch=dev pytorch_branch=main python_version=3.12
```
This will build an image named `pyro-gpu-dev-main-3.12` (the name has the form
`pyro-{cpu,gpu}-<pyro_branch>-<pytorch_branch>-<python_version>`). To spin up a docker
container from this image and run jupyter notebook in it, use the following `make` command:
```sh
make notebook-gpu img=pyro-gpu-dev-main-3.12
```
For help on the `make` commands available, run `make help`.
**NOTE (Mac Users)**: Please increase the memory available to the Docker application
via *Preferences --> Advanced* from 2GB (default) to at least 4GB prior to building the
docker image (especially when building PyTorch from source).
### Running the Docker container
Once the image is built, the docker container can be started via `make run`, or
`make run-gpu`. By default this starts a *bash* shell. One could start an *ipython*
shell instead by running `make run cmd=ipython`. The image to be used can be
specified via the argument `img`.
To run a *jupyter notebook* use `make notebook`, or `make notebook-gpu`. This will
start a jupyter notebook server which can be accessed from the browser using the link
mentioned in the terminal.
Note that there is a shared volume between the container and the host system, with the
location `$DOCKER_WORK_DIR` on the container, and `$HOST_WORK_DIR` on the local system.
These variables can be configured in the `Makefile`.
================================================
FILE: docker/install.sh
================================================
#!/usr/bin/env bash
# Copyright Contributors to the Pyro project.
#
# SPDX-License-Identifier: Apache-2.0
set -xe

pip install --upgrade pip
pip install notebook ipywidgets matplotlib

# 1. Install PyTorch.
# The prebuilt torch/torchvision/torchaudio wheels for the `pytorch_whl`
# variant (e.g. `cpu`, `cu126`) are always installed from the PyTorch wheel
# index. If pytorch_branch != 'release', torch is then replaced by a build
# from source of git branch `pytorch_branch`.
pip install torch torchvision torchaudio --index-url "https://download.pytorch.org/whl/${pytorch_whl}"
if [ "${pytorch_branch}" != "release" ]
then
    git clone --recursive https://github.com/pytorch/pytorch.git
    # Separate commands so that `set -e` aborts if either step fails.
    pushd pytorch
    git checkout "${pytorch_branch}"
    # The recursive clone checked out submodules for the default branch;
    # bring them in line with the requested branch.
    git submodule sync
    git submodule update --init --recursive
    pip uninstall -y torch
    conda install -y cmake ninja
    pip install -r requirements.txt
    pip install mkl-static mkl-include
    if [ "${pytorch_whl}" != "cpu" ]
    then
        # Strip the "cu" prefix, e.g. cu126 -> magma-cuda126.
        conda install -y -c pytorch "magma-cuda${pytorch_whl:2}"
    fi
    pip install -e .
    popd
fi

# 2. Install Pyro.
# Use the pypi wheel if pyro_branch = 'release'.
# Else, install from source, using git branch `pyro_branch`.
if [ "${pyro_branch}" = "release" ]
then
    pip install pyro-ppl
else
    # Clone into an explicit `pyro` directory so that forks with a different
    # repository name (via pyro_git_url) still work with the `cd` below.
    git clone "${pyro_git_url}" pyro
    # Quote ".[dev]" so the shell cannot glob-expand the brackets.
    (cd pyro && git checkout "${pyro_branch}" && pip install -e ".[dev]")
fi
================================================
FILE: docs/Makefile
================================================
# Copyright Contributors to the Pyro project.
#
# SPDX-License-Identifier: Apache-2.0
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line.
SPHINXOPTS    ?= -E -W
SPHINXBUILD   = python -msphinx
APIDOC        = sphinx-apidoc
SPHINXPROJ    = Pyro
SOURCEDIR     = source
PROJECTDIR    = ../pyro
BUILDDIR      = build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

# `apidoc` produces no file named "apidoc", so it must be phony; otherwise a
# stray file or directory of that name would make the target a no-op.
.PHONY: help apidoc Makefile

# Regenerate the API .rst stubs in $(SOURCEDIR) from the package at $(PROJECTDIR).
apidoc:
	$(APIDOC) -o "$(SOURCEDIR)" "$(PROJECTDIR)"

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
================================================
FILE: docs/README.md
================================================
<!--
Copyright Contributors to the Pyro project.
SPDX-License-Identifier: Apache-2.0
-->
# Documentation #
Pyro Documentation is supported by [Sphinx](http://www.sphinx-doc.org/en/stable/).
To build the docs, run from the toplevel directory:
```
make docs
```
## Installation ##
```
pip install -r requirements.txt
```
Note that you will need to install [graphviz](https://www.graphviz.org/) separately.
## Workflow ##
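To change the documentation, update the `*.rst` files in `source`.
To regenerate the API `.rst` stubs from the package docstrings, run `make apidoc` from this directory (it wraps `sphinx-apidoc [options] -o <output_path> <module_path> [exclude_pattern, ...]`).
To build the HTML pages, run `make html` from this directory.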
================================================
FILE: docs/requirements.txt
================================================
# Copyright Contributors to the Pyro project.
#
# SPDX-License-Identifier: Apache-2.0
sphinx==4.2.0
sphinx-rtd-theme==1.0.0
graphviz>=0.8
numpy>=1.7
observations>=0.1.4
opt_einsum>=2.3.2
pyro-api>=0.1.1
tqdm>=4.36
funsor[torch]
setuptools
sphinx_copybutton
================================================
FILE: docs/source/_static/css/pyro.css
================================================
/*
* Copyright Contributors to the Pyro project.
*
* SPDX-License-Identifier: Apache-2.0
*/
@import url("theme.css");

/* Sidebar header containing the logo and search box: dark grey. */
.wy-side-nav-search {
    background-color: #565656;
}

.wy-side-nav-search a {
    margin: 0;
}

/* Version label under the logo: orange. */
.wy-side-nav-search > div.version {
    color: #f26822;
}

/* Top navigation bar (shown on narrow screens). */
.wy-nav-top {
    background: #404040;
}

/* Highlight of the currently selected sidebar menu entry. */
.wy-menu-vertical li.on a,
.wy-menu-vertical li.current > a {
    background: #cccccc;
}

.wy-side-nav-search input[type=text] {
    border-color: #313131;
}

/* Keep the sidebar logo from filling the whole header width. */
.wy-side-nav-search > a img.logo,
.wy-side-nav-search .wy-dropdown > a img.logo {
    max-width: 60%;
}
================================================
FILE: docs/source/_static/img/favicon/browserconfig.xml
================================================
<?xml version="1.0" encoding="utf-8"?>
<!--
Copyright Contributors to the Pyro project.
SPDX-License-Identifier: Apache-2.0
-->
<browserconfig>
  <msapplication>
    <tile>
      <square70x70logo src="/ms-icon-70x70.png"/>
      <square150x150logo src="/ms-icon-150x150.png"/>
      <square310x310logo src="/ms-icon-310x310.png"/>
      <TileColor>#ffffff</TileColor>
    </tile>
  </msapplication>
</browserconfig>
================================================
FILE: docs/source/_static/img/favicon/manifest.json
================================================
{
"name": "App",
"icons": [
{
"src": "\/android-icon-36x36.png",
"sizes": "36x36",
"type": "image\/png",
"density": "0.75"
},
{
"src": "\/android-icon-48x48.png",
"sizes": "48x48",
"type": "image\/png",
"density": "1.0"
},
{
"src": "\/android-icon-72x72.png",
"sizes": "72x72",
"type": "image\/png",
"density": "1.5"
},
{
"src": "\/android-icon-96x96.png",
"sizes": "96x96",
"type": "image\/png",
"density": "2.0"
},
{
"src": "\/android-icon-144x144.png",
"sizes": "144x144",
"type": "image\/png",
"density": "3.0"
},
{
"src": "\/android-icon-192x192.png",
"sizes": "192x192",
"type": "image\/png",
"density": "4.0"
}
]
}
================================================
FILE: docs/source/conf.py
================================================
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import os
import sys
# import pkg_resources
# -*- coding: utf-8 -*-
#
# Pyro documentation build configuration file, created by
# sphinx-quickstart on Thu Jun 15 17:16:14 2017.
#
# This file is execfile()d with the current directory set to its
# containing dir.
#
# Note that not all possible configuration values are present in this
# autogenerated file.
#
# All configuration values have a default; values that are commented out
# serve to show the default.
# Put the repository root (two directory levels above this file's directory)
# at the front of the import path, so autodoc imports ``pyro`` from the
# working tree.
sys.path.insert(0, os.path.abspath(os.path.join("..", "..")))

# -- General configuration ------------------------------------------------

# Sphinx extensions to load, in order: the built-in ``sphinx.ext.*``
# extensions (listed by short name) followed by third-party extensions.
extensions = [
    "sphinx.ext." + name
    for name in (
        "intersphinx",
        "todo",
        "mathjax",
        "ifconfig",
        "viewcode",
        "githubpages",
        "graphviz",
        "autodoc",
        "doctest",
        "napoleon",
    )
] + ["sphinx_copybutton"]

# Do not let autodoc inherit docstrings from parent classes; parents are
# often PyTorch classes whose docstrings use a different format.
autodoc_inherit_docstrings = False

# Directories (relative to this file) holding custom templates.
templates_path = ["_templates"]

# Only reStructuredText sources are used.
source_suffix = ".rst"

# Root document of the toctree.
master_doc = "index"

# Project metadata.
project = "Pyro"
copyright = "2017-2018, Uber Technologies, Inc"
author = "Uber AI Labs"
# Values for the |version| and |release| substitutions. Builds on Read the
# Docs leave the version blank; local builds use the installed pyro version.
if "READTHEDOCS" in os.environ:
    version = ""
else:
    from pyro import __version__  # noqa: E402

    version = __version__
# The full release string is identical to the short version.
release = version

# Language of the generated content (also used for gettext catalogs when
# translating; usually overridden from the command line in that case).
language = "en"

# Glob patterns, relative to the source directory, of files and directories
# to ignore when collecting sources. They also apply to html_static_path and
# html_extra_path.
exclude_patterns = []

# Pygments style used for syntax highlighting.
pygments_style = "sphinx"

# Render ``todo`` and ``todolist`` directives in the output.
todo_include_todos = True

# Do not prepend module names to documented object names.
add_module_names = False
# -- Options for HTML output ----------------------------------------------

# Sidebar logo and browser favicon (paths relative to this directory).
html_logo = "_static/img/pyro_logo_wide.png"
html_favicon = "_static/img/favicon/favicon.ico"

# HTML theme and its theme-specific options.
html_theme = "sphinx_rtd_theme"
html_theme_options = dict(navigation_depth=3, logo_only=True)

# Directories of custom static files (style sheets, images). They are copied
# after the builtin static files, so a file here overrides a builtin file of
# the same name.
html_static_path = ["_static"]
# Main stylesheet for the HTML pages (css/pyro.css @imports theme.css).
html_style = "css/pyro.css"

# -- Options for HTMLHelp output ------------------------------------------

# Output file base name for the HTML help builder.
htmlhelp_basename = "Pyrodoc"

# -- Options for LaTeX output ---------------------------------------------

# LaTeX customizations; none are set, so Sphinx defaults apply. Possible keys
# include 'papersize' ('letterpaper' or 'a4paper'), 'pointsize' ('10pt',
# '11pt' or '12pt'), 'preamble', and 'figure_align' (e.g. 'htbp').
latex_elements = {}

# Each entry: (source start file, target name, title, author,
#              documentclass [howto, manual, or own class]).
latex_documents = [
    (master_doc, "Pyro.tex", "Pyro Documentation", "Uber AI Labs", "manual"),
]

# -- Options for manual page output ---------------------------------------

# Each entry: (source start file, name, description, authors, manual section).
man_pages = [(master_doc, "pyro", "Pyro Documentation", [author], 1)]

# -- Options for Texinfo output -------------------------------------------

# Each entry: (source start file, target name, title, author,
#              dir menu entry, description, category).
texinfo_documents = [
    (
        master_doc,
        "Pyro",
        "Pyro Documentation",
        author,
        "Pyro",
        "Deep Universal Probabilistic Programming.",
        "Miscellaneous",
    ),
]

# Intersphinx targets: name -> (base URL of the docs, inventory location;
# None means the default objects.inv under the base URL).
intersphinx_mapping = {
    "python": ("https://docs.python.org/3/", None),
    "torch": ("https://pytorch.org/docs/main/", None),
    "funsor": ("https://funsor.pyro.ai/en/stable/", None),
    "opt_einsum": ("https://optimized-einsum.readthedocs.io/en/stable/", None),
    "scipy": ("https://docs.scipy.org/doc/scipy/", None),
    "Bio": ("https://biopython.org/docs/latest/", None),
    "horovod": ("https://horovod.readthedocs.io/en/stable/", None),
    "graphviz": ("https://graphviz.readthedocs.io/en/stable/", None),
}
# document class constructors (__init__ methods):
""" comment out this functionality for now;
def skip(app, what, name, obj, skip, options):
if name == "__init__":
return False
return skip
"""
def setup(app):
app.add_css_file("css/pyro.css")
# app.connect("autodoc-skip-member", skip)
# On Read the Docs, install numpy and a CPU-only PyTorch build while this
# config is loaded, so they are available when the automodule/autoclass
# directives import pyro later in the build.
# pip is invoked as ``sys.executable -m pip`` so packages are installed into
# the same interpreter that runs Sphinx (a bare ``pip`` on PATH may belong to
# a different environment). This is best-effort: exit statuses are ignored.
if "READTHEDOCS" in os.environ:
    import subprocess

    subprocess.call([sys.executable, "-m", "pip", "install", "numpy"])
    subprocess.call(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "torch==2.0+cpu",
            "torchvision==0.15.0+cpu",
            "-f",
            "https://download.pytorch.org/whl/torch_stable.html",
        ]
    )
================================================
FILE: docs/source/contrib.autoname.rst
================================================
Automatic Name Generation
==========================
.. automodule:: pyro.contrib.autoname
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
Named Data Structures
---------------------
.. automodule:: pyro.contrib.autoname.named
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
Scoping
-------
.. automodule:: pyro.contrib.autoname.scoping
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
================================================
FILE: docs/source/contrib.bnn.rst
================================================
Bayesian Neural Networks
=========================
.. automodule:: pyro.contrib.bnn
HiddenLayer
-------------------------
.. automodule:: pyro.contrib.bnn.hidden_layer
:members:
:member-order: bysource
================================================
FILE: docs/source/contrib.cevae.rst
================================================
Causal Effect VAE
=================
.. automodule:: pyro.contrib.cevae
CEVAE Class
-----------
.. autoclass:: pyro.contrib.cevae.CEVAE
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
CEVAE Components
----------------
.. autoclass:: pyro.contrib.cevae.Model
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
.. autoclass:: pyro.contrib.cevae.Guide
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
.. autoclass:: pyro.contrib.cevae.TraceCausalEffect_ELBO
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
Utilities
---------
.. autoclass:: pyro.contrib.cevae.FullyConnected
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
.. autoclass:: pyro.contrib.cevae.DistributionNet
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
.. autoclass:: pyro.contrib.cevae.BernoulliNet
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
.. autoclass:: pyro.contrib.cevae.ExponentialNet
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
.. autoclass:: pyro.contrib.cevae.LaplaceNet
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
.. autoclass:: pyro.contrib.cevae.NormalNet
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
.. autoclass:: pyro.contrib.cevae.StudentTNet
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
.. autoclass:: pyro.contrib.cevae.DiagNormalNet
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
================================================
FILE: docs/source/contrib.easyguide.rst
================================================
Easy Custom Guides
==================
.. automodule:: pyro.contrib.easyguide
EasyGuide
---------
.. autoclass:: pyro.contrib.easyguide.EasyGuide
:members:
:undoc-members:
:special-members: __call__
:show-inheritance:
:member-order: bysource
easy_guide
----------
.. autofunction:: pyro.contrib.easyguide.easy_guide
Group
-----
.. autoclass:: pyro.contrib.easyguide.easyguide.Group
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
================================================
FILE: docs/source/contrib.epidemiology.rst
================================================
Epidemiology
============
.. automodule:: pyro.contrib.epidemiology
.. warning:: Code in ``pyro.contrib.epidemiology`` is under development.
This code makes no guarantee about maintaining backwards compatibility.
``pyro.contrib.epidemiology`` provides a modeling language for a class of
stochastic discrete-time discrete-count compartmental models. This module
implements black-box **inference** (both Stochastic Variational Inference and
Hamiltonian Monte Carlo), **prediction** of latent variables, and
**forecasting** of future trajectories.
For example usage see the following tutorials:
- `Introduction <http://pyro.ai/examples/epi_intro.html>`_
- `Univariate models <http://pyro.ai/examples/epi_sir.html>`_
- `Regional models <http://pyro.ai/examples/epi_regional.html>`_
- `Inference via auxiliary variable HMC <http://pyro.ai/examples/sir_hmc.html>`_
Base Compartmental Model
------------------------
.. automodule:: pyro.contrib.epidemiology.compartmental
:members:
:show-inheritance:
:member-order: bysource
Example Models
--------------
.. automodule:: pyro.contrib.epidemiology.models
Distributions
-------------
.. automodule:: pyro.contrib.epidemiology.distributions
:members:
:show-inheritance:
:member-order: bysource
.. autoclass:: pyro.distributions.CoalescentRateLikelihood
:members:
:show-inheritance:
:member-order: bysource
:special-members: __call__
.. autofunction:: pyro.distributions.coalescent.bio_phylo_to_times
================================================
FILE: docs/source/contrib.examples.rst
================================================
Pyro Examples
=============
Datasets
--------
Multi MNIST
~~~~~~~~~~~
.. automodule:: pyro.contrib.examples.multi_mnist
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
BART Ridership
~~~~~~~~~~~~~~
.. automodule:: pyro.contrib.examples.bart
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
Nextstrain SARS-CoV-2 counts
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: pyro.contrib.examples.nextstrain
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
Utilities
---------
.. automodule:: pyro.contrib.examples.util
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
================================================
FILE: docs/source/contrib.forecast.rst
================================================
Forecasting
===========
.. automodule:: pyro.contrib.forecast
``pyro.contrib.forecast`` is a lightweight framework for experimenting with a
restricted class of time series models and inference algorithms using familiar
Pyro modeling syntax and PyTorch neural networks.
Models include hierarchical multivariate heavy-tailed time series of ~1000 time
steps and ~1000 separate series. Inference combines subsample-compatible
variational inference with Gaussian variable elimination based on the
:class:`~pyro.distributions.GaussianHMM` class. Inference using Hamiltonian Monte Carlo
sampling is also supported with :class:`~pyro.contrib.forecast.forecaster.HMCForecaster`.
Forecasts are in the form of joint posterior samples at multiple future time steps.
Hierarchical models use the familiar :class:`~pyro.plate` syntax for
general hierarchical modeling in Pyro. Plates can be subsampled, enabling
training of joint models over thousands of time series. Multivariate
observations are handled via multivariate likelihoods like
:class:`~pyro.distributions.MultivariateNormal`, :class:`~pyro.distributions.GaussianHMM`, or
:class:`~pyro.distributions.LinearHMM`. Heavy tailed models are possible by
using :class:`~pyro.distributions.StudentT` or
:class:`~pyro.distributions.Stable` likelihoods, possibly together with
:class:`~pyro.distributions.LinearHMM` and reparameterizers including
:class:`~pyro.infer.reparam.studentt.StudentTReparam`,
:class:`~pyro.infer.reparam.stable.StableReparam`, and
:class:`~pyro.infer.reparam.hmm.LinearHMMReparam`.
Seasonality can be handled using the helpers
:func:`~pyro.ops.tensor_utils.periodic_repeat`,
:func:`~pyro.ops.tensor_utils.periodic_cumsum`, and
:func:`~pyro.ops.tensor_utils.periodic_features`.
See :mod:`pyro.contrib.timeseries` for ways to construct temporal Gaussian processes useful as likelihoods.
For example usage see:
- The `univariate forecasting tutorial <http://pyro.ai/examples/forecasting_i.html>`_
- The `state space modeling tutorial <http://pyro.ai/examples/forecasting_ii.html>`_
- The `hierarchical forecasting tutorial <http://pyro.ai/examples/forecasting_iii.html>`_
- The `forecasting example <http://pyro.ai/examples/forecasting_simple.html>`_
Forecaster Interface
---------------------
.. automodule:: pyro.contrib.forecast.forecaster
:members:
:special-members: __call__
:show-inheritance:
:member-order: bysource
Evaluation
----------
.. automodule:: pyro.contrib.forecast.evaluate
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
================================================
FILE: docs/source/contrib.funsor.rst
================================================
Funsor-based Pyro
==========================
Primitives
----------
.. automodule:: pyro.contrib.funsor
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
Effect handlers
---------------------
.. automodule:: pyro.contrib.funsor.handlers
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
.. automodule:: pyro.contrib.funsor.handlers.named_messenger
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
.. automodule:: pyro.contrib.funsor.handlers.primitives
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
.. automodule:: pyro.contrib.funsor.handlers.runtime
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
Inference algorithms
--------------------
.. automodule:: pyro.contrib.funsor.infer.elbo
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
.. automodule:: pyro.contrib.funsor.infer.trace_elbo
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
.. automodule:: pyro.contrib.funsor.infer.traceenum_elbo
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
.. automodule:: pyro.contrib.funsor.infer.tracetmc_elbo
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
.. automodule:: pyro.contrib.funsor.infer.discrete
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
================================================
FILE: docs/source/contrib.gp.rst
================================================
Gaussian Processes
==================
See the `Gaussian Processes tutorial <http://pyro.ai/examples/gp.html>`_ for an introduction.
.. automodule:: pyro.contrib.gp
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
Models
~~~~~~
GPModel
-------
.. automodule:: pyro.contrib.gp.models.model
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
GPRegression
------------
.. automodule:: pyro.contrib.gp.models.gpr
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
SparseGPRegression
------------------
.. automodule:: pyro.contrib.gp.models.sgpr
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
VariationalGP
-------------
.. automodule:: pyro.contrib.gp.models.vgp
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
VariationalSparseGP
-------------------
.. automodule:: pyro.contrib.gp.models.vsgp
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
GPLVM
-----
.. automodule:: pyro.contrib.gp.models.gplvm
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
Kernels
~~~~~~~
.. automodule:: pyro.contrib.gp.kernels
Likelihoods
~~~~~~~~~~~
.. automodule:: pyro.contrib.gp.likelihoods
Parameterized
~~~~~~~~~~~~~
.. automodule:: pyro.contrib.gp.parameterized
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
Util
~~~~
.. automodule:: pyro.contrib.gp.util
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
================================================
FILE: docs/source/contrib.minipyro.rst
================================================
Minipyro
========
.. automodule:: pyro.contrib.minipyro
:members:
:undoc-members:
:special-members: __call__
:show-inheritance:
================================================
FILE: docs/source/contrib.mue.rst
================================================
Biological Sequence Models with MuE
===================================
.. automodule:: pyro.contrib.mue
.. warning:: Code in ``pyro.contrib.mue`` is under development.
This code makes no guarantee about maintaining backwards compatibility.
``pyro.contrib.mue`` provides modeling tools for working with biological
sequence data. In particular, it implements MuE distributions, which are used as
sequence data. In particular it implements MuE distributions, which are used as
a fully generative alternative to multiple sequence alignment-based
preprocessing.
Reference:
MuE models were described in Weinstein and Marks (2021),
https://www.biorxiv.org/content/10.1101/2020.07.31.231381v2.
Example MuE Models
------------------
.. automodule:: pyro.contrib.mue.models
:members:
:show-inheritance:
:member-order: bysource
State Arrangers for Parameterizing MuEs
---------------------------------------
.. automodule:: pyro.contrib.mue.statearrangers
:members:
:show-inheritance:
:member-order: bysource
Missing or Variable Length Data HMM
-----------------------------------
.. automodule:: pyro.contrib.mue.missingdatahmm
:members:
:show-inheritance:
:member-order: bysource
Biosequence Dataset Loading
---------------------------
.. automodule:: pyro.contrib.mue.dataloaders
:members:
:show-inheritance:
:member-order: bysource
================================================
FILE: docs/source/contrib.oed.rst
================================================
Optimal Experiment Design
=========================
.. automodule:: pyro.contrib.oed
Expected Information Gain
-------------------------
.. automodule:: pyro.contrib.oed.eig
:members:
:member-order: bysource
Generalised Linear Mixed Models
-------------------------------
.. automodule:: pyro.contrib.oed.glmm
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
================================================
FILE: docs/source/contrib.randomvariable.rst
================================================
Random Variables
================
.. automodule:: pyro.contrib.randomvariable
Random Variable
---------------
.. autoclass:: pyro.contrib.randomvariable.random_variable.RandomVariable
:members:
:undoc-members:
:show-inheritance:
================================================
FILE: docs/source/contrib.timeseries.rst
================================================
Time Series
===========
.. automodule:: pyro.contrib.timeseries
See the `GP example <http://pyro.ai/examples/timeseries.html>`_ for example usage.
Abstract Models
---------------
.. automodule:: pyro.contrib.timeseries.base
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
Gaussian Processes
------------------
.. automodule:: pyro.contrib.timeseries.gp
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
Linear Gaussian State Space Models
----------------------------------
.. automodule:: pyro.contrib.timeseries.lgssm
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
.. automodule:: pyro.contrib.timeseries.lgssmgp
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
================================================
FILE: docs/source/contrib.tracking.rst
================================================
Tracking
========
.. automodule:: pyro.contrib.tracking
Data Association
----------------
.. automodule:: pyro.contrib.tracking.assignment
:members:
:member-order: bysource
Distributions
-------------
.. automodule:: pyro.contrib.tracking.distributions
:members:
:member-order: bysource
Dynamic Models
--------------
.. automodule:: pyro.contrib.tracking.dynamic_models
:members:
:member-order: bysource
Extended Kalman Filter
----------------------
.. automodule:: pyro.contrib.tracking.extended_kalman_filter
:members:
:member-order: bysource
Hashing
-------
.. automodule:: pyro.contrib.tracking.hashing
:members:
:member-order: bysource
Measurements
------------
.. automodule:: pyro.contrib.tracking.measurements
:members:
:member-order: bysource
================================================
FILE: docs/source/contrib.zuko.rst
================================================
Zuko in Pyro
============
.. automodule:: pyro.contrib.zuko
:members:
================================================
FILE: docs/source/distributions.rst
================================================
Distributions
=============
.. toctree::
:glob:
:maxdepth: 2
:caption: Contents:
PyTorch Distributions
~~~~~~~~~~~~~~~~~~~~~
Most distributions in Pyro are thin wrappers around PyTorch distributions.
For details on the PyTorch distribution interface, see
:class:`torch.distributions.distribution.Distribution`.
For differences between the Pyro and PyTorch interfaces, see
:class:`~pyro.distributions.torch_distribution.TorchDistributionMixin`.
.. automodule:: pyro.distributions.torch
Pyro Distributions
~~~~~~~~~~~~~~~~~~
Abstract Distribution
---------------------
.. autoclass:: pyro.distributions.Distribution
:members:
:undoc-members:
:special-members: __call__
:show-inheritance:
:member-order: bysource
TorchDistributionMixin
----------------------
.. autoclass:: pyro.distributions.torch_distribution.TorchDistributionMixin
:members:
:undoc-members:
:special-members: __call__
:show-inheritance:
:member-order: bysource
TorchDistribution
-----------------
.. autoclass:: pyro.distributions.TorchDistribution
:members:
:undoc-members:
:special-members: __call__
:show-inheritance:
:member-order: bysource
AffineBeta
---------------------
.. autoclass:: pyro.distributions.AffineBeta
:members:
:undoc-members:
:show-inheritance:
AsymmetricLaplace
---------------------
.. autoclass:: pyro.distributions.AsymmetricLaplace
:members:
:undoc-members:
:show-inheritance:
AVFMultivariateNormal
---------------------
.. autoclass:: pyro.distributions.AVFMultivariateNormal
:members:
:undoc-members:
:show-inheritance:
BetaBinomial
------------
.. autoclass:: pyro.distributions.BetaBinomial
:members:
:undoc-members:
:show-inheritance:
CoalescentTimes
---------------
.. autoclass:: pyro.distributions.CoalescentTimes
:members:
:undoc-members:
:show-inheritance:
CoalescentTimesWithRate
-----------------------
.. autoclass:: pyro.distributions.CoalescentTimesWithRate
:members:
:undoc-members:
:show-inheritance:
ConditionalDistribution
-----------------------
.. autoclass:: pyro.distributions.ConditionalDistribution
:members:
:undoc-members:
:show-inheritance:
ConditionalTransformedDistribution
----------------------------------
.. autoclass:: pyro.distributions.ConditionalTransformedDistribution
:members:
:undoc-members:
:show-inheritance:
Delta
-----
.. autoclass:: pyro.distributions.Delta
:members:
:undoc-members:
:show-inheritance:
DirichletMultinomial
--------------------
.. autoclass:: pyro.distributions.DirichletMultinomial
:members:
:undoc-members:
:show-inheritance:
DiscreteHMM
-----------
.. autoclass:: pyro.distributions.DiscreteHMM
:members:
:undoc-members:
:show-inheritance:
EmpiricalDistribution
---------------------
.. autoclass:: pyro.distributions.Empirical
:members:
:undoc-members:
:show-inheritance:
ExtendedBetaBinomial
--------------------
.. autoclass:: pyro.distributions.ExtendedBetaBinomial
:members:
:undoc-members:
:show-inheritance:
ExtendedBinomial
----------------
.. autoclass:: pyro.distributions.ExtendedBinomial
:members:
:undoc-members:
:show-inheritance:
FoldedDistribution
---------------------
.. autoclass:: pyro.distributions.FoldedDistribution
:members:
:undoc-members:
:show-inheritance:
GammaGaussianHMM
----------------
.. autoclass:: pyro.distributions.GammaGaussianHMM
:members:
:undoc-members:
:show-inheritance:
GammaPoisson
------------
.. autoclass:: pyro.distributions.GammaPoisson
:members:
:undoc-members:
:show-inheritance:
GaussianHMM
-----------
.. autoclass:: pyro.distributions.GaussianHMM
:members:
:undoc-members:
:show-inheritance:
GaussianMRF
-----------
.. autoclass:: pyro.distributions.GaussianMRF
:members:
:undoc-members:
:show-inheritance:
GaussianScaleMixture
--------------------
.. autoclass:: pyro.distributions.GaussianScaleMixture
:members:
:undoc-members:
:show-inheritance:
GroupedNormalNormal
-------------------
.. autoclass:: pyro.distributions.GroupedNormalNormal
:members:
:undoc-members:
:show-inheritance:
ImproperUniform
---------------
.. autoclass:: pyro.distributions.improper_uniform.ImproperUniform
:members:
:undoc-members:
:show-inheritance:
IndependentHMM
--------------
.. autoclass:: pyro.distributions.IndependentHMM
:members:
:undoc-members:
:show-inheritance:
InverseGamma
------------
.. autoclass:: pyro.distributions.InverseGamma
:members:
:undoc-members:
:show-inheritance:
LinearHMM
---------
.. autoclass:: pyro.distributions.LinearHMM
:members:
:undoc-members:
:show-inheritance:
LKJ
---
.. autoclass:: pyro.distributions.LKJ
:members:
:undoc-members:
:show-inheritance:
LKJCorrCholesky
---------------
.. autoclass:: pyro.distributions.LKJCorrCholesky
:members:
:undoc-members:
:show-inheritance:
LogNormalNegativeBinomial
-------------------------
.. autoclass:: pyro.distributions.LogNormalNegativeBinomial
:members:
:undoc-members:
:show-inheritance:
Logistic
--------
.. autoclass:: pyro.distributions.Logistic
:members:
:undoc-members:
:show-inheritance:
MaskedDistribution
------------------
.. autoclass:: pyro.distributions.MaskedDistribution
:members:
:undoc-members:
:show-inheritance:
MaskedMixture
-------------
.. autoclass:: pyro.distributions.MaskedMixture
:members:
:undoc-members:
:show-inheritance:
MixtureOfDiagNormals
--------------------
.. autoclass:: pyro.distributions.MixtureOfDiagNormals
:members:
:undoc-members:
:show-inheritance:
MixtureOfDiagNormalsSharedCovariance
------------------------------------
.. autoclass:: pyro.distributions.MixtureOfDiagNormalsSharedCovariance
:members:
:undoc-members:
:show-inheritance:
MultivariateStudentT
--------------------
.. autoclass:: pyro.distributions.MultivariateStudentT
:members:
:undoc-members:
:show-inheritance:
NanMaskedNormal
---------------
.. autoclass:: pyro.distributions.NanMaskedNormal
:members:
:undoc-members:
:show-inheritance:
NanMaskedMultivariateNormal
---------------------------
.. autoclass:: pyro.distributions.NanMaskedMultivariateNormal
:members:
:undoc-members:
:show-inheritance:
OMTMultivariateNormal
---------------------
.. autoclass:: pyro.distributions.OMTMultivariateNormal
:members:
:undoc-members:
:show-inheritance:
OneOneMatching
--------------
.. autoclass:: pyro.distributions.OneOneMatching
:members:
:undoc-members:
:show-inheritance:
OneTwoMatching
--------------
.. autoclass:: pyro.distributions.OneTwoMatching
:members:
:undoc-members:
:show-inheritance:
OrderedLogistic
-------------------------------
.. autoclass:: pyro.distributions.OrderedLogistic
:members:
:undoc-members:
:show-inheritance:
ProjectedNormal
---------------
.. autoclass:: pyro.distributions.ProjectedNormal
:members:
:undoc-members:
:show-inheritance:
RelaxedBernoulliStraightThrough
-------------------------------
.. autoclass:: pyro.distributions.RelaxedBernoulliStraightThrough
:members:
:undoc-members:
:show-inheritance:
RelaxedOneHotCategoricalStraightThrough
---------------------------------------
.. autoclass:: pyro.distributions.RelaxedOneHotCategoricalStraightThrough
:members:
:undoc-members:
:show-inheritance:
Rejector
--------
.. autoclass:: pyro.distributions.Rejector
:members:
:undoc-members:
:show-inheritance:
SineBivariateVonMises
---------------------
.. autoclass:: pyro.distributions.SineBivariateVonMises
:members:
:undoc-members:
:show-inheritance:
SineSkewed
----------
.. autoclass:: pyro.distributions.SineSkewed
:members:
:undoc-members:
:show-inheritance:
SkewLogistic
------------
.. autoclass:: pyro.distributions.SkewLogistic
:members:
:undoc-members:
:show-inheritance:
SoftAsymmetricLaplace
---------------------
.. autoclass:: pyro.distributions.SoftAsymmetricLaplace
:members:
:undoc-members:
:show-inheritance:
SoftLaplace
-------------
.. autoclass:: pyro.distributions.SoftLaplace
:members:
:undoc-members:
:show-inheritance:
SpanningTree
------------
.. autoclass:: pyro.distributions.SpanningTree
:members:
:undoc-members:
:show-inheritance:
Stable
------
.. autoclass:: pyro.distributions.Stable
:members:
:undoc-members:
:show-inheritance:
StableWithLogProb
-----------------
.. autoclass:: pyro.distributions.StableWithLogProb
:members:
:undoc-members:
:show-inheritance:
TruncatedPolyaGamma
-------------------
.. autoclass:: pyro.distributions.TruncatedPolyaGamma
:members:
:undoc-members:
:show-inheritance:
Unit
----
.. autoclass:: pyro.distributions.Unit
:members:
:undoc-members:
:show-inheritance:
VonMises3D
----------
.. autoclass:: pyro.distributions.VonMises3D
:members:
:undoc-members:
:show-inheritance:
ZeroInflatedDistribution
------------------------
.. autoclass:: pyro.distributions.ZeroInflatedDistribution
:members:
:undoc-members:
:show-inheritance:
ZeroInflatedNegativeBinomial
----------------------------
.. autoclass:: pyro.distributions.ZeroInflatedNegativeBinomial
:members:
:undoc-members:
:show-inheritance:
ZeroInflatedPoisson
-------------------
.. autoclass:: pyro.distributions.ZeroInflatedPoisson
:members:
:undoc-members:
:show-inheritance:
Transforms
~~~~~~~~~~
ConditionalTransform
--------------------
.. autoclass:: pyro.distributions.ConditionalTransform
:members:
:undoc-members:
:show-inheritance:
CholeskyTransform
-----------------
.. autoclass:: pyro.distributions.transforms.CholeskyTransform
:members:
:undoc-members:
:show-inheritance:
CorrMatrixCholeskyTransform
---------------------------
.. autoclass:: pyro.distributions.transforms.CorrMatrixCholeskyTransform
:members:
:undoc-members:
:show-inheritance:
DiscreteCosineTransform
-----------------------
.. autoclass:: pyro.distributions.transforms.DiscreteCosineTransform
:members:
:undoc-members:
:show-inheritance:
ELUTransform
------------
.. autoclass:: pyro.distributions.transforms.ELUTransform
:members:
:undoc-members:
:show-inheritance:
HaarTransform
-------------
.. autoclass:: pyro.distributions.transforms.HaarTransform
:members:
:undoc-members:
:show-inheritance:
LeakyReLUTransform
------------------
.. autoclass:: pyro.distributions.transforms.LeakyReLUTransform
:members:
:undoc-members:
:show-inheritance:
LowerCholeskyAffine
-------------------
.. autoclass:: pyro.distributions.transforms.LowerCholeskyAffine
:members:
:undoc-members:
:show-inheritance:
Normalize
---------
.. autoclass:: pyro.distributions.transforms.Normalize
:members:
:undoc-members:
:show-inheritance:
OrderedTransform
----------------
.. autoclass:: pyro.distributions.transforms.OrderedTransform
:members:
:undoc-members:
:show-inheritance:
Permute
-------
.. autoclass:: pyro.distributions.transforms.Permute
:members:
:undoc-members:
:show-inheritance:
PositivePowerTransform
----------------------
.. autoclass:: pyro.distributions.transforms.PositivePowerTransform
:members:
:undoc-members:
:show-inheritance:
SimplexToOrderedTransform
-------------------------
.. autoclass:: pyro.distributions.transforms.SimplexToOrderedTransform
:members:
:undoc-members:
:show-inheritance:
SoftplusLowerCholeskyTransform
------------------------------
.. autoclass:: pyro.distributions.transforms.SoftplusLowerCholeskyTransform
:members:
:undoc-members:
:show-inheritance:
SoftplusTransform
-----------------
.. autoclass:: pyro.distributions.transforms.SoftplusTransform
:members:
:undoc-members:
:show-inheritance:
UnitLowerCholeskyTransform
--------------------------
.. autoclass:: pyro.distributions.transforms.UnitLowerCholeskyTransform
:members:
:undoc-members:
:show-inheritance:
TransformModules
~~~~~~~~~~~~~~~~
AffineAutoregressive
--------------------
.. autoclass:: pyro.distributions.transforms.AffineAutoregressive
:members:
:undoc-members:
:show-inheritance:
AffineCoupling
--------------
.. autoclass:: pyro.distributions.transforms.AffineCoupling
:members:
:undoc-members:
:show-inheritance:
BatchNorm
---------
.. autoclass:: pyro.distributions.transforms.BatchNorm
:members:
:undoc-members:
:show-inheritance:
BlockAutoregressive
-------------------
.. autoclass:: pyro.distributions.transforms.BlockAutoregressive
:members:
:undoc-members:
:show-inheritance:
ConditionalAffineAutoregressive
-------------------------------
.. autoclass:: pyro.distributions.transforms.ConditionalAffineAutoregressive
:members:
:undoc-members:
:show-inheritance:
ConditionalAffineCoupling
-------------------------
.. autoclass:: pyro.distributions.transforms.ConditionalAffineCoupling
:members:
:undoc-members:
:show-inheritance:
ConditionalGeneralizedChannelPermute
------------------------------------
.. autoclass:: pyro.distributions.transforms.ConditionalGeneralizedChannelPermute
:members:
:undoc-members:
:show-inheritance:
ConditionalHouseholder
----------------------
.. autoclass:: pyro.distributions.transforms.ConditionalHouseholder
:members:
:undoc-members:
:show-inheritance:
ConditionalMatrixExponential
----------------------------
.. autoclass:: pyro.distributions.transforms.ConditionalMatrixExponential
:members:
:undoc-members:
:show-inheritance:
ConditionalNeuralAutoregressive
-------------------------------
.. autoclass:: pyro.distributions.transforms.ConditionalNeuralAutoregressive
:members:
:undoc-members:
:show-inheritance:
ConditionalPlanar
-----------------
.. autoclass:: pyro.distributions.transforms.ConditionalPlanar
:members:
:undoc-members:
:show-inheritance:
ConditionalRadial
-----------------
.. autoclass:: pyro.distributions.transforms.ConditionalRadial
:members:
:undoc-members:
:show-inheritance:
ConditionalSpline
-----------------
.. autoclass:: pyro.distributions.transforms.ConditionalSpline
:members:
:undoc-members:
:show-inheritance:
ConditionalSplineAutoregressive
-------------------------------
.. autoclass:: pyro.distributions.transforms.ConditionalSplineAutoregressive
:members:
:undoc-members:
:show-inheritance:
ConditionalTransformModule
--------------------------
.. autoclass:: pyro.distributions.ConditionalTransformModule
:members:
:undoc-members:
:show-inheritance:
GeneralizedChannelPermute
-------------------------
.. autoclass:: pyro.distributions.transforms.GeneralizedChannelPermute
:members:
:undoc-members:
:show-inheritance:
Householder
-----------
.. autoclass:: pyro.distributions.transforms.Householder
:members:
:undoc-members:
:show-inheritance:
MatrixExponential
-----------------
.. autoclass:: pyro.distributions.transforms.MatrixExponential
:members:
:undoc-members:
:show-inheritance:
NeuralAutoregressive
--------------------
.. autoclass:: pyro.distributions.transforms.NeuralAutoregressive
:members:
:undoc-members:
:show-inheritance:
Planar
------
.. autoclass:: pyro.distributions.transforms.Planar
:members:
:undoc-members:
:show-inheritance:
Polynomial
----------
.. autoclass:: pyro.distributions.transforms.Polynomial
:members:
:undoc-members:
:show-inheritance:
Radial
------
.. autoclass:: pyro.distributions.transforms.Radial
:members:
:undoc-members:
:show-inheritance:
Spline
------
.. autoclass:: pyro.distributions.transforms.Spline
:members:
:undoc-members:
:show-inheritance:
SplineAutoregressive
--------------------
.. autoclass:: pyro.distributions.transforms.SplineAutoregressive
:members:
:undoc-members:
:show-inheritance:
SplineCoupling
--------------
.. autoclass:: pyro.distributions.transforms.SplineCoupling
:members:
:undoc-members:
:show-inheritance:
Sylvester
---------
.. autoclass:: pyro.distributions.transforms.Sylvester
:members:
:undoc-members:
:show-inheritance:
TransformModule
---------------
.. autoclass:: pyro.distributions.TransformModule
:members:
:undoc-members:
:show-inheritance:
ComposeTransformModule
----------------------
.. autoclass:: pyro.distributions.ComposeTransformModule
:members:
:undoc-members:
:show-inheritance:
Transform Factories
~~~~~~~~~~~~~~~~~~~
Each :class:`~torch.distributions.transforms.Transform` and :class:`~pyro.distributions.TransformModule` includes a corresponding lowercase helper function that takes as input, at minimum, the input dimensions of the transform, and possibly additional arguments that customize the transform in an intuitive way. The purpose of these helper functions is to hide from the user whether or not the transform requires the construction of a hypernet and, if so, the input and output dimensions of that hypernet.
iterated
--------
.. autofunction:: pyro.distributions.transforms.iterated
affine_autoregressive
---------------------
.. autofunction:: pyro.distributions.transforms.affine_autoregressive
affine_coupling
---------------
.. autofunction:: pyro.distributions.transforms.affine_coupling
batchnorm
---------
.. autofunction:: pyro.distributions.transforms.batchnorm
block_autoregressive
--------------------
.. autofunction:: pyro.distributions.transforms.block_autoregressive
conditional_affine_autoregressive
---------------------------------
.. autofunction:: pyro.distributions.transforms.conditional_affine_autoregressive
conditional_affine_coupling
---------------------------
.. autofunction:: pyro.distributions.transforms.conditional_affine_coupling
conditional_generalized_channel_permute
---------------------------------------
.. autofunction:: pyro.distributions.transforms.conditional_generalized_channel_permute
conditional_householder
-----------------------
.. autofunction:: pyro.distributions.transforms.conditional_householder
conditional_matrix_exponential
------------------------------
.. autofunction:: pyro.distributions.transforms.conditional_matrix_exponential
conditional_neural_autoregressive
---------------------------------
.. autofunction:: pyro.distributions.transforms.conditional_neural_autoregressive
conditional_planar
------------------
.. autofunction:: pyro.distributions.transforms.conditional_planar
conditional_radial
------------------
.. autofunction:: pyro.distributions.transforms.conditional_radial
conditional_spline
------------------
.. autofunction:: pyro.distributions.transforms.conditional_spline
conditional_spline_autoregressive
---------------------------------
.. autofunction:: pyro.distributions.transforms.conditional_spline_autoregressive
elu
---
.. autofunction:: pyro.distributions.transforms.elu
generalized_channel_permute
---------------------------
.. autofunction:: pyro.distributions.transforms.generalized_channel_permute
householder
-----------
.. autofunction:: pyro.distributions.transforms.householder
leaky_relu
----------
.. autofunction:: pyro.distributions.transforms.leaky_relu
matrix_exponential
------------------
.. autofunction:: pyro.distributions.transforms.matrix_exponential
neural_autoregressive
---------------------
.. autofunction:: pyro.distributions.transforms.neural_autoregressive
permute
-------
.. autofunction:: pyro.distributions.transforms.permute
planar
------
.. autofunction:: pyro.distributions.transforms.planar
polynomial
----------
.. autofunction:: pyro.distributions.transforms.polynomial
radial
------
.. autofunction:: pyro.distributions.transforms.radial
spline
------
.. autofunction:: pyro.distributions.transforms.spline
spline_autoregressive
---------------------
.. autofunction:: pyro.distributions.transforms.spline_autoregressive
spline_coupling
---------------
.. autofunction:: pyro.distributions.transforms.spline_coupling
sylvester
---------
.. autofunction:: pyro.distributions.transforms.sylvester
Constraints
~~~~~~~~~~~
.. automodule:: pyro.distributions.constraints
================================================
FILE: docs/source/getting_started.rst
================================================
Getting Started
===============
- `Install Pyro <http://pyro.ai#install>`_.
- Learn the basic concepts of Pyro:
`models and inference <http://pyro.ai/examples/intro_long.html>`_
- Dive into other `tutorials <http://pyro.ai/examples>`_ and
`examples <https://github.com/pyro-ppl/pyro/tree/dev/examples>`_.
================================================
FILE: docs/source/index.rst
================================================
.. Pyro documentation master file, created by
sphinx-quickstart on Thu Jun 15 17:16:14 2017.
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
:github_url: https://github.com/pyro-ppl/pyro
Pyro Documentation
================================
.. toctree::
:glob:
:maxdepth: 2
:caption: Pyro Core:
getting_started
primitives
inference
distributions
parameters
nn
optimization
poutine
ops
settings
testing
.. toctree::
:glob:
:maxdepth: 2
:caption: Contributed Code:
contrib.autoname
contrib.bnn
contrib.cevae
contrib.easyguide
contrib.epidemiology
contrib.examples
contrib.forecast
contrib.funsor
contrib.gp
contrib.minipyro
contrib.mue
contrib.oed
contrib.randomvariable
contrib.timeseries
contrib.tracking
contrib.zuko
Indices and tables
==================
* :ref:`genindex`
* :ref:`search`
.. * :ref:`modindex`
================================================
FILE: docs/source/infer.autoguide.rst
================================================
Automatic Guide Generation
==========================
.. automodule:: pyro.infer.autoguide
AutoGuide
---------
.. autoclass:: pyro.infer.autoguide.AutoGuide
:members:
:undoc-members:
:member-order: bysource
:show-inheritance:
AutoGuideList
-------------
.. autoclass:: pyro.infer.autoguide.AutoGuideList
:members:
:undoc-members:
:member-order: bysource
:show-inheritance:
AutoCallable
------------
.. autoclass:: pyro.infer.autoguide.AutoCallable
:members:
:undoc-members:
:member-order: bysource
:show-inheritance:
AutoNormal
----------
.. autoclass:: pyro.infer.autoguide.AutoNormal
:members:
:undoc-members:
:member-order: bysource
:show-inheritance:
AutoDelta
---------
.. autoclass:: pyro.infer.autoguide.AutoDelta
:members:
:undoc-members:
:member-order: bysource
:show-inheritance:
AutoContinuous
--------------
.. autoclass:: pyro.infer.autoguide.AutoContinuous
:members:
:undoc-members:
:member-order: bysource
:show-inheritance:
AutoMultivariateNormal
----------------------
.. autoclass:: pyro.infer.autoguide.AutoMultivariateNormal
:members:
:undoc-members:
:member-order: bysource
:show-inheritance:
AutoDiagonalNormal
------------------
.. autoclass:: pyro.infer.autoguide.AutoDiagonalNormal
:members:
:undoc-members:
:member-order: bysource
:show-inheritance:
AutoLowRankMultivariateNormal
-----------------------------
.. autoclass:: pyro.infer.autoguide.AutoLowRankMultivariateNormal
:members:
:undoc-members:
:member-order: bysource
:show-inheritance:
AutoNormalizingFlow
-------------------
.. autoclass:: pyro.infer.autoguide.AutoNormalizingFlow
:members:
:undoc-members:
:member-order: bysource
:show-inheritance:
AutoIAFNormal
-----------------------------
.. autoclass:: pyro.infer.autoguide.AutoIAFNormal
:members:
:undoc-members:
:member-order: bysource
:show-inheritance:
AutoLaplaceApproximation
-----------------------------
.. autoclass:: pyro.infer.autoguide.AutoLaplaceApproximation
:members:
:undoc-members:
:member-order: bysource
:show-inheritance:
AutoDiscreteParallel
--------------------
.. autoclass:: pyro.infer.autoguide.AutoDiscreteParallel
:members:
:undoc-members:
:member-order: bysource
:show-inheritance:
AutoStructured
--------------------
.. autoclass:: pyro.infer.autoguide.AutoStructured
:members:
:undoc-members:
:member-order: bysource
:show-inheritance:
AutoGaussian
------------
.. autoclass:: pyro.infer.autoguide.AutoGaussian
:members:
:undoc-members:
:member-order: bysource
:show-inheritance:
AutoMessenger
-------------
.. autoclass:: pyro.infer.autoguide.AutoMessenger
:members:
:undoc-members:
:member-order: bysource
:show-inheritance:
AutoNormalMessenger
-------------------
.. autoclass:: pyro.infer.autoguide.AutoNormalMessenger
:members:
:undoc-members:
:member-order: bysource
:show-inheritance:
AutoHierarchicalNormalMessenger
-------------------------------
.. autoclass:: pyro.infer.autoguide.AutoHierarchicalNormalMessenger
:members:
:undoc-members:
:member-order: bysource
:show-inheritance:
AutoRegressiveMessenger
-----------------------
.. autoclass:: pyro.infer.autoguide.AutoRegressiveMessenger
:members:
:undoc-members:
:member-order: bysource
:show-inheritance:
.. _autoguide-initialization:
Initialization
--------------
.. automodule:: pyro.infer.autoguide.initialization
:members:
:undoc-members:
:special-members: __call__
:member-order: bysource
:show-inheritance:
================================================
FILE: docs/source/infer.reparam.rst
================================================
Reparameterizers
================
.. automodule:: pyro.infer.reparam
The :mod:`pyro.infer.reparam` module contains reparameterization strategies for
the :func:`pyro.poutine.handlers.reparam` effect. These are useful for altering
the geometry of a poorly-conditioned parameter space so that the posterior is
better shaped. They can be used with a variety of inference algorithms, e.g.
``Auto*Normal`` guides and MCMC.
.. automodule:: pyro.infer.reparam.reparam
:members:
:undoc-members:
:member-order: bysource
:special-members: __call__
Automatic Strategies
--------------------
.. automodule:: pyro.infer.reparam.strategies
:members:
:undoc-members:
:member-order: bysource
:special-members: __call__
:show-inheritance:
Conjugate Updating
------------------
.. automodule:: pyro.infer.reparam.conjugate
:members:
:undoc-members:
:member-order: bysource
:special-members: __call__
:show-inheritance:
Loc-Scale Decentering
---------------------
.. automodule:: pyro.infer.reparam.loc_scale
:members:
:undoc-members:
:member-order: bysource
:special-members: __call__
:show-inheritance:
Gumbel-Softmax
--------------
.. automodule:: pyro.infer.reparam.softmax
:members:
:undoc-members:
:member-order: bysource
:special-members: __call__
:show-inheritance:
Transformed Distributions
-------------------------
.. automodule:: pyro.infer.reparam.transform
:members:
:undoc-members:
:member-order: bysource
:special-members: __call__
:show-inheritance:
Discrete Cosine Transform
-------------------------
.. automodule:: pyro.infer.reparam.discrete_cosine
:members:
:undoc-members:
:member-order: bysource
:special-members: __call__
:show-inheritance:
Haar Transform
--------------
.. automodule:: pyro.infer.reparam.haar
:members:
:undoc-members:
:member-order: bysource
:special-members: __call__
:show-inheritance:
Unit Jacobian Transforms
------------------------
.. automodule:: pyro.infer.reparam.unit_jacobian
:members:
:undoc-members:
:member-order: bysource
:special-members: __call__
:show-inheritance:
StudentT Distributions
----------------------
.. automodule:: pyro.infer.reparam.studentt
:members:
:undoc-members:
:member-order: bysource
:special-members: __call__
:show-inheritance:
Stable Distributions
--------------------
.. automodule:: pyro.infer.reparam.stable
:members:
:undoc-members:
:member-order: bysource
:special-members: __call__
:show-inheritance:
Projected Normal Distributions
------------------------------
.. automodule:: pyro.infer.reparam.projected_normal
:members:
:undoc-members:
:member-order: bysource
:special-members: __call__
:show-inheritance:
Hidden Markov Models
--------------------
.. automodule:: pyro.infer.reparam.hmm
:members:
:undoc-members:
:member-order: bysource
:special-members: __call__
:show-inheritance:
Site Splitting
--------------
.. automodule:: pyro.infer.reparam.split
:members:
:undoc-members:
:member-order: bysource
:special-members: __call__
:show-inheritance:
Neural Transport
----------------
.. automodule:: pyro.infer.reparam.neutra
:members:
:undoc-members:
:member-order: bysource
:special-members: __call__
:show-inheritance:
Structured Preconditioning
--------------------------
.. automodule:: pyro.infer.reparam.structured
:members:
:undoc-members:
:member-order: bysource
:special-members: __call__
:show-inheritance:
================================================
FILE: docs/source/infer.util.rst
================================================
Inference utilities
===================
.. autofunction:: pyro.infer.util.enable_validation
.. autofunction:: pyro.infer.util.is_validation_enabled
.. autofunction:: pyro.infer.util.validation_enabled
Model inspection
----------------
.. automodule:: pyro.infer.inspect
:members:
:member-order: bysource
Interactive prior tuning
------------------------
.. automodule:: pyro.infer.resampler
:members:
:member-order: bysource
================================================
FILE: docs/source/inference.rst
================================================
Inference
=========
In the context of probabilistic modeling, learning is usually called inference.
In the particular case of Bayesian inference, this often involves computing
(approximate) posterior distributions. In the case of parameterized models, this
usually involves some sort of optimization. Pyro supports multiple inference algorithms,
with support for stochastic variational inference (SVI) being the most extensive.
Look here for more inference algorithms in future versions of Pyro.
See the `Introductory tutorial <http://pyro.ai/examples/intro_long.html>`_ for a discussion of inference in Pyro.
.. toctree::
:glob:
:maxdepth: 2
:caption: Contents:
inference_algos
mcmc
infer.autoguide
infer.reparam
infer.util
================================================
FILE: docs/source/inference_algos.rst
================================================
SVI
---
.. automodule:: pyro.infer.svi
:members:
:undoc-members:
:show-inheritance:
ELBO
----
.. automodule:: pyro.infer.elbo
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
.. automodule:: pyro.infer.trace_elbo
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
.. automodule:: pyro.infer.tracegraph_elbo
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
.. automodule:: pyro.infer.traceenum_elbo
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
.. automodule:: pyro.infer.trace_mean_field_elbo
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
.. automodule:: pyro.infer.trace_tail_adaptive_elbo
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
.. automodule:: pyro.infer.renyi_elbo
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
.. automodule:: pyro.infer.tracetmc_elbo
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
Importance
----------
.. automodule:: pyro.infer.importance
:members:
:undoc-members:
:show-inheritance:
Reweighted Wake-Sleep
---------------------
.. automodule:: pyro.infer.rws
:members:
:undoc-members:
:show-inheritance:
Sequential Monte Carlo
----------------------
.. automodule:: pyro.infer.smcfilter
:members:
:undoc-members:
:show-inheritance:
Stein Methods
----------------------
.. automodule:: pyro.infer.svgd
:members:
:undoc-members:
:show-inheritance:
Likelihood-free methods
-----------------------
.. automodule:: pyro.infer.energy_distance
:members:
:undoc-members:
:special-members: __call__
:show-inheritance:
:member-order: bysource
Discrete Inference
------------------
.. automodule:: pyro.infer.discrete
:members:
:show-inheritance:
:member-order: bysource
Prediction utilities
--------------------
.. automodule:: pyro.infer.predictive
:members:
:undoc-members:
:show-inheritance:
.. automodule:: pyro.infer.abstract_infer
:members:
:undoc-members:
:show-inheritance:
================================================
FILE: docs/source/mcmc.rst
================================================
MCMC
====
.. include:: pyro.infer.mcmc.txt
================================================
FILE: docs/source/nn.rst
================================================
Neural Networks
===============
The module `pyro.nn` provides implementations of neural network modules
that are useful in the context of deep probabilistic programming.
Pyro Modules
------------
.. automodule:: pyro.nn.module
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
AutoRegressiveNN
----------------
.. autoclass:: pyro.nn.auto_reg_nn.AutoRegressiveNN
:members:
:undoc-members:
:show-inheritance:
DenseNN
-------
.. autoclass:: pyro.nn.dense_nn.DenseNN
:members:
:undoc-members:
:show-inheritance:
ConditionalAutoRegressiveNN
---------------------------
.. autoclass:: pyro.nn.auto_reg_nn.ConditionalAutoRegressiveNN
:members:
:undoc-members:
:show-inheritance:
ConditionalDenseNN
------------------
.. autoclass:: pyro.nn.dense_nn.ConditionalDenseNN
:members:
:undoc-members:
:show-inheritance:
================================================
FILE: docs/source/ops.rst
================================================
Miscellaneous Ops
=================
The ``pyro.ops`` module implements tensor utilities
that are mostly independent of the rest of Pyro.
Utilities for HMC
-----------------
.. automodule:: pyro.ops.dual_averaging
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
.. automodule:: pyro.ops.integrator
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
.. automodule:: pyro.ops.welford
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
Newton Optimizers
-----------------
.. automodule:: pyro.ops.newton
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
Special Functions
-----------------
.. automodule:: pyro.ops.special
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
Tensor Utilities
----------------
.. automodule:: pyro.ops.tensor_utils
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
Tensor Indexing
---------------
.. automodule:: pyro.ops.indexing
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
Tensor Contraction
------------------
.. automodule:: pyro.ops.einsum
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
.. autofunction:: pyro.ops.contract.einsum
.. autofunction:: pyro.ops.contract.ubersum
Gaussian Contraction
--------------------
.. automodule:: pyro.ops.gaussian
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
:special-members: __add__,__getitem__
Statistical Utilities
---------------------
.. automodule:: pyro.ops.stats
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
Streaming Statistics
--------------------
.. automodule:: pyro.ops.streaming
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
State Space Model and GP Utilities
----------------------------------
.. automodule:: pyro.ops.ssm_gp
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
================================================
FILE: docs/source/optimization.rst
================================================
Optimization
============
The module `pyro.optim` provides support for optimization in Pyro. In particular
it provides `PyroOptim`, which is used to wrap PyTorch optimizers
and manage optimizers for dynamically generated parameters
(see the tutorial `SVI Part I <http://pyro.ai/examples/svi_part_i.html>`_ for
a discussion). Custom optimization algorithms can also be found here.
.. include:: pyro.optim.txt
================================================
FILE: docs/source/parameters.rst
================================================
Parameters
==========
Parameters in Pyro are basically thin wrappers around PyTorch Tensors that carry unique names.
As such, Parameters are the primary stateful objects in Pyro. Users typically interact with parameters
via the Pyro primitive `pyro.param`. Parameters play a central role in stochastic variational inference,
where they are used to represent point estimates for the parameters in parameterized families of
models and guides.
ParamStore
----------
.. automodule:: pyro.params.param_store
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
================================================
FILE: docs/source/poutine.rst
================================================
Poutine (Effect handlers)
==========================
Beneath the built-in inference algorithms, Pyro has a library of composable
effect handlers for creating new inference algorithms and working with probabilistic
programs. Pyro's inference algorithms are all built by applying these handlers to stochastic functions.
For a general understanding of what effect handlers are and what problems they solve, read
`An Introduction to Algebraic Effects and Handlers <https://www.eff-lang.org/handlers-tutorial.pdf>`_
by Matija Pretnar.
Handlers
---------
.. automodule:: pyro.poutine.handlers
:members:
.. autofunction:: pyro.infer.enum.config_enumerate
Trace
------
.. autoclass:: pyro.poutine.Trace
:members:
:undoc-members:
:show-inheritance:
Runtime
--------
.. automodule:: pyro.poutine.runtime
:members:
:undoc-members:
:show-inheritance:
Utilities
----------
.. automodule:: pyro.poutine.util
:members:
:undoc-members:
:show-inheritance:
Messengers
-----------
Messenger objects contain the implementations of the effects exposed by handlers.
Advanced users may modify the implementations of messengers behind existing handlers or write new messengers
that implement new effects and compose correctly with the rest of the library.
.. include:: pyro.poutine.txt
================================================
FILE: docs/source/primitives.rst
================================================
Primitives
==========
.. automodule:: pyro.primitives
:members:
:show-inheritance:
:member-order: bysource
.. autofunction:: pyro.ops.jit.trace
================================================
FILE: docs/source/pyro.infer.mcmc.txt
================================================
MCMC
----
.. autoclass:: pyro.infer.mcmc.api.MCMC
:members:
:undoc-members:
:show-inheritance:
StreamingMCMC
-------------
.. autoclass:: pyro.infer.mcmc.api.StreamingMCMC
:members:
:undoc-members:
:show-inheritance:
MCMCKernel
----------
.. autoclass:: pyro.infer.mcmc.mcmc_kernel.MCMCKernel
:members:
:undoc-members:
:show-inheritance:
HMC
---
.. autoclass:: pyro.infer.mcmc.HMC
:members:
:undoc-members:
:show-inheritance:
NUTS
----
.. autoclass:: pyro.infer.mcmc.NUTS
:members:
:undoc-members:
:show-inheritance:
RandomWalkKernel
----------------
.. autoclass:: pyro.infer.mcmc.RandomWalkKernel
:members:
:undoc-members:
:show-inheritance:
BlockMassMatrix
---------------
.. autoclass:: pyro.infer.mcmc.BlockMassMatrix
:members:
:undoc-members:
:show-inheritance:
Utilities
---------
.. autofunction:: pyro.infer.mcmc.util.initialize_model
.. autofunction:: pyro.infer.mcmc.util.diagnostics
.. autofunction:: pyro.infer.mcmc.util.select_samples
================================================
FILE: docs/source/pyro.optim.txt
================================================
Pyro Optimizers
---------------
.. automodule:: pyro.optim.optim
:members:
:undoc-members:
:special-members: __call__
:show-inheritance:
:member-order: bysource
.. automodule:: pyro.optim.lr_scheduler
:members:
:undoc-members:
:special-members: __call__
:show-inheritance:
.. automodule:: pyro.optim.adagrad_rmsprop
:members:
:undoc-members:
:special-members: __call__
:show-inheritance:
:member-order: bysource
.. automodule:: pyro.optim.clipped_adam
:members:
:undoc-members:
:special-members: __call__
:show-inheritance:
:member-order: bysource
.. automodule:: pyro.optim.horovod
:members:
:undoc-members:
:special-members: __call__
:show-inheritance:
:member-order: bysource
PyTorch Optimizers
------------------
.. automodule:: pyro.optim.pytorch_optimizers
:members:
:undoc-members:
:special-members: __call__
:show-inheritance:
:member-order: bysource
Higher-Order Optimizers
-----------------------
.. automodule:: pyro.optim.multi
:members:
:undoc-members:
:special-members: __call__
:show-inheritance:
:member-order: bysource
================================================
FILE: docs/source/pyro.poutine.txt
================================================
Messenger
__________
.. automodule:: pyro.poutine.messenger
:members:
:undoc-members:
:show-inheritance:
BlockMessenger
_______________
.. automodule:: pyro.poutine.block_messenger
:members:
:undoc-members:
:show-inheritance:
BroadcastMessenger
__________________
.. automodule:: pyro.poutine.broadcast_messenger
:members:
:undoc-members:
:show-inheritance:
CollapseMessenger
_________________
.. automodule:: pyro.poutine.collapse_messenger
:members:
:undoc-members:
:show-inheritance:
ConditionMessenger
___________________
.. automodule:: pyro.poutine.condition_messenger
:members:
:undoc-members:
:show-inheritance:
DoMessenger
___________________
.. automodule:: pyro.poutine.do_messenger
:members:
:undoc-members:
:show-inheritance:
EnumMessenger
________________
.. automodule:: pyro.poutine.enum_messenger
:members:
:undoc-members:
:show-inheritance:
EqualizeMessenger
____________________
.. automodule:: pyro.poutine.equalize_messenger
:members:
:undoc-members:
:show-inheritance:
EscapeMessenger
________________
.. automodule:: pyro.poutine.escape_messenger
:members:
:undoc-members:
:show-inheritance:
IndepMessenger
_______________
.. automodule:: pyro.poutine.indep_messenger
:members:
:undoc-members:
:show-inheritance:
InferConfigMessenger
____________________
.. automodule:: pyro.poutine.infer_config_messenger
:members:
:undoc-members:
:show-inheritance:
LiftMessenger
______________
.. automodule:: pyro.poutine.lift_messenger
:members:
:undoc-members:
:show-inheritance:
MarkovMessenger
_______________
.. automodule:: pyro.poutine.markov_messenger
:members:
:undoc-members:
:show-inheritance:
MaskMessenger
______________
.. automodule:: pyro.poutine.mask_messenger
:members:
:undoc-members:
:show-inheritance:
PlateMessenger
______________
.. automodule:: pyro.poutine.plate_messenger
:members:
:undoc-members:
:show-inheritance:
ReentrantMessenger
___________________
.. automodule:: pyro.poutine.reentrant_messenger
:members:
:undoc-members:
:show-inheritance:
ReparamMessenger
________________
.. automodule:: pyro.poutine.reparam_messenger
:members:
:undoc-members:
:show-inheritance:
ReplayMessenger
________________
.. automodule:: pyro.poutine.replay_messenger
:members:
:undoc-members:
:show-inheritance:
ScaleMessenger
_______________
.. automodule:: pyro.poutine.scale_messenger
:members:
:undoc-members:
:show-inheritance:
SeedMessenger
_______________
.. automodule:: pyro.poutine.seed_messenger
:members:
:undoc-members:
:show-inheritance:
SubsampleMessenger
__________________
.. automodule:: pyro.poutine.subsample_messenger
:members:
:undoc-members:
:show-inheritance:
SubstituteMessenger
___________________
.. automodule:: pyro.poutine.substitute_messenger
:members:
:undoc-members:
:show-inheritance:
TraceMessenger
_______________
.. automodule:: pyro.poutine.trace_messenger
:members:
:undoc-members:
:show-inheritance:
UnconditionMessenger
____________________
.. automodule:: pyro.poutine.uncondition_messenger
:members:
:undoc-members:
:show-inheritance:
GuideMessenger
______________
.. automodule:: pyro.poutine.guide
:members:
:undoc-members:
:special-members: __call__
:member-order: bysource
:show-inheritance:
================================================
FILE: docs/source/settings.rst
================================================
Settings
--------
.. automodule:: pyro.settings
:members:
:member-order: bysource
================================================
FILE: docs/source/testing.rst
================================================
Testing Utilities
-----------------
.. automodule:: pyro.distributions.testing.gof
:members:
:member-order: bysource
================================================
FILE: examples/__init__.py
================================================
================================================
FILE: examples/air/air.py
================================================
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
"""
An implementation of the model described in [1].
[1] Eslami, SM Ali, et al. "Attend, infer, repeat: Fast scene
understanding with generative models." Advances in Neural Information
Processing Systems. 2016.
"""
from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from modules import MLP, Decoder, Encoder, Identity, Predict
import pyro
import pyro.distributions as dist
# Default prior success probability for z_pres.
def default_z_pres_prior_p(t):
    """Return the prior probability that ``z_pres`` is 1 at step ``t``.

    The default schedule is constant: every step uses the same probability,
    so ``t`` is accepted only to match the ``z_pres_prior_p(t)`` interface.
    """
    constant_prob = 0.5
    return constant_prob
# Per-step state threaded through the generative model (see AIR.prior_step).
ModelState = namedtuple("ModelState", "x z_pres z_where")
# Per-step state threaded through the guide (see AIR.guide_step); the bl_*
# fields hold the baseline LSTM state.
GuideState = namedtuple("GuideState", "h c bl_h bl_c z_pres z_where z_what")
class AIR(nn.Module):
    """Attend, Infer, Repeat (AIR) generative model and guide, after [1].

    The generative model (:meth:`model`, via :meth:`prior`) runs for
    ``num_steps`` steps. At each step it samples a presence flag ``z_pres``,
    an attention transform ``z_where`` (one scale and two position
    components) and a latent code ``z_what``. It decodes ``z_what`` into a
    ``window_size`` x ``window_size`` patch, places that patch into an
    ``x_size`` x ``x_size`` canvas, and adds it to the image built so far.
    The guide (:meth:`guide`) mirrors these steps with an LSTM-driven
    inference network. When ``use_baselines`` is set, a separate LSTM
    baseline network supplies ``baseline_value`` entries for the ``z_pres``
    sites.

    :param int num_steps: Number of model/guide steps.
    :param int x_size: Side length of the square image.
    :param int window_size: Side length of the square attention window.
    :param int z_what_size: Dimension of each ``z_what`` code.
    :param int rnn_hidden_size: Hidden size of the inference and baseline LSTM
        cells.
    :param list encoder_net: Layer specification forwarded to ``Encoder``.
    :param list decoder_net: Layer specification forwarded to ``Decoder``.
    :param list predict_net: Layer specification forwarded to ``Predict``.
    :param list embed_net: Optional layer sizes for the image-embedding MLPs.
        Its last entry is used as the embedding size fed to the LSTMs. If
        ``None``, the flattened image is fed in directly.
    :param list bl_predict_net: Layer sizes for the baseline MLP; ``[1]`` is
        appended before the MLP is built.
    :param str non_linearity: Name of a ``torch.nn`` activation class.
    :param decoder_output_bias: Forwarded to ``Decoder``.
    :param bool decoder_output_use_sigmoid: Forwarded to ``Decoder``.
    :param bool use_masking: If True, the ``z_where``/``z_what`` sites (and the
        baseline values) are masked by ``z_pres``.
    :param bool use_baselines: If True, the guide attaches baseline values to
        the ``z_pres`` sites.
    :param float baseline_scalar: Optional multiplier for baseline outputs.
    :param float scale_prior_mean: Prior loc of the ``z_where`` scale.
    :param float scale_prior_sd: Prior scale of the ``z_where`` scale.
    :param float pos_prior_mean: Prior loc of the two ``z_where`` positions.
    :param float pos_prior_sd: Prior scale of the two ``z_where`` positions.
    :param float likelihood_sd: Standard deviation of the Normal likelihood.
    :param bool use_cuda: If True, allocate new tensors on and move the module
        to CUDA.
    """

    def __init__(
        self,
        num_steps,
        x_size,
        window_size,
        z_what_size,
        rnn_hidden_size,
        encoder_net=[],
        decoder_net=[],
        predict_net=[],
        embed_net=None,
        bl_predict_net=[],
        non_linearity="ReLU",
        decoder_output_bias=None,
        decoder_output_use_sigmoid=False,
        use_masking=True,
        use_baselines=True,
        baseline_scalar=None,
        scale_prior_mean=3.0,
        scale_prior_sd=0.1,
        pos_prior_mean=0.0,
        pos_prior_sd=1.0,
        likelihood_sd=0.3,
        use_cuda=False,
    ):
        # NOTE(review): the list defaults above are shared mutable objects.
        # This constructor does not mutate them (``bl_predict_net + [1]``
        # builds a new list); confirm Encoder/Decoder/Predict/MLP don't either.
        super().__init__()

        self.num_steps = num_steps
        self.x_size = x_size
        self.window_size = window_size
        self.z_what_size = z_what_size
        self.rnn_hidden_size = rnn_hidden_size
        self.use_masking = use_masking
        self.use_baselines = use_baselines
        self.baseline_scalar = baseline_scalar
        self.likelihood_sd = likelihood_sd
        self.use_cuda = use_cuda
        # dtype/device keyword arguments used whenever new tensors are allocated.
        prototype = torch.tensor(0.0).cuda() if use_cuda else torch.tensor(0.0)
        self.options = dict(dtype=prototype.dtype, device=prototype.device)

        self.z_pres_size = 1
        self.z_where_size = 3
        # Storing the priors as nn.Parameters means they are moved to the
        # GPU along with the module when necessary. They have
        # requires_grad=False and are not registered with pyro, so they are
        # not optimized.
        self.z_where_loc_prior = nn.Parameter(
            torch.FloatTensor([scale_prior_mean, pos_prior_mean, pos_prior_mean]),
            requires_grad=False,
        )
        self.z_where_scale_prior = nn.Parameter(
            torch.FloatTensor([scale_prior_sd, pos_prior_sd, pos_prior_sd]),
            requires_grad=False,
        )

        # Create nn modules.
        rnn_input_size = x_size**2 if embed_net is None else embed_net[-1]
        rnn_input_size += self.z_where_size + z_what_size + self.z_pres_size
        nl = getattr(nn, non_linearity)

        self.rnn = nn.LSTMCell(rnn_input_size, rnn_hidden_size)
        self.encode = Encoder(window_size**2, encoder_net, z_what_size, nl)
        self.decode = Decoder(
            window_size**2,
            decoder_net,
            z_what_size,
            decoder_output_bias,
            decoder_output_use_sigmoid,
            nl,
        )
        self.predict = Predict(
            rnn_hidden_size, predict_net, self.z_pres_size, self.z_where_size, nl
        )
        self.embed = (
            Identity() if embed_net is None else MLP(x_size**2, embed_net, nl, True)
        )

        self.bl_rnn = nn.LSTMCell(rnn_input_size, rnn_hidden_size)
        self.bl_predict = MLP(rnn_hidden_size, bl_predict_net + [1], nl)
        self.bl_embed = (
            Identity() if embed_net is None else MLP(x_size**2, embed_net, nl, True)
        )

        # Create parameters.
        self.h_init = nn.Parameter(torch.zeros(1, rnn_hidden_size))
        self.c_init = nn.Parameter(torch.zeros(1, rnn_hidden_size))
        self.bl_h_init = nn.Parameter(torch.zeros(1, rnn_hidden_size))
        self.bl_c_init = nn.Parameter(torch.zeros(1, rnn_hidden_size))
        self.z_where_init = nn.Parameter(torch.zeros(1, self.z_where_size))
        self.z_what_init = nn.Parameter(torch.zeros(1, self.z_what_size))

        if use_cuda:
            self.cuda()

    def prior(self, n, **kwargs):
        """Run the generative process for a batch of ``n`` images.

        :param int n: Batch size.
        :param kwargs: Forwarded to :meth:`prior_step` (e.g. ``z_pres_prior_p``).
        :return: ``((z_where, z_pres), x)``. ``z_where`` and ``z_pres`` are
            lists with one tensor per step; ``x`` is the accumulated canvas of
            shape ``(n, x_size, x_size)``.
        """
        state = ModelState(
            x=torch.zeros(n, self.x_size, self.x_size, **self.options),
            z_pres=torch.ones(n, self.z_pres_size, **self.options),
            z_where=None,
        )
        z_pres = []
        z_where = []

        for t in range(self.num_steps):
            state = self.prior_step(t, n, state, **kwargs)
            z_where.append(state.z_where)
            z_pres.append(state.z_pres)

        return (z_where, z_pres), state.x

    def prior_step(self, t, n, prev, z_pres_prior_p=default_z_pres_prior_p):
        """Sample one object (``z_pres``, ``z_where``, ``z_what``) and draw it.

        :param int t: Step index, used in site names and passed to
            ``z_pres_prior_p``.
        :param int n: Batch size.
        :param ModelState prev: State returned by the previous step.
        :param callable z_pres_prior_p: Maps ``t`` to the prior probability that
            ``z_pres`` is 1 (before being multiplied by ``prev.z_pres``).
        :return: The updated :class:`ModelState`.
        """
        # Sample presence indicators. Multiplying by prev.z_pres makes the
        # probability zero once a previous step has sampled z_pres = 0.
        z_pres = pyro.sample(
            "z_pres_{}".format(t),
            dist.Bernoulli(z_pres_prior_p(t) * prev.z_pres).to_event(1),
        )

        # If zero is sampled for a data point, then no more objects
        # will be added to its output image. We can't straightforwardly
        # avoid generating further objects, so instead we zero out the
        # log_prob_sum of future choices.
        sample_mask = z_pres if self.use_masking else torch.tensor(1.0)

        # Sample attention window position.
        z_where = pyro.sample(
            "z_where_{}".format(t),
            dist.Normal(
                self.z_where_loc_prior.expand(n, self.z_where_size),
                self.z_where_scale_prior.expand(n, self.z_where_size),
            )
            .mask(sample_mask)
            .to_event(1),
        )

        # Sample latent code for contents of the attention window.
        z_what = pyro.sample(
            "z_what_{}".format(t),
            dist.Normal(
                torch.zeros(n, self.z_what_size, **self.options),
                torch.ones(n, self.z_what_size, **self.options),
            )
            .mask(sample_mask)
            .to_event(1),
        )

        # Map latent code to pixel space.
        y_att = self.decode(z_what)

        # Position/scale attention window within larger image.
        y = window_to_image(z_where, self.window_size, self.x_size, y_att)

        # Combine the image generated at this step with the image so far.
        # (Note that there's no notion of occlusion here. Overlapping
        # objects can create pixel intensities > 1.)
        x = prev.x + (y * z_pres.view(-1, 1, 1))

        return ModelState(x=x, z_pres=z_pres, z_where=z_where)

    def model(self, data, batch_size, **kwargs):
        """Generative model: score ``data`` under a Normal likelihood around the
        rendered canvas.

        ``batch_size`` is not used here; only the guide's plate passes
        ``subsample_size``. Presumably the model's plate indices are replayed
        from the guide during inference; verify against the SVI setup.
        """
        pyro.module("decode", self.decode)
        with pyro.plate("data", data.size(0), device=data.device) as ix:
            batch = data[ix]
            n = batch.size(0)
            (z_where, z_pres), x = self.prior(n, **kwargs)
            pyro.sample(
                "obs",
                dist.Normal(
                    x.view(n, -1),
                    (
                        self.likelihood_sd
                        * torch.ones(n, self.x_size**2, **self.options)
                    ),
                ).to_event(1),
                obs=batch.view(n, -1),
            )

    def guide(self, data, batch_size, **kwargs):
        """Inference network over a subsample of ``batch_size`` images.

        :return: ``(z_where, z_pres)`` lists with one tensor per step.
        """
        pyro.module("rnn", self.rnn)
        pyro.module("predict", self.predict)
        pyro.module("encode", self.encode)
        pyro.module("embed", self.embed)
        pyro.module("bl_rnn", self.bl_rnn)
        pyro.module("bl_predict", self.bl_predict)
        pyro.module("bl_embed", self.bl_embed)

        pyro.param("h_init", self.h_init)
        pyro.param("c_init", self.c_init)
        pyro.param("z_where_init", self.z_where_init)
        pyro.param("z_what_init", self.z_what_init)
        pyro.param("bl_h_init", self.bl_h_init)
        pyro.param("bl_c_init", self.bl_c_init)

        with pyro.plate(
            "data", data.size(0), subsample_size=batch_size, device=data.device
        ) as ix:
            batch = data[ix]
            n = batch.size(0)

            # Embed inputs.
            flattened_batch = batch.view(n, -1)
            inputs = {
                "raw": batch,
                "embed": self.embed(flattened_batch),
                "bl_embed": self.bl_embed(flattened_batch),
            }

            # Initial state.
            state = GuideState(
                h=self.h_init.expand(n, -1),
                c=self.c_init.expand(n, -1),
                bl_h=self.bl_h_init.expand(n, -1),
                bl_c=self.bl_c_init.expand(n, -1),
                z_pres=torch.ones(n, self.z_pres_size, **self.options),
                z_where=self.z_where_init.expand(n, -1),
                z_what=self.z_what_init.expand(n, -1),
            )

            z_pres = []
            z_where = []

            for t in range(self.num_steps):
                state = self.guide_step(t, n, state, inputs)
                z_where.append(state.z_where)
                z_pres.append(state.z_pres)

            return z_where, z_pres

    def guide_step(self, t, n, prev, inputs):
        """One guide step: update the inference LSTM and sample the step-``t``
        latents from the predicted distributions.

        :param dict inputs: Holds ``"raw"`` (the image batch), ``"embed"`` and
            ``"bl_embed"`` (embeddings from :meth:`guide`).
        :return: The updated :class:`GuideState`.
        """
        rnn_input = torch.cat(
            (inputs["embed"], prev.z_where, prev.z_what, prev.z_pres), 1
        )
        h, c = self.rnn(rnn_input, (prev.h, prev.c))
        z_pres_p, z_where_loc, z_where_scale = self.predict(h)

        # Compute baseline estimates for discrete choice z_pres.
        infer_dict, bl_h, bl_c = self.baseline_step(prev, inputs)

        # Sample presence.
        z_pres = pyro.sample(
            "z_pres_{}".format(t),
            dist.Bernoulli(z_pres_p * prev.z_pres).to_event(1),
            infer=infer_dict,
        )

        sample_mask = z_pres if self.use_masking else torch.tensor(1.0)

        # The predicted loc/scale are offsets/multipliers relative to the prior.
        z_where = pyro.sample(
            "z_where_{}".format(t),
            dist.Normal(
                z_where_loc + self.z_where_loc_prior,
                z_where_scale * self.z_where_scale_prior,
            )
            .mask(sample_mask)
            .to_event(1),
        )

        # Figure 2 of [1] shows x_att depending on z_where and h,
        # rather than z_where and x as here, but I think this is
        # correct.
        x_att = image_to_window(z_where, self.window_size, self.x_size, inputs["raw"])

        # Encode attention windows.
        z_what_loc, z_what_scale = self.encode(x_att)

        z_what = pyro.sample(
            "z_what_{}".format(t),
            dist.Normal(z_what_loc, z_what_scale).mask(sample_mask).to_event(1),
        )
        return GuideState(
            h=h,
            c=c,
            bl_h=bl_h,
            bl_c=bl_c,
            z_pres=z_pres,
            z_where=z_where,
            z_what=z_what,
        )

    def baseline_step(self, prev, inputs):
        """Compute the baseline value for this step's ``z_pres`` site.

        :return: ``(infer_dict, bl_h, bl_c)``. When baselines are disabled this
            is ``({}, None, None)``.
        """
        if not self.use_baselines:
            return dict(), None, None

        # Prevent gradients flowing back from baseline loss to
        # inference net by detaching from graph here.
        rnn_input = torch.cat(
            (
                inputs["bl_embed"],
                prev.z_where.detach(),
                prev.z_what.detach(),
                prev.z_pres.detach(),
            ),
            1,
        )
        bl_h, bl_c = self.bl_rnn(rnn_input, (prev.bl_h, prev.bl_c))
        bl_value = self.bl_predict(bl_h)

        # Zero out values for finished data points. This avoids adding
        # superfluous terms to the loss.
        if self.use_masking:
            bl_value = bl_value * prev.z_pres

        # The value that the baseline net is estimating can be very
        # large. An option to scale the net's output is provided
        # to make it easier for the net to output values of this
        # scale.
        if self.baseline_scalar is not None:
            bl_value = bl_value * self.baseline_scalar

        infer_dict = dict(baseline=dict(baseline_value=bl_value.squeeze(-1)))
        return infer_dict, bl_h, bl_c
# Spatial transformer helpers.

# Column order that turns a row [0, s, x, y] into the flattened 2x3 affine
# matrix [s, 0, x, 0, s, y].
expansion_indices = torch.LongTensor([1, 0, 2, 0, 1, 3])


def expand_z_where(z_where):
    """Expand a batch of ``[s, x, y]`` rows into 2x3 affine matrices.

    Each row becomes::

        [s,x,y] -> [[s,0,x],
                    [0,s,y]]

    :param torch.Tensor z_where: Tensor of shape ``(n, 3)``.
    :return: Tensor of shape ``(n, 2, 3)`` on the same device as ``z_where``.
    """
    n = z_where.size(0)
    # Prepend a zero column so index 0 selects 0 and indices 1..3 select s, x, y.
    out = torch.cat((z_where.new_zeros(n, 1), z_where), 1)
    # Move the index to z_where's own device. ``.cuda()`` would target the
    # *current* CUDA device, which breaks when z_where lives on another GPU.
    ix = expansion_indices.to(z_where.device)
    out = torch.index_select(out, 1, ix)
    out = out.view(n, 2, 3)
    return out
def z_where_inv(z_where):
    """Invert a batch of ``[s, x, y]`` spatial transforms.

    Each row maps to ``[1/s, -x/s, -y/s]``: the parameters of the inverse of
    the spatial transform applied in the generative model.

    Dividing by ``scale`` is unsatisfactory, because ``scale`` could be zero.

    :param torch.Tensor z_where: Tensor of shape ``(n, 3)``.
    :return: Tensor of shape ``(n, 3)``.
    """
    scale = z_where[:, 0:1]
    shift = z_where[:, 1:]
    unit = z_where.new_ones(z_where.size(0), 1)
    # Build [1, -x, -y] and divide every entry by the scale in one go.
    return torch.cat((unit, -shift), 1) / scale
def window_to_image(z_where, window_size, image_size, windows):
    """Place each flattened window into a larger image with a spatial transformer.

    :param torch.Tensor z_where: ``(n, 3)`` transforms ``[s, x, y]``.
    :param int window_size: Side length of each square window.
    :param int image_size: Side length of the square output image.
    :param torch.Tensor windows: ``(n, window_size**2)`` flattened windows.
    :return: Tensor of shape ``(n, image_size, image_size)``.
    """
    n = windows.size(0)
    assert windows.size(1) == window_size**2, "Size mismatch."
    theta = expand_z_where(z_where)
    # align_corners=False is PyTorch's default; passing it explicitly keeps the
    # behavior and silences the warning emitted when it is left unspecified.
    grid = F.affine_grid(
        theta, torch.Size((n, 1, image_size, image_size)), align_corners=False
    )
    out = F.grid_sample(
        windows.view(n, 1, window_size, window_size), grid, align_corners=False
    )
    return out.view(n, image_size, image_size)
def image_to_window(z_where, window_size, image_size, images):
    """Crop a window from each image using the inverse of ``z_where``.

    :param torch.Tensor z_where: ``(n, 3)`` transforms ``[s, x, y]``.
    :param int window_size: Side length of each square window.
    :param int image_size: Side length of the square input images.
    :param torch.Tensor images: ``(n, image_size, image_size)`` images.
    :return: Tensor of shape ``(n, window_size**2)`` (flattened windows).
    """
    n = images.size(0)
    assert images.size(1) == images.size(2) == image_size, "Size mismatch."
    theta_inv = expand_z_where(z_where_inv(z_where))
    # align_corners=False is PyTorch's default; passing it explicitly keeps the
    # behavior and silences the warning emitted when it is left unspecified.
    grid = F.affine_grid(
        theta_inv, torch.Size((n, 1, window_size, window_size)), align_corners=False
    )
    out = F.grid_sample(
        images.view(n, 1, image_size, image_size), grid, align_corners=False
    )
    return out.view(n, -1)
def latents_to_tensor(z):
    """Pack per-step ``z_where``/``z_pres`` lists into a single CPU tensor.

    :param z: Pair ``(z_where, z_pres)`` as returned by the model and guide:
        two equal-length lists with one ``(batch_size, 3)`` and one
        ``(batch_size, 1)`` tensor per step.
    :return: Detached CPU tensor of shape
        ``[batch_size, num_steps, z_where_size + z_pres_size]``.
    """
    # ``.detach()`` is the supported replacement for the legacy ``.data``.
    return torch.stack(
        [
            torch.cat((z_where.detach().cpu(), z_pres.detach().cpu()), 1)
            for z_where, z_pres in zip(*z)
        ]
    ).transpose(0, 1)
================================================
FILE: examples/air/main.py
================================================
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
"""
AIR applied to the multi-mnist data set [1].
[1] Eslami, SM Ali, et al. "Attend, infer, repeat: Fast scene
understanding with generative models." Advances in Neural Information
Processing Systems. 2016.
"""
import argparse
import math
import os
import time
from functools import partial
import numpy as np
import torch
import visdom
from air import AIR, latents_to_tensor
from viz import draw_many, tensor_to_objs
import pyro
import pyro.contrib.examples.multi_mnist as multi_mnist
import pyro.optim as optim
import pyro.poutine as poutine
from pyro.contrib.examples.util import get_data_directory
from pyro.infer import SVI, JitTraceGraph_ELBO, TraceGraph_ELBO
def count_accuracy(X, true_counts, air, batch_size):
    """Measure how often the guide infers the correct number of objects.

    ``X`` is processed in consecutive batches of ``batch_size``. For each
    batch the inferred count is the sum over steps of the ``z_pres`` values
    returned by ``air.guide``.

    :param X: input images; ``X.size(0)`` must be a multiple of ``batch_size``.
    :param true_counts: 1-D tensor of ground-truth counts, one per row of ``X``
        (assumed to lie in 0..2, since they are one-hot encoded into 3 columns).
    :param air: object whose ``guide(data, batch_size)`` returns
        ``(z_where, z_pres)`` lists with one entry per step.
    :param int batch_size: number of examples per guide call.
    :return: ``(acc, counts, error_latents, error_indices)``:
        ``acc`` is the fraction of correctly counted examples;
        ``counts`` is a 3x4 confusion matrix (rows: true count 0-2,
        columns: inferred count 0-3); ``error_latents`` stacks the packed
        latents of the misclassified examples; ``error_indices`` is a 1-D
        tensor of their row indices in ``X`` (moved to CUDA if ``X`` is).
    """
    assert X.size(0) == true_counts.size(0), "Size mismatch."
    assert X.size(0) % batch_size == 0, "Input size must be multiple of batch_size."
    counts = torch.LongTensor(3, 4).zero_()
    error_latents = []
    error_indicators = []

    def count_vec_to_mat(vec, max_index):
        # One-hot encode a 1-D vector of counts into (len(vec), max_index + 1).
        out = torch.LongTensor(vec.size(0), max_index + 1).zero_()
        out.scatter_(1, vec.type(torch.LongTensor).view(vec.size(0), 1), 1)
        return out

    for i in range(X.size(0) // batch_size):
        X_batch = X[i * batch_size : (i + 1) * batch_size]
        true_counts_batch = true_counts[i * batch_size : (i + 1) * batch_size]
        z_where, z_pres = air.guide(X_batch, batch_size)
        # reshape(-1) instead of squeeze(): squeeze() yields a 0-d tensor when
        # batch_size == 1, which count_vec_to_mat cannot handle.
        inferred_counts = sum(z.cpu() for z in z_pres).reshape(-1).data
        true_counts_m = count_vec_to_mat(true_counts_batch, 2)
        inferred_counts_m = count_vec_to_mat(inferred_counts, 3)
        counts += torch.mm(true_counts_m.t(), inferred_counts_m)
        error_ind = 1 - (true_counts_batch == inferred_counts).long()
        # nonzero() returns shape (k, 1); squeeze(1) keeps it 1-D even when
        # k == 1, where a bare squeeze() would produce a 0-d tensor.
        error_ix = error_ind.nonzero(as_tuple=False).squeeze(1)
        error_latents.append(
            latents_to_tensor((z_where, z_pres)).index_select(0, error_ix)
        )
        error_indicators.append(error_ind)
    acc = counts.diag().sum().float() / X.size(0)
    error_indices = torch.cat(error_indicators).nonzero(as_tuple=False).squeeze(1)
    if X.is_cuda:
        error_indices = error_indices.cuda()
    return acc, counts, torch.cat(error_latents), error_indices
def make_prior(k):
    """Build a truncated-geometric-like schedule of z_pres success probabilities.

    The returned function maps a time step ``t`` (0, 1 or 2) to the
    probability of continuing at that step. The three probabilities are
    chosen so that the implied prior over the number of steps
    ``n in {0, 1, 2, 3}`` is proportional to ``k**n``. Consequently, as with
    a geometric distribution, ``log p(steps=n) - log p(steps=n+1)`` is the
    same constant (``-log k``) for every ``n``.

    :param float k: ratio between consecutive step probabilities, in ``(0, 1]``.
    :return: callable ``t -> probability``.
    """
    assert 0 < k <= 1
    # Normaliser so that u * (1 + k + k**2 + k**3) == 1.
    norm = 1 / (1 + k + k**2 + k**3)
    first = 1 - norm
    second = 1 - (k * norm) / first
    third = 1 - (k**2 * norm) / (first * second)
    schedule = [first, second, third]
    # Implied step distribution:
    # [1 - p0, p0 * (1 - p1), p0 * p1 * (1 - p2), p0 * p1 * p2]

    def prior_p(t):
        return schedule[t]

    return prior_p
# Implements "prior annealing" as described in this blog post:
# http://akosiorek.github.io/ml/2017/09/03/implementing-air.html
# That implementation does something very close to the following:
# --z-pres-prior (1 - 1e-15)
# --z-pres-prior-raw
# --anneal-prior exp
# --anneal-prior-to 1e-7
# --anneal-prior-begin 1000
# --anneal-prior-duration 1e6
# e.g. After 200K steps z_pres_p will have decayed to ~0.04
# These functions compute the value of a decaying quantity at time step t.
# initial: initial value
# final: final value, reached after begin + duration steps
# begin: number of steps before decay begins
# duration: number of steps over which decay occurs
# t: current time step
def lin_decay(initial, final, begin, duration, t):
    """Linearly interpolate from ``initial`` to ``final`` over ``duration`` steps.

    The value stays at ``initial`` until step ``begin`` and reaches ``final``
    at ``begin + duration``, after which it is held there. The result is
    clamped with ``max(min(value, initial), final)``. This assumes
    ``final <= initial``; otherwise ``final`` is always returned.
    """
    assert duration > 0
    slope_term = (final - initial) * (t - begin) / duration
    value = slope_term + initial
    capped_above = min(value, initial)
    return max(capped_above, final)
def exp_decay(initial, final, begin, duration, t):
    """Exponentially decay from ``initial`` to ``final`` over ``duration`` steps.

    The rate is chosen so the value equals ``final`` exactly ``duration``
    steps after ``begin`` (the half-life is
    ``log(2) / log(initial / final) * duration``). The result is clamped with
    ``max(min(value, initial), final)``. ``initial / final`` must be positive
    because its logarithm is taken.
    """
    assert final > 0
    assert duration > 0
    rate = math.log(initial / final) / duration
    elapsed = t - begin
    value = initial * math.exp(-rate * elapsed)
    return max(min(value, initial), final)
def load_data():
    """Load the multi-MNIST data set as torch tensors.

    :return: ``(X, counts)``. ``X`` is a float32 tensor of the images divided
        by 255 (presumably raw 0-255 intensities, giving values in [0, 1]).
        ``counts`` is a float tensor holding ``len`` of each label entry
        returned by ``multi_mnist.load``, i.e. the number of objects per image.
    """
    data_dir = get_data_directory(__file__)
    images_np, labels = multi_mnist.load(data_dir)
    images_np = images_np.astype(np.float32)
    images_np /= 255.0
    images = torch.from_numpy(images_np)
    # Float (not integer) counts so they compare directly with values
    # sampled from a Bernoulli.
    counts = torch.FloatTensor([len(objs) for objs in labels])
    return images, counts
def main(**kwargs):
    """Train AIR on multi-MNIST with SVI, optionally visualising and evaluating.

    :param kwargs: option values, normally ``vars(parser.parse_args())``.
        The parser uses ``argument_default=argparse.SUPPRESS``, so options
        without an explicit default (e.g. ``save``, ``load``) are present only
        when given. Their presence is tested with ``"name" in args``.
    :raises RuntimeError: if ``save`` is given and that file already exists.
    """
    args = argparse.Namespace(**kwargs)

    # Refuse to clobber an existing checkpoint before doing any work.
    if "save" in args:
        if os.path.exists(args.save):
            raise RuntimeError('Output file "{}" already exists.'.format(args.save))

    if args.seed is not None:
        pyro.set_rng_seed(args.seed)

    X, true_counts = load_data()
    X_size = X.size(0)
    if args.cuda:
        X = X.cuda()

    # Build a function to compute z_pres prior probabilities.
    if args.z_pres_prior_raw:

        def base_z_pres_prior_p(t):
            # Same success probability at every time step.
            return args.z_pres_prior

    else:
        base_z_pres_prior_p = make_prior(args.z_pres_prior)

    # Wrap with logic to apply any annealing.
    def z_pres_prior_p(opt_step, time_step):
        # time_step selects the base probability; opt_step (the SVI step
        # number) drives the annealing schedule.
        p = base_z_pres_prior_p(time_step)
        if args.anneal_prior == "none":
            return p
        else:
            decay = dict(lin=lin_decay, exp=exp_decay)[args.anneal_prior]
            return decay(
                p,
                args.anneal_prior_to,
                args.anneal_prior_begin,
                args.anneal_prior_duration,
                opt_step,
            )

    # Optional AIR constructor arguments. Only options present in ``args``
    # (given on the command line, or having a parser default) are forwarded;
    # the rest presumably fall back to AIR's own defaults.
    model_arg_keys = [
        "window_size",
        "rnn_hidden_size",
        "decoder_output_bias",
        "decoder_output_use_sigmoid",
        "baseline_scalar",
        "encoder_net",
        "decoder_net",
        "predict_net",
        "embed_net",
        "bl_predict_net",
        "non_linearity",
        "pos_prior_mean",
        "pos_prior_sd",
        "scale_prior_mean",
        "scale_prior_sd",
    ]
    model_args = {key: getattr(args, key) for key in model_arg_keys if key in args}
    air = AIR(
        num_steps=args.model_steps,
        x_size=50,
        use_masking=not args.no_masking,
        use_baselines=not args.no_baselines,
        z_what_size=args.encoder_latent_size,
        use_cuda=args.cuda,
        **model_args
    )

    if args.verbose:
        print(air)
        print(args)

    if "load" in args:
        print("Loading parameters...")
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # load checkpoints that come from a trusted source.
        air.load_state_dict(torch.load(args.load, weights_only=False))

    # Viz sample from prior.
    if args.viz:
        vis = visdom.Visdom(env=args.visdom_env)
        z, x = air.prior(5, z_pres_prior_p=partial(z_pres_prior_p, 0))
        vis.images(draw_many(x, tensor_to_objs(latents_to_tensor(z))))

    def isBaselineParam(param_name):
        return "bl_" in param_name

    def per_param_optim_args(param_name):
        # Baseline-network parameters (names containing "bl_") get their own
        # learning rate.
        lr = (
            args.baseline_learning_rate
            if isBaselineParam(param_name)
            else args.learning_rate
        )
        return {"lr": lr}

    adam = optim.Adam(per_param_optim_args)
    elbo = JitTraceGraph_ELBO() if args.jit else TraceGraph_ELBO()
    svi = SVI(air.model, air.guide, adam, loss=elbo)

    # Do inference.
    t0 = time.time()
    examples_to_viz = X[5:10]

    for i in range(1, args.num_steps + 1):
        loss = svi.step(
            X, batch_size=args.batch_size, z_pres_prior_p=partial(z_pres_prior_p, i)
        )

        if args.progress_every > 0 and i % args.progress_every == 0:
            # "elapsed" is in hours; the loss is divided by the data set size.
            print(
                "i={}, epochs={:.2f}, elapsed={:.2f}, elbo={:.2f}".format(
                    i,
                    (i * args.batch_size) / X_size,
                    (time.time() - t0) / 3600,
                    loss / X_size,
                )
            )

        if args.viz and i % args.viz_every == 0:
            trace = poutine.trace(air.guide).get_trace(examples_to_viz, None)
            z, recons = poutine.replay(air.prior, trace=trace)(examples_to_viz.size(0))
            z_wheres = tensor_to_objs(latents_to_tensor(z))

            # Show data with inferred object positions.
            vis.images(draw_many(examples_to_viz, z_wheres))
            # Show reconstructions of data.
            vis.images(draw_many(recons, z_wheres))

        if args.eval_every > 0 and i % args.eval_every == 0:
            # Measure count accuracy over the full training set.
            acc, counts, error_z, error_ix = count_accuracy(X, true_counts, air, 1000)
            # Flatten defensively: a single misclassified example can come
            # back as a 0-d index tensor, on which size(0) and slicing fail.
            error_ix = error_ix.reshape(-1)
            print(
                "i={}, accuracy={}, counts={}".format(i, acc, counts.numpy().tolist())
            )
            if args.viz and error_ix.size(0) > 0:
                vis.images(
                    draw_many(X[error_ix[0:5]], tensor_to_objs(error_z[0:5])),
                    opts=dict(caption="errors ({})".format(i)),
                )

        if "save" in args and i % args.save_every == 0:
            print("Saving parameters...")
            torch.save(air.state_dict(), args.save)
if __name__ == "__main__":
    assert pyro.__version__.startswith("1.9.1")
    # SUPPRESS: options without an explicit default are omitted from the
    # namespace unless given, which main() relies on ("save" in args, ...).
    parser = argparse.ArgumentParser(
        description="Pyro AIR example", argument_default=argparse.SUPPRESS
    )
    parser.add_argument(
        "-n",
        "--num-steps",
        type=int,
        default=int(1e8),
        help="number of optimization steps to take",
    )
    parser.add_argument("-b", "--batch-size", type=int, default=64, help="batch size")
    parser.add_argument(
        "-lr", "--learning-rate", type=float, default=1e-4, help="learning rate"
    )
    parser.add_argument(
        "-blr",
        "--baseline-learning-rate",
        type=float,
        default=1e-3,
        help="baseline learning rate",
    )
    parser.add_argument(
        "--progress-every",
        type=int,
        default=1,
        help="number of steps between writing progress to stdout",
    )
    parser.add_argument(
        "--eval-every", type=int, default=0, help="number of steps between evaluations"
    )
    parser.add_argument(
        "--baseline-scalar",
        type=float,
        help="scale the output of the baseline nets by this value",
    )
    parser.add_argument(
        "--no-baselines",
        action="store_true",
        default=False,
        help="do not use data dependent baselines",
    )
    parser.add_argument(
        "--encoder-net",
        type=int,
        nargs="+",
        default=[200],
        help="encoder net hidden layer sizes",
    )
    parser.add_argument(
        "--decoder-net",
        type=int,
        nargs="+",
        default=[200],
        help="decoder net hidden layer sizes",
    )
    parser.add_argument(
        "--predict-net", type=int, nargs="+", help="predict net hidden layer sizes"
    )
    parser.add_argument(
        "--embed-net", type=int, nargs="+", help="embed net architecture"
    )
    parser.add_argument(
        "--bl-predict-net",
        type=int,
        nargs="+",
        help="baseline predict net hidden layer sizes",
    )
    parser.add_argument(
        "--non-linearity", type=str, help="non linearity to use throughout"
    )
    parser.add_argument(
        "--viz",
        action="store_true",
        default=False,
        help="generate visualizations during optimization",
    )
    parser.add_argument(
        "--viz-every",
        type=int,
        default=100,
        help="number of steps between visualizations",
    )
    parser.add_argument("--visdom-env", default="main", help="visdom environment name")
    parser.add_argument("--load", type=str, help="load previously saved parameters")
    parser.add_argument("--save", type=str, help="save parameters to specified file")
    # argparse only applies ``type`` to string defaults, so the default must
    # already be an int (a literal ``1e4`` would stay a float).
    parser.add_argument(
        "--save-every",
        type=int,
        default=10000,
        help="number of steps between parameter saves",
    )
    parser.add_argument("--cuda", action="store_true", default=False, help="use cuda")
    parser.add_argument(
        "--jit", action="store_true", default=False, help="use PyTorch jit"
    )
    parser.add_argument(
        "-t", "--model-steps", type=int, default=3, help="number of time steps"
    )
    parser.add_argument(
        "--rnn-hidden-size", type=int, default=256, help="rnn hidden size"
    )
    parser.add_argument(
        "--encoder-latent-size",
        type=int,
        default=50,
        help="attention window encoder/decoder latent space size",
    )
    parser.add_argument(
        "--decoder-output-bias",
        type=float,
        help="bias added to decoder output (prior to applying non-linearity)",
    )
    parser.add_argument(
        "--decoder-output-use-sigmoid",
        action="store_true",
        help="apply sigmoid function to output of decoder network",
    )
    parser.add_argument(
        "--window-size", type=int, default=28, help="attention window size"
    )
    parser.add_argument(
        "--z-pres-prior",
        type=float,
        default=0.5,
        help="prior success probability for z_pres",
    )
    parser.add_argument(
        "--z-pres-prior-raw",
        action="store_true",
        default=False,
        help="use --z-pres-prior directly as success prob instead of a geometric like prior",
    )
    parser.add_argument(
        "--anneal-prior",
        choices="none lin exp".split(),
        default="none",
        help="anneal z_pres prior during optimization",
    )
    parser.add_argument(
        "--anneal-prior-to", type=float, default=1e-7, help="target z_pres prior prob"
    )
    parser.add_argument(
        "--anneal-prior-begin",
        type=int,
        default=0,
        help="number of steps to wait before beginning to anneal the prior",
    )
    parser.add_argument(
        "--anneal-prior-duration",
        type=int,
        default=100000,
        help="number of steps over which to anneal the prior",
    )
    parser.add_argument(
        "--pos-prior-mean", type=float, help="mean of the window position prior"
    )
    parser.add_argument(
        "--pos-prior-sd", type=float, help="std. dev. of the window position prior"
    )
    parser.add_argument(
        "--scale-prior-mean", type=float, help="mean of the window scale prior"
    )
    parser.add_argument(
        "--scale-prior-sd", type=float, help="std. dev. of the window scale prior"
    )
    parser.add_argument(
        "--no-masking",
        action="store_true",
        default=False,
        help="do not mask out the costs of unused choices",
    )
    parser.add_argument("--seed", type=int, help="random seed", default=None)
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        default=False,
        help="write hyper parameters and network architecture to stdout",
    )
    main(**vars(parser.parse_args()))
================================================
FILE: examples/air/modules.py
================================================
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import torch
import torch.nn as nn
from torch.nn.functional import softplus
class Encoder(nn.Module):
    """Map attention-window pixel intensities to the parameters of ``z_what``.

    An MLP produces ``2 * z_size`` outputs per row. The first ``z_size``
    columns are returned unchanged as the mean. The remaining ``z_size``
    columns are passed through ``softplus`` to give a positive standard
    deviation.
    """

    def __init__(self, x_size, h_sizes, z_size, non_linear_layer):
        super().__init__()
        self.z_size = z_size
        self.mlp = MLP(x_size, h_sizes + [2 * z_size], non_linear_layer)

    def forward(self, x):
        hidden = self.mlp(x)
        loc = hidden[:, 0 : self.z_size]
        raw_scale = hidden[:, self.z_size :]
        return loc, softplus(raw_scale)
class Decoder(nn.Module):
    """Map a latent code ``z_what`` to pixel intensities.

    An MLP maps ``z_size`` inputs to ``x_size`` outputs. If ``bias`` is not
    ``None`` it is added to the MLP output, and if ``use_sigmoid`` is set the
    result is squashed with ``torch.sigmoid``.
    """

    def __init__(self, x_size, h_sizes, z_size, bias, use_sigmoid, non_linear_layer):
        super().__init__()
        self.bias = bias
        self.use_sigmoid = use_sigmoid
        self.mlp = MLP(z_size, h_sizes + [x_size], non_linear_layer)

    def forward(self, z):
        pixels = self.mlp(z)
        if self.bias is not None:
            pixels = pixels + self.bias
        if self.use_sigmoid:
            return torch.sigmoid(pixels)
        return pixels
class MLP(nn.Module):
    """Fully connected network built from a list of layer output sizes.

    Every layer except the last is followed by ``non_linear_layer()``. The
    last layer is followed by one only if ``output_non_linearity`` is set.
    For example, with 256 inputs:

        [Linear (256 -> 1)]
        [Linear (256 -> 256), ReLU (), Linear (256 -> 1)]
        [Linear (256 -> 256), ReLU (), Linear (256 -> 1), ReLU ()]

    :param int in_size: number of input features.
    :param list out_sizes: output size of each linear layer (non-empty list).
    :param non_linear_layer: zero-argument callable returning a module.
    :param bool output_non_linearity: also apply the non-linearity at the end.
    """

    def __init__(
        self, in_size, out_sizes, non_linear_layer, output_non_linearity=False
    ):
        super().__init__()
        assert len(out_sizes) >= 1
        fan_ins = [in_size] + out_sizes[0:-1]
        *hidden, (last_in, last_out) = zip(fan_ins, out_sizes)
        layers = []
        for fan_in, fan_out in hidden:
            layers.extend((nn.Linear(fan_in, fan_out), non_linear_layer()))
        layers.append(nn.Linear(last_in, last_out))
        if output_non_linearity:
            layers.append(non_linear_layer())
        self.seq = nn.Sequential(*layers)

    def forward(self, x):
        return self.seq(x)
class Predict(nn.Module):
    """Map the guide RNN hidden state to guide parameters for z_pres and z_where.

    One MLP emits ``z_pres_size + 2 * z_where_size`` columns, which are split
    into three parts: a sigmoid success probability for ``z_pres``, a
    location for ``z_where``, and a softplus (positive) scale for ``z_where``.
    """

    def __init__(
        self, input_size, h_sizes, z_pres_size, z_where_size, non_linear_layer
    ):
        super().__init__()
        self.z_pres_size = z_pres_size
        self.z_where_size = z_where_size
        n_outputs = z_pres_size + 2 * z_where_size
        self.mlp = MLP(input_size, h_sizes + [n_outputs], non_linear_layer)

    def forward(self, h):
        raw = self.mlp(h)
        pres_end = self.z_pres_size
        loc_end = pres_end + self.z_where_size
        z_pres_p = torch.sigmoid(raw[:, 0:pres_end])
        z_where_loc = raw[:, pres_end:loc_end]
        z_where_scale = softplus(raw[:, loc_end:])
        return z_pres_p, z_where_loc, z_where_scale
class Identity(nn.Module):
    """Module whose forward pass returns its input unchanged."""

    def forward(self, x):
        return x
================================================
FILE: examples/air/viz.py
================================================
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import math
from collections import namedtuple
import numpy as np
from PIL import Image, ImageDraw
def bounding_box(z_where, x_size):
    """Approximate the pixel bounding box of an attention window.

    Interpolation in the spatial transformer is ignored, but the result is
    close enough for visualisation.

    :param z_where: object with ``s`` (scale), ``x`` and ``y`` attributes.
    :param x_size: side length of the image in pixels (used for both axes).
    :return: ``((left, top), width, height)`` with the origin at the top left.
    """
    scale = z_where.s
    width = x_size / scale
    height = x_size / scale
    shift_x = -z_where.x / scale * x_size / 2.0
    shift_y = -z_where.y / scale * x_size / 2.0
    left = (x_size - width) / 2 + shift_x
    top = (x_size - height) / 2 + shift_y
    return (left, top), width, height
def arr2img(arr):
    """Convert a 2-D float array with values in ``[0, 1]`` to a greyscale image.

    :param arr: 2-D numpy array of shape ``(height, width)``.
    :return: PIL image in mode ``"L"``.
    """
    height, width = arr.shape
    # ndarray.tostring() was removed in NumPy 2.0; tobytes() is the same
    # operation under its current name.
    data = (arr * 255).astype(np.uint8).tobytes()
    # PIL sizes are (width, height) while numpy shapes are (rows, cols);
    # passing arr.shape directly only worked for square images.
    return Image.frombuffer("L", (width, height), data, "raw", "L", 0, 1)
def img2arr(img):
    """Convert an RGB image to a channel-first ``uint8`` array for visdom.

    :param img: RGB image; ``getdata()`` must yield three values per pixel.
    :return: numpy array of shape ``(3, height, width)``.
    """
    width, height = img.size
    # getdata() is flattened row by row, so the pixel grid is
    # (height, width, 3). Reshaping to img.size + (3,) scrambled
    # non-square images.
    pixels = np.array(img.getdata(), np.uint8).reshape((height, width, 3))
    return pixels.transpose((2, 0, 1))
def colors(k):
    """Return the RGB outline colour for the ``k``-th object, cycling red/green/blue."""
    palette = ((255, 0, 0), (0, 255, 0), (0, 0, 255))
    return palette[k % len(palette)]
def draw_one(imgarr, z_arr):
    """Render one image with a bounding box for each object.

    Objects with ``pres <= 0`` get no box. Other boxes use the object's
    colour from ``colors`` scaled by ``pres``, so partially present objects
    appear darker. The sum of ``pres`` is written in the top-left corner,
    with one decimal place if any ``pres`` is non-integral (a relaxed value)
    and none otherwise.

    :param imgarr: 2-D image tensor; clipped to ``[0, 1]`` before drawing.
    :param z_arr: sequence of objects with ``s``, ``x``, ``y`` and ``pres``.
    :return: channel-first ``uint8`` array produced by ``img2arr``.
    """
    # This clipping makes the visualisation somewhat misleading, as it
    # incorrectly suggests that objects occlude one another.
    pixels = np.clip(imgarr.detach().cpu().numpy(), 0, 1)
    canvas = arr2img(pixels).convert("RGB")
    pen = ImageDraw.Draw(canvas)
    for index, obj in enumerate(z_arr):
        # Ideally z_pres would control box opacity; the original author could
        # not make that work with PIL, so the colour is darkened instead.
        if obj.pres > 0:
            (left, top), width, height = bounding_box(obj, imgarr.size(0))
            shade = tuple(int(channel * obj.pres) for channel in colors(index))
            pen.rectangle([left, top, left + width, top + height], outline=shade)
    relaxed = any(obj.pres != math.floor(obj.pres) for obj in z_arr)
    label_fmt = "{:.1f}" if relaxed else "{:.0f}"
    total = sum(obj.pres for obj in z_arr)
    pen.text((0, 0), label_fmt.format(total), fill="white")
    return img2arr(canvas)
def draw_many(imgarrs, z_arr):
    """Render a batch of images with their object bounding boxes.

    :param imgarrs: tensor of images, moved to CPU and iterated along its
        first dimension (presumably ``(n, height, width)`` -- TODO confirm).
    :param z_arr: sequence with one entry per image; each entry is the list
        of ``z`` objects for that image, e.g. from ``tensor_to_objs``.
    :return: list of arrays returned by ``draw_one``, one per image.
    """
    cpu_images = imgarrs.cpu()
    return [draw_one(image, objs) for image, objs in zip(cpu_images, z_arr)]
# Per-step latents: window scale ``s``, position ``x``/``y`` and presence ``pres``.
z_obj = namedtuple("z", ["s", "x", "y", "pres"])


def tensor_to_objs(latents):
    """Map latents (as produced by ``latents_to_tensor``) to nested ``z_obj`` lists.

    :param latents: iterable over examples, each an iterable over steps whose
        items have exactly four values ``(s, x, y, pres)``.
    :return: list (one per example) of lists (one per step) of ``z_obj``.
    """
    objs = []
    for sample in latents:
        objs.append([z_obj._make(step) for step in sample])
    return objs
================================================
FILE: examples/baseball.py
================================================
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import argparse
import logging
import math
import pandas as pd
import torch
import pyro
from pyro.distributions import Beta, Binomial, HalfCauchy, Normal, Pareto, Uniform
from pyro.distributions.util import scalar_like
from pyro.infer import MCMC, NUTS, Predictive
from pyro.infer.mcmc.util import initialize_model, summary
from pyro.util import ignore_experimental_warning
"""
Example has been adapted from [1]. It demonstrates how to do Bayesian inference using
NUTS (or, HMC) in Pyro, and use of some common inference utilities.
As in the Stan tutorial, this uses the small baseball dataset of Efron and Morris [2]
to estimate players' batting average which is the fraction of times a player got a
base hit out of the number of times they went up at bat.
The dataset separates the initial 45 at-bats statistics from the remaining season.
We use the hits data from the initial 45 at-bats to estimate the batting average
for each player. We then use the remaining season's data to validate the predictions
from our models.
Four models are evaluated:
- Complete pooling model: The success probability of scoring a hit is shared
amongst all players.
- No pooling model: Each individual player's success probability is distinct and
there is no data sharing amongst players.
- Partial pooling model: A hierarchical model with partial data sharing.
- Partial pooling model with a logit link: A hierarchical model in which each
player's log-odds of a hit are drawn from a Normal prior with shared hyperpriors.
We recommend Radford Neal's tutorial on HMC ([3]) to users who would like to get a
more comprehensive understanding of HMC and its variants, and to [4] for details on
the No U-Turn Sampler, which provides an efficient and automated way (i.e. limited
hyper-parameters) of running HMC on different problems.
[1] Carpenter B. (2016), ["Hierarchical Partial Pooling for Repeated Binary Trials"]
(http://mc-stan.org/users/documentation/case-studies/pool-binary-trials.html).
[2] Efron B., Morris C. (1975), "Data analysis using Stein's estimator and its
generalizations", J. Amer. Statist. Assoc., 70, 311-319.
[3] Neal, R. (2012), "MCMC using Hamiltonian Dynamics",
(https://arxiv.org/pdf/1206.1901.pdf)
[4] Hoffman, M. D. and Gelman, A. (2014), "The No-U-turn sampler: Adaptively setting
path lengths in Hamiltonian Monte Carlo", (https://arxiv.org/abs/1111.4246)
"""
# Log bare messages (no level or logger prefix) at INFO level.
logging.basicConfig(level=logging.INFO, format="%(message)s")
# Tab-separated Efron & Morris (1975) batting data, read directly from this URL.
DATA_URL = "https://github.com/pyro-ppl/datasets/blob/master/EfronMorrisBB.txt?raw=true"
# ===================================
# MODELS
# ===================================
def fully_pooled(at_bats, hits):
    r"""
    Complete pooling: every player shares a single success probability.

    The number of hits in $K$ at bats for each player is Binomial with a
    common success probability $\phi \sim \mathrm{Uniform}(0, 1)$.

    :param (torch.Tensor) at_bats: Number of at bats for each player.
    :param (torch.Tensor) hits: Number of hits for the given at bats.
    :return: Number of hits predicted by the model.
    """
    low, high = scalar_like(at_bats, 0), scalar_like(at_bats, 1)
    phi = pyro.sample("phi", Uniform(low, high))
    with pyro.plate("num_players", at_bats.shape[0]):
        return pyro.sample("obs", Binomial(at_bats, phi), obs=hits)
def not_pooled(at_bats, hits):
    r"""
    No pooling: each player has an independent success probability.

    The number of hits in $K$ at bats for player $i$ is Binomial with success
    probability $\phi_i \sim \mathrm{Uniform}(0, 1)$, drawn independently
    per player inside the ``num_players`` plate.

    :param (torch.Tensor) at_bats: Number of at bats for each player.
    :param (torch.Tensor) hits: Number of hits for the given at bats.
    :return: Number of hits predicted by the model.
    """
    n_players = at_bats.shape[0]
    with pyro.plate("num_players", n_players):
        low, high = scalar_like(at_bats, 0), scalar_like(at_bats, 1)
        phi = pyro.sample("phi", Uniform(low, high))
        return pyro.sample("obs", Binomial(at_bats, phi), obs=hits)
def partially_pooled(at_bats, hits):
    r"""
    Partial pooling: a hierarchical model with shared hyperpriors.

    The number of hits is Binomial with a player-specific success
    probability $\phi_i \sim \mathrm{Beta}(c_1, c_2)$, where
    $c_1 = m \kappa$ and $c_2 = (1 - m) \kappa$, with
    $m \sim \mathrm{Uniform}(0, 1)$ and $\kappa \sim \mathrm{Pareto}(1, 1.5)$.

    :param (torch.Tensor) at_bats: Number of at bats for each player.
    :param (torch.Tensor) hits: Number of hits for the given at bats.
    :return: Number of hits predicted by the model.
    """
    n_players = at_bats.shape[0]
    m = pyro.sample("m", Uniform(scalar_like(at_bats, 0), scalar_like(at_bats, 1)))
    pareto_scale = scalar_like(at_bats, 1)
    pareto_alpha = scalar_like(at_bats, 1.5)
    kappa = pyro.sample("kappa", Pareto(pareto_scale, pareto_alpha))
    with pyro.plate("num_players", n_players):
        concentration1 = m * kappa
        concentration0 = (1 - m) * kappa
        phi = pyro.sample("phi", Beta(concentration1, concentration0))
        return pyro.sample("obs", Binomial(at_bats, phi), obs=hits)
def partially_pooled_with_logit(at_bats, hits):
    r"""
    Partial pooling on the log-odds scale.

    Each player's logit $\alpha_i \sim \mathcal{N}(\mu, \sigma)$, where the
    shared hyperparameters have priors $\mu \sim \mathcal{N}(-1, 1)$ and
    $\sigma \sim \mathrm{HalfCauchy}(1)$. The number of hits is Binomial
    with logits $\alpha_i$.

    :param (torch.Tensor) at_bats: Number of at bats for each player.
    :param (torch.Tensor) hits: Number of hits for the given at bats.
    :return: Number of hits predicted by the model.
    """
    n_players = at_bats.shape[0]
    mu = pyro.sample("loc", Normal(scalar_like(at_bats, -1), scalar_like(at_bats, 1)))
    sigma = pyro.sample("scale", HalfCauchy(scale=scalar_like(at_bats, 1)))
    with pyro.plate("num_players", n_players):
        alpha = pyro.sample("alpha", Normal(mu, sigma))
        return pyro.sample("obs", Binomial(at_bats, logits=alpha), obs=hits)
# ===================================
# DATA SUMMARIZE UTILS
# ===================================
def get_summary_table(
    posterior,
    sites,
    player_names,
    transforms={},
    diagnostics=False,
    group_by_chain=False,
):
    """
    Summarise posterior samples of each site in ``sites`` as a DataFrame.

    :param dict posterior: maps site name to a tensor of samples.
    :param list sites: names of the sites to summarise.
    :param list player_names: row labels used for vector-valued sites.
    :param dict transforms: optional mapping from site name to a callable
        applied to that site's samples before summarising (not mutated).
    :param bool diagnostics: keep the ``n_eff`` and ``r_hat`` columns.
    :param bool group_by_chain: forwarded to ``summary``.
    :return: dict mapping site name to a DataFrame of float statistics,
        rounded to 2 decimals, computed by ``summary`` with ``prob=0.5``.
    """
    tables = {}
    for name in sites:
        samples = posterior[name].cpu()
        if name in transforms:
            samples = transforms[name](samples)
        stats = summary({name: samples}, prob=0.5, group_by_chain=group_by_chain)[
            name
        ]
        if not stats["mean"].shape:
            # Scalar site: turn 0-d statistics into floats for a one-row frame.
            stats = {key: float(value) for key, value in stats.items()}
            frame = pd.DataFrame(stats, index=[0])
        else:
            frame = pd.DataFrame(stats, index=player_names)
        if not diagnostics:
            frame = frame.drop(["n_eff", "r_hat"], axis=1)
        tables[name] = frame.astype(float).round(2)
    return tables
def train_test_split(pd_dataframe):
    """
    Split the batting table into training and validation tensors.

    Training data: the initial 45 at-bats and hits for each player
    (columns ``At-Bats`` and ``Hits``).
    Validation data: full-season at-bats and hits for each player
    (columns ``SeasonAt-Bats`` and ``SeasonHits``).

    :return: ``(train_data, test_data, player_names)``. Both tensors are
        float, on the default tensor device, with one row per player.
        ``player_names`` is a list of ``"First Last"`` strings.
    """
    device = torch.Tensor().device

    def as_tensor(columns):
        return torch.tensor(
            pd_dataframe[columns].values, dtype=torch.float, device=device
        )

    train_data = as_tensor(["At-Bats", "Hits"])
    test_data = as_tensor(["SeasonAt-Bats", "SeasonHits"])
    firsts = pd_dataframe["FirstName"].values
    lasts = pd_dataframe["LastName"].values
    player_names = [" ".join([first, last]) for first, last in zip(firsts, lasts)]
    return train_data, test_data, player_names
# ===================================
# MODEL EVALUATION UTILS
# ===================================
def sample_posterior_predictive(model, posterior_samples, baseball_dataset):
    """
    Generate samples from the posterior predictive distribution and log
    summaries of the predicted hits next to the actual hits.

    Predictions are made for the initial 45 at-bats (training data) and for
    the full-season at-bats (validation data).

    :param model: model function taking ``(at_bats, hits)``.
    :param dict posterior_samples: posterior samples keyed by site name.
    :param pandas.DataFrame baseball_dataset: the raw batting table.
    """
    train, test, player_names = train_test_split(baseball_dataset)
    at_bats = train[:, 0]
    at_bats_season = test[:, 0]
    # (A dead ``logging.Formatter(...)`` call was removed here: it built a
    # formatter and discarded it. Output format comes from logging.basicConfig.)
    logging.info("\nPosterior Predictive:")
    logging.info("Hit Rate - Initial 45 At Bats")
    logging.info("-----------------------------")
    # set hits=None to convert it from observation node to sample node
    train_predict = Predictive(model, posterior_samples)(at_bats, None)
    train_summary = get_summary_table(
        train_predict, sites=["obs"], player_names=player_names
    )["obs"]
    train_summary = train_summary.assign(ActualHits=baseball_dataset[["Hits"]].values)
    logging.info(train_summary)
    logging.info("\nHit Rate - Season Predictions")
    logging.info("-----------------------------")
    with ignore_experimental_warning():
        test_predict = Predictive(model, posterior_samples)(at_bats_season, None)
    test_summary = get_summary_table(
        test_predict, sites=["obs"], player_names=player_names
    )["obs"]
    test_summary = test_summary.assign(
        ActualHits=baseball_dataset[["SeasonHits"]].values
    )
    logging.info(test_summary)
def evaluate_pointwise_pred_density(model, posterior_samples, baseball_dataset):
    """
    Log the log pointwise predictive density of the unseen season hits.

    For each player j this computes
    ``log((1 / S) * sum_i p(hits_j | theta_i))`` over the ``S`` posterior
    samples ``theta_i``, using logsumexp for numerical stability, and sums
    the result over players.
    """
    _, test, _ = train_test_split(baseball_dataset)
    season_at_bats = test[:, 0]
    season_hits = test[:, 1]
    predictive = Predictive(model, posterior_samples)
    trace = predictive.get_vectorized_trace(season_at_bats, season_hits)
    trace.compute_log_prob()
    # Posterior samples are along dim 0 of the per-observation log likelihood.
    log_lik = trace.nodes["obs"]["log_prob"]
    num_samples = log_lik.shape[0]
    lppd = (log_lik.logsumexp(0) - math.log(num_samples)).sum()
    logging.info("\nLog pointwise predictive density")
    logging.info("--------------------------------")
    logging.info("{:.4f}\n".format(lppd))
def _run_nuts(model, at_bats, hits, args):
    """Run NUTS on ``model`` with data ``(at_bats, hits)`` and return the MCMC object."""
    nuts_kernel = NUTS(model, jit_compile=args.jit, ignore_jit_warnings=True)
    mcmc = MCMC(
        nuts_kernel,
        num_samples=args.num_samples,
        warmup_steps=args.warmup_steps,
        num_chains=args.num_chains,
    )
    mcmc.run(at_bats, hits)
    return mcmc


def _report(
    mcmc,
    model,
    model_name,
    site,
    site_label,
    player_names,
    baseball_dataset,
    transforms=None,
):
    """Log posterior summary, divergences and predictive checks for one fitted model.

    :param mcmc: MCMC object that has already been run.
    :param model: model function used for posterior predictive checks.
    :param str model_name: shown in the header as ``"Model: <model_name>"``.
    :param str site: sample site to summarise.
    :param str site_label: heading printed above the site summary.
    :param dict transforms: optional per-site transforms for the summary.
    """
    samples = mcmc.get_samples()
    header = "Model: {}".format(model_name)
    logging.info("\n" + header)
    logging.info("=" * len(header))
    logging.info("\n{}:".format(site_label))
    logging.info(
        get_summary_table(
            mcmc.get_samples(group_by_chain=True),
            sites=[site],
            player_names=player_names,
            transforms={} if transforms is None else transforms,
            diagnostics=True,
            group_by_chain=True,
        )[site]
    )
    num_divergences = sum(map(len, mcmc.diagnostics()["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(model, samples, baseball_dataset)
    evaluate_pointwise_pred_density(model, samples, baseball_dataset)


def main(args):
    """Fit each batting-average model with NUTS and log its evaluation.

    :param argparse.Namespace args: parsed options; reads ``num_samples``,
        ``warmup_steps``, ``num_chains`` and ``jit``.
    """
    baseball_dataset = pd.read_csv(DATA_URL, sep="\t")
    train, _, player_names = train_test_split(baseball_dataset)
    at_bats, hits = train[:, 0], train[:, 1]
    logging.info("Original Dataset:")
    logging.info(baseball_dataset)

    # (1) Full Pooling Model
    # In this model, we illustrate how to use MCMC with general potential_fn.
    init_params, potential_fn, transforms, _ = initialize_model(
        fully_pooled,
        model_args=(at_bats, hits),
        num_chains=args.num_chains,
        jit_compile=args.jit,
        skip_jit_warnings=True,
    )
    nuts_kernel = NUTS(potential_fn=potential_fn)
    mcmc = MCMC(
        nuts_kernel,
        num_samples=args.num_samples,
        warmup_steps=args.warmup_steps,
        num_chains=args.num_chains,
        initial_params=init_params,
        transforms=transforms,
    )
    mcmc.run(at_bats, hits)
    _report(
        mcmc, fully_pooled, "Fully Pooled", "phi", "phi", player_names, baseball_dataset
    )

    # (2) No Pooling Model
    mcmc = _run_nuts(not_pooled, at_bats, hits, args)
    _report(
        mcmc, not_pooled, "Not Pooled", "phi", "phi", player_names, baseball_dataset
    )

    # (3) Partially Pooled Model
    mcmc = _run_nuts(partially_pooled, at_bats, hits, args)
    _report(
        mcmc,
        partially_pooled,
        "Partially Pooled",
        "phi",
        "phi",
        player_names,
        baseball_dataset,
    )

    # (4) Partially Pooled with Logit Model
    mcmc = _run_nuts(partially_pooled_with_logit, at_bats, hits, args)
    _report(
        mcmc,
        partially_pooled_with_logit,
        "Partially Pooled with Logit",
        "alpha",
        "Sigmoid(alpha)",
        player_names,
        baseball_dataset,
        transforms={"alpha": torch.sigmoid},
    )
if __name__ == "__main__":
    assert pyro.__version__.startswith("1.9.1")
    parser = argparse.ArgumentParser(description="Baseball batting average using HMC")
    # Integer sampler options, registered in the same order as before.
    for flags, default in (
        (("-n", "--num-samples"), 200),
        (("--num-chains",), 4),
        (("--warmup-steps",), 100),
        (("--rng_seed",), 0),
    ):
        parser.add_argument(*flags, nargs="?", default=default, type=int)
    parser.add_argument(
        "--jit", action="store_true", default=False, help="use PyTorch jit"
    )
    parser.add_argument(
        "--cuda", action="store_true", default=False, help="run this example in GPU"
    )
    args = parser.parse_args()

    # Work around "CUDA error: initialization error" by using the "spawn"
    # start method; see https://github.com/pytorch/pytorch/issues/2517
    torch.multiprocessing.set_start_method("spawn")
    pyro.set_rng_seed(args.rng_seed)

    # Work around "RuntimeError: received 0 items of ancdata" by sharing
    # tensors through the file system; see
    # https://discuss.pytorch.org/t/received-0-items-of-ancdata-pytorch-0-4-0/19823
    torch.multiprocessing.set_sharing_strategy("file_system")

    if args.cuda:
        torch.set_default_device("cuda")
    main(args)
================================================
FILE: examples/capture_recapture/cjs.py
================================================
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
"""
We show how to implement several variants of the Cormack-Jolly-Seber (CJS)
[4, 5, 6] model used in ecology to analyze animal capture-recapture data.
For a discussion of these models see reference [1].
We make use of two datasets:
-- the European Dipper (Cinclus cinclus) data from reference [2]
(this is Norway's national bird).
-- the meadow voles data from reference [3].
Compare to the Stan implementations in [7].
References
[1] Kery, M., & Schaub, M. (2011). Bayesian population analysis using
WinBUGS: a hierarchical perspective. Academic Press.
[2] Lebreton, J.D., Burnham, K.P., Clobert, J., & Anderson, D.R. (1992).
Modeling survival and testing biological hypotheses using marked animals:
a unified approach with case studies. Ecological monographs, 62(1), 67-118.
[3] Nichols, Pollock, Hines (1984) The use of a robust capture-recapture design
in small mammal population studies: A field example with Microtus pennsylvanicus.
Acta Theriologica 29:357-365.
[4] Cormack, R.M., 1964. Estimates of survival from the sighting of marked animals.
Biometrika 51, 429-438.
[5] Jolly, G.M., 1965. Explicit estimates from capture-recapture data with both death
and immigration-stochastic model. Biometrika 52, 225-247.
[6] Seber, G.A.F., 1965. A note on the multiple recapture census. Biometrika 52, 249-259.
[7] https://github.com/stan-dev/example-models/tree/master/BPA/Ch.07
"""
import argparse
import os
import numpy as np
import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer import SVI, TraceEnum_ELBO, TraceTMC_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.optim import Adam
"""
Our first and simplest CJS model variant only has two continuous
(scalar) latent random variables: i) the survival probability phi;
and ii) the recapture probability rho. These are treated as fixed
effects with no temporal or individual/group variation.
"""
def model_1(capture_history, sex):
N, T = capture_history.shape
phi = pyro.sample("phi", dist.Uniform(0.0, 1.0)) # survival probability
rho = pyro.sample("rho", dist.Uniform(0.0, 1.0)) # recapture probability
with pyro.plate("animals", N, dim=-1):
z = torch.ones(N)
# we use this mask to el
gitextract_cigk8yyx/
├── .codecov.yml
├── .coveragerc
├── .gitattributes
├── .github/
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE/
│ │ ├── config.yml
│ │ └── issue_template.md
│ └── workflows/
│ ├── ci.yml
│ └── publish.yml
├── .gitignore
├── .readthedocs.yml
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE.md
├── LICENSES/
│ ├── Apache-2.0.txt
│ ├── BSD-3-Clause.txt
│ └── MIT.txt
├── MANIFEST.in
├── Makefile
├── README.md
├── RELEASE-MANAGEMENT.md
├── docker/
│ ├── Dockerfile
│ ├── Makefile
│ ├── README.md
│ └── install.sh
├── docs/
│ ├── Makefile
│ ├── README.md
│ ├── requirements.txt
│ └── source/
│ ├── _static/
│ │ ├── css/
│ │ │ └── pyro.css
│ │ └── img/
│ │ └── favicon/
│ │ ├── browserconfig.xml
│ │ └── manifest.json
│ ├── conf.py
│ ├── contrib.autoname.rst
│ ├── contrib.bnn.rst
│ ├── contrib.cevae.rst
│ ├── contrib.easyguide.rst
│ ├── contrib.epidemiology.rst
│ ├── contrib.examples.rst
│ ├── contrib.forecast.rst
│ ├── contrib.funsor.rst
│ ├── contrib.gp.rst
│ ├── contrib.minipyro.rst
│ ├── contrib.mue.rst
│ ├── contrib.oed.rst
│ ├── contrib.randomvariable.rst
│ ├── contrib.timeseries.rst
│ ├── contrib.tracking.rst
│ ├── contrib.zuko.rst
│ ├── distributions.rst
│ ├── getting_started.rst
│ ├── index.rst
│ ├── infer.autoguide.rst
│ ├── infer.reparam.rst
│ ├── infer.util.rst
│ ├── inference.rst
│ ├── inference_algos.rst
│ ├── mcmc.rst
│ ├── nn.rst
│ ├── ops.rst
│ ├── optimization.rst
│ ├── parameters.rst
│ ├── poutine.rst
│ ├── primitives.rst
│ ├── pyro.infer.mcmc.txt
│ ├── pyro.optim.txt
│ ├── pyro.poutine.txt
│ ├── settings.rst
│ └── testing.rst
├── examples/
│ ├── __init__.py
│ ├── air/
│ │ ├── air.py
│ │ ├── main.py
│ │ ├── modules.py
│ │ └── viz.py
│ ├── baseball.py
│ ├── capture_recapture/
│ │ └── cjs.py
│ ├── contrib/
│ │ ├── __init__.py
│ │ ├── autoname/
│ │ │ ├── mixture.py
│ │ │ ├── scoping_mixture.py
│ │ │ └── tree_data.py
│ │ ├── cevae/
│ │ │ └── synthetic.py
│ │ ├── epidemiology/
│ │ │ ├── regional.py
│ │ │ └── sir.py
│ │ ├── forecast/
│ │ │ └── bart.py
│ │ ├── funsor/
│ │ │ ├── __init__.py
│ │ │ └── hmm.py
│ │ ├── gp/
│ │ │ └── sv-dkl.py
│ │ ├── mue/
│ │ │ ├── FactorMuE.py
│ │ │ └── ProfileHMM.py
│ │ ├── oed/
│ │ │ ├── ab_test.py
│ │ │ └── gp_bayes_opt.py
│ │ └── timeseries/
│ │ └── gp_models.py
│ ├── cvae/
│ │ ├── __init__.py
│ │ ├── baseline.py
│ │ ├── cvae.py
│ │ ├── main.py
│ │ ├── mnist.py
│ │ └── util.py
│ ├── dmm.py
│ ├── eight_schools/
│ │ ├── README.md
│ │ ├── data.py
│ │ ├── mcmc.py
│ │ └── svi.py
│ ├── einsum.py
│ ├── hmm.py
│ ├── inclined_plane.py
│ ├── lda.py
│ ├── lkj.py
│ ├── minipyro.py
│ ├── mixed_hmm/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── experiment.py
│ │ ├── model.py
│ │ └── seal_data.py
│ ├── neutra.py
│ ├── rsa/
│ │ ├── README.md
│ │ ├── generics.py
│ │ ├── hyperbole.py
│ │ ├── schelling.py
│ │ ├── schelling_false.py
│ │ ├── search_inference.py
│ │ └── semantic_parsing.py
│ ├── scanvi/
│ │ ├── __init__.py
│ │ └── scanvi.py
│ ├── sir_hmc.py
│ ├── smcfilter.py
│ ├── sparse_gamma_def.py
│ ├── sparse_regression.py
│ ├── svi_horovod.py
│ ├── svi_lightning.py
│ ├── svi_torch.py
│ ├── toy_mixture_model_discrete_enumeration.py
│ └── vae/
│ ├── ss_vae_M2.py
│ ├── utils/
│ │ ├── __init__.py
│ │ ├── custom_mlp.py
│ │ ├── mnist_cached.py
│ │ └── vae_plots.py
│ ├── vae.py
│ └── vae_comparison.py
├── profiler/
│ ├── __init__.py
│ ├── distributions.py
│ ├── gaussianhmm.py
│ ├── hmm.py
│ └── profiling_utils.py
├── pyproject.toml
├── pyro/
│ ├── __init__.py
│ ├── contrib/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── autoguide.py
│ │ ├── autoname/
│ │ │ ├── __init__.py
│ │ │ ├── autoname.py
│ │ │ ├── named.py
│ │ │ └── scoping.py
│ │ ├── bnn/
│ │ │ ├── __init__.py
│ │ │ ├── hidden_layer.py
│ │ │ └── utils.py
│ │ ├── cevae/
│ │ │ └── __init__.py
│ │ ├── conjugate/
│ │ │ ├── __init__.py
│ │ │ └── infer.py
│ │ ├── easyguide/
│ │ │ ├── __init__.py
│ │ │ └── easyguide.py
│ │ ├── epidemiology/
│ │ │ ├── __init__.py
│ │ │ ├── compartmental.py
│ │ │ ├── distributions.py
│ │ │ ├── models.py
│ │ │ └── util.py
│ │ ├── examples/
│ │ │ ├── __init__.py
│ │ │ ├── bart.py
│ │ │ ├── finance.py
│ │ │ ├── multi_mnist.py
│ │ │ ├── nextstrain.py
│ │ │ ├── polyphonic_data_loader.py
│ │ │ ├── scanvi_data.py
│ │ │ └── util.py
│ │ ├── forecast/
│ │ │ ├── __init__.py
│ │ │ ├── evaluate.py
│ │ │ ├── forecaster.py
│ │ │ └── util.py
│ │ ├── funsor/
│ │ │ ├── __init__.py
│ │ │ ├── handlers/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── enum_messenger.py
│ │ │ │ ├── named_messenger.py
│ │ │ │ ├── plate_messenger.py
│ │ │ │ ├── primitives.py
│ │ │ │ ├── replay_messenger.py
│ │ │ │ ├── runtime.py
│ │ │ │ └── trace_messenger.py
│ │ │ └── infer/
│ │ │ ├── __init__.py
│ │ │ ├── discrete.py
│ │ │ ├── elbo.py
│ │ │ ├── trace_elbo.py
│ │ │ ├── traceenum_elbo.py
│ │ │ └── tracetmc_elbo.py
│ │ ├── gp/
│ │ │ ├── __init__.py
│ │ │ ├── kernels/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── brownian.py
│ │ │ │ ├── coregionalize.py
│ │ │ │ ├── dot_product.py
│ │ │ │ ├── isotropic.py
│ │ │ │ ├── kernel.py
│ │ │ │ ├── periodic.py
│ │ │ │ └── static.py
│ │ │ ├── likelihoods/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── binary.py
│ │ │ │ ├── gaussian.py
│ │ │ │ ├── likelihood.py
│ │ │ │ ├── multi_class.py
│ │ │ │ └── poisson.py
│ │ │ ├── models/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── gplvm.py
│ │ │ │ ├── gpr.py
│ │ │ │ ├── model.py
│ │ │ │ ├── sgpr.py
│ │ │ │ ├── vgp.py
│ │ │ │ └── vsgp.py
│ │ │ ├── parameterized.py
│ │ │ └── util.py
│ │ ├── minipyro.py
│ │ ├── mue/
│ │ │ ├── __init__.py
│ │ │ ├── dataloaders.py
│ │ │ ├── missingdatahmm.py
│ │ │ ├── models.py
│ │ │ └── statearrangers.py
│ │ ├── oed/
│ │ │ ├── __init__.py
│ │ │ ├── eig.py
│ │ │ ├── glmm/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── glmm.py
│ │ │ │ └── guides.py
│ │ │ ├── search.py
│ │ │ └── util.py
│ │ ├── randomvariable/
│ │ │ ├── __init__.py
│ │ │ └── random_variable.py
│ │ ├── timeseries/
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── gp.py
│ │ │ ├── lgssm.py
│ │ │ └── lgssmgp.py
│ │ ├── tracking/
│ │ │ ├── __init__.py
│ │ │ ├── assignment.py
│ │ │ ├── distributions.py
│ │ │ ├── dynamic_models.py
│ │ │ ├── extended_kalman_filter.py
│ │ │ ├── hashing.py
│ │ │ └── measurements.py
│ │ ├── util.py
│ │ └── zuko.py
│ ├── distributions/
│ │ ├── __init__.py
│ │ ├── affine_beta.py
│ │ ├── asymmetriclaplace.py
│ │ ├── avf_mvn.py
│ │ ├── coalescent.py
│ │ ├── conditional.py
│ │ ├── conjugate.py
│ │ ├── constraints.py
│ │ ├── delta.py
│ │ ├── diag_normal_mixture.py
│ │ ├── diag_normal_mixture_shared_cov.py
│ │ ├── distribution.py
│ │ ├── empirical.py
│ │ ├── extended.py
│ │ ├── folded.py
│ │ ├── gaussian_scale_mixture.py
│ │ ├── grouped_normal_normal.py
│ │ ├── hmm.py
│ │ ├── improper_uniform.py
│ │ ├── inverse_gamma.py
│ │ ├── kl.py
│ │ ├── lkj.py
│ │ ├── log_normal_negative_binomial.py
│ │ ├── logistic.py
│ │ ├── mixture.py
│ │ ├── multivariate_studentt.py
│ │ ├── nanmasked.py
│ │ ├── omt_mvn.py
│ │ ├── one_one_matching.py
│ │ ├── one_two_matching.py
│ │ ├── ordered_logistic.py
│ │ ├── polya_gamma.py
│ │ ├── projected_normal.py
│ │ ├── rejector.py
│ │ ├── relaxed_straight_through.py
│ │ ├── score_parts.py
│ │ ├── sine_bivariate_von_mises.py
│ │ ├── sine_skewed.py
│ │ ├── softlaplace.py
│ │ ├── spanning_tree.cpp
│ │ ├── spanning_tree.py
│ │ ├── stable.py
│ │ ├── stable_log_prob.py
│ │ ├── testing/
│ │ │ ├── __init__.py
│ │ │ ├── fakes.py
│ │ │ ├── gof.py
│ │ │ ├── naive_dirichlet.py
│ │ │ ├── rejection_exponential.py
│ │ │ ├── rejection_gamma.py
│ │ │ └── special.py
│ │ ├── torch.py
│ │ ├── torch_distribution.py
│ │ ├── torch_patch.py
│ │ ├── torch_transform.py
│ │ ├── transforms/
│ │ │ ├── __init__.py
│ │ │ ├── affine_autoregressive.py
│ │ │ ├── affine_coupling.py
│ │ │ ├── basic.py
│ │ │ ├── batchnorm.py
│ │ │ ├── block_autoregressive.py
│ │ │ ├── cholesky.py
│ │ │ ├── discrete_cosine.py
│ │ │ ├── generalized_channel_permute.py
│ │ │ ├── haar.py
│ │ │ ├── householder.py
│ │ │ ├── lower_cholesky_affine.py
│ │ │ ├── matrix_exponential.py
│ │ │ ├── neural_autoregressive.py
│ │ │ ├── normalize.py
│ │ │ ├── ordered.py
│ │ │ ├── permute.py
│ │ │ ├── planar.py
│ │ │ ├── polynomial.py
│ │ │ ├── power.py
│ │ │ ├── radial.py
│ │ │ ├── simplex_to_ordered.py
│ │ │ ├── softplus.py
│ │ │ ├── spline.py
│ │ │ ├── spline_autoregressive.py
│ │ │ ├── spline_coupling.py
│ │ │ ├── sylvester.py
│ │ │ ├── unit_cholesky.py
│ │ │ └── utils.py
│ │ ├── unit.py
│ │ ├── util.py
│ │ ├── von_mises_3d.py
│ │ └── zero_inflated.py
│ ├── generic.py
│ ├── infer/
│ │ ├── __init__.py
│ │ ├── abstract_infer.py
│ │ ├── autoguide/
│ │ │ ├── __init__.py
│ │ │ ├── effect.py
│ │ │ ├── gaussian.py
│ │ │ ├── guides.py
│ │ │ ├── initialization.py
│ │ │ ├── structured.py
│ │ │ └── utils.py
│ │ ├── csis.py
│ │ ├── discrete.py
│ │ ├── elbo.py
│ │ ├── energy_distance.py
│ │ ├── enum.py
│ │ ├── importance.py
│ │ ├── inspect.py
│ │ ├── mcmc/
│ │ │ ├── __init__.py
│ │ │ ├── adaptation.py
│ │ │ ├── api.py
│ │ │ ├── hmc.py
│ │ │ ├── logger.py
│ │ │ ├── mcmc_kernel.py
│ │ │ ├── nuts.py
│ │ │ ├── rwkernel.py
│ │ │ └── util.py
│ │ ├── predictive.py
│ │ ├── renyi_elbo.py
│ │ ├── reparam/
│ │ │ ├── __init__.py
│ │ │ ├── conjugate.py
│ │ │ ├── discrete_cosine.py
│ │ │ ├── haar.py
│ │ │ ├── hmm.py
│ │ │ ├── loc_scale.py
│ │ │ ├── neutra.py
│ │ │ ├── projected_normal.py
│ │ │ ├── reparam.py
│ │ │ ├── softmax.py
│ │ │ ├── split.py
│ │ │ ├── stable.py
│ │ │ ├── strategies.py
│ │ │ ├── structured.py
│ │ │ ├── studentt.py
│ │ │ ├── transform.py
│ │ │ └── unit_jacobian.py
│ │ ├── resampler.py
│ │ ├── rws.py
│ │ ├── smcfilter.py
│ │ ├── svgd.py
│ │ ├── svi.py
│ │ ├── trace_elbo.py
│ │ ├── trace_mean_field_elbo.py
│ │ ├── trace_mmd.py
│ │ ├── trace_tail_adaptive_elbo.py
│ │ ├── traceenum_elbo.py
│ │ ├── tracegraph_elbo.py
│ │ ├── tracetmc_elbo.py
│ │ └── util.py
│ ├── logger.py
│ ├── nn/
│ │ ├── __init__.py
│ │ ├── auto_reg_nn.py
│ │ ├── dense_nn.py
│ │ └── module.py
│ ├── ops/
│ │ ├── __init__.py
│ │ ├── arrowhead.py
│ │ ├── contract.py
│ │ ├── dual_averaging.py
│ │ ├── einsum/
│ │ │ ├── __init__.py
│ │ │ ├── adjoint.py
│ │ │ ├── torch_log.py
│ │ │ ├── torch_map.py
│ │ │ ├── torch_marginal.py
│ │ │ ├── torch_sample.py
│ │ │ └── util.py
│ │ ├── gamma_gaussian.py
│ │ ├── gaussian.py
│ │ ├── hessian.py
│ │ ├── indexing.py
│ │ ├── integrator.py
│ │ ├── jit.py
│ │ ├── linalg.py
│ │ ├── newton.py
│ │ ├── packed.py
│ │ ├── provenance.py
│ │ ├── rings.py
│ │ ├── special.py
│ │ ├── ssm_gp.py
│ │ ├── stats.py
│ │ ├── streaming.py
│ │ ├── tensor_utils.py
│ │ └── welford.py
│ ├── optim/
│ │ ├── __init__.py
│ │ ├── adagrad_rmsprop.py
│ │ ├── clipped_adam.py
│ │ ├── dct_adam.py
│ │ ├── horovod.py
│ │ ├── lr_scheduler.py
│ │ ├── multi.py
│ │ ├── optim.py
│ │ └── pytorch_optimizers.py
│ ├── params/
│ │ ├── __init__.py
│ │ └── param_store.py
│ ├── poutine/
│ │ ├── __init__.py
│ │ ├── block_messenger.py
│ │ ├── broadcast_messenger.py
│ │ ├── collapse_messenger.py
│ │ ├── condition_messenger.py
│ │ ├── do_messenger.py
│ │ ├── enum_messenger.py
│ │ ├── equalize_messenger.py
│ │ ├── escape_messenger.py
│ │ ├── guide.py
│ │ ├── handlers.py
│ │ ├── indep_messenger.py
│ │ ├── infer_config_messenger.py
│ │ ├── lift_messenger.py
│ │ ├── markov_messenger.py
│ │ ├── mask_messenger.py
│ │ ├── messenger.py
│ │ ├── plate_messenger.py
│ │ ├── reentrant_messenger.py
│ │ ├── reparam_messenger.py
│ │ ├── replay_messenger.py
│ │ ├── runtime.py
│ │ ├── scale_messenger.py
│ │ ├── seed_messenger.py
│ │ ├── subsample_messenger.py
│ │ ├── substitute_messenger.py
│ │ ├── trace_messenger.py
│ │ ├── trace_struct.py
│ │ ├── uncondition_messenger.py
│ │ └── util.py
│ ├── primitives.py
│ ├── py.typed
│ ├── settings.py
│ └── util.py
├── scripts/
│ ├── install_pytorch.sh
│ ├── perf_test.sh
│ ├── profile_model.sh
│ ├── update_headers.py
│ └── update_version.py
├── setup.cfg
├── setup.py
├── tests/
│ ├── README.md
│ ├── __init__.py
│ ├── common.py
│ ├── conftest.py
│ ├── contrib/
│ │ ├── __init__.py
│ │ ├── autoname/
│ │ │ ├── test_autoname.py
│ │ │ ├── test_named.py
│ │ │ └── test_scoping.py
│ │ ├── bnn/
│ │ │ └── test_hidden_layer.py
│ │ ├── cevae/
│ │ │ └── test_cevae.py
│ │ ├── conftest.py
│ │ ├── easyguide/
│ │ │ └── test_easyguide.py
│ │ ├── epidemiology/
│ │ │ ├── __init__.py
│ │ │ ├── test_distributions.py
│ │ │ ├── test_models.py
│ │ │ ├── test_quant.py
│ │ │ └── test_util.py
│ │ ├── forecast/
│ │ │ ├── __init__.py
│ │ │ ├── test_evaluate.py
│ │ │ ├── test_forecaster.py
│ │ │ └── test_util.py
│ │ ├── funsor/
│ │ │ ├── conftest.py
│ │ │ ├── test_enum_funsor.py
│ │ │ ├── test_infer_discrete.py
│ │ │ ├── test_named_handlers.py
│ │ │ ├── test_pyroapi_funsor.py
│ │ │ ├── test_tmc.py
│ │ │ ├── test_valid_models_enum.py
│ │ │ ├── test_valid_models_plate.py
│ │ │ ├── test_valid_models_sequential_plate.py
│ │ │ └── test_vectorized_markov.py
│ │ ├── gp/
│ │ │ ├── __init__.py
│ │ │ ├── test_conditional.py
│ │ │ ├── test_kernels.py
│ │ │ ├── test_likelihoods.py
│ │ │ ├── test_models.py
│ │ │ └── test_parameterized.py
│ │ ├── mue/
│ │ │ ├── test_dataloaders.py
│ │ │ ├── test_missingdatahmm.py
│ │ │ ├── test_models.py
│ │ │ └── test_statearrangers.py
│ │ ├── oed/
│ │ │ ├── test_ewma.py
│ │ │ ├── test_finite_spaces_eig.py
│ │ │ ├── test_glmm.py
│ │ │ ├── test_linear_models_eig.py
│ │ │ └── test_xexpx.py
│ │ ├── randomvariable/
│ │ │ └── test_random_variable.py
│ │ ├── test_hessian.py
│ │ ├── test_minipyro.py
│ │ ├── test_util.py
│ │ ├── test_zuko.py
│ │ ├── timeseries/
│ │ │ ├── test_gp.py
│ │ │ └── test_lgssm.py
│ │ └── tracking/
│ │ ├── __init__.py
│ │ ├── test_assignment.py
│ │ ├── test_distributions.py
│ │ ├── test_dynamic_models.py
│ │ ├── test_ekf.py
│ │ ├── test_em.py
│ │ ├── test_hashing.py
│ │ └── test_measurements.py
│ ├── distributions/
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── dist_fixture.py
│ │ ├── test_binomial.py
│ │ ├── test_categorical.py
│ │ ├── test_coalescent.py
│ │ ├── test_conjugate.py
│ │ ├── test_conjugate_update.py
│ │ ├── test_constraints.py
│ │ ├── test_cuda.py
│ │ ├── test_delta.py
│ │ ├── test_distributions.py
│ │ ├── test_empirical.py
│ │ ├── test_extended.py
│ │ ├── test_gaussian_mixtures.py
│ │ ├── test_grouped_normal_normal.py
│ │ ├── test_haar.py
│ │ ├── test_hmm.py
│ │ ├── test_ig.py
│ │ ├── test_improper_uniform.py
│ │ ├── test_independent.py
│ │ ├── test_kl.py
│ │ ├── test_lkj.py
│ │ ├── test_log_normal_negative_binomial.py
│ │ ├── test_lowrank_mvn.py
│ │ ├── test_mask.py
│ │ ├── test_mixture.py
│ │ ├── test_mvn.py
│ │ ├── test_mvt.py
│ │ ├── test_nanmasked.py
│ │ ├── test_omt_mvn.py
│ │ ├── test_one_hot_categorical.py
│ │ ├── test_one_one_matching.py
│ │ ├── test_one_two_matching.py
│ │ ├── test_ordered_logistic.py
│ │ ├── test_pickle.py
│ │ ├── test_polya_gamma.py
│ │ ├── test_projected_normal.py
│ │ ├── test_rejector.py
│ │ ├── test_relaxed_straight_through.py
│ │ ├── test_reshape.py
│ │ ├── test_shapes.py
│ │ ├── test_sine_bivariate_von_mises.py
│ │ ├── test_sine_skewed.py
│ │ ├── test_spanning_tree.py
│ │ ├── test_stable.py
│ │ ├── test_stable_log_prob.py
│ │ ├── test_tensor_type.py
│ │ ├── test_torch_patch.py
│ │ ├── test_transforms.py
│ │ ├── test_unit.py
│ │ ├── test_util.py
│ │ ├── test_von_mises.py
│ │ ├── test_zero_inflated.py
│ │ └── testing/
│ │ ├── test_gof.py
│ │ └── test_special.py
│ ├── doctest_fixtures.py
│ ├── infer/
│ │ ├── __init__.py
│ │ ├── autoguide/
│ │ │ ├── __init__.py
│ │ │ ├── conftest.py
│ │ │ ├── test_gaussian.py
│ │ │ ├── test_inference.py
│ │ │ └── test_mean_field_entropy.py
│ │ ├── conftest.py
│ │ ├── enum_growth.ipynb
│ │ ├── mcmc/
│ │ │ ├── __init__.py
│ │ │ ├── test_adaptation.py
│ │ │ ├── test_hmc.py
│ │ │ ├── test_mcmc_api.py
│ │ │ ├── test_mcmc_util.py
│ │ │ ├── test_nuts.py
│ │ │ ├── test_rwkernel.py
│ │ │ └── test_valid_models.py
│ │ ├── reparam/
│ │ │ ├── __init__.py
│ │ │ ├── test_conjugate.py
│ │ │ ├── test_discrete_cosine.py
│ │ │ ├── test_haar.py
│ │ │ ├── test_hmm.py
│ │ │ ├── test_loc_scale.py
│ │ │ ├── test_neutra.py
│ │ │ ├── test_projected_normal.py
│ │ │ ├── test_softmax.py
│ │ │ ├── test_split.py
│ │ │ ├── test_stable.py
│ │ │ ├── test_strategies.py
│ │ │ ├── test_structured.py
│ │ │ ├── test_studentt.py
│ │ │ ├── test_transform.py
│ │ │ ├── test_unit_jacobian.py
│ │ │ └── util.py
│ │ ├── test_abstract_infer.py
│ │ ├── test_autoguide.py
│ │ ├── test_compute_downstream_costs.py
│ │ ├── test_conjugate_gradients.py
│ │ ├── test_csis.py
│ │ ├── test_discrete.py
│ │ ├── test_elbo_mapdata.py
│ │ ├── test_enum.py
│ │ ├── test_gradient.py
│ │ ├── test_inference.py
│ │ ├── test_initialization.py
│ │ ├── test_inspect.py
│ │ ├── test_jit.py
│ │ ├── test_multi_sample_elbos.py
│ │ ├── test_predictive.py
│ │ ├── test_resampler.py
│ │ ├── test_sampling.py
│ │ ├── test_smcfilter.py
│ │ ├── test_svgd.py
│ │ ├── test_tmc.py
│ │ ├── test_util.py
│ │ └── test_valid_models.py
│ ├── integration_tests/
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── test_conjugate_gaussian_models.py
│ │ └── test_tracegraph_elbo.py
│ ├── nn/
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── test_autoregressive.py
│ │ └── test_module.py
│ ├── ops/
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── einsum/
│ │ │ ├── conftest.py
│ │ │ ├── test_adjoint.py
│ │ │ └── test_torch_log.py
│ │ ├── gamma_gaussian.py
│ │ ├── gaussian.py
│ │ ├── test_arrowhead.py
│ │ ├── test_contract.py
│ │ ├── test_gamma_gaussian.py
│ │ ├── test_gaussian.py
│ │ ├── test_indexing.py
│ │ ├── test_integrator.py
│ │ ├── test_jit.py
│ │ ├── test_linalg.py
│ │ ├── test_newton.py
│ │ ├── test_packed.py
│ │ ├── test_provenance.py
│ │ ├── test_special.py
│ │ ├── test_ssm_gp.py
│ │ ├── test_stats.py
│ │ ├── test_streaming.py
│ │ ├── test_tensor_utils.py
│ │ └── test_welford.py
│ ├── optim/
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── test_multi.py
│ │ └── test_optim.py
│ ├── params/
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── test_module.py
│ │ └── test_param.py
│ ├── perf/
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ └── test_benchmark.py
│ ├── poutine/
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── test_counterfactual.py
│ │ ├── test_mapdata.py
│ │ ├── test_nesting.py
│ │ ├── test_poutines.py
│ │ ├── test_properties.py
│ │ ├── test_runtime.py
│ │ └── test_trace_struct.py
│ ├── pyroapi/
│ │ ├── conftest.py
│ │ └── test_pyroapi.py
│ ├── test_examples.py
│ ├── test_generic.py
│ ├── test_primitives.py
│ ├── test_settings.py
│ └── test_util.py
└── tutorial/
├── Makefile
├── README.md
└── source/
├── RSA-hyperbole.ipynb
├── RSA-implicature.ipynb
├── _static/
│ ├── css/
│ │ └── pyro.css
│ └── img/
│ └── dmm.tex
├── air.ipynb
├── autoname_examples.rst
├── baseball.rst
├── bayesian_regression.ipynb
├── bayesian_regression_ii.ipynb
├── bo.ipynb
├── boosting_bbvi.ipynb
├── capture_recapture.rst
├── cevae.rst
├── cleannb.py
├── conf.py
├── contrib_funsor_intro_i.ipynb
├── contrib_funsor_intro_ii.ipynb
├── csis.ipynb
├── custom_objectives.ipynb
├── cvae.ipynb
├── dirichlet_process_mixture.ipynb
├── dkl.rst
├── dmm.ipynb
├── easyguide.ipynb
├── effect_handlers.ipynb
├── einsum.rst
├── ekf.ipynb
├── elections.ipynb
├── enumeration.ipynb
├── epi_intro.ipynb
├── epi_regional.rst
├── epi_sir.rst
├── forecast_simple.rst
├── forecasting_dlm.ipynb
├── forecasting_i.ipynb
├── forecasting_ii.ipynb
├── forecasting_iii.ipynb
├── gmm.ipynb
├── gp.ipynb
├── gplvm.ipynb
├── hmm.rst
├── hmm_funsor.rst
├── inclined_plane.rst
├── index.rst
├── intro_long.ipynb
├── intro_part_i.ipynb
├── intro_part_ii.ipynb
├── jit.ipynb
├── lda.rst
├── lkj.rst
├── logistic-growth.ipynb
├── mcmc.rst
├── minipyro.rst
├── mixed_hmm.rst
├── mle_map.ipynb
├── model_rendering.ipynb
├── modules.ipynb
├── mue_factor.rst
├── mue_profile.rst
├── neutra.rst
├── normalizing_flows_intro.ipynb
├── predictive_deterministic.ipynb
├── prior_predictive.ipynb
├── prodlda.ipynb
├── reconciling_experts.ipynb
├── scanvi.ipynb
├── search_inference.py
├── sir_hmc.rst
├── smcfilter.rst
├── sparse_gamma.rst
├── sparse_regression.rst
├── ss-vae.ipynb
├── stable.ipynb
├── svi_flow_guide.ipynb
├── svi_horovod.rst
├── svi_lightning.rst
├── svi_part_i.ipynb
├── svi_part_ii.ipynb
├── svi_part_iii.ipynb
├── svi_part_iv.ipynb
├── svi_torch.rst
├── tensor_shapes.ipynb
├── timeseries.rst
├── toy_mixture_model_discrete_enumeration.rst
├── tracking_1d.ipynb
├── vae.ipynb
├── vae_flow_prior.ipynb
├── workflow.ipynb
└── working_memory.ipynb
Showing preview only (467K chars total). Download the full file or copy to clipboard to get everything.
SYMBOL INDEX (5897 symbols across 565 files)
FILE: docs/source/conf.py
function setup (line 212) | def setup(app):
FILE: examples/air/air.py
function default_z_pres_prior_p (line 24) | def default_z_pres_prior_p(t):
class AIR (line 34) | class AIR(nn.Module):
method __init__ (line 35) | def __init__(
method prior (line 128) | def prior(self, n, **kwargs):
method prior_step (line 145) | def prior_step(self, t, n, prev, z_pres_prior_p=default_z_pres_prior_p):
method model (line 193) | def model(self, data, batch_size, **kwargs):
method guide (line 211) | def guide(self, data, batch_size, **kwargs):
method guide_step (line 262) | def guide_step(self, t, n, prev, inputs):
method baseline_step (line 313) | def baseline_step(self, prev, inputs):
function expand_z_where (line 352) | def expand_z_where(z_where):
function z_where_inv (line 369) | def z_where_inv(z_where):
function window_to_image (line 382) | def window_to_image(z_where, window_size, image_size, windows):
function image_to_window (line 391) | def image_to_window(z_where, window_size, image_size, images):
function latents_to_tensor (line 403) | def latents_to_tensor(z):
FILE: examples/air/main.py
function count_accuracy (line 33) | def count_accuracy(X, true_counts, air, batch_size):
function make_prior (line 70) | def make_prior(k):
function lin_decay (line 103) | def lin_decay(initial, final, begin, duration, t):
function exp_decay (line 109) | def exp_decay(initial, final, begin, duration, t):
function load_data (line 118) | def load_data():
function main (line 130) | def main(**kwargs):
FILE: examples/air/modules.py
class Encoder (line 12) | class Encoder(nn.Module):
method __init__ (line 13) | def __init__(self, x_size, h_sizes, z_size, non_linear_layer):
method forward (line 19) | def forward(self, x):
class Decoder (line 25) | class Decoder(nn.Module):
method __init__ (line 26) | def __init__(self, x_size, h_sizes, z_size, bias, use_sigmoid, non_lin...
method forward (line 32) | def forward(self, z):
class MLP (line 44) | class MLP(nn.Module):
method __init__ (line 45) | def __init__(
method forward (line 61) | def forward(self, x):
class Predict (line 67) | class Predict(nn.Module):
method __init__ (line 68) | def __init__(
method forward (line 77) | def forward(self, h):
class Identity (line 85) | class Identity(nn.Module):
method __init__ (line 86) | def __init__(self):
method forward (line 89) | def forward(self, x):
FILE: examples/air/viz.py
function bounding_box (line 11) | def bounding_box(z_where, x_size):
function arr2img (line 23) | def arr2img(arr):
function img2arr (line 30) | def img2arr(img):
function colors (line 38) | def colors(k):
function draw_one (line 42) | def draw_one(imgarr, z_arr):
function draw_many (line 64) | def draw_many(imgarrs, z_arr):
function tensor_to_objs (line 75) | def tensor_to_objs(latents):
FILE: examples/baseball.py
function fully_pooled (line 63) | def fully_pooled(at_bats, hits):
function not_pooled (line 79) | def not_pooled(at_bats, hits):
function partially_pooled (line 95) | def partially_pooled(at_bats, hits):
function partially_pooled_with_logit (line 118) | def partially_pooled_with_logit(at_bats, hits):
function get_summary_table (line 141) | def get_summary_table(
function train_test_split (line 176) | def train_test_split(pd_dataframe):
function sample_posterior_predictive (line 203) | def sample_posterior_predictive(model, posterior_samples, baseball_datas...
function evaluate_pointwise_pred_density (line 234) | def evaluate_pointwise_pred_density(model, posterior_samples, baseball_d...
function main (line 255) | def main(args):
FILE: examples/capture_recapture/cjs.py
function model_1 (line 54) | def model_1(capture_history, sex):
function model_2 (line 89) | def model_2(capture_history, sex):
function model_3 (line 128) | def model_3(capture_history, sex):
function model_4 (line 178) | def model_4(capture_history, sex):
function model_5 (line 224) | def model_5(capture_history, sex):
function main (line 269) | def main(args):
FILE: examples/contrib/autoname/mixture.py
function model (line 21) | def model(data, k):
function local_model (line 35) | def local_model(latent, ps, locs, scales, obs=None):
function guide (line 40) | def guide(data, k):
function local_guide (line 48) | def local_guide(latent, k):
function main (line 54) | def main(args):
FILE: examples/contrib/autoname/scoping_mixture.py
function model (line 16) | def model(K, data):
function local_model (line 27) | def local_model(weights, locs, scale, data):
function guide (line 34) | def guide(K, data):
function local_guide (line 45) | def local_guide(probs):
function main (line 49) | def main(args):
FILE: examples/contrib/autoname/tree_data.py
function model (line 25) | def model(data):
function model_recurse (line 31) | def model_recurse(data, latent):
function guide (line 51) | def guide(data):
function guide_recurse (line 55) | def guide_recurse(data, latent):
function main (line 73) | def main(args):
FILE: examples/contrib/cevae/synthetic.py
function generate_data (line 29) | def generate_data(args):
function main (line 46) | def main(args):
FILE: examples/contrib/epidemiology/regional.py
function Model (line 15) | def Model(args, data):
function generate_data (line 22) | def generate_data(args):
function infer_mcmc (line 53) | def infer_mcmc(args, model):
function infer_svi (line 87) | def infer_svi(args, model):
function predict (line 109) | def predict(args, model, truth):
function main (line 152) | def main(args):
FILE: examples/contrib/epidemiology/sir.py
function Model (line 29) | def Model(args, data):
function generate_data (line 58) | def generate_data(args):
function infer_mcmc (line 103) | def infer_mcmc(args, model):
function infer_svi (line 143) | def infer_svi(args, model):
function evaluate (line 167) | def evaluate(args, model, samples):
function predict (line 261) | def predict(args, model, truth):
function main (line 316) | def main(args):
FILE: examples/contrib/forecast/bart.py
function preprocess (line 20) | def preprocess(args):
class Model (line 45) | class Model(ForecastingModel):
method model (line 51) | def model(self, zero_data, covariates):
function main (line 125) | def main(args):
FILE: examples/contrib/funsor/hmm.py
function model_0 (line 94) | def model_0(sequences, lengths, args, batch_size=None, include_prior=True):
function model_1 (line 185) | def model_1(sequences, lengths, args, batch_size=None, include_prior=True):
function model_2 (line 276) | def model_2(sequences, lengths, args, batch_size=None, include_prior=True):
function model_3 (line 327) | def model_3(sequences, lengths, args, batch_size=None, include_prior=True):
function model_4 (line 382) | def model_4(sequences, lengths, args, batch_size=None, include_prior=True):
class TonesGenerator (line 437) | class TonesGenerator(nn.Module):
method __init__ (line 438) | def __init__(self, args, data_dim):
method forward (line 448) | def forward(self, x, y):
function model_5 (line 470) | def model_5(sequences, lengths, args, batch_size=None, include_prior=True):
function model_6 (line 523) | def model_6(sequences, lengths, args, batch_size=None, include_prior=Fal...
function model_7 (line 591) | def model_7(sequences, lengths, args, batch_size=None, include_prior=True):
function main (line 671) | def main(args):
FILE: examples/contrib/gp/sv-dkl.py
class CNN (line 45) | class CNN(nn.Module):
method __init__ (line 46) | def __init__(self):
method forward (line 53) | def forward(self, x):
function train (line 62) | def train(args, train_loader, gpmodule, optimizer, loss_fn, epoch):
function test (line 87) | def test(args, test_loader, gpmodule):
function main (line 111) | def main(args):
FILE: examples/contrib/mue/FactorMuE.py
function generate_data (line 47) | def generate_data(small_test, include_stop, device):
function main (line 62) | def main(args):
FILE: examples/contrib/mue/ProfileHMM.py
function generate_data (line 51) | def generate_data(small_test, include_stop, device):
function main (line 66) | def main(args):
FILE: examples/contrib/oed/ab_test.py
function estimated_ape (line 60) | def estimated_ape(ns, num_vi_steps):
function true_ape (line 82) | def true_ape(ns):
function main (line 94) | def main(num_vi_steps, num_bo_steps, seed):
FILE: examples/contrib/oed/gp_bayes_opt.py
class GPBayesOptimizer (line 14) | class GPBayesOptimizer(pyro.optim.multi.MultiOptimizer):
method __init__ (line 19) | def __init__(self, constraints, gpmodel, num_acquisitions, acquisition...
method update_posterior (line 36) | def update_posterior(self, X, y):
method find_a_candidate (line 50) | def find_a_candidate(self, differentiable, x_init):
method opt_differentiable (line 83) | def opt_differentiable(self, differentiable, num_candidates=5):
method acquire_thompson (line 110) | def acquire_thompson(self, num_acquisitions=1, **opt_params):
method get_step (line 132) | def get_step(self, loss, params, verbose=False):
FILE: examples/contrib/timeseries/gp_models.py
function download_data (line 16) | def download_data():
function main (line 23) | def main(args):
FILE: examples/cvae/baseline.py
class BaselineNet (line 14) | class BaselineNet(nn.Module):
method __init__ (line 15) | def __init__(self, hidden_1, hidden_2):
method forward (line 22) | def forward(self, x):
class MaskedBCELoss (line 30) | class MaskedBCELoss(nn.Module):
method __init__ (line 31) | def __init__(self, masked_with=-1):
method forward (line 35) | def forward(self, input, target):
function train (line 46) | def train(
FILE: examples/cvae/cvae.py
class Encoder (line 16) | class Encoder(nn.Module):
method __init__ (line 17) | def __init__(self, z_dim, hidden_1, hidden_2):
method forward (line 25) | def forward(self, x, y):
class Decoder (line 40) | class Decoder(nn.Module):
method __init__ (line 41) | def __init__(self, z_dim, hidden_1, hidden_2):
method forward (line 48) | def forward(self, z):
class CVAE (line 55) | class CVAE(nn.Module):
method __init__ (line 56) | def __init__(self, z_dim, hidden_1, hidden_2, pre_trained_baseline_net):
method model (line 68) | def model(self, xs, ys=None):
method guide (line 107) | def guide(self, xs, ys=None):
function train (line 122) | def train(
FILE: examples/cvae/main.py
function main (line 15) | def main(args):
FILE: examples/cvae/mnist.py
class CVAEMNIST (line 12) | class CVAEMNIST(Dataset):
method __init__ (line 13) | def __init__(self, root, train=True, transform=None, download=False):
method __len__ (line 17) | def __len__(self):
method __getitem__ (line 20) | def __getitem__(self, item):
class ToTensor (line 29) | class ToTensor:
method __call__ (line 30) | def __call__(self, sample):
class MaskImages (line 38) | class MaskImages:
method __init__ (line 46) | def __init__(self, num_quadrant_inputs, mask_with=-1):
method __call__ (line 52) | def __call__(self, sample):
function get_data (line 77) | def get_data(num_quadrant_inputs, batch_size):
FILE: examples/cvae/util.py
function imshow (line 19) | def imshow(inp, image_path=None):
function visualize (line 41) | def visualize(
function generate_table (line 115) | def generate_table(
FILE: examples/dmm.py
class Emitter (line 44) | class Emitter(nn.Module):
method __init__ (line 49) | def __init__(self, input_dim, z_dim, emission_dim):
method forward (line 58) | def forward(self, z_t):
class GatedTransition (line 69) | class GatedTransition(nn.Module):
method __init__ (line 75) | def __init__(self, z_dim, transition_dim):
method forward (line 92) | def forward(self, z_t_1):
class Combiner (line 114) | class Combiner(nn.Module):
method __init__ (line 121) | def __init__(self, z_dim, rnn_dim):
method forward (line 131) | def forward(self, z_t_1, h_rnn):
class DMM (line 147) | class DMM(nn.Module):
method __init__ (line 153) | def __init__(
method model (line 203) | def model(
method guide (line 263) | def guide(
function main (line 334) | def main(args):
FILE: examples/eight_schools/mcmc.py
function model (line 19) | def model(sigma):
function conditioned_model (line 29) | def conditioned_model(model, sigma, y):
function main (line 33) | def main(args):
FILE: examples/eight_schools/svi.py
function model (line 20) | def model(data):
function guide (line 34) | def guide(data):
function main (line 67) | def main(args):
FILE: examples/einsum.py
function jit_prob (line 38) | def jit_prob(equation, *operands, **kwargs):
function jit_logprob (line 55) | def jit_logprob(equation, *operands, **kwargs):
function jit_gradient (line 74) | def jit_gradient(equation, *operands, **kwargs):
function _jit_adjoint (line 105) | def _jit_adjoint(equation, *operands, **kwargs):
function jit_map (line 144) | def jit_map(equation, *operands, **kwargs):
function jit_sample (line 150) | def jit_sample(equation, *operands, **kwargs):
function jit_marginal (line 156) | def jit_marginal(equation, *operands, **kwargs):
function time_fn (line 162) | def time_fn(fn, equation, *operands, **kwargs):
function main (line 175) | def main(args):
FILE: examples/hmm.py
function model_0 (line 83) | def model_0(sequences, lengths, args, batch_size=None, include_prior=True):
function model_1 (line 174) | def model_1(sequences, lengths, args, batch_size=None, include_prior=True):
function model_2 (line 265) | def model_2(sequences, lengths, args, batch_size=None, include_prior=True):
function model_3 (line 316) | def model_3(sequences, lengths, args, batch_size=None, include_prior=True):
function model_4 (line 371) | def model_4(sequences, lengths, args, batch_size=None, include_prior=True):
class TonesGenerator (line 426) | class TonesGenerator(nn.Module):
method __init__ (line 427) | def __init__(self, args, data_dim):
method forward (line 437) | def forward(self, x, y):
function model_5 (line 459) | def model_5(sequences, lengths, args, batch_size=None, include_prior=True):
function model_6 (line 512) | def model_6(sequences, lengths, args, batch_size=None, include_prior=Fal...
function model_7 (line 580) | def model_7(sequences, lengths, args, batch_size=None, include_prior=True):
function main (line 621) | def main(args):
FILE: examples/inclined_plane.py
function simulate (line 36) | def simulate(mu, length=2.0, phi=np.pi / 6.0, dt=0.005, noise_sigma=None):
function analytic_T (line 66) | def analytic_T(mu, length=2.0, phi=np.pi / 6.0):
function model (line 86) | def model(observed_data):
function main (line 101) | def main(args):
FILE: examples/lda.py
function model (line 42) | def model(data=None, args=None, batch_size=None):
function make_predictor (line 78) | def make_predictor(args):
function parametrized_guide (line 96) | def parametrized_guide(predictor, data, args, batch_size=None):
function main (line 125) | def main(args):
FILE: examples/lkj.py
function model (line 22) | def model(y):
function main (line 45) | def main(args):
FILE: examples/minipyro.py
function main (line 19) | def main(args):
FILE: examples/mixed_hmm/experiment.py
function aic_num_parameters (line 19) | def aic_num_parameters(model, guide=None):
function run_expt (line 37) | def run_expt(args):
FILE: examples/mixed_hmm/model.py
function guide_generic (line 14) | def guide_generic(config):
function model_generic (line 70) | def model_generic(config):
FILE: examples/mixed_hmm/seal_data.py
function download_seal_data (line 13) | def download_seal_data(filename):
function prepare_seal (line 20) | def prepare_seal(filename, random_effects):
FILE: examples/neutra.py
class BananaShaped (line 46) | class BananaShaped(dist.TorchDistribution):
method __init__ (line 50) | def __init__(self, a, b, rho=0.9):
method sample (line 58) | def sample(self, sample_shape=()):
method log_prob (line 66) | def log_prob(self, x):
function model (line 74) | def model(a, b, rho=0.9):
function fit_guide (line 78) | def fit_guide(guide, args):
function run_hmc (line 88) | def run_hmc(args, model):
function main (line 96) | def main(args):
FILE: examples/rsa/generics.py
function Marginal (line 26) | def Marginal(fn):
function discretize_beta_pdf (line 38) | def discretize_beta_pdf(bins, gamma, delta):
function structured_prior_model (line 55) | def structured_prior_model(params):
function threshold_prior (line 73) | def threshold_prior():
function utterance_prior (line 81) | def utterance_prior():
function meaning (line 87) | def meaning(utterance, state, threshold):
function listener0 (line 106) | def listener0(utterance, threshold, prior):
function speaker1 (line 114) | def speaker1(state, threshold, prior):
function listener1 (line 124) | def listener1(utterance, prior):
function speaker2 (line 133) | def speaker2(prevalence, prior):
function main (line 140) | def main(args):
FILE: examples/rsa/hyperbole.py
function Marginal (line 23) | def Marginal(fn):
function approx (line 35) | def approx(x, b=None):
function price_prior (line 43) | def price_prior():
function valence_prior (line 52) | def valence_prior(price):
function meaning (line 68) | def meaning(utterance, price):
function qud_prior (line 83) | def qud_prior():
function utterance_cost (line 91) | def utterance_cost(numberUtt):
function utterance_prior (line 96) | def utterance_prior():
function literal_listener (line 106) | def literal_listener(utterance, qud):
function speaker (line 114) | def speaker(qudValue, qud):
function pragmatic_listener (line 124) | def pragmatic_listener(utterance):
function test_truth (line 138) | def test_truth():
function main (line 204) | def main(args):
FILE: examples/rsa/schelling.py
function location (line 23) | def location(preference):
function alice (line 32) | def alice(preference, depth):
function bob (line 42) | def bob(preference, depth):
function main (line 55) | def main(args):
FILE: examples/rsa/schelling_false.py
function location (line 24) | def location(preference):
function alice_fb (line 33) | def alice_fb(preference, depth):
function alice (line 46) | def alice(preference, depth):
function bob (line 56) | def bob(preference, depth):
function main (line 69) | def main(args):
FILE: examples/rsa/search_inference.py
function memoize (line 22) | def memoize(fn=None, **kwargs):
class HashingMarginal (line 28) | class HashingMarginal(dist.Distribution):
method __init__ (line 37) | def __init__(self, trace_dist, sites=None):
method _dist_and_values (line 54) | def _dist_and_values(self):
method sample (line 85) | def sample(self):
method log_prob (line 90) | def log_prob(self, val):
method enumerate_support (line 100) | def enumerate_support(self):
method _dict_to_tuple (line 104) | def _dict_to_tuple(self, d):
method _weighted_mean (line 115) | def _weighted_mean(self, value, dim=0):
method mean (line 122) | def mean(self):
method variance (line 127) | def variance(self):
class Search (line 138) | class Search(TracePosterior):
method __init__ (line 143) | def __init__(self, model, max_tries=int(1e6), **kwargs):
method _traces (line 148) | def _traces(self, *args, **kwargs):
function pqueue (line 162) | def pqueue(fn, queue):
class BestFirstSearch (line 200) | class BestFirstSearch(TracePosterior):
method __init__ (line 206) | def __init__(self, model, num_samples=None, **kwargs):
method _traces (line 213) | def _traces(self, *args, **kwargs):
FILE: examples/rsa/semantic_parsing.py
function Marginal (line 22) | def Marginal(fn=None, **kwargs):
function flip (line 35) | def flip(name, p):
function Obj (line 43) | def Obj(name):
class Meaning (line 52) | class Meaning:
method sem (line 53) | def sem(self, world):
method syn (line 58) | def syn(self):
class UndefinedMeaning (line 62) | class UndefinedMeaning(Meaning):
method sem (line 63) | def sem(self, world):
method syn (line 66) | def syn(self):
class BlondMeaning (line 70) | class BlondMeaning(Meaning):
method sem (line 71) | def sem(self, world):
method syn (line 74) | def syn(self):
class NiceMeaning (line 78) | class NiceMeaning(Meaning):
method sem (line 79) | def sem(self, world):
method syn (line 82) | def syn(self):
class TallMeaning (line 86) | class TallMeaning(Meaning):
method sem (line 87) | def sem(self, world):
method syn (line 90) | def syn(self):
class BobMeaning (line 94) | class BobMeaning(Meaning):
method sem (line 95) | def sem(self, world):
method syn (line 98) | def syn(self):
class SomeMeaning (line 102) | class SomeMeaning(Meaning):
method sem (line 103) | def sem(self, world):
method syn (line 112) | def syn(self):
class AllMeaning (line 124) | class AllMeaning(Meaning):
method sem (line 125) | def sem(self, world):
method syn (line 136) | def syn(self):
class NoneMeaning (line 148) | class NoneMeaning(Meaning):
method sem (line 149) | def sem(self, world):
method syn (line 158) | def syn(self):
class CompoundMeaning (line 170) | class CompoundMeaning(Meaning):
method __init__ (line 171) | def __init__(self, sem, syn):
method sem (line 175) | def sem(self, world):
method syn (line 178) | def syn(self):
function heuristic (line 187) | def heuristic(is_good):
function world_prior (line 193) | def world_prior(num_objs, meaning_fn):
function lexical_meaning (line 206) | def lexical_meaning(word):
function apply_world_passing (line 221) | def apply_world_passing(f, a):
function syntax_match (line 225) | def syntax_match(s, t):
function can_apply (line 236) | def can_apply(meanings):
function combine_meaning (line 255) | def combine_meaning(meanings, c):
function combine_meanings (line 273) | def combine_meanings(meanings, c=0):
function meaning (line 280) | def meaning(utterance):
function literal_listener (line 288) | def literal_listener(utterance):
function utterance_prior (line 295) | def utterance_prior():
function speaker (line 306) | def speaker(world):
function rsa_listener (line 313) | def rsa_listener(utterance, qud):
function literal_listener_raw (line 320) | def literal_listener_raw(utterance, qud):
function main (line 327) | def main(args):
FILE: examples/scanvi/scanvi.py
function make_fc (line 41) | def make_fc(dims):
function split_in_half (line 51) | def split_in_half(t):
function broadcast_inputs (line 56) | def broadcast_inputs(input_args):
class Z2Decoder (line 63) | class Z2Decoder(nn.Module):
method __init__ (line 64) | def __init__(self, z1_dim, y_dim, z2_dim, hidden_dims):
method forward (line 69) | def forward(self, z1, y):
class XDecoder (line 84) | class XDecoder(nn.Module):
method __init__ (line 85) | def __init__(self, num_genes, z2_dim, hidden_dims):
method forward (line 90) | def forward(self, z2):
class Z2LEncoder (line 97) | class Z2LEncoder(nn.Module):
method __init__ (line 98) | def __init__(self, num_genes, z2_dim, hidden_dims):
method forward (line 103) | def forward(self, x):
class Z1Encoder (line 115) | class Z1Encoder(nn.Module):
method __init__ (line 116) | def __init__(self, num_labels, z1_dim, z2_dim, hidden_dims):
method forward (line 121) | def forward(self, z2, y):
class Classifier (line 136) | class Classifier(nn.Module):
method __init__ (line 137) | def __init__(self, z2_dim, hidden_dims, num_labels):
method forward (line 142) | def forward(self, x):
class SCANVI (line 148) | class SCANVI(nn.Module):
method __init__ (line 149) | def __init__(
method model (line 209) | def model(self, x, y=None):
method guide (line 252) | def guide(self, x, y=None):
function main (line 280) | def main(args):
FILE: examples/sir_hmc.py
function global_model (line 63) | def global_model(population):
function discrete_model (line 75) | def discrete_model(args, data):
function generate_data (line 94) | def generate_data(args):
function reparameterized_discrete_model (line 166) | def reparameterized_discrete_model(args, data):
function infer_hmc_enum (line 209) | def infer_hmc_enum(args, data):
function _infer_hmc (line 214) | def _infer_hmc(args, data, model, init_values={}):
function quantize (line 267) | def quantize(name, x_real, min, max):
function continuous_model (line 303) | def continuous_model(args, data):
function heuristic_init (line 350) | def heuristic_init(args, data):
function infer_hmc_cont (line 371) | def infer_hmc_cont(model, args, data):
function quantize_enumerate (line 383) | def quantize_enumerate(x_real, min, max):
function vectorized_model (line 415) | def vectorized_model(args, data):
function evaluate (line 482) | def evaluate(args, samples):
function predict (line 527) | def predict(args, data, samples, truth=None):
function main (line 611) | def main(args):
FILE: examples/smcfilter.py
class SimpleHarmonicModel (line 25) | class SimpleHarmonicModel:
method __init__ (line 26) | def __init__(self, process_noise, measurement_noise):
method init (line 32) | def init(self, state, initial):
method step (line 36) | def step(self, state, y=None):
class SimpleHarmonicModel_Guide (line 48) | class SimpleHarmonicModel_Guide:
method __init__ (line 49) | def __init__(self, model):
method init (line 52) | def init(self, state, initial):
method step (line 56) | def step(self, state, y=None):
function generate_data (line 68) | def generate_data(args):
function main (line 84) | def main(args):
FILE: examples/sparse_gamma_def.py
function rand_tensor (line 39) | def rand_tensor(shape, mean, sigma):
class SparseGammaDEF (line 43) | class SparseGammaDEF:
method __init__ (line 44) | def __init__(self):
method model (line 61) | def model(self, x):
method guide (line 113) | def guide(self, x):
function clip_params (line 161) | def clip_params():
class MyEasyGuide (line 178) | class MyEasyGuide(EasyGuide):
method guide (line 179) | def guide(self, x):
function main (line 206) | def main(args):
FILE: examples/sparse_regression.py
function dot (line 47) | def dot(X, Z):
function kernel (line 52) | def kernel(X, Z, eta1, eta2, c):
function model (line 62) | def model(X, Y, hypers, jitter=1.0e-4):
function compute_posterior_stats (line 102) | def compute_posterior_stats(X, Y, msq, lam, eta1, xisq, c, sigma, jitter...
function get_data (line 219) | def get_data(N=20, P=10, S=2, Q=2, sigma_obs=0.15):
function init_loc_fn (line 255) | def init_loc_fn(site):
function main (line 264) | def main(args):
FILE: examples/svi_horovod.py
class Model (line 42) | class Model(PyroModule):
method __init__ (line 43) | def __init__(self, size):
method forward (line 47) | def forward(self, covariates, data=None):
function main (line 64) | def main(args):
FILE: examples/svi_lightning.py
class Model (line 30) | class Model(PyroModule):
method __init__ (line 31) | def __init__(self, size):
method forward (line 35) | def forward(self, covariates, data=None):
class PyroLightningModule (line 53) | class PyroLightningModule(pl.LightningModule):
method __init__ (line 54) | def __init__(self, loss_fn: pyro.infer.elbo.ELBOModule, lr: float):
method forward (line 64) | def forward(self, *args):
method training_step (line 67) | def training_step(self, batch, batch_idx):
method configure_optimizers (line 74) | def configure_optimizers(self):
function main (line 79) | def main(args):
FILE: examples/svi_torch.py
class Model (line 24) | class Model(PyroModule):
method __init__ (line 25) | def __init__(self, size):
method forward (line 34) | def forward(self, covariates, data=None):
function main (line 48) | def main(args):
FILE: examples/toy_mixture_model_discrete_enumeration.py
function main (line 40) | def main(args):
function generate_data (line 48) | def generate_data(num_obs):
function model (line 71) | def model(prior, obs, num_obs):
function guide (line 86) | def guide(prior, obs, num_obs):
function train (line 95) | def train(prior, data, num_steps, num_obs):
function evaluate (line 114) | def evaluate(CPDs, posterior_params):
function get_true_pred_CPDs (line 129) | def get_true_pred_CPDs(CPD, posterior_param):
FILE: examples/vae/ss_vae_M2.py
class SSVAE (line 27) | class SSVAE(nn.Module):
method __init__ (line 44) | def __init__(
method setup_networks (line 69) | def setup_networks(self):
method model (line 109) | def model(self, xs, ys=None):
method guide (line 152) | def guide(self, xs, ys=None):
method classifier (line 179) | def classifier(self, xs):
method model_classify (line 198) | def model_classify(self, xs, ys=None):
method guide_classify (line 214) | def guide_classify(self, xs, ys=None):
function run_inference_for_epoch (line 221) | def run_inference_for_epoch(data_loaders, losses, periodic_interval_batc...
function get_accuracy (line 268) | def get_accuracy(data_loader, classifier_fn, batch_size):
function visualize (line 292) | def visualize(ss_vae, viz, test_loader):
function main (line 298) | def main(args):
FILE: examples/vae/utils/custom_mlp.py
class Exp (line 12) | class Exp(nn.Module):
method __init__ (line 17) | def __init__(self):
method forward (line 20) | def forward(self, val):
class ConcatModule (line 24) | class ConcatModule(nn.Module):
method __init__ (line 29) | def __init__(self, allow_broadcast=False):
method forward (line 33) | def forward(self, *input_args):
class ListOutModule (line 51) | class ListOutModule(nn.ModuleList):
method __init__ (line 56) | def __init__(self, modules):
method forward (line 59) | def forward(self, *args, **kwargs):
function call_nn_op (line 64) | def call_nn_op(op):
class MLP (line 78) | class MLP(nn.Module):
method __init__ (line 79) | def __init__(
method forward (line 202) | def forward(self, *args, **kwargs):
FILE: examples/vae/utils/mnist_cached.py
function fn_x_mnist (line 20) | def fn_x_mnist(x, use_cuda):
function fn_y_mnist (line 35) | def fn_y_mnist(y, use_cuda):
function get_ss_indices_per_class (line 48) | def get_ss_indices_per_class(y, sup_per_class):
function split_sup_unsup_valid (line 73) | def split_sup_unsup_valid(X, y, sup_num, validation_num=10000):
function print_distribution_labels (line 104) | def print_distribution_labels(y):
class MNISTCached (line 119) | class MNISTCached(MNIST):
method __init__ (line 133) | def __init__(self, mode, sup_num, use_cuda=True, *args, **kwargs):
method __getitem__ (line 201) | def __getitem__(self, index):
function setup_data_loaders (line 215) | def setup_data_loaders(
function mkdir_p (line 251) | def mkdir_p(path):
FILE: examples/vae/utils/vae_plots.py
function plot_conditional_samples_ssvae (line 7) | def plot_conditional_samples_ssvae(ssvae, visdom_session):
function plot_llk (line 28) | def plot_llk(train_elbo, test_elbo):
function plot_vae_samples (line 65) | def plot_vae_samples(vae, visdom_session):
function mnist_test_tsne (line 78) | def mnist_test_tsne(vae=None, test_loader=None):
function mnist_test_tsne_ssvae (line 89) | def mnist_test_tsne_ssvae(name=None, ssvae=None, test_loader=None):
function plot_tsne (line 101) | def plot_tsne(z_loc, classes, name):
FILE: examples/vae/vae.py
class Encoder (line 22) | class Encoder(nn.Module):
method __init__ (line 23) | def __init__(self, z_dim, hidden_dim):
method forward (line 32) | def forward(self, x):
class Decoder (line 47) | class Decoder(nn.Module):
method __init__ (line 48) | def __init__(self, z_dim, hidden_dim):
method forward (line 56) | def forward(self, z):
class VAE (line 67) | class VAE(nn.Module):
method __init__ (line 70) | def __init__(self, z_dim=50, hidden_dim=400, use_cuda=False):
method model (line 84) | def model(self, x):
method guide (line 105) | def guide(self, x):
method reconstruct_img (line 115) | def reconstruct_img(self, x):
function main (line 125) | def main(args):
FILE: examples/vae/vae_comparison.py
class Encoder (line 35) | class Encoder(nn.Module):
method __init__ (line 36) | def __init__(self):
method forward (line 43) | def forward(self, x):
class Decoder (line 50) | class Decoder(nn.Module):
method __init__ (line 51) | def __init__(self):
method forward (line 57) | def forward(self, z):
class VAE (line 62) | class VAE(object, metaclass=ABCMeta):
method __init__ (line 68) | def __init__(self, args, train_loader, test_loader):
method set_train (line 76) | def set_train(self, is_train=True):
method compute_loss_and_gradient (line 87) | def compute_loss_and_gradient(self, x):
method model_eval (line 97) | def model_eval(self, x):
method train (line 112) | def train(self, epoch):
method test (line 124) | def test(self, epoch):
class PyTorchVAEImpl (line 146) | class PyTorchVAEImpl(VAE):
method __init__ (line 152) | def __init__(self, *args, **kwargs):
method compute_loss_and_gradient (line 156) | def compute_loss_and_gradient(self, x):
method initialize_optimizer (line 174) | def initialize_optimizer(self, lr=1e-3):
class PyroVAEImpl (line 181) | class PyroVAEImpl(VAE):
method __init__ (line 188) | def __init__(self, *args, **kwargs):
method model (line 192) | def model(self, data):
method guide (line 204) | def guide(self, data):
method compute_loss_and_gradient (line 210) | def compute_loss_and_gradient(self, x):
method initialize_optimizer (line 218) | def initialize_optimizer(self, lr):
function setup (line 224) | def setup(args):
function main (line 248) | def main(args):
FILE: profiler/distributions.py
function T (line 26) | def T(arr):
function get_tool (line 66) | def get_tool():
function get_tool_cfg (line 70) | def get_tool_cfg():
function sample (line 82) | def sample(dist, batch_size):
function log_prob (line 94) | def log_prob(dist, batch):
function run_with_tool (line 98) | def run_with_tool(tool, dists, batch_sizes):
function set_tool_cfg (line 127) | def set_tool_cfg(args):
function main (line 139) | def main():
FILE: profiler/gaussianhmm.py
function random_mvn (line 12) | def random_mvn(batch_shape, dim, requires_grad=False):
function main (line 22) | def main(args):
FILE: profiler/hmm.py
function main (line 20) | def main(args):
FILE: profiler/profiling_utils.py
class ProfilePrinter (line 20) | class ProfilePrinter:
method __init__ (line 21) | def __init__(self, column_widths=None, field_format=None, template="co...
method _formatted_values (line 32) | def _formatted_values(self, values):
method _add_using_row_format (line 41) | def _add_using_row_format(self, values):
method _add_using_column_format (line 47) | def _add_using_column_format(self, values):
method push (line 51) | def push(self, values):
method header (line 57) | def header(self, values):
method print (line 70) | def print(self):
function profile_print (line 75) | def profile_print(column_widths=None, field_format=None, template="colum...
function profile_timeit (line 83) | def profile_timeit(fn_callable, repeat=1):
function profile_cprofile (line 88) | def profile_cprofile(fn_callable, prof_file):
class Profile (line 98) | class Profile:
method __init__ (line 99) | def __init__(self, tool, tool_cfg, fn_id):
method _set_decorator_params (line 104) | def _set_decorator_params(self):
method __call__ (line 110) | def __call__(self, fn):
FILE: pyro/contrib/autoname/autoname.py
function genname (line 15) | def genname(name="name"):
class NameScope (line 19) | class NameScope:
method __init__ (line 20) | def __init__(self, name=None):
method __str__ (line 25) | def __str__(self):
method allocate (line 30) | def allocate(self, name):
class ScopeStack (line 36) | class ScopeStack:
method __init__ (line 41) | def __init__(self):
method __str__ (line 44) | def __str__(self):
method global_scope (line 48) | def global_scope(self):
method current_scope (line 52) | def current_scope(self):
method push_scope (line 57) | def push_scope(self, scope):
method pop_scope (line 61) | def pop_scope(self):
method fresh_name (line 64) | def fresh_name(self, name):
class AutonameMessenger (line 71) | class AutonameMessenger(ReentrantMessenger):
method __init__ (line 115) | def __init__(self, name=None):
method __call__ (line 119) | def __call__(self, fn_or_iter):
method __enter__ (line 131) | def __enter__(self):
method __exit__ (line 136) | def __exit__(self, *args):
method __iter__ (line 140) | def __iter__(self):
method _pyro_genname (line 148) | def _pyro_genname(msg):
function autoname (line 157) | def autoname(fn=None, name=None): ...
function sample (line 161) | def sample(*args):
function _sample_name (line 166) | def _sample_name(name, fn, *args, **kwargs): # the current syntax of py...
function _sample_dist (line 172) | def _sample_dist(fn, *args, **kwargs):
FILE: pyro/contrib/autoname/named.py
class Object (line 57) | class Object:
method __init__ (line 81) | def __init__(self, name):
method __str__ (line 85) | def __str__(self):
method __getattribute__ (line 88) | def __getattribute__(self, key):
method __setattr__ (line 101) | def __setattr__(self, key, value):
method sample_ (line 111) | def sample_(self, fn, *args, **kwargs):
method param_ (line 121) | def param_(self, *args, **kwargs):
class List (line 129) | class List(list):
method __init__ (line 147) | def __init__(self, name=None):
method __str__ (line 150) | def __str__(self):
method _set_name (line 153) | def _set_name(self, name):
method add (line 160) | def add(self):
method __setitem__ (line 179) | def __setitem__(self, pos, value):
class Dict (line 195) | class Dict(dict):
method __init__ (line 213) | def __init__(self, name=None):
method __str__ (line 216) | def __str__(self):
method _set_name (line 219) | def _set_name(self, name):
method __getitem__ (line 226) | def __getitem__(self, key):
method __setitem__ (line 239) | def __setitem__(self, key, value):
FILE: pyro/contrib/autoname/scoping.py
class NameCountMessenger (line 15) | class NameCountMessenger(Messenger):
method __enter__ (line 20) | def __enter__(self):
method _increment_name (line 24) | def _increment_name(self, name, label):
method _pyro_sample (line 34) | def _pyro_sample(self, msg):
method _pyro_post_sample (line 37) | def _pyro_post_sample(self, msg):
method _pyro_post_scope (line 40) | def _pyro_post_scope(self, msg):
method _pyro_scope (line 43) | def _pyro_scope(self, msg):
class ScopeMessenger (line 47) | class ScopeMessenger(Messenger):
method __init__ (line 52) | def __init__(self, prefix=None, inner=None):
method _collect_scope (line 59) | def _collect_scope(prefixed_scope):
method __enter__ (line 62) | def __enter__(self):
method __call__ (line 73) | def __call__(self, fn):
method _pyro_scope (line 84) | def _pyro_scope(self, msg):
method _pyro_sample (line 87) | def _pyro_sample(self, msg):
function scope (line 91) | def scope(fn=None, prefix=None, inner=None):
function name_count (line 146) | def name_count(fn=None):
FILE: pyro/contrib/bnn/hidden_layer.py
class HiddenLayer (line 12) | class HiddenLayer(TorchDistribution):
method __init__ (line 61) | def __init__(
method log_prob (line 90) | def log_prob(self, value):
method KL (line 94) | def KL(self):
method rsample (line 101) | def rsample(self, sample_shape=torch.Size()):
FILE: pyro/contrib/bnn/utils.py
function xavier_uniform (line 9) | def xavier_uniform(D_in, D_out):
function adjoin_ones_vector (line 15) | def adjoin_ones_vector(x):
function adjoin_zeros_vector (line 19) | def adjoin_zeros_vector(x):
FILE: pyro/contrib/cevae/__init__.py
class FullyConnected (line 42) | class FullyConnected(nn.Sequential):
method __init__ (line 47) | def __init__(self, sizes, final_activation=None):
method append (line 57) | def append(self, layer):
class DistributionNet (line 62) | class DistributionNet(nn.Module):
method get_class (line 68) | def get_class(dtype):
class BernoulliNet (line 80) | class BernoulliNet(DistributionNet):
method __init__ (line 94) | def __init__(self, sizes):
method forward (line 99) | def forward(self, x):
method make_dist (line 104) | def make_dist(logits):
class ExponentialNet (line 108) | class ExponentialNet(DistributionNet):
method __init__ (line 122) | def __init__(self, sizes):
method forward (line 127) | def forward(self, x):
method make_dist (line 133) | def make_dist(rate):
class LaplaceNet (line 137) | class LaplaceNet(DistributionNet):
method __init__ (line 152) | def __init__(self, sizes):
method forward (line 157) | def forward(self, x):
method make_dist (line 164) | def make_dist(loc, scale):
class NormalNet (line 168) | class NormalNet(DistributionNet):
method __init__ (line 183) | def __init__(self, sizes):
method forward (line 188) | def forward(self, x):
method make_dist (line 195) | def make_dist(loc, scale):
class StudentTNet (line 199) | class StudentTNet(DistributionNet):
method __init__ (line 214) | def __init__(self, sizes):
method forward (line 220) | def forward(self, x):
method make_dist (line 228) | def make_dist(df, loc, scale):
class DiagNormalNet (line 232) | class DiagNormalNet(nn.Module):
method __init__ (line 250) | def __init__(self, sizes):
method forward (line 256) | def forward(self, x):
class PreWhitener (line 265) | class PreWhitener(nn.Module):
method __init__ (line 270) | def __init__(self, data):
method forward (line 279) | def forward(self, data):
class Model (line 283) | class Model(PyroModule):
method __init__ (line 301) | def __init__(self, config):
method forward (line 319) | def forward(self, x, t=None, y=None, size=None):
method y_mean (line 329) | def y_mean(self, x, t=None):
method z_dist (line 336) | def z_dist(self):
method x_dist (line 339) | def x_dist(self, z):
method y_dist (line 343) | def y_dist(self, t, z):
method t_dist (line 351) | def t_dist(self, z):
class Guide (line 356) | class Guide(PyroModule):
method __init__ (line 374) | def __init__(self, config):
method forward (line 397) | def forward(self, x, t=None, y=None, size=None):
method t_dist (line 409) | def t_dist(self, x):
method y_dist (line 413) | def y_dist(self, t, x):
method z_dist (line 423) | def z_dist(self, y, t, x):
class TraceCausalEffect_ELBO (line 435) | class TraceCausalEffect_ELBO(Trace_ELBO):
method _differentiable_loss_particle (line 443) | def _differentiable_loss_particle(self, model_trace, guide_trace):
method loss (line 466) | def loss(self, model, guide, *args, **kwargs):
class CEVAE (line 470) | class CEVAE(nn.Module):
method __init__ (line 512) | def __init__(
method fit (line 539) | def fit(
method ite (line 607) | def ite(self, x, num_samples=None, batch_size=None):
method to_script_module (line 648) | def to_script_module(self):
FILE: pyro/contrib/conjugate/infer.py
function _make_cls (line 15) | def _make_cls(base, static_attrs, instance_attrs, parent_linkage=None):
function _latent (line 46) | def _latent(base, parent):
function _conditional (line 52) | def _conditional(base, parent):
function _compound (line 58) | def _compound(base, parent):
class BetaBinomialPair (line 62) | class BetaBinomialPair:
method __init__ (line 63) | def __init__(self):
method latent (line 67) | def latent(self, *args, **kwargs):
method conditional (line 71) | def conditional(self, *args, **kwargs):
method posterior (line 75) | def posterior(self, obs):
method compound (line 90) | def compound(self):
class GammaPoissonPair (line 98) | class GammaPoissonPair:
method __init__ (line 99) | def __init__(self):
method latent (line 103) | def latent(self, *args, **kwargs):
method conditional (line 107) | def conditional(self, *args, **kwargs):
method posterior (line 111) | def posterior(self, obs):
method compound (line 119) | def compound(self):
class UncollapseConjugateMessenger (line 125) | class UncollapseConjugateMessenger(Messenger):
method __init__ (line 131) | def __init__(self, trace):
method _pyro_sample (line 141) | def _pyro_sample(self, msg):
function uncollapse_conjugate (line 172) | def uncollapse_conjugate(fn=None, trace=None):
class CollapseConjugateMessenger (line 185) | class CollapseConjugateMessenger(Messenger):
method _pyro_sample (line 186) | def _pyro_sample(self, msg):
function collapse_conjugate (line 198) | def collapse_conjugate(fn=None):
function posterior_replay (line 210) | def posterior_replay(model, posterior_samples, *args, **kwargs):
FILE: pyro/contrib/easyguide/easyguide.py
class _EasyGuideMeta (line 22) | class _EasyGuideMeta(type(PyroModule), ABCMeta):
class EasyGuide (line 26) | class EasyGuide(PyroModule, metaclass=_EasyGuideMeta):
method __init__ (line 46) | def __init__(self, model):
method model (line 56) | def model(self):
method _setup_prototype (line 59) | def _setup_prototype(self, *args, **kwargs):
method guide (line 75) | def guide(self, *args, **kargs):
method init (line 81) | def init(self, site):
method forward (line 95) | def forward(self, *args, **kwargs):
method plate (line 108) | def plate(
method group (line 122) | def group(self, match=".*"):
method map_estimate (line 145) | def map_estimate(self, name):
class Group (line 177) | class Group:
method __init__ (line 189) | def __init__(self, guide, sites):
method __getstate__ (line 232) | def __getstate__(self):
method __setstate__ (line 237) | def __setstate__(self, state):
method guide (line 242) | def guide(self):
method sample (line 245) | def sample(self, guide_name, fn, infer=None):
method map_estimate (line 305) | def map_estimate(self):
function easy_guide (line 318) | def easy_guide(model):
FILE: pyro/contrib/epidemiology/compartmental.py
function _require_double_precision (line 57) | def _require_double_precision():
function _disallow_latent_variables (line 67) | def _disallow_latent_variables(section_name):
class CompartmentalModel (line 81) | class CompartmentalModel(ABC):
method __init__ (line 150) | def __init__(self, compartments, duration, population, *, approximate=...
method time_plate (line 183) | def time_plate(self):
method region_plate (line 194) | def region_plate(self):
method _clear_plates (line 206) | def _clear_plates(self):
method full_mass (line 211) | def full_mass(self):
method series (line 226) | def series(self):
method global_model (line 248) | def global_model(self):
method initialize (line 258) | def initialize(self, params):
method transition (line 269) | def transition(self, params, state, t):
method finalize (line 297) | def finalize(self, params, prev, curr):
method compute_flows (line 322) | def compute_flows(self, prev, curr, t):
method generate (line 361) | def generate(self, fixed={}):
method fit_svi (line 384) | def fit_svi(
method fit_mcmc (line 534) | def fit_mcmc(self, **options):
method predict (line 663) | def predict(self, forecast=0):
method heuristic (line 737) | def heuristic(self, num_particles=1024, ess_threshold=0.5, retries=10):
method _heuristic (line 788) | def _heuristic(self, haar, **options):
method _concat_series (line 804) | def _concat_series(self, samples, trace, forecast=0):
method _non_compartmental (line 829) | def _non_compartmental(self):
method _sample_auxiliary (line 861) | def _sample_auxiliary(self):
method _transition_bwd (line 900) | def _transition_bwd(self, params, prev, curr, t):
method _generative_model (line 921) | def _generative_model(self, forecast=0):
method _sequential_model (line 948) | def _sequential_model(self):
method _quantized_model (line 1000) | def _quantized_model(self):
method _relaxed_model (line 1097) | def _relaxed_model(self):
class _SMCModel (line 1138) | class _SMCModel:
method __init__ (line 1143) | def __init__(self, model):
method init (line 1147) | def init(self, state):
method step (line 1158) | def step(self, state):
class _SMCGuide (line 1180) | class _SMCGuide(_SMCModel):
method init (line 1185) | def init(self, state):
method step (line 1188) | def step(self, state):
class _HaarSplitReparam (line 1193) | class _HaarSplitReparam:
method __init__ (line 1199) | def __init__(self, split, duration, dims, supports):
method __bool__ (line 1206) | def __bool__(self):
method reparam (line 1209) | def reparam(self, model):
method aux_to_user (line 1229) | def aux_to_user(self, samples):
FILE: pyro/contrib/epidemiology/distributions.py
function _all (line 17) | def _all(x):
function _is_zero (line 21) | def _is_zero(x):
function set_approx_sample_thresh (line 26) | def set_approx_sample_thresh(thresh):
function set_approx_log_prob_tol (line 56) | def set_approx_log_prob_tol(tol):
function set_relaxed_distributions (line 84) | def set_relaxed_distributions(relaxed=True):
function _validate_overdispersion (line 94) | def _validate_overdispersion(overdispersion):
function _relaxed_binomial (line 102) | def _relaxed_binomial(total_count, probs):
function _relaxed_beta_binomial (line 117) | def _relaxed_beta_binomial(concentration1, concentration0, total_count):
function binomial_dist (line 137) | def binomial_dist(total_count, probs, *, overdispersion=0.0):
function beta_binomial_dist (line 194) | def beta_binomial_dist(
function poisson_dist (line 230) | def poisson_dist(rate, *, overdispersion=0.0):
function negative_binomial_dist (line 237) | def negative_binomial_dist(
function infection_dist (line 246) | def infection_dist(
FILE: pyro/contrib/epidemiology/models.py
class SimpleSIRModel (line 16) | class SimpleSIRModel(CompartmentalModel):
method __init__ (line 35) | def __init__(self, population, recovery_time, data):
method global_model (line 46) | def global_model(self):
method initialize (line 52) | def initialize(self, params):
method transition (line 56) | def transition(self, params, state, t):
class SimpleSEIRModel (line 84) | class SimpleSEIRModel(CompartmentalModel):
method __init__ (line 106) | def __init__(self, population, incubation_time, recovery_time, data):
method global_model (line 121) | def global_model(self):
method initialize (line 128) | def initialize(self, params):
method transition (line 132) | def transition(self, params, state, t):
class SimpleSEIRDModel (line 162) | class SimpleSEIRDModel(CompartmentalModel):
method __init__ (line 189) | def __init__(
method global_model (line 210) | def global_model(self):
method initialize (line 218) | def initialize(self, params):
method transition (line 222) | def transition(self, params, state, t):
method compute_flows (line 260) | def compute_flows(self, prev, curr, t):
class OverdispersedSIRModel (line 275) | class OverdispersedSIRModel(CompartmentalModel):
method __init__ (line 314) | def __init__(self, population, recovery_time, data):
method global_model (line 325) | def global_model(self):
method initialize (line 332) | def initialize(self, params):
method transition (line 336) | def transition(self, params, state, t):
class OverdispersedSEIRModel (line 367) | class OverdispersedSEIRModel(CompartmentalModel):
method __init__ (line 408) | def __init__(self, population, incubation_time, recovery_time, data):
method global_model (line 423) | def global_model(self):
method initialize (line 431) | def initialize(self, params):
method transition (line 435) | def transition(self, params, state, t):
class SuperspreadingSIRModel (line 470) | class SuperspreadingSIRModel(CompartmentalModel):
method __init__ (line 509) | def __init__(self, population, recovery_time, data):
method global_model (line 520) | def global_model(self):
method initialize (line 527) | def initialize(self, params):
method transition (line 531) | def transition(self, params, state, t):
class SuperspreadingSEIRModel (line 560) | class SuperspreadingSEIRModel(CompartmentalModel):
method __init__ (line 610) | def __init__(
method global_model (line 642) | def global_model(self):
method initialize (line 650) | def initialize(self, params):
method transition (line 654) | def transition(self, params, state, t):
class HeterogeneousSIRModel (line 696) | class HeterogeneousSIRModel(CompartmentalModel):
method __init__ (line 716) | def __init__(self, population, recovery_time, data):
method global_model (line 727) | def global_model(self):
method initialize (line 753) | def initialize(self, params):
method transition (line 759) | def transition(self, params, state, t):
class SparseSIRModel (line 797) | class SparseSIRModel(CompartmentalModel):
method __init__ (line 825) | def __init__(self, population, recovery_time, data, mask):
method global_model (line 838) | def global_model(self):
method initialize (line 844) | def initialize(self, params):
method transition (line 848) | def transition(self, params, state, t):
method compute_flows (line 880) | def compute_flows(self, prev, curr, t):
class UnknownStartSIRModel (line 892) | class UnknownStartSIRModel(CompartmentalModel):
method __init__ (line 917) | def __init__(self, population, recovery_time, pre_obs_window, data):
method global_model (line 943) | def global_model(self):
method initialize (line 968) | def initialize(self, params):
method transition (line 972) | def transition(self, params, state, t):
method predict (line 1000) | def predict(self, forecast=0):
class RegionalSIRModel (line 1022) | class RegionalSIRModel(CompartmentalModel):
method __init__ (line 1064) | def __init__(self, population, coupling, recovery_time, data):
method global_model (line 1084) | def global_model(self):
method initialize (line 1100) | def initialize(self, params):
method transition (line 1107) | def transition(self, params, state, t):
class HeterogeneousRegionalSIRModel (line 1144) | class HeterogeneousRegionalSIRModel(CompartmentalModel):
method __init__ (line 1171) | def __init__(self, population, coupling, recovery_time, data):
method global_model (line 1191) | def global_model(self):
method initialize (line 1205) | def initialize(self, params):
method transition (line 1217) | def transition(self, params, state, t):
FILE: pyro/contrib/epidemiology/util.py
function clamp (line 14) | def clamp(tensor, *, min=None, max=None):
function cat2 (line 30) | def cat2(lhs, rhs, *, dim=-1):
function align_samples (line 56) | def align_samples(samples, model, particle_dim):
function compute_bin_probs (line 174) | def compute_bin_probs(s, num_quant_bins):
function _all (line 332) | def _all(x):
function _unsqueeze (line 336) | def _unsqueeze(x):
function quantize (line 340) | def quantize(name, x_real, min, max, num_quant_bins=4):
function quantize_enumerate (line 363) | def quantize_enumerate(x_real, min, max, num_quant_bins=4):
FILE: pyro/contrib/examples/bart.py
function _load_hourly_od (line 40) | def _load_hourly_od(basename):
function load_bart_od (line 91) | def load_bart_od():
function load_fake_od (line 167) | def load_fake_od():
FILE: pyro/contrib/examples/finance.py
function load_snp500 (line 17) | def load_snp500():
FILE: pyro/contrib/examples/multi_mnist.py
function imresize (line 21) | def imresize(arr, size):
function sample_one (line 25) | def sample_one(canvas_size, mnist):
function sample_multi (line 42) | def sample_multi(num_digits, canvas_size, mnist):
function mk_dataset (line 56) | def mk_dataset(n, mnist, max_digits, canvas_size):
function load_mnist (line 67) | def load_mnist(root_path):
function load (line 75) | def load(root_path):
FILE: pyro/contrib/examples/nextstrain.py
function load_nextstrain_counts (line 17) | def load_nextstrain_counts(map_location=None) -> dict:
FILE: pyro/contrib/examples/polyphonic_data_loader.py
function process_data (line 58) | def process_data(base_path, dataset, min_note=21, note_range=88):
function load_data (line 100) | def load_data(dataset):
function reverse_sequences (line 119) | def reverse_sequences(mini_batch, seq_lengths):
function pad_and_reverse (line 131) | def pad_and_reverse(rnn_output, seq_lengths):
function get_mini_batch_mask (line 139) | def get_mini_batch_mask(mini_batch, seq_lengths):
function get_mini_batch (line 151) | def get_mini_batch(mini_batch_indices, sequences, seq_lengths, cuda=False):
FILE: pyro/contrib/examples/scanvi_data.py
class BatchDataLoader (line 18) | class BatchDataLoader(object):
method __init__ (line 24) | def __init__(self, data_x, data_y, batch_size, num_classes=4, missing_...
method size (line 43) | def size(self):
method __len__ (line 46) | def __len__(self):
method _sample_batch_indices (line 49) | def _sample_batch_indices(self):
method __iter__ (line 66) | def __iter__(self):
function _get_score (line 81) | def _get_score(normalized_adata, gene_set):
function _get_cell_mask (line 95) | def _get_cell_mask(normalized_adata, gene_set):
function get_data (line 107) | def get_data(dataset="pbmc", batch_size=100, cuda=False):
FILE: pyro/contrib/examples/util.py
class MNIST (line 12) | class MNIST(datasets.MNIST):
method download (line 15) | def download(self) -> None:
function get_data_loader (line 43) | def get_data_loader(
function print_and_log (line 64) | def print_and_log(logger, msg):
function get_data_directory (line 73) | def get_data_directory(filepath=None):
function _mkdir_p (line 79) | def _mkdir_p(dirname):
FILE: pyro/contrib/forecast/evaluate.py
function eval_mae (line 19) | def eval_mae(pred, truth):
function eval_rmse (line 32) | def eval_rmse(pred, truth):
function eval_crps (line 46) | def eval_crps(pred, truth):
function backtest (line 71) | def backtest(
FILE: pyro/contrib/forecast/forecaster.py
class _ForecastingModelMeta (line 33) | class _ForecastingModelMeta(type(PyroModule), ABCMeta):
class ForecastingModel (line 37) | class ForecastingModel(PyroModule, metaclass=_ForecastingModelMeta):
method __init__ (line 44) | def __init__(self):
method model (line 49) | def model(self, zero_data, covariates):
method time_plate (line 71) | def time_plate(self):
method predict (line 82) | def predict(self, noise_dist, prediction):
method forward (line 169) | def forward(self, data, covariates):
class Forecaster (line 197) | class Forecaster(nn.Module):
method __init__ (line 262) | def __init__(
method __call__ (line 340) | def __call__(self, data, covariates, num_samples, batch_size=None):
method forward (line 365) | def forward(self, data, covariates, num_samples, batch_size=None):
class HMCForecaster (line 395) | class HMCForecaster(nn.Module):
method __init__ (line 427) | def __init__(
method __call__ (line 487) | def __call__(self, data, covariates, num_samples, batch_size=None):
method forward (line 512) | def forward(self, data, covariates, num_samples, batch_size=None):
FILE: pyro/contrib/forecast/util.py
function time_reparam_dct (line 17) | def time_reparam_dct(msg):
function time_reparam_haar (line 30) | def time_reparam_haar(msg):
class MarkDCTParamMessenger (line 43) | class MarkDCTParamMessenger(Messenger):
method __init__ (line 52) | def __init__(self, name):
method _postprocess_message (line 56) | def _postprocess_message(self, msg):
class PrefixWarmStartMessenger (line 70) | class PrefixWarmStartMessenger(Messenger):
method _pyro_param (line 77) | def _pyro_param(self, msg):
class PrefixReplayMessenger (line 113) | class PrefixReplayMessenger(Messenger):
method __init__ (line 124) | def __init__(self, trace):
method _pyro_post_sample (line 128) | def _pyro_post_sample(self, msg):
class PrefixConditionMessenger (line 154) | class PrefixConditionMessenger(Messenger):
method __init__ (line 162) | def __init__(self, data):
method _pyro_sample (line 166) | def _pyro_sample(self, msg):
function prefix_condition (line 205) | def prefix_condition(d, data):
function _ (line 227) | def _(d, data):
function _ (line 234) | def _(d, data):
function _ (line 240) | def _(d, data):
function _ (line 247) | def _(d, data):
function _ (line 253) | def _(d, data):
function _prefix_condition_univariate (line 260) | def _prefix_condition_univariate(d, data):
function _ (line 271) | def _(d, data):
function reshape_batch (line 279) | def reshape_batch(d, batch_shape):
function _ (line 298) | def _(d, batch_shape):
function _ (line 307) | def _(d, batch_shape):
function _ (line 314) | def _(d, batch_shape):
function _ (line 321) | def _(d, batch_shape):
function _ (line 327) | def _(d, batch_shape):
function _reshape_batch_univariate (line 337) | def _reshape_batch_univariate(d, batch_shape):
function _ (line 351) | def _(d, batch_shape):
function _ (line 359) | def _(d, batch_shape):
function _ (line 388) | def _(d, batch_shape):
function reshape_transform_batch (line 431) | def reshape_transform_batch(t, old_shape, new_shape):
function _reshape_batch_univariate_transform (line 451) | def _reshape_batch_univariate_transform(t, old_shape, new_shape):
function _ (line 463) | def _(t, old_shape, new_shape):
function _ (line 468) | def _(t, old_shape, new_shape):
function _ (line 480) | def _(t, old_shape, new_shape):
FILE: pyro/contrib/funsor/__init__.py
function plate (line 24) | def plate(*args, **kwargs):
FILE: pyro/contrib/funsor/handlers/__init__.py
function enum (line 26) | def enum(fn=None, first_available_dim=None): ...
function markov (line 30) | def markov(fn=None, history=1, keep=False): ...
function named (line 34) | def named(fn=None, first_available_dim=None): ...
function plate (line 38) | def plate(
function replay (line 51) | def replay(fn=None, trace=None, params=None): ...
function trace (line 55) | def trace(fn=None, graph_type=None, param_only=None, pack_online=True): ...
function vectorized_markov (line 59) | def vectorized_markov(fn=None, name=None, size=None, dim=None, history=1...
FILE: pyro/contrib/funsor/handlers/enum_messenger.py
function _get_support_value (line 28) | def _get_support_value(funsor_dist, name, **kwargs):
function _get_support_value_contraction (line 35) | def _get_support_value_contraction(funsor_dist, name, **kwargs):
function _get_support_value_delta (line 46) | def _get_support_value_delta(funsor_dist, name, **kwargs):
function _get_support_value_tensor (line 52) | def _get_support_value_tensor(funsor_dist, name, **kwargs):
function _get_support_value_distribution (line 62) | def _get_support_value_distribution(funsor_dist, name, expand=False):
function _enum_strategy_default (line 67) | def _enum_strategy_default(dist, msg):
function _enum_strategy_diagonal (line 78) | def _enum_strategy_diagonal(dist, msg):
function _enum_strategy_mixture (line 103) | def _enum_strategy_mixture(dist, msg):
function _enum_strategy_full (line 146) | def _enum_strategy_full(dist, msg):
function _enum_strategy_exact (line 156) | def _enum_strategy_exact(dist, msg):
function enumerate_site (line 162) | def enumerate_site(dist, msg):
class EnumMessenger (line 182) | class EnumMessenger(NamedMessenger):
method _pyro_sample (line 188) | def _pyro_sample(self, msg):
function queue (line 213) | def queue(
FILE: pyro/contrib/funsor/handlers/named_messenger.py
class NamedMessenger (line 16) | class NamedMessenger(ReentrantMessenger):
method __init__ (line 27) | def __init__(self, first_available_dim=None):
method __enter__ (line 35) | def __enter__(self):
method __exit__ (line 48) | def __exit__(self, *args, **kwargs):
method _pyro_to_data (line 61) | def _pyro_to_data(msg):
method _pyro_to_funsor (line 83) | def _pyro_to_funsor(msg):
class MarkovMessenger (line 117) | class MarkovMessenger(NamedMessenger):
method __init__ (line 130) | def __init__(self, history=1, keep=False):
method __call__ (line 137) | def __call__(self, fn):
method __iter__ (line 143) | def __iter__(self):
method __enter__ (line 152) | def __enter__(self):
method __exit__ (line 166) | def __exit__(self, *args, **kwargs):
class GlobalNamedMessenger (line 174) | class GlobalNamedMessenger(NamedMessenger):
method __init__ (line 185) | def __init__(self, first_available_dim=None):
method __enter__ (line 189) | def __enter__(self):
method __exit__ (line 198) | def __exit__(self, *args):
FILE: pyro/contrib/funsor/handlers/plate_messenger.py
class IndepMessenger (line 29) | class IndepMessenger(GlobalNamedMessenger):
method __init__ (line 35) | def __init__(self, name=None, size=None, dim=None, indices=None):
method __enter__ (line 57) | def __enter__(self):
method _pyro_sample (line 65) | def _pyro_sample(self, msg):
method _pyro_param (line 69) | def _pyro_param(self, msg):
class SubsampleMessenger (line 75) | class SubsampleMessenger(IndepMessenger):
method __init__ (line 76) | def __init__(
method _pyro_sample (line 95) | def _pyro_sample(self, msg):
method _pyro_param (line 99) | def _pyro_param(self, msg):
method _subsample_site_value (line 103) | def _subsample_site_value(self, value, event_dim=None):
method _pyro_post_param (line 115) | def _pyro_post_param(self, msg):
method _pyro_post_subsample (line 131) | def _pyro_post_subsample(self, msg):
class PlateMessenger (line 136) | class PlateMessenger(SubsampleMessenger):
method __enter__ (line 143) | def __enter__(self):
method _pyro_sample (line 147) | def _pyro_sample(self, msg):
method __iter__ (line 151) | def __iter__(self):
class _SequentialPlateMessenger (line 159) | class _SequentialPlateMessenger(Messenger):
method __init__ (line 164) | def __init__(self, name, size, indices, scale):
method __iter__ (line 172) | def __iter__(self):
method _pyro_sample (line 179) | def _pyro_sample(self, msg):
method _pyro_param (line 184) | def _pyro_param(self, msg):
class VectorizedMarkovMessenger (line 190) | class VectorizedMarkovMessenger(NamedMessenger):
method __init__ (line 296) | def __init__(self, name=None, size=None, dim=None, history=1):
method _markov_chain (line 305) | def _markov_chain(name=None, markov_vars=set(), suffixes=list()):
method __iter__ (line 325) | def __iter__(self):
method _pyro_sample (line 348) | def _pyro_sample(self, msg):
method _pyro_post_sample (line 366) | def _pyro_post_sample(self, msg):
FILE: pyro/contrib/funsor/handlers/primitives.py
function to_funsor (line 9) | def to_funsor(x, output=None, dim_to_name=None, dim_type=DimType.LOCAL):
function to_data (line 21) | def to_data(x, name_to_dim=None, dim_type=DimType.LOCAL):
FILE: pyro/contrib/funsor/handlers/replay_messenger.py
class ReplayMessenger (line 8) | class ReplayMessenger(OrigReplayMessenger):
method _pyro_sample (line 15) | def _pyro_sample(self, msg):
FILE: pyro/contrib/funsor/handlers/runtime.py
class StackFrame (line 8) | class StackFrame:
method __init__ (line 14) | def __init__(self, name_to_dim, dim_to_name, history=1, keep=False):
method __setitem__ (line 28) | def __setitem__(self, key, value):
method __getitem__ (line 37) | def __getitem__(self, key):
method __delitem__ (line 41) | def __delitem__(self, key):
method __contains__ (line 51) | def __contains__(self, key):
class DimType (line 56) | class DimType(Enum):
class DimStack (line 68) | class DimStack:
method __init__ (line 77) | def __init__(self):
method set_first_available_dim (line 93) | def set_first_available_dim(self, dim):
method push_global (line 98) | def push_global(self, frame):
method pop_global (line 101) | def pop_global(self):
method push_iter (line 105) | def push_iter(self, frame):
method pop_iter (line 108) | def pop_iter(self):
method push_local (line 112) | def push_local(self, frame):
method pop_local (line 115) | def pop_local(self):
method global_frame (line 120) | def global_frame(self):
method local_frame (line 124) | def local_frame(self):
method current_write_env (line 128) | def current_write_env(self):
method current_read_env (line 136) | def current_read_env(self):
method _genvalue (line 147) | def _genvalue(self, key, value_request):
method allocate (line 183) | def allocate(self, key_to_value_request):
method names_from_batch_shape (line 227) | def names_from_batch_shape(self, batch_shape, dim_type=DimType.LOCAL):
FILE: pyro/contrib/funsor/handlers/trace_messenger.py
function _mask_fn (line 13) | def _mask_fn(fn, mask):
class TraceMessenger (line 20) | class TraceMessenger(OrigTraceMessenger):
method __init__ (line 30) | def __init__(self, graph_type=None, param_only=None, pack_online=True):
method _pyro_post_sample (line 34) | def _pyro_post_sample(self, msg):
method _pyro_post_markov_chain (line 81) | def _pyro_post_markov_chain(self, msg):
FILE: pyro/contrib/funsor/infer/discrete.py
function _sample_posterior (line 15) | def _sample_posterior(model, first_available_dim, temperature, *args, **...
function infer_discrete (line 71) | def infer_discrete(model, first_available_dim=None, temperature=1):
FILE: pyro/contrib/funsor/infer/elbo.py
class ELBO (line 9) | class ELBO(_OrigELBO):
method _get_trace (line 10) | def _get_trace(self, *args, **kwargs):
method differentiable_loss (line 13) | def differentiable_loss(self, model, guide, *args, **kwargs):
method loss (line 16) | def loss(self, model, guide, *args, **kwargs):
method loss_and_grads (line 19) | def loss_and_grads(self, model, guide, *args, **kwargs):
class Jit_ELBO (line 25) | class Jit_ELBO(ELBO):
method differentiable_loss (line 26) | def differentiable_loss(self, model, guide, *args, **kwargs):
FILE: pyro/contrib/funsor/infer/trace_elbo.py
class Trace_ELBO (line 19) | class Trace_ELBO(ELBO):
method differentiable_loss (line 20) | def differentiable_loss(self, model, guide, *args, **kwargs):
class JitTrace_ELBO (line 51) | class JitTrace_ELBO(Jit_ELBO, Trace_ELBO):
FILE: pyro/contrib/funsor/infer/traceenum_elbo.py
function apply_optimizer (line 20) | def apply_optimizer(x):
function terms_from_trace (line 28) | def terms_from_trace(tr):
class TraceMarkovEnum_ELBO (line 93) | class TraceMarkovEnum_ELBO(ELBO):
method differentiable_loss (line 94) | def differentiable_loss(self, model, guide, *args, **kwargs):
class TraceEnum_ELBO (line 172) | class TraceEnum_ELBO(ELBO):
method differentiable_loss (line 173) | def differentiable_loss(self, model, guide, *args, **kwargs):
class JitTraceEnum_ELBO (line 278) | class JitTraceEnum_ELBO(Jit_ELBO, TraceEnum_ELBO):
class JitTraceMarkovEnum_ELBO (line 282) | class JitTraceMarkovEnum_ELBO(Jit_ELBO, TraceMarkovEnum_ELBO):
FILE: pyro/contrib/funsor/infer/tracetmc_elbo.py
class TraceTMC_ELBO (line 17) | class TraceTMC_ELBO(ELBO):
method differentiable_loss (line 18) | def differentiable_loss(self, model, guide, *args, **kwargs):
class JitTraceTMC_ELBO (line 53) | class JitTraceTMC_ELBO(Jit_ELBO, TraceTMC_ELBO):
FILE: pyro/contrib/gp/kernels/brownian.py
class Brownian (line 11) | class Brownian(Kernel):
method __init__ (line 26) | def __init__(self, input_dim, variance=None, active_dims=None):
method forward (line 34) | def forward(self, X, Z=None, diag=False):
FILE: pyro/contrib/gp/kernels/coregionalize.py
class Coregionalize (line 12) | class Coregionalize(Kernel):
method __init__ (line 48) | def __init__(
method forward (line 80) | def forward(self, X, Z=None, diag=False):
FILE: pyro/contrib/gp/kernels/dot_product.py
class DotProduct (line 11) | class DotProduct(Kernel):
method __init__ (line 16) | def __init__(self, input_dim, variance=None, active_dims=None):
method _dot_product (line 22) | def _dot_product(self, X, Z=None, diag=False):
class Linear (line 39) | class Linear(DotProduct):
method __init__ (line 53) | def __init__(self, input_dim, variance=None, active_dims=None):
method forward (line 56) | def forward(self, X, Z=None, diag=False):
class Polynomial (line 60) | class Polynomial(DotProduct):
method __init__ (line 70) | def __init__(self, input_dim, variance=None, bias=None, degree=1, acti...
method forward (line 82) | def forward(self, X, Z=None, diag=False):
FILE: pyro/contrib/gp/kernels/isotropic.py
function _torch_sqrt (line 11) | def _torch_sqrt(x, eps=1e-12):
class Isotropy (line 20) | class Isotropy(Kernel):
method __init__ (line 32) | def __init__(self, input_dim, variance=None, lengthscale=None, active_...
method _square_scaled_dist (line 41) | def _square_scaled_dist(self, X, Z=None):
method _scaled_dist (line 60) | def _scaled_dist(self, X, Z=None):
method _diag (line 66) | def _diag(self, X):
class RBF (line 73) | class RBF(Isotropy):
method __init__ (line 82) | def __init__(self, input_dim, variance=None, lengthscale=None, active_...
method forward (line 85) | def forward(self, X, Z=None, diag=False):
class RationalQuadratic (line 93) | class RationalQuadratic(Isotropy):
method __init__ (line 104) | def __init__(
method forward (line 118) | def forward(self, X, Z=None, diag=False):
class Exponential (line 128) | class Exponential(Isotropy):
method __init__ (line 135) | def __init__(self, input_dim, variance=None, lengthscale=None, active_...
method forward (line 138) | def forward(self, X, Z=None, diag=False):
class Matern32 (line 146) | class Matern32(Isotropy):
method __init__ (line 154) | def __init__(self, input_dim, variance=None, lengthscale=None, active_...
method forward (line 157) | def forward(self, X, Z=None, diag=False):
class Matern52 (line 166) | class Matern52(Isotropy):
method __init__ (line 174) | def __init__(self, input_dim, variance=None, lengthscale=None, active_...
method forward (line 177) | def forward(self, X, Z=None, diag=False):
FILE: pyro/contrib/gp/kernels/kernel.py
class Kernel (line 9) | class Kernel(Parameterized):
method __init__ (line 30) | def __init__(self, input_dim, active_dims=None):
method forward (line 42) | def forward(self, X, Z=None, diag=False):
method _slice_input (line 57) | def _slice_input(self, X):
class Combination (line 74) | class Combination(Kernel):
method __init__ (line 83) | def __init__(self, kern0, kern1):
class Sum (line 105) | class Sum(Combination):
method forward (line 111) | def forward(self, X, Z=None, diag=False):
class Product (line 118) | class Product(Combination):
method forward (line 124) | def forward(self, X, Z=None, diag=False):
class Transforming (line 131) | class Transforming(Kernel):
method __init__ (line 139) | def __init__(self, kern):
class Exponent (line 145) | class Exponent(Transforming):
method forward (line 152) | def forward(self, X, Z=None, diag=False):
class VerticalScaling (line 156) | class VerticalScaling(Transforming):
method __init__ (line 167) | def __init__(self, kern, vscaling_fn):
method forward (line 172) | def forward(self, X, Z=None, diag=False):
function _Horner_evaluate (line 188) | def _Horner_evaluate(x, coef):
class Warping (line 200) | class Warping(Transforming):
method __init__ (line 229) | def __init__(self, kern, iwarping_fn=None, owarping_coef=None):
method forward (line 248) | def forward(self, X, Z=None, diag=False):
FILE: pyro/contrib/gp/kernels/periodic.py
class Cosine (line 14) | class Cosine(Isotropy):
method __init__ (line 23) | def __init__(self, input_dim, variance=None, lengthscale=None, active_...
method forward (line 26) | def forward(self, X, Z=None, diag=False):
class Periodic (line 34) | class Periodic(Kernel):
method __init__ (line 51) | def __init__(
method forward (line 65) | def forward(self, X, Z=None, diag=False):
FILE: pyro/contrib/gp/kernels/static.py
class Constant (line 11) | class Constant(Kernel):
method __init__ (line 18) | def __init__(self, input_dim, variance=None, active_dims=None):
method forward (line 24) | def forward(self, X, Z=None, diag=False):
class WhiteNoise (line 33) | class WhiteNoise(Kernel):
method __init__ (line 42) | def __init__(self, input_dim, variance=None, active_dims=None):
method forward (line 48) | def forward(self, X, Z=None, diag=False):
FILE: pyro/contrib/gp/likelihoods/binary.py
class Binary (line 11) | class Binary(Likelihood):
method __init__ (line 24) | def __init__(self, response_function=None):
method forward (line 30) | def forward(self, f_loc, f_var, y=None):
FILE: pyro/contrib/gp/likelihoods/gaussian.py
class Gaussian (line 13) | class Gaussian(Likelihood):
method __init__ (line 23) | def __init__(self, variance=None):
method forward (line 29) | def forward(self, f_loc, f_var, y=None):
FILE: pyro/contrib/gp/likelihoods/likelihood.py
class Likelihood (line 7) | class Likelihood(Parameterized):
method __init__ (line 15) | def __init__(self):
method forward (line 18) | def forward(self, f_loc, f_var, y=None):
FILE: pyro/contrib/gp/likelihoods/multi_class.py
function _softmax (line 11) | def _softmax(x):
class MultiClass (line 15) | class MultiClass(Likelihood):
method __init__ (line 29) | def __init__(self, num_classes, response_function=None):
method forward (line 36) | def forward(self, f_loc, f_var, y=None):
FILE: pyro/contrib/gp/likelihoods/poisson.py
class Poisson (line 11) | class Poisson(Likelihood):
method __init__ (line 23) | def __init__(self, response_function=None):
method forward (line 29) | def forward(self, f_loc, f_var, y=None):
FILE: pyro/contrib/gp/models/gplvm.py
class GPLVM (line 9) | class GPLVM(Parameterized):
method __init__ (line 59) | def __init__(self, base_model):
method model (line 75) | def model(self):
method guide (line 82) | def guide(self):
method forward (line 88) | def forward(self, **kwargs):
FILE: pyro/contrib/gp/models/gpr.py
class GPRegression (line 16) | class GPRegression(GPModel):
method __init__ (line 69) | def __init__(self, X, y, kernel, noise=None, mean_function=None, jitte...
method model (line 83) | def model(self):
method guide (line 106) | def guide(self):
method forward (line 110) | def forward(self, Xnew, full_cov=False, noiseless=True):
method iter_sample (line 159) | def iter_sample(self, noiseless=True):
FILE: pyro/contrib/gp/models/model.py
function _zero_mean_function (line 9) | def _zero_mean_function(x):
class GPModel (line 13) | class GPModel(Parameterized):
method __init__ (line 92) | def __init__(self, X, y, kernel, mean_function=None, jitter=1e-6):
method model (line 109) | def model(self):
method guide (line 116) | def guide(self):
method forward (line 123) | def forward(self, Xnew, full_cov=False):
method set_data (line 144) | def set_data(self, X, y=None):
method _check_Xnew_shape (line 207) | def _check_Xnew_shape(self, Xnew):
FILE: pyro/contrib/gp/models/sgpr.py
class SparseGPRegression (line 14) | class SparseGPRegression(GPModel):
method __init__ (line 98) | def __init__(
method model (line 130) | def model(self):
method guide (line 178) | def guide(self):
method forward (line 182) | def forward(self, Xnew, full_cov=False, noiseless=True):
FILE: pyro/contrib/gp/models/vgp.py
class VariationalGP (line 16) | class VariationalGP(GPModel):
method __init__ (line 63) | def __init__(
method model (line 100) | def model(self):
method guide (line 137) | def guide(self):
method forward (line 148) | def forward(self, Xnew, full_cov=False):
FILE: pyro/contrib/gp/models/vsgp.py
class VariationalSparseGP (line 17) | class VariationalSparseGP(GPModel):
method __init__ (line 82) | def __init__(
method model (line 127) | def model(self):
method guide (line 174) | def guide(self):
method forward (line 185) | def forward(self, Xnew, full_cov=False):
FILE: pyro/contrib/gp/parameterized.py
function _is_real_support (line 17) | def _is_real_support(support):
function _get_sample_fn (line 24) | def _get_sample_fn(module, name):
class Parameterized (line 57) | class Parameterized(PyroModule):
method __init__ (line 92) | def __init__(self):
method set_prior (line 98) | def set_prior(self, name, prior):
method __setattr__ (line 113) | def __setattr__(self, name, value):
method autoguide (line 122) | def autoguide(self, name, dist_constructor):
method _load_pyro_samples (line 181) | def _load_pyro_samples(self):
method set_mode (line 190) | def set_mode(self, mode):
method mode (line 207) | def mode(self):
method mode (line 211) | def mode(self, mode):
FILE: pyro/contrib/gp/util.py
function conditional (line 10) | def conditional(
function train (line 161) | def train(gpmodule, optimizer=None, loss_fn=None, retain_graph=None, num...
FILE: pyro/contrib/minipyro.py
function get_param_store (line 38) | def get_param_store():
class Messenger (line 43) | class Messenger:
method __init__ (line 44) | def __init__(self, fn=None):
method __enter__ (line 49) | def __enter__(self):
method __exit__ (line 52) | def __exit__(self, *args, **kwargs):
method process_message (line 56) | def process_message(self, msg):
method postprocess_message (line 59) | def postprocess_message(self, msg):
method __call__ (line 62) | def __call__(self, *args, **kwargs):
class trace (line 70) | class trace(Messenger):
method __enter__ (line 71) | def __enter__(self):
method postprocess_message (line 78) | def postprocess_message(self, msg):
method get_trace (line 84) | def get_trace(self, *args, **kwargs):
class replay (line 94) | class replay(Messenger):
method __init__ (line 95) | def __init__(self, fn, guide_trace):
method process_message (line 99) | def process_message(self, msg):
class block (line 107) | class block(Messenger):
method __init__ (line 108) | def __init__(self, fn=None, hide_fn=lambda msg: True):
method process_message (line 112) | def process_message(self, msg):
class seed (line 118) | class seed(Messenger):
method __init__ (line 119) | def __init__(self, fn=None, rng_seed=None):
method __enter__ (line 123) | def __enter__(self):
method __exit__ (line 133) | def __exit__(self, type, value, traceback):
class PlateMessenger (line 143) | class PlateMessenger(Messenger):
method __init__ (line 144) | def __init__(self, fn, size, dim):
method process_message (line 150) | def process_message(self, msg):
method __iter__ (line 158) | def __iter__(self):
function apply_stack (line 164) | def apply_stack(msg):
function sample (line 186) | def sample(name, fn, *args, **kwargs):
function param (line 210) | def param(
function plate (line 259) | def plate(name, size, dim=None):
class Adam (line 268) | class Adam:
method __init__ (line 269) | def __init__(self, optim_args):
method __call__ (line 275) | def __call__(self, params):
class SVI (line 293) | class SVI:
method __init__ (line 294) | def __init__(self, model, guide, optim, loss):
method step (line 302) | def step(self, *args, **kwargs):
function elbo (line 328) | def elbo(model, guide, *args, **kwargs):
function Trace_ELBO (line 358) | def Trace_ELBO(**kwargs):
class JitTrace_ELBO (line 365) | class JitTrace_ELBO:
method __init__ (line 366) | def __init__(self, **kwargs):
method __call__ (line 371) | def __call__(self, model, guide, *args):
FILE: pyro/contrib/mue/dataloaders.py
class BiosequenceDataset (line 37) | class BiosequenceDataset(Dataset):
method __init__ (line 56) | def __init__(
method _load_fasta (line 102) | def _load_fasta(self, source):
method _one_hot (line 122) | def _one_hot(self, seq, alphabet, length):
method __len__ (line 136) | def __len__(self):
method __getitem__ (line 139) | def __getitem__(self, ind):
function write (line 143) | def write(x, alphabet, file, truncate_stop=False, append=False, scores=N...
FILE: pyro/contrib/mue/missingdatahmm.py
class MissingDataDiscreteHMM (line 13) | class MissingDataDiscreteHMM(TorchDistribution):
method __init__ (line 47) | def __init__(
method log_prob (line 85) | def log_prob(self, value):
method sample (line 115) | def sample(self, sample_shape=torch.Size([])):
method filter (line 147) | def filter(self, value):
method smooth (line 188) | def smooth(self, value):
method sample_states (line 220) | def sample_states(self, value):
method map_states (line 244) | def map_states(self, value):
method given_states (line 290) | def given_states(self, states):
method sample_given_states (line 308) | def sample_given_states(self, states):
FILE: pyro/contrib/mue/models.py
class ProfileHMM (line 26) | class ProfileHMM(nn.Module):
method __init__ (line 47) | def __init__(
method model (line 79) | def model(self, seq_data, local_scale):
method guide (line 132) | def guide(self, seq_data, local_scale):
method fit_svi (line 173) | def fit_svi(
method evaluate (line 242) | def evaluate(self, dataset_train, dataset_test=None, jit=False):
method _local_variables (line 276) | def _local_variables(self, name, site):
method _evaluate_local_elbo (line 280) | def _evaluate_local_elbo(self, svi, dataload, data_size):
class Encoder (line 309) | class Encoder(nn.Module):
method __init__ (line 310) | def __init__(self, data_length, alphabet_length, z_dim):
method forward (line 317) | def forward(self, data):
class FactorMuE (line 325) | class FactorMuE(nn.Module):
method __init__ (line 371) | def __init__(
method decoder (line 452) | def decoder(self, z, W, B, inverse_temp):
method model (line 488) | def model(self, seq_data, local_scale, local_prior_scale):
method guide (line 610) | def guide(self, seq_data, local_scale, local_prior_scale):
method fit_svi (line 681) | def fit_svi(
method _beta_anneal (line 763) | def _beta_anneal(self, step, batch_size, data_size, anneal_length):
method evaluate (line 770) | def evaluate(self, dataset_train, dataset_test=None, jit=False):
method _local_variables (line 809) | def _local_variables(self, name, site):
method _evaluate_local_elbo (line 813) | def _evaluate_local_elbo(self, svi, dataload, data_size):
method embed (line 841) | def embed(self, dataset, batch_size=None):
method _reconstruct_regressor_seq (line 863) | def _reconstruct_regressor_seq(self, data, ind, param):
FILE: pyro/contrib/mue/statearrangers.py
class Profile (line 8) | class Profile(nn.Module):
method __init__ (line 32) | def __init__(self, M, epsilon=1e-32):
method _make_transfer (line 40) | def _make_transfer(self):
method forward (line 135) | def forward(
function mg2k (line 205) | def mg2k(m, g, M):
FILE: pyro/contrib/oed/eig.py
function laplace_eig (line 29) | def laplace_eig(
function _eig_from_ape (line 87) | def _eig_from_ape(model, design, target_labels, ape, eig, prior_entropy_...
function _laplace_vi_ape (line 108) | def _laplace_vi_ape(
function vi_eig (line 152) | def vi_eig(
function _vi_ape (line 230) | def _vi_ape(
function nmc_eig (line 268) | def nmc_eig(
function donsker_varadhan_eig (line 376) | def donsker_varadhan_eig(
function posterior_eig (line 442) | def posterior_eig(
function _posterior_ape (line 525) | def _posterior_ape(
function marginal_eig (line 555) | def marginal_eig(
function marginal_likelihood_eig (line 620) | def marginal_likelihood_eig(
function lfire_eig (line 683) | def lfire_eig(
function vnmc_eig (line 756) | def vnmc_eig(
function opt_eig_ape_loss (line 826) | def opt_eig_ape_loss(
function monte_carlo_entropy (line 869) | def monte_carlo_entropy(model, design, target_labels, num_prior_samples=...
function _donsker_varadhan_loss (line 884) | def _donsker_varadhan_loss(model, T, observation_labels, target_labels):
function _posterior_loss (line 927) | def _posterior_loss(
function _marginal_loss (line 965) | def _marginal_loss(model, guide, observation_labels, target_labels):
function _marginal_likelihood_loss (line 994) | def _marginal_likelihood_loss(
function _lfire_loss (line 1034) | def _lfire_loss(
function _vnmc_eig_loss (line 1082) | def _vnmc_eig_loss(model, guide, observation_labels, target_labels):
function _safe_mean_terms (line 1123) | def _safe_mean_terms(terms):
function xexpx (line 1135) | def xexpx(a):
class _EwmaLogFn (line 1149) | class _EwmaLogFn(torch.autograd.Function):
method forward (line 1151) | def forward(ctx, input, ewma):
method backward (line 1156) | def backward(ctx, grad_output):
class EwmaLog (line 1164) | class EwmaLog:
method __init__ (line 1180) | def __init__(self, alpha):
method __call__ (line 1186) | def __call__(self, inputs, s, dim=0, keepdim=False):
FILE: pyro/contrib/oed/glmm/glmm.py
function known_covariance_linear_model (line 22) | def known_covariance_linear_model(
function normal_guide (line 57) | def normal_guide(observation_sd, coef_shape, coef_label="w"):
function group_linear_model (line 65) | def group_linear_model(
function group_normal_guide (line 92) | def group_normal_guide(
function zero_mean_unit_obs_sd_lm (line 102) | def zero_mean_unit_obs_sd_lm(coef_sd, coef_label="w"):
function normal_inverse_gamma_linear_model (line 110) | def normal_inverse_gamma_linear_model(
function normal_inverse_gamma_guide (line 123) | def normal_inverse_gamma_guide(coef_shape, coef_label="w", **kwargs):
function logistic_regression_model (line 132) | def logistic_regression_model(
function lmer_model (line 145) | def lmer_model(
function sigmoid_model (line 168) | def sigmoid_model(
function bayesian_linear_model (line 209) | def bayesian_linear_model(
function normal_inv_gamma_family_guide (line 348) | def normal_inv_gamma_family_guide(design, obs_sd, w_sizes, mf=False):
function group_assignment_matrix (line 409) | def group_assignment_matrix(design):
function rf_group_assignments (line 432) | def rf_group_assignments(n, random_intercept=True):
function analytic_posterior_cov (line 448) | def analytic_posterior_cov(prior_cov, x, obs_sd):
function broadcast_cat (line 464) | def broadcast_cat(ws):
FILE: pyro/contrib/oed/glmm/guides.py
class LinearModelPosteriorGuide (line 23) | class LinearModelPosteriorGuide(nn.Module):
method __init__ (line 24) | def __init__(
method get_params (line 71) | def get_params(self, y_dict, design, target_labels):
method linear_model_formula (line 75) | def linear_model_formula(self, y, design, target_labels):
method forward (line 84) | def forward(self, y_dict, design, observation_labels, target_labels):
class LinearModelLaplaceGuide (line 95) | class LinearModelLaplaceGuide(nn.Module):
method __init__ (line 107) | def __init__(self, d, w_sizes, tau_label=None, init_value=0.1, **kwargs):
method _hessian_diag (line 124) | def _hessian_diag(y, x, event_shape):
method finalize (line 164) | def finalize(self, loss, target_labels):
method forward (line 184) | def forward(self, design, target_labels=None):
class SigmoidGuide (line 214) | class SigmoidGuide(LinearModelPosteriorGuide):
method __init__ (line 215) | def __init__(self, d, n, w_sizes, **kwargs):
method get_params (line 221) | def get_params(self, y_dict, design, target_labels):
class NormalInverseGammaGuide (line 235) | class NormalInverseGammaGuide(LinearModelPosteriorGuide):
method __init__ (line 236) | def __init__(
method get_params (line 252) | def get_params(self, y_dict, design, target_labels):
method forward (line 267) | def forward(self, y_dict, design, observation_labels, target_labels):
class GuideDV (line 290) | class GuideDV(nn.Module):
method __init__ (line 295) | def __init__(self, guide):
method forward (line 299) | def forward(self, design, trace, observation_labels, target_labels):
FILE: pyro/contrib/oed/search.py
class Search (line 14) | class Search(TracePosterior):
method __init__ (line 19) | def __init__(self, model, max_tries=int(1e6), **kwargs):
method _traces (line 24) | def _traces(self, *args, **kwargs):
FILE: pyro/contrib/oed/util.py
function linear_model_ground_truth (line 13) | def linear_model_ground_truth(
FILE: pyro/contrib/randomvariable/random_variable.py
class RVMagicOps (line 21) | class RVMagicOps:
method __add__ (line 24) | def __add__(self, x: Union[float, Tensor]):
method __radd__ (line 29) | def __radd__(self, x: Union[float, Tensor]):
method __sub__ (line 34) | def __sub__(self, x: Union[float, Tensor]):
method __rsub__ (line 39) | def __rsub__(self, x: Union[float, Tensor]):
method __mul__ (line 44) | def __mul__(self, x: Union[float, Tensor]):
method __rmul__ (line 49) | def __rmul__(self, x: Union[float, Tensor]):
method __truediv__ (line 54) | def __truediv__(self, x: Union[float, Tensor]):
method __neg__ (line 59) | def __neg__(self):
method __abs__ (line 64) | def __abs__(self):
method __pow__ (line 69) | def __pow__(self, x):
class RVChainOps (line 75) | class RVChainOps:
method add (line 80) | def add(self, x):
method sub (line 83) | def sub(self, x):
method mul (line 86) | def mul(self, x):
method div (line 89) | def div(self, x):
method abs (line 92) | def abs(self):
method pow (line 95) | def pow(self, x):
method neg (line 98) | def neg(self):
method exp (line 101) | def exp(self):
method log (line 104) | def log(self):
method sigmoid (line 107) | def sigmoid(self):
method tanh (line 110) | def tanh(self):
method softmax (line 113) | def softmax(self):
class RandomVariable (line 117) | class RandomVariable(RVMagicOps, RVChainOps):
method __init__ (line 144) | def __init__(self, distribution):
method transform (line 152) | def transform(self, t: Transform):
method dist (line 168) | def dist(self):
FILE: pyro/contrib/timeseries/base.py
class TimeSeriesModel (line 7) | class TimeSeriesModel(PyroModule):
method log_prob (line 13) | def log_prob(self, targets):
method forecast (line 28) | def forecast(self, targets, dts):
method get_dist (line 43) | def get_dist(self):
FILE: pyro/contrib/timeseries/gp.py
class IndependentMaternGP (line 17) | class IndependentMaternGP(TimeSeriesModel):
method __init__ (line 35) | def __init__(
method _get_init_dist (line 68) | def _get_init_dist(self):
method _get_obs_dist (line 74) | def _get_obs_dist(self):
method get_dist (line 80) | def get_dist(self, duration=None):
method log_prob (line 107) | def log_prob(self, targets):
method _filter (line 118) | def _filter(self, targets):
method _forecast (line 126) | def _forecast(self, dts, filtering_state, include_observation_noise=Tr...
method forecast (line 154) | def forecast(self, targets, dts):
class LinearlyCoupledMaternGP (line 171) | class LinearlyCoupledMaternGP(TimeSeriesModel):
method __init__ (line 196) | def __init__(
method _get_obs_matrix (line 235) | def _get_obs_matrix(self):
method _stationary_covariance (line 245) | def _stationary_covariance(self):
method _get_init_dist (line 248) | def _get_init_dist(self):
method _get_obs_dist (line 252) | def _get_obs_dist(self):
method get_dist (line 256) | def get_dist(self, duration=None):
method log_prob (line 282) | def log_prob(self, targets):
method _filter (line 293) | def _filter(self, targets):
method _forecast (line 301) | def _forecast(
method forecast (line 338) | def forecast(self, targets, dts):
class DependentMaternGP (line 356) | class DependentMaternGP(TimeSeriesModel):
method __init__ (line 379) | def __init__(
method _get_obs_matrix (line 428) | def _get_obs_matrix(self):
method _get_init_dist (line 438) | def _get_init_dist(self, stationary_covariance):
method _get_obs_dist (line 443) | def _get_obs_dist(self):
method _get_wiener_cov (line 448) | def _get_wiener_cov(self):
method _stationary_covariance (line 456) | def _stationary_covariance(self):
method _get_trans_dist (line 470) | def _get_trans_dist(self, trans_matrix, stationary_covariance):
method _trans_matrix_distribution_stat_covar (line 477) | def _trans_matrix_distribution_stat_covar(self, dts):
method get_dist (line 484) | def get_dist(self, duration=None):
method log_prob (line 507) | def log_prob(self, targets):
method _filter (line 518) | def _filter(self, targets):
method _forecast (line 526) | def _forecast(self, dts, filtering_state, include_observation_noise=Tr...
method forecast (line 554) | def forecast(self, targets, dts):
FILE: pyro/contrib/timeseries/lgssm.py
class GenericLGSSM (line 14) | class GenericLGSSM(TimeSeriesModel):
method __init__ (line 26) | def __init__(
method _get_init_dist (line 61) | def _get_init_dist(self):
method _get_obs_dist (line 65) | def _get_obs_dist(self):
method _get_trans_dist (line 68) | def _get_trans_dist(self):
method get_dist (line 72) | def get_dist(self, duration=None):
method log_prob (line 90) | def log_prob(self, targets):
method _filter (line 101) | def _filter(self, targets):
method _forecast (line 109) | def _forecast(self, N_timesteps, filtering_state, include_observation_...
method forecast (line 142) | def forecast(self, targets, N_timesteps):
FILE: pyro/contrib/timeseries/lgssmgp.py
class GenericLGSSMWithGPNoiseModel (line 15) | class GenericLGSSMWithGPNoiseModel(TimeSeriesModel):
method __init__ (line 44) | def __init__(
method _get_obs_matrix (line 102) | def _get_obs_matrix(self):
method _get_init_dist (line 106) | def _get_init_dist(self):
method _get_obs_dist (line 117) | def _get_obs_dist(self):
method get_dist (line 120) | def get_dist(self, duration=None):
method log_prob (line 165) | def log_prob(self, targets):
method _filter (line 176) | def _filter(self, targets):
method _forecast (line 184) | def _forecast(self, N_timesteps, filtering_state, include_observation_...
method forecast (line 260) | def forecast(self, targets, N_timesteps):
FILE: pyro/contrib/tracking/assignment.py
function _product (line 14) | def _product(factors):
function _exp (line 21) | def _exp(value):
class MarginalAssignment (line 27) | class MarginalAssignment:
method __init__ (line 56) | def __init__(self, exists_logits, assign_logits, bp_iters=None):
class MarginalAssignmentSparse (line 81) | class MarginalAssignmentSparse:
method __init__ (line 108) | def __init__(
class MarginalAssignmentPersistent (line 142) | class MarginalAssignmentPersistent:
method __init__ (line 180) | def __init__(self, exists_logits, assign_logits, bp_iters=None, bp_mom...
function compute_marginals (line 207) | def compute_marginals(exists_logits, assign_logits):
function compute_marginals_bp (line 249) | def compute_marginals_bp(exists_logits, assign_logits, bp_iters):
function compute_marginals_sparse_bp (line 284) | def compute_marginals_sparse_bp(
function compute_marginals_persistent (line 334) | def compute_marginals_persistent(exists_logits, assign_logits):
function compute_marginals_persistent_bp (line 389) | def compute_marginals_persistent_bp(
FILE: pyro/contrib/tracking/distributions.py
class EKFDistribution (line 13) | class EKFDistribution(TorchDistribution):
method __init__ (line 38) | def __init__(
method rsample (line 60) | def rsample(self, sample_shape=torch.Size()):
method filter_states (line 63) | def filter_states(self, value):
method log_prob (line 83) | def log_prob(self, value):
FILE: pyro/contrib/tracking/dynamic_models.py
class DynamicModel (line 14) | class DynamicModel(nn.Module, metaclass=ABCMeta):
method __init__ (line 25) | def __init__(self, dimension, dimension_pv, num_process_noise_paramete...
method dimension (line 32) | def dimension(self):
method dimension_pv (line 39) | def dimension_pv(self):
method num_process_noise_parameters (line 46) | def num_process_noise_parameters(self):
method forward (line 53) | def forward(self, x, dt, do_normalization=True):
method geodesic_difference (line 67) | def geodesic_difference(self, x1, x0):
method mean2pv (line 79) | def mean2pv(self, x):
method cov2pv (line 91) | def cov2pv(self, P):
method process_noise_cov (line 103) | def process_noise_cov(self, dt=0.0):
method process_noise_dist (line 115) | def process_noise_dist(self, dt=0.0):
class DifferentiableDynamicModel (line 129) | class DifferentiableDynamicModel(DynamicModel):
method jacobian (line 136) | def jacobian(self, dt):
class Ncp (line 147) | class Ncp(DifferentiableDynamicModel):
method __init__ (line 159) | def __init__(self, dimension, sv2):
method forward (line 168) | def forward(self, x, dt, do_normalization=True):
method mean2pv (line 182) | def mean2pv(self, x):
method cov2pv (line 196) | def cov2pv(self, P):
method jacobian (line 211) | def jacobian(self, dt):
method process_noise_cov (line 222) | def process_noise_cov(self, dt=0.0):
class Ncv (line 233) | class Ncv(DifferentiableDynamicModel):
method __init__ (line 245) | def __init__(self, dimension, sa2):
method forward (line 254) | def forward(self, x, dt, do_normalization=True):
method mean2pv (line 270) | def mean2pv(self, x):
method cov2pv (line 281) | def cov2pv(self, P):
method jacobian (line 292) | def jacobian(self, dt):
method process_noise_cov (line 310) | def process_noise_cov(self, dt=0.0):
class NcpContinuous (line 321) | class NcpContinuous(Ncp):
method process_noise_cov (line 336) | def process_noise_cov(self, dt=0.0):
class NcvContinuous (line 355) | class NcvContinuous(Ncv):
method process_noise_cov (line 370) | def process_noise_cov(self, dt=0.0):
class NcpDiscrete (line 398) | class NcpDiscrete(Ncp):
method process_noise_cov (line 413) | def process_noise_cov(self, dt=0.0):
class NcvDiscrete (line 428) | class NcvDiscrete(Ncv):
method process_noise_cov (line 443) | def process_noise_cov(self, dt=0.0):
FILE: pyro/contrib/tracking/extended_kalman_filter.py
class EKFState (line 11) | class EKFState:
method __init__ (line 27) | def __init__(self, dynamic_model, mean, cov, time=None, frame_num=None):
method dynamic_model (line 37) | def dynamic_model(self):
method dimension (line 44) | def dimension(self):
method mean (line 51) | def mean(self):
method cov (line 58) | def cov(self):
method dimension_pv (line 65) | def dimension_pv(self):
method mean_pv (line 72) | def mean_pv(self):
method cov_pv (line 79) | def cov_pv(self):
method time (line 86) | def time(self):
method frame_num (line 93) | def frame_num(self):
method predict (line 99) | def predict(self, dt=None, destination_time=None, destination_frame_nu...
method innovation (line 139) | def innovation(self, measurement):
method log_likelihood_of_update (line 165) | def log_likelihood_of_update(self, measurement):
method update (line 180) | def update(self, measurement):
FILE: pyro/contrib/tracking/hashing.py
class LSH (line 12) | class LSH:
method __init__ (line 48) | def __init__(self, radius):
method _hash (line 57) | def _hash(self, point):
method add (line 61) | def add(self, key, point):
method remove (line 75) | def remove(self, key):
method nearby (line 88) | def nearby(self, key):
class ApproxSet (line 110) | class ApproxSet:
method __init__ (line 119) | def __init__(self, radius):
method _hash (line 127) | def _hash(self, point):
method try_add (line 131) | def try_add(self, point):
function merge_points (line 147) | def merge_points(points, radius):
FILE: pyro/contrib/tracking/measurements.py
class Measurement (line 11) | class Measurement(object, metaclass=ABCMeta):
method __init__ (line 23) | def __init__(self, mean, cov, time=None, frame_num=None):
method dimension (line 33) | def dimension(self):
method mean (line 40) | def mean(self):
method cov (line 47) | def cov(self):
method time (line 54) | def time(self):
method frame_num (line 61) | def frame_num(self):
method __call__ (line 68) | def __call__(self, x, do_normalization=True):
method geodesic_difference (line 80) | def geodesic_difference(self, z1, z0):
class DifferentiableMeasurement (line 92) | class DifferentiableMeasurement(Measurement):
method jacobian (line 99) | def jacobian(self, x=None):
class PositionMeasurement (line 111) | class PositionMeasurement(DifferentiableMeasurement):
method __init__ (line 120) | def __init__(self, mean, cov, time=None, frame_num=None):
method __call__ (line 132) | def __call__(self, x, do_normalization=True):
method jacobian (line 144) | def jacobian(self, x=None):
FILE: pyro/contrib/util.py
function get_indices (line 12) | def get_indices(labels, sizes=None, tensors=None):
function tensor_to_dict (line 25) | def tensor_to_dict(sizes, tensor, subset=None):
function rmm (line 38) | def rmm(A, B):
function rmv (line 43) | def rmv(A, b):
function rvv (line 48) | def rvv(a, b):
function lexpand (line 53) | def lexpand(A, *dimensions):
function rexpand (line 58) | def rexpand(A, *dimensions):
function rdiag (line 63) | def rdiag(v):
function rtril (line 68) | def rtril(M, diagonal=0, upper=False):
function iter_plates_to_shape (line 75) | def iter_plates_to_shape(shape):
function check_no_weakref (line 81) | def check_no_weakref(obj, path="", avoid_ids=None):
FILE: pyro/contrib/zuko.py
class ZukoToPyro (line 18) | class ZukoToPyro(pyro.distributions.TorchDistribution):
method __init__ (line 48) | def __init__(self, dist: torch.distributions.Distribution):
method has_rsample (line 53) | def has_rsample(self) -> bool:
method event_shape (line 57) | def event_shape(self) -> Size:
method batch_shape (line 61) | def batch_shape(self) -> Size:
method __call__ (line 64) | def __call__(self, shape: Size = ()) -> Tensor:
method log_prob (line 74) | def log_prob(self, x: Tensor) -> Tensor:
method expand (line 80) | def expand(self, *args, **kwargs):
FILE: pyro/distributions/affine_beta.py
class AffineBeta (line 12) | class AffineBeta(TransformedDistribution):
method __init__ (line 39) | def __init__(self, concentration1, concentration0, loc, scale, validat...
method infer_shapes (line 48) | def infer_shapes(concentration1, concentration0, loc, scale):
method expand (line 53) | def expand(self, batch_shape, _instance=None):
method sample (line 57) | def sample(self, sample_shape=torch.Size()):
method rsample (line 71) | def rsample(self, sample_shape=torch.Size()):
method support (line 85) | def support(self):
method concentration1 (line 89) | def concentration1(self):
method concentration0 (line 93) | def concentration0(self):
method sample_size (line 97) | def sample_size(self):
method loc (line 101) | def loc(self):
method scale (line 105) | def scale(self):
method low (line 109) | def low(self):
method high (line 113) | def high(self):
method mean (line 117) | def mean(self):
method variance (line 121) | def variance(self):
FILE: pyro/distributions/asymmetriclaplace.py
class AsymmetricLaplace (line 13) | class AsymmetricLaplace(TorchDistribution):
method __init__ (line 36) | def __init__(self, loc, scale, asymmetry, *, validate_args=None):
method left_scale (line 41) | def left_scale(self):
method right_scale (line 45) | def right_scale(self):
method expand (line 48) | def expand(self, batch_shape, _instance=None):
method log_prob (line 58) | def log_prob(self, value):
method rsample (line 65) | def rsample(self, sample_shape=torch.Size()):
method mean (line 71) | def mean(self):
method variance (line 76) | def variance(self):
class SoftAsymmetricLaplace (line 85) | class SoftAsymmetricLaplace(TorchDistribution):
method __init__ (line 121) | def __init__(self, loc, scale, asymmetry=1.0, softness=1.0, *, validat...
method left_scale (line 131) | def left_scale(self):
method right_scale (line 135) | def right_scale(self):
method soft_scale (line 139) | def soft_scale(self):
method expand (line 142) | def expand(self, batch_shape, _instance=None):
method log_prob (line 153) | def log_prob(self, value):
method rsample (line 182) | def rsample(self, sample_shape=torch.Size()):
method mean (line 191) | def mean(self):
method variance (line 196) | def variance(self):
function _logerfc (line 205) | def _logerfc(x):
FILE: pyro/distributions/avf_mvn.py
class AVFMultivariateNormal (line 13) | class AVFMultivariateNormal(MultivariateNormal):
method __init__ (line 48) | def __init__(self, loc, scale_tril, control_var):
method rsample (line 64) | def rsample(self, sample_shape=torch.Size()):
class _AVFMVNSample (line 70) | class _AVFMVNSample(Function):
method forward (line 72) | def forward(ctx, loc, scale_tril, control_var, shape):
method backward (line 80) | def backward(ctx, grad_output):
FILE: pyro/distributions/coalescent.py
class CoalescentTimesConstraint (line 17) | class CoalescentTimesConstraint(constraints.Constraint):
method __init__ (line 18) | def __init__(self, leaf_times, *, ordered=True):
method check (line 22) | def check(self, value):
class CoalescentTimes (line 35) | class CoalescentTimes(TorchDistribution):
method __init__ (line 65) | def __init__(self, leaf_times, rate=1.0, *, validate_args=None):
method support (line 74) | def support(self):
method log_prob (line 77) | def log_prob(self, value):
method sample (line 96) | def sample(self, sample_shape=torch.Size()):
class CoalescentTimesWithRate (line 102) | class CoalescentTimesWithRate(TorchDistribution):
method __init__ (line 149) | def __init__(self, leaf_times, rate_grid, *, validate_args=None):
method support (line 157) | def support(self):
method duration (line 161) | def duration(self):
method expand (line 164) | def expand(self, batch_shape, _instance=None):
method log_prob (line 174) | def log_prob(self, value):
class CoalescentRateLikelihood (line 213) | class CoalescentRateLikelihood:
method __init__ (line 249) | def __init__(self, leaf_times, coal_times, duration, *, validate_args=...
method __call__ (line 292) | def __call__(self, rate_grid, t=slice(None)):
function bio_phylo_to_times (line 326) | def bio_phylo_to_times(tree, *, get_time=None):
function _gather (line 374) | def _gather(tensor, dim, index):
function _interpolate_gather (line 386) | def _interpolate_gather(array, x):
function _interpolate_scatter_add_ (line 399) | def _interpolate_scatter_add_(dst, x, src):
function _weak_memoize (line 412) | def _weak_memoize(fn):
function _make_phylogeny (line 450) | def _make_phylogeny(leaf_times, coal_times):
function _sample_coalescent_times (line 487) | def _sample_coalescent_times(leaf_times):
FILE: pyro/distributions/conditional.py
class ConditionalDistribution (line 13) | class ConditionalDistribution(ABC):
method condition (line 15) | def condition(self, context):
class ConditionalTransform (line 20) | class ConditionalTransform(ABC):
method condition (line 22) | def condition(self, context):
class ConditionalTransformModule (line 27) | class ConditionalTransformModule(ConditionalTransform, torch.nn.Module):
method __init__ (line 34) | def __init__(self, *args, **kwargs):
method __hash__ (line 37) | def __hash__(self):
method inv (line 41) | def inv(self) -> "ConditionalTransformModule":
class _ConditionalInverseTransformModule (line 45) | class _ConditionalInverseTransformModule(ConditionalTransformModule):
method __init__ (line 46) | def __init__(self, transform: ConditionalTransform):
method inv (line 51) | def inv(self) -> ConditionalTransform:
method condition (line 54) | def condition(self, context: torch.Tensor):
class ConditionalComposeTransformModule (line 58) | class ConditionalComposeTransformModule(
method __init__ (line 83) | def __init__(self, transforms, cache_size: int = 0):
method condition (line 101) | def condition(self, context: torch.Tensor) -> ComposeTransformModule:
class ConstantConditionalDistribution (line 107) | class ConstantConditionalDistribution(ConditionalDistribution):
method __init__ (line 108) | def __init__(self, base_dist):
method condition (line 112) | def condition(self, context):
class ConstantConditionalTransform (line 116) | class ConstantConditionalTransform(ConditionalTransform):
method __init__ (line 117) | def __init__(self, transform):
method condition (line 121) | def condition(self, context):
method clear_cache (line 124) | def clear_cache(self):
class ConditionalTransformedDistribution (line 128) | class ConditionalTransformedDistribution(ConditionalDistribution):
method __init__ (line 129) | def __init__(self, base_dist, transforms):
method condition (line 144) | def condition(self, context):
method clear_cache (line 149) | def clear_cache(self):
FILE: pyro/distributions/conjugate.py
function _log_beta_1 (line 17) | def _log_beta_1(alpha, value, is_sparse):
class BetaBinomial (line 34) | class BetaBinomial(TorchDistribution):
method __init__ (line 65) | def __init__(
method concentration1 (line 76) | def concentration1(self):
method concentration0 (line 80) | def concentration0(self):
method expand (line 83) | def expand(self, batch_shape, _instance=None):
method sample (line 92) | def sample(self, sample_shape=()):
method log_prob (line 96) | def log_prob(self, value):
method mean (line 112) | def mean(self):
method variance (line 116) | def variance(self):
method enumerate_support (line 123) | def enumerate_support(self, expand=True):
class DirichletMultinomial (line 140) | class DirichletMultinomial(TorchDistribution):
method __init__ (line 161) | def __init__(
method concentration (line 178) | def concentration(self):
method infer_shapes (line 182) | def infer_shapes(concentration, total_count=()):
method expand (line 187) | def expand(self, batch_shape, _instance=None):
method sample (line 199) | def sample(self, sample_shape=()):
method log_prob (line 208) | def log_prob(self, value):
method mean (line 217) | def mean(self):
method variance (line 221) | def variance(self):
class GammaPoisson (line 229) | class GammaPoisson(TorchDistribution):
method __init__ (line 252) | def __init__(self, concentration, rate, validate_args=None):
method concentration (line 259) | def concentration(self):
method rate (line 263) | def rate(self):
method expand (line 266) | def expand(self, batch_shape, _instance=None):
method sample (line 274) | def sample(self, sample_shape=()):
method log_prob (line 278) | def log_prob(self, value):
method mean (line 290) | def mean(self):
method variance (line 294) | def variance(self):
FILE: pyro/distributions/constraints.py
class _Integer (line 50) | class _Integer(Constraint):
method check (line 57) | def check(self, value):
method __repr__ (line 60) | def __repr__(self):
class _Sphere (line 64) | class _Sphere(Constraint):
method check (line 72) | def check(self, value):
method __repr__ (line 78) | def __repr__(self):
class _CorrMatrix (line 82) | class _CorrMatrix(Constraint):
method check (line 89) | def check(self, value):
class _OrderedVector (line 98) | class _OrderedVector(Constraint):
method check (line 106) | def check(self, value):
class _PositiveOrderedVector (line 115) | class _PositiveOrderedVector(Constraint):
method check (line 121) | def check(self, value):
class _SoftplusPositive (line 125) | class _SoftplusPositive(type(positive)):
method __init__ (line 126) | def __init__(self):
class _SoftplusLowerCholesky (line 130) | class _SoftplusLowerCholesky(type(lower_cholesky)):
class _UnitLowerCholesky (line 134) | class _UnitLowerCholesky(Constraint):
method check (line 141) | def check(self, value):
FILE: pyro/distributions/delta.py
class Delta (line 14) | class Delta(TorchDistribution):
method __init__ (line 32) | def __init__(self, v, log_density=0.0, event_dim=0, validate_args=None):
method support (line 57) | def support(self):
method expand (line 60) | def expand(self, batch_shape, _instance=None):
method rsample (line 69) | def rsample(self, sample_shape=torch.Size()):
method log_prob (line 73) | def log_prob(self, x):
method mean (line 80) | def mean(self):
method variance (line 84) | def variance(self):
FILE: pyro/distributions/diag_normal_mixture.py
class MixtureOfDiagNormals (line 15) | class MixtureOfDiagNormals(TorchDistribution):
method __init__ (line 51) | def __init__(self, locs, coord_scale, component_logits):
method expand (line 87) | def expand(self, batch_shape, _instance=None):
method log_prob (line 107) | def log_prob(self, value):
method rsample (line 122) | def rsample(self, sample_shape=torch.Size()):
class _MixDiagNormalSample (line 134) | class _MixDiagNormalSample(Function):
method forward (line 136) | def forward(ctx, locs, scales, component_logits, pis, which, noise_sha...
method backward (line 151) | def backward(ctx, grad_output):
FILE: pyro/distributions/diag_normal_mixture_shared_cov.py
class MixtureOfDiagNormalsSharedCovariance (line 15) | class MixtureOfDiagNormalsSharedCovariance(TorchDistribution):
method __init__ (line 50) | def __init__(self, locs, coord_scale, component_logits):
method expand (line 85) | def expand(self, batch_shape, _instance=None):
method log_prob (line 108) | def log_prob(self, value):
method rsample (line 124) | def rsample(self, sample_shape=torch.Size()):
class _MixDiagNormalSharedCovarianceSample (line 136) | class _MixDiagNormalSharedCovarianceSample(Function):
method forward (line 138) | def forward(ctx, locs, coord_scale, component_logits, pis, which, nois...
method backward (line 152) | def backward(ctx, grad_output):
FILE: pyro/distributions/distribution.py
class DistributionMeta (line 15) | class DistributionMeta(ABCMeta):
method __init__ (line 16) | def __init__(cls, *args, **kwargs):
method __call__ (line 21) | def __call__(cls, *args, **kwargs):
class Distribution (line 29) | class Distribution(metaclass=DistributionMeta):
method __call__ (line 55) | def __call__(self, *args, **kwargs):
method sample (line 68) | def sample(self, *args, **kwargs):
method log_prob (line 85) | def log_prob(self, x, *args, **kwargs):
method score_parts (line 98) | def score_parts(self, x, *args, **kwargs):
method enumerate_support (line 127) | def enumerate_support(self, expand: bool = True) -> torch.Tensor:
method conjugate_update (line 147) | def conjugate_update(self, other):
method has_rsample_ (line 180) | def has_rsample_(self, value):
method rv (line 199) | def rv(self):
FILE: pyro/distributions/empirical.py
class Empirical (line 13) | class Empirical(TorchDistribution):
method __init__ (line 53) | def __init__(self, samples, log_weights, validate_args=None):
method sample_size (line 77) | def sample_size(self):
method sample (line 85) | def sample(self, sample_shape=torch.Size()):
method log_prob (line 103) | def log_prob(self, value):
method _weighted_mean (line 128) | def _weighted_mean(self, value, keepdim=False):
method event_shape (line 141) | def event_shape(self):
method mean (line 145) | def mean(self):
method variance (line 157) | def variance(self):
method log_weights (line 171) | def log_weights(self):
method enumerate_support (line 174) | def enumerate_support(self, expand=True):
FILE: pyro/distributions/extended.py
class ExtendedBinomial (line 12) | class ExtendedBinomial(Binomial):
method log_prob (line 27) | def log_prob(self, value):
class ExtendedBetaBinomial (line 33) | class ExtendedBetaBinomial(BetaBinomial):
method log_prob (line 48) | def log_prob(self, value):
FILE: pyro/distributions/folded.py
class FoldedDistribution (line 10) | class FoldedDistribution(TransformedDistribution):
method __init__ (line 21) | def __init__(self, base_dist, validate_args=None):
method expand (line 26) | def expand(self, batch_shape, _instance=None):
method log_prob (line 30) | def log_prob(self, value):
FILE: pyro/distributions/gaussian_scale_mixture.py
class GaussianScaleMixture (line 15) | class GaussianScaleMixture(TorchDistribution):
method __init__ (line 60) | def __init__(self, coord_scale, component_logits, component_scale):
method _compute_coeffs (line 83) | def _compute_coeffs(self):
method log_prob (line 93) | def log_prob(self, value):
method rsample (line 108) | def rsample(self, sample_shape=torch.Size()):
class _GSMSample (line 121) | class _GSMSample(Function):
method forward (line 123) | def forward(
method backward (line 136) | def backward(ctx, grad_output):
FILE: pyro/distributions/grouped_normal_normal.py
class GroupedNormalNormal (line 15) | class GroupedNormalNormal(TorchDistribution):
method __init__ (line 59) | def __init__(
method expand (line 98) | def expand(self, batch_shape, _instance=None):
method sample (line 101) | def sample(self, sample_shape=()):
method get_posterior (line 104) | def get_posterior(self, value):
method log_prob (line 131) | def log_prob(self, value):
FILE: pyro/distributions/hmm.py
function _linear_integrate (line 32) | def _linear_integrate(init, trans, shift):
function _logmatmulexp (line 51) | def _logmatmulexp(x, y):
function _sequential_logmatmulexp (line 65) | def _sequential_logmatmulexp(logits):
function _markov_index (line 88) | def _markov_index(x, y):
function _sequential_index (line 96) | def _sequential_index(samples):
function _sequential_gamma_gaussian_tensordot (line 164) | def _sequential_gamma_gaussian_tensordot(gamma_gaussian):
class HiddenMarkovModel (line 189) | class HiddenMarkovModel(TorchDistribution):
method __init__ (line 200) | def __init__(self, duration, batch_shape, event_shape, validate_args=N...
method duration (line 218) | def duration(self):
method _validate_sample (line 224) | def _validate_sample(self, value):
class DiscreteHMM (line 243) | class DiscreteHMM(HiddenMarkovModel):
method __init__ (line 293) | def __init__(
method support (line 333) | def support(self):
method expand (line 336) | def expand(self, batch_shape, _instance=None):
method log_prob (line 352) | def log_prob(self, value):
method filter (line 371) | def filter(self, value):
method sample (line 400) | def sample(self, sample_shape=torch.Size()):
class GaussianHMM (line 434) | class GaussianHMM(HiddenMarkovModel):
method __init__ (line 498) | def __init__(
method expand (line 546) | def expand(self, batch_shape, _instance=None):
method log_prob (line 565) | def log_prob(self, value):
method rsample (line 584) | def rsample(self, sample_shape=torch.Size()):
method rsample_posterior (line 596) | def rsample_posterior(self, value, sample_shape=torch.Size()):
method filter (line 606) | def filter(self, value):
method conjugate_update (line 638) | def conjugate_update(self, other):
method prefix_condition (line 690) | def prefix_condition(self, data):
class GammaGaussianHMM (line 744) | class GammaGaussianHMM(HiddenMarkovModel):
method __init__ (line 817) | def __init__(
method expand (line 862) | def expand(self, batch_shape, _instance=None):
method log_prob (line 879) | def log_prob(self, value):
method filter (line 901) | def filter(self, value):
class LinearHMM (line 939) | class LinearHMM(HiddenMarkovModel):
method __init__ (line 1011) | def __init__(
method support (line 1094) | def support(self): # noqa: F811
method expand (line 1097) | def expand(self, batch_shape, _instance=None):
method log_prob (line 1119) | def log_prob(self, value):
method rsample (line 1122) | def rsample(self, sample_shape=torch.Size()):
class IndependentHMM (line 1141) | class IndependentHMM(TorchDistribution):
method __init__ (line 1159) | def __init__(self, base_dist):
method support (line 1169) | def support(self):
method has_rsample (line 1173) | def has_rsample(self):
method duration (line 1177) | def duration(self):
method expand (line 1180) | def expand(self, batch_shape, _instance=None):
method rsample (line 1192) | def rsample(self, sample_shape=torch.Size()):
method log_prob (line 1196) | def log_prob(self, value):
class GaussianMRF (line 1201) | class GaussianMRF(TorchDistribution):
method __init__ (line 1244) | def __init__(
method support (line 1270) | def support(self):
method expand (line 1273) | def expand(self, batch_shape, _instance=None):
method log_prob (line 1291) | def log_prob(self, value):
FILE: pyro/distributions/improper_uniform.py
class ImproperUniform (line 11) | class ImproperUniform(TorchDistribution):
method __init__ (line 46) | def __init__(self, support, batch_shape, event_shape):
method support (line 52) | def support(self):
method expand (line 55) | def expand(self, batch_shape, _instance=None):
method log_prob (line 62) | def log_prob(self, value):
method sample (line 67) | def sample(self, sample_shape=torch.Size()):
FILE: pyro/distributions/inverse_gamma.py
class InverseGamma (line 11) | class InverseGamma(TransformedDistribution):
method __init__ (line 30) | def __init__(self, concentration, rate, validate_args=None):
method expand (line 38) | def expand(self, batch_shape, _instance=None):
method concentration (line 43) | def concentration(self):
method rate (line 47) | def rate(self):
FILE: pyro/distributions/kl.py
function _kl_delta (line 20) | def _kl_delta(p, q):
function _kl_independent_independent (line 25) | def _kl_independent_independent(p, q):
function _kl_independent_mvn (line 38) | def _kl_independent_mvn(p, q):
FILE: pyro/distributions/lkj.py
class LKJCorrCholesky (line 14) | class LKJCorrCholesky(LKJCholesky): # DEPRECATED
method __init__ (line 15) | def __init__(self, d, eta, validate_args=None):
class LKJ (line 24) | class LKJ(TransformedDistribution):
method __init__ (line 49) | def __init__(self, dim, concentration=1.0, validate_args=None):
method expand (line 56) | def expand(self, batch_shape, _instance=None):
method mean (line 61) | def mean(self):
FILE: pyro/distributions/log_normal_negative_binomial.py
class LogNormalNegativeBinomial (line 14) | class LogNormalNegativeBinomial(TorchDistribution):
method __init__ (line 77) | def __init__(
method log_prob (line 114) | def log_prob(self, value):
method sample (line 118) | def sample(self, sample_shape=torch.Size()):
method expand (line 121) | def expand(self, batch_shape, _instance=None):
method mean (line 139) | def mean(self):
method variance (line 147) | def variance(self):
FILE: pyro/distributions/logistic.py
class Logistic (line 14) | class Logistic(TorchDistribution):
method __init__ (line 40) | def __init__(self, loc, scale, *, validate_args=None):
method expand (line 44) | def expand(self, batch_shape, _instance=None):
method log_prob (line 53) | def log_prob(self, value):
method rsample (line 59) | def rsample(self, sample_shape=torch.Size()):
method cdf (line 64) | def cdf(self, value):
method icdf (line 70) | def icdf(self, value):
method mean (line 74) | def mean(self):
method variance (line 78) | def variance(self):
method entropy (line 81) | def entropy(self):
class SkewLogistic (line 85) | class SkewLogistic(TorchDistribution):
method __init__ (line 124) | def __init__(self, loc, scale, asymmetry=1.0, *, validate_args=None):
method expand (line 128) | def expand(self, batch_shape, _instance=None):
method log_prob (line 138) | def log_prob(self, value):
method rsample (line 145) | def rsample(self, sample_shape=torch.Size()):
method cdf (line 150) | def cdf(self, value):
method icdf (line 156) | def icdf(self, value):
FILE: pyro/distributions/mixture.py
class MaskedConstraint (line 12) | class MaskedConstraint(constraints.Constraint):
method __init__ (line 23) | def __init__(self, mask, constraint0, constraint1):
method check (line 28) | def check(self, value):
class MaskedMixture (line 39) | class MaskedMixture(TorchDistribution):
method __init__ (line 66) | def __init__(self, mask, component0, component1, validate_args=None):
method has_rsample (line 98) | def has_rsample(self):
method support (line 102) | def support(self):
method expand (line 109) | def expand(self, batch_shape):
method sample (line 118) | def sample(self, sample_shape=torch.Size()):
method rsample (line 128) | def rsample(self, sample_shape=torch.Size()):
method log_prob (line 138) | def log_prob(self, value):
method mean (line 154) | def mean(self):
method variance (line 160) | def variance(self):
FILE: pyro/distributions/multivariate_studentt.py
class MultivariateStudentT (line 15) | class MultivariateStudentT(TorchDistribution):
method __init__ (line 34) | def __init__(self, df, loc, scale_tril, validate_args=None):
method scale_tril (line 50) | def scale_tril(self):
method covariance_matrix (line 56) | def covariance_matrix(self):
method precision_matrix (line 65) | def precision_matrix(self):
method infer_shapes (line 74) | def infer_shapes(df, loc, scale_tril):
method expand (line 79) | def expand(self, batch_shape, _instance=None):
method rsample (line 100) | def rsample(self, sample_shape=torch.Size()):
method log_prob (line 107) | def log_prob(self, value):
method mean (line 124) | def mean(self):
method variance (line 130) | def variance(self):
FILE: pyro/distributions/nanmasked.py
class NanMaskedNormal (line 9) | class NanMaskedNormal(Normal):
method log_prob (line 24) | def log_prob(self, value: torch.Tensor) -> torch.Tensor:
class NanMaskedMultivariateNormal (line 40) | class NanMaskedMultivariateNormal(MultivariateNormal):
method log_prob (line 65) | def log_prob(self, value: torch.Tensor) -> torch.Tensor:
FILE: pyro/distributions/omt_mvn.py
class OMTMultivariateNormal (line 13) | class OMTMultivariateNormal(MultivariateNormal):
method __init__ (line 30) | def __init__(self, loc, scale_tril):
method rsample (line 37) | def rsample(self, sample_shape=torch.Size()):
class _OMTMVNSample (line 43) | class _OMTMVNSample(Function):
method forward (line 45) | def forward(ctx, loc, scale_tril, shape):
method backward (line 53) | def backward(ctx, grad_output):
FILE: pyro/distributions/one_one_matching.py
class OneOneMatchingConstraint (line 18) | class OneOneMatchingConstraint(constraints.Constraint):
method __init__ (line 19) | def __init__(self, num_nodes):
method check (line 22) | def check(self, value):
class OneOneMatching (line 41) | class OneOneMatching(TorchDistribution):
method __init__ (line 84) | def __init__(self, logits, *, bp_iters=None, validate_args=None):
method support (line 97) | def support(self):
method log_partition_function (line 101) | def log_partition_function(self):
method log_prob (line 133) | def log_prob(self, value):
method enumerate_support (line 140) | def enumerate_support(self, expand=True):
method sample (line 143) | def sample(self, sample_shape=torch.Size()):
method mode (line 161) | def mode(self):
function maximum_weight_matching (line 169) | def maximum_weight_matching(logits):
FILE: pyro/distributions/one_two_matching.py
class OneTwoMatchingConstraint (line 18) | class OneTwoMatchingConstraint(constraints.Constraint):
method __init__ (line 19) | def __init__(self, num_destins):
method check (line 23) | def check(self, value):
class OneTwoMatching (line 42) | class OneTwoMatching(TorchDistribution):
method __init__ (line 85) | def __init__(self, logits, *, bp_iters=None, validate_args=None):
method support (line 98) | def support(self):
method log_partition_function (line 102) | def log_partition_function(self):
method log_prob (line 142) | def log_prob(self, value):
method enumerate_support (line 149) | def enumerate_support(self, expand=True):
method sample (line 152) | def sample(self, sample_shape=torch.Size()):
method mode (line 170) | def mode(self):
function enumerate_one_two_matchings (line 177) | def enumerate_one_two_matchings(num_destins):
function maximum_weight_matching (line 204) | def maximum_weight_matching(logits):
FILE: pyro/distributions/ordered_logistic.py
class OrderedLogistic (line 10) | class OrderedLogistic(Categorical):
method __init__ (line 41) | def __init__(self, predictor, cutpoints, validate_args=None):
method expand (line 56) | def expand(self, batch_shape, _instance=None):
FILE: pyro/distributions/polya_gamma.py
class TruncatedPolyaGamma (line 13) | class TruncatedPolyaGamma(TorchDistribution):
method __init__ (line 41) | def __init__(self, prototype, validate_args=None):
method expand (line 47) | def expand(self, batch_shape, _instance=None):
method sample (line 56) | def sample(self, sample_shape=()):
method log_prob (line 65) | def log_prob(self, value):
FILE: pyro/distributions/projected_normal.py
class ProjectedNormal (line 14) | class ProjectedNormal(TorchDistribution):
method __init__ (line 56) | def __init__(self, concentration, *, validate_args=None):
method infer_shapes (line 64) | def infer_shapes(concentration):
method expand (line 69) | def expand(self, batch_shape, _instance=None):
method mean (line 80) | def mean(self):
method mode (line 88) | def mode(self):
method rsample (line 91) | def rsample(self, sample_shape=torch.Size()):
method log_prob (line 98) | def log_prob(self, value):
method _register_log_prob (line 118) | def _register_log_prob(cls, dim, fn=None):
function _dot (line 125) | def _dot(x, y):
function _safe_log (line 129) | def _safe_log(x):
function _log_prob_2 (line 134) | def _log_prob_2(concentration, value):
function _log_prob_3 (line 157) | def _log_prob_3(concentration, value):
function _log_prob_4 (line 179) | def _log_prob_4(concentration, value):
FILE: pyro/distributions/rejector.py
class Rejector (line 10) | class Rejector(TorchDistribution):
method __init__ (line 25) | def __init__(
method _log_prob_accept (line 41) | def _log_prob_accept(self, x):
method _propose_log_prob (line 46) | def _propose_log_prob(self, x):
method rsample (line 51) | def rsample(self, sample_shape=torch.Size()):
method log_prob (line 67) | def log_prob(self, x):
method score_parts (line 70) | def score_parts(self, x):
FILE: pyro/distributions/relaxed_straight_through.py
class RelaxedOneHotCategoricalStraightThrough (line 12) | class RelaxedOneHotCategoricalStraightThrough(RelaxedOneHotCategorical):
method rsample (line 34) | def rsample(self, sample_shape=torch.Size()):
method log_prob (line 40) | def log_prob(self, value):
class QuantizeCategorical (line 45) | class QuantizeCategorical(torch.autograd.Function):
method forward (line 47) | def forward(ctx, soft_value):
method backward (line 56) | def backward(ctx, grad):
class RelaxedBernoulliStraightThrough (line 61) | class RelaxedBernoulliStraightThrough(RelaxedBernoulli):
method rsample (line 83) | def rsample(self, sample_shape=torch.Size()):
method log_prob (line 89) | def log_prob(self, value):
class QuantizeBernoulli (line 94) | class QuantizeBernoulli(torch.autograd.Function):
method forward (line 96) | def forward(ctx, soft_value):
method backward (line 102) | def backward(ctx, grad):
FILE: pyro/distributions/score_parts.py
class ScoreParts (line 11) | class ScoreParts(NamedTuple):
method scale_and_mask (line 21) | def scale_and_mask(
FILE: pyro/distributions/sine_bivariate_von_mises.py
class SineBivariateVonMises (line 18) | class SineBivariateVonMises(TorchDistribution):
method __init__ (line 84) | def __init__(
method norm_const (line 143) | def norm_const(self):
method log_prob (line 166) | def log_prob(self, value):
method sample (line 179) | def sample(self, sample_shape=torch.Size()):
method mean (line 295) | def mean(self):
method infer_shapes (line 299) | def infer_shapes(cls, **arg_shapes):
method expand (line 303) | def expand(self, batch_shape, _instance=None):
method _bfind (line 315) | def _bfind(self, eig):
method _lbinoms (line 326) | def _lbinoms(n):
FILE: pyro/distributions/sine_skewed.py
class SineSkewed (line 16) | class SineSkewed(TorchDistribution):
method __init__ (line 92) | def __init__(self, base_dist: TorchDistribution, skewness, validate_ar...
method __repr__ (line 112) | def __repr__(self):
method sample (line 134) | def sample(self, sample_shape=torch.Size()):
method log_prob (line 149) | def log_prob(self, value):
method expand (line 161) | def expand(self, batch_shape, _instance=None):
FILE: pyro/distributions/softlaplace.py
class SoftLaplace (line 13) | class SoftLaplace(TorchDistribution):
method __init__ (line 35) | def __init__(self, loc, scale, *, validate_args=None):
method expand (line 39) | def expand(self, batch_shape, _instance=None):
method log_prob (line 48) | def log_prob(self, value):
method rsample (line 54) | def rsample(self, sample_shape=torch.Size()):
method cdf (line 59) | def cdf(self, value):
method icdf (line 65) | def icdf(self, value):
method mean (line 69) | def mean(self):
method variance (line 73) | def variance(self):
FILE: pyro/distributions/spanning_tree.cpp
function make_complete_graph (line 11) | at::Tensor make_complete_graph(int num_vertices) {
function _remove_edge (line 26) | int _remove_edge(at::Tensor grid, at::Tensor edge_ids,
function _add_edge (line 49) | void _add_edge(at::Tensor grid, at::Tensor edge_ids,
function _find_valid_edges (line 60) | int _find_valid_edges(const std::vector<bool> &components, at::Tensor va...
function sample_tree_mcmc (line 77) | at::Tensor sample_tree_mcmc(at::Tensor edge_logits, at::Tensor edges) {
function sample_tree_approx (line 134) | at::Tensor sample_tree_approx(at::Tensor edge_logits) {
function find_best_tree (line 178) | at::Tensor find_best_tree(at::Tensor edge_logits) {
function PYBIND11_MODULE (line 221) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: pyro/distributions/spanning_tree.py
class SpanningTree (line 14) | class SpanningTree(TorchDistribution):
method __init__ (line 57) | def __init__(self, edge_logits, sampler_options=None, validate_args=No...
method validate_edges (line 78) | def validate_edges(self, edges):
method log_partition_function (line 119) | def log_partition_function(self):
method log_prob (line 142) | def log_prob(self, edges):
method sample (line 150) | def sample(self, sample_shape=torch.Size()):
method enumerate_support (line 177) | def enumerate_support(self, expand=True):
method mode (line 185) | def mode(self):
method edge_mean (line 194) | def edge_mean(self):
function _get_cpp_module (line 225) | def _get_cpp_module():
function make_complete_graph (line 244) | def make_complete_graph(num_vertices, backend="python"):
function _make_complete_graph (line 261) | def _make_complete_graph(num_vertices):
function _remove_edge (line 277) | def _remove_edge(grid, edge_ids, neighbors, components, e):
function _add_edge (line 297) | def _add_edge(grid, edge_ids, neighbors, components, e, k):
function _find_valid_edges (line 309) | def _find_valid_edges(components, valid_edge_ids):
function _sample_tree_mcmc (line 332) | def _sample_tree_mcmc(edge_logits, edges):
function sample_tree_mcmc (line 381) | def sample_tree_mcmc(edge_logits, edges, backend="python"):
function _sample_tree_approx (line 415) | def _sample_tree_approx(edge_logits):
function sample_tree_approx (line 452) | def sample_tree_approx(edge_logits, backend="python"):
function sample_tree (line 473) | def sample_tree(edge_logits, init_edges=None, mcmc_steps=1, backend="pyt...
function _find_best_tree (line 483) | def _find_best_tree(edge_logits):
function find_best_tree (line 519) | def find_best_tree(edge_logits, backend="python"):
function _permute_tree (line 593) | def _permute_tree(perm, tree):
function _close_under_permutations (line 599) | def _close_under_permutations(V, tree_generators):
function enumerate_spanning_trees (line 610) | def enumerate_spanning_trees(V):
FILE: pyro/distributions/stable.py
function _unsafe_standard_stable (line 14) | def _unsafe_standard_stable(alpha, beta, V, W, coords):
function _standard_stable (line 51) | def _standard_stable(alpha, beta, aux_uniform, aux_exponential, coords):
class Stable (line 96) | class Stable(TorchDistribution):
method __init__ (line 161) | def __init__(
method expand (line 171) | def expand(self, batch_shape, _instance=None):
method log_prob (line 181) | def log_prob(self, value):
method rsample (line 209) | def rsample(self, sample_shape=torch.Size()):
method mean (line 224) | def mean(self):
method variance (line 233) | def variance(self):
class StableWithLogProb (line 238) | class StableWithLogProb(Stable):
FILE: pyro/distributions/stable_log_prob.py
function create_integrator (line 19) | def create_integrator(num_points):
function set_integrator (line 40) | def set_integrator(num_points):
function integrate (line 47) | def integrate(*args, **kwargs): # noqa: F811
function _stable_log_prob (line 52) | def _stable_log_prob(alpha, beta, value, coords):
function _unsafe_alpha_stable_log_prob_S0 (line 90) | def _unsafe_alpha_stable_log_prob_S0(alpha, beta, Z):
function _unsafe_stable_log_prob (line 130) | def _unsafe_stable_log_prob(alpha, beta, Z):
function _unsafe_stable_given_uniform_log_prob (line 154) | def _unsafe_stable_given_uniform_log_prob(V, alpha, beta, Z):
function _unsafe_alpha_stable_log_prob_at_zero (line 188) | def _unsafe_alpha_stable_log_prob_at_zero(alpha, beta):
FILE: pyro/distributions/testing/fakes.py
class NonreparameterizedBeta (line 7) | class NonreparameterizedBeta(Beta):
class NonreparameterizedDirichlet (line 11) | class NonreparameterizedDirichlet(Dirichlet):
class NonreparameterizedGamma (line 15) | class NonreparameterizedGamma(Gamma):
class NonreparameterizedNormal (line 19) | class NonreparameterizedNormal(Normal):
FILE: pyro/distributions/testing/gof.py
class InvalidTest (line 68) | class InvalidTest(ValueError):
function print_histogram (line 72) | def print_histogram(probs, counts):
function multinomial_goodness_of_fit (line 81) | def multinomial_goodness_of_fit(
function unif01_goodness_of_fit (line 137) | def unif01_goodness_of_fit(samples, *, plot=False):
function exp_goodness_of_fit (line 160) | def exp_goodness_of_fit(samples, plot=False):
function density_goodness_of_fit (line 176) | def density_goodness_of_fit(samples, probs, plot=False):
function volume_of_sphere (line 205) | def volume_of_sphere(dim, radius):
function get_nearest_neighbor_distances (line 209) | def get_nearest_neighbor_distances(samples):
function vector_density_goodness_of_fit (line 224) | def vector_density_goodness_of_fit(samples, probs, *, dim=None, plot=Fal...
function auto_goodness_of_fit (line 266) | def auto_goodness_of_fit(samples, probs, *, dim=None, plot=False):
FILE: pyro/distributions/testing/naive_dirichlet.py
class NaiveDirichlet (line 11) | class NaiveDirichlet(Dirichlet):
method __init__ (line 19) | def __init__(self, concentration, validate_args=None):
method rsample (line 25) | def rsample(self, sample_shape=torch.Size()):
class NaiveBeta (line 31) | class NaiveBeta(Beta):
method __init__ (line 39) | def __init__(self, concentration1, concentration0, validate_args=None):
method rsample (line 44) | def rsample(self, sample_shape=torch.Size()):
FILE: pyro/distributions/testing/rejection_exponential.py
class RejectionExponential (line 14) | class RejectionExponential(Rejector):
method __init__ (line 18) | def __init__(self, rate, factor):
method log_prob_accept (line 26) | def log_prob_accept(self, x):
method batch_shape (line 32) | def batch_shape(self):
method event_shape (line 36) | def event_shape(self):
FILE: pyro/distributions/testing/rejection_gamma.py
class RejectionStandardGamma (line 13) | class RejectionStandardGamma(Rejector):
method __init__ (line 19) | def __init__(self, concentration):
method expand (line 42) | def expand(self, batch_shape, _instance=None):
method propose (line 63) | def propose(self, sample_shape=torch.Size()):
method propose_log_prob (line 74) | def propose_log_prob(self, value):
method log_prob_accept (line 87) | def log_prob_accept(self, value):
method log_prob (line 95) | def log_prob(self, x):
class RejectionGamma (line 100) | class RejectionGamma(Gamma):
method __init__ (line 103) | def __init__(self, concentration, rate, validate_args=None):
method expand (line 108) | def expand(self, batch_shape, _instance=None):
method rsample (line 115) | def rsample(self, sample_shape=torch.Size()):
method log_prob (line 118) | def log_prob(self, x):
method score_parts (line 121) | def score_parts(self, x):
class ShapeAugmentedGamma (line 128) | class ShapeAugmentedGamma(Gamma):
method __init__ (line 136) | def __init__(self, concentration, rate, boost=1, validate_args=None):
method expand (line 145) | def expand(self, batch_shape, _instance=None):
method rsample (line 155) | def rsample(self, sample_shape=torch.Size()):
method score_parts (line 164) | def score_parts(self, boosted_x=None):
class ShapeAugmentedDirichlet (line 175) | class ShapeAugmentedDirichlet(Dirichlet):
method __init__ (line 183) | def __init__(self, concentration, boost=1, validate_args=None):
method expand (line 189) | def expand(self, batch_shape, _instance=None):
method rsample (line 199) | def rsample(self, sample_shape=torch.Size()):
class ShapeAugmentedBeta (line 205) | class ShapeAugmentedBeta(Beta):
method __init__ (line 213) | def __init__(self, concentration1, concentration0, boost=1, validate_a...
method expand (line 220) | def expand(self, batch_shape, _instance=None):
method rsample (line 230) | def rsample(self, sample_shape=torch.Size()):
FILE: pyro/distributions/testing/special.py
function log (line 41) | def log(x):
function incomplete_gamma (line 49) | def incomplete_gamma(x, s):
function chi2sf (line 82) | def chi2sf(x, s):
FILE: pyro/distributions/torch.py
function _clamp_by_zero (line 18) | def _clamp_by_zero(x):
class Beta (line 23) | class Beta(torch.distributions.Beta, TorchDistributionMixin):
method conjugate_update (line 24) | def conjugate_update(self, other):
class Binomial (line 44) | class Binomial(torch.distributions.Binomial, TorchDistributionMixin):
method sample (line 56) | def sample(self, sample_shape=torch.Size()):
method log_prob (line 83) | def log_prob(self, value):
function _validate_thresh (line 107) | def _validate_thresh(thresh):
function _validate_tol (line 115) | def _validate_tol(tol):
class Categorical (line 124) | class Categorical(torch.distributions.Categorical, TorchDistributionMixin):
method log_prob (line 127) | def log_prob(self, value):
method enumerate_support (line 145) | def enumerate_support(self, expand=True):
class Dirichlet (line 152) | class Dirichlet(torch.distributions.Dirichlet, TorchDistributionMixin):
method infer_shapes (line 154) | def infer_shapes(concentration):
method conjugate_update (line 159) | def conjugate_update(self, other):
class Gamma (line 177) | class Gamma(torch.distributions.Gamma, TorchDistributionMixin):
method conjugate_update (line 178) | def conjugate_update(self, other):
class Geometric (line 197) | class Geometric(torch.distributions.Geometric, TorchDistributionMixin):
method log_prob (line 199) | def log_prob(self, value):
class LogNormal (line 205) | class LogNormal(torch.distributions.LogNormal, TorchDistributionMixin):
method __init__ (line 206) | def __init__(self, loc, scale, validate_args=None):
method expand (line 216) | def expand(self, batch_shape, _instance=None):
class LowRankMultivariateNormal (line 223) | class LowRankMultivariateNormal(
method infer_shapes (line 227) | def infer_shapes(loc, cov_factor, cov_diag):
class MultivariateNormal (line 233) | class MultivariateNormal(
method infer_shapes (line 237) | def infer_shapes(
class Multinomial (line 247) | class Multinomial(torch.distributions.Multinomial, TorchDistributionMixin):
method infer_shapes (line 248) | def infer_shapes(total_count=None, probs=None, logits=None):
class Normal (line 256) | class Normal(torch.distributions.Normal, TorchDistributionMixin):
class OneHotCategorical (line 260) | class OneHotCategorical(torch.distributions.OneHotCategorical, TorchDist...
method infer_shapes (line 262) | def infer_shapes(probs=None, logits=None):
class Poisson (line 269) | class Poisson(torch.distributions.Poisson, TorchDistributionMixin):
method __init__ (line 270) | def __init__(self, rate, *, is_sparse=False, validate_args=None):
method expand (line 274) | def expand(self, batch_shape, _instance=None):
method log_prob (line 280) | def log_prob(self, value):
class Independent (line 297) | class Independent(torch.distributions.Independent, TorchDistributionMixin):
method infer_shapes (line 299) | def infer_shapes(**kwargs):
method _validate_args (line 303) | def _validate_args(self):
method _validate_args (line 307) | def _validate_args(self, value):
method conjugate_update (line 310) | def conjugate_update(self, other):
class Uniform (line 321) | class Uniform(torch.distributions.Uniform, TorchDistributionMixin):
method __init__ (line 322) | def __init__(self, low, high, validate_args=None):
method expand (line 327) | def expand(self, batch_shape, _instance=None):
method support (line 335) | def support(self):
function _cat_docstrings (line 339) | def _cat_docstrings(*docstrings):
FILE: pyro/distributions/torch_distribution.py
class TorchDistributionMixin (line 19) | class TorchDistributionMixin(Distribution, Callable):
method __call__ (line 31) | def __call__(self, sample_shape: torch.Size = torch.Size()) -> torch.T...
method batch_shape (line 55) | def batch_shape(self) -> torch.Size:
method event_shape (line 63) | def event_shape(self) -> torch.Size:
method event_dim (line 71) | def event_dim(self) -> int:
method shape (line 78) | def shape(self, sample_shape=torch.Size()):
method infer_shapes (line 95) | def infer_shapes(cls, **arg_shapes):
method expand (line 122) | def expand(self, batch_shape, _instance=None) -> "ExpandedDistribution":
method expand_by (line 135) | def expand_by(self, sample_shape):
method reshape (line 156) | def reshape(self, sample_shape=None, extra_event_dims=None):
method to_event (line 163) | def to_event(self, reinterpreted_batch_ndims=None):
method independent (line 215) | def independent(self, reinterpreted_batch_ndims=None):
method mask (line 221) | def mask(self, mask):
class TorchDistribution (line 235) | class TorchDistribution(torch.distributions.Distribution, TorchDistribut...
class MaskedDistribution (line 302) | class MaskedDistribution(TorchDistribution):
method __init__ (line 317) | def __init__(self, base_dist, mask):
method expand (line 330) | def expand(self, batch_shape, _instance=None):
method has_rsample (line 344) | def has_rsample(self):
method has_enumerate_support (line 348) | def has_enumerate_support(self):
method support (line 352) | def support(self):
method sample (line 355) | def sample(self, sample_shape=torch.Size()):
method rsample (line 358) | def rsample(self, sample_shape=torch.Size()):
method log_prob (line 361) | def log_prob(self, value):
method score_parts (line 371) | def score_parts(self, value):
method enumerate_support (line 376) | def enumerate_support(self, expand=True):
method mean (line 380) | def mean(self):
method variance (line 384) | def variance(self):
method conjugate_update (line 387) | def conjugate_update(self, other):
class ExpandedDistribution (line 399) | class ExpandedDistribution(TorchDistribution):
method __init__ (line 402) | def __init__(self, base_dist, batch_shape=torch.Size()):
method expand (line 408) | def expand(self, batch_shape, _instance=None):
method _broadcast_shape (line 421) | def _broadcast_shape(existing_shape, new_shape):
method has_rsample (line 451) | def has_rsample(self):
method has_enumerate_support (line 455) | def has_enumerate_support(self):
method support (line 459) | def support(self):
method _sample (line 462) | def _sample(self, sample_fn, sample_shape):
method sample (line 477) | def sample(self, sample_shape=torch.Size()):
method rsample (line 480) | def rsample(self, sample_shape=torch.Size()):
method log_prob (line 483) | def log_prob(self, value):
method score_parts (line 490) | def score_parts(self, value):
method enumerate_support (line 503) | def enumerate_support(self, expand=True):
method mean (line 512) | def mean(self):
method variance (line 516) | def variance(self):
method conjugate_update (line 519) | def conjugate_update(self, other):
function _kl_masked_masked (line 530) | def _kl_masked_masked(p, q):
FILE: pyro/distributions/torch_patch.py
function patch_dependency (line 12) | def patch_dependency(target, root_module=torch):
function _Transform__getstate__ (line 46) | def _Transform__getstate__(self):
function _Transform_clear_cache (line 57) | def _Transform_clear_cache(self):
function _TransformedDistribution_clear_cache (line 64) | def _TransformedDistribution_clear_cache(self):
function _HalfCauchy_logprob (line 71) | def _HalfCauchy_logprob(self, value):
function _CorrCholesky_check (line 83) | def _CorrCholesky_check(self, value):
function _lazy_property__call__ (line 91) | def _lazy_property__call__(self):
FILE: pyro/distributions/torch_transform.py
class TransformModule (line 7) | class TransformModule(torch.distributions.Transform, torch.nn.Module):
method __init__ (line 13) | def __init__(self, *args, **kwargs):
method __hash__ (line 16) | def __hash__(self):
class ComposeTransformModule (line 20) | class ComposeTransformModule(torch.distributions.ComposeTransform, torch...
method __init__ (line 28) | def __init__(self, parts, cache_size=0):
method __hash__ (line 34) | def __hash__(self):
method with_cache (line 37) | def with_cache(self, cache_size=1):
FILE: pyro/distributions/transforms/__init__.py
function _transform_to_sphere (line 112) | def _transform_to_sphere(constraint):
function _transform_to_corr_matrix (line 118) | def _transform_to_corr_matrix(constraint):
function _transform_to_ordered_vector (line 126) | def _transform_to_ordered_vector(constraint):
function _transform_to_positive_ordered_vector (line 132) | def _transform_to_positive_ordered_vector(constraint):
function _transform_to_positive_definite (line 138) | def _transform_to_positive_definite(constraint):
function _transform_to_softplus_positive (line 144) | def _transform_to_softplus_positive(constraint):
function _transform_to_softplus_lower_cholesky (line 149) | def _transform_to_softplus_lower_cholesky(constraint):
function _transform_to_unit_lower_cholesky (line 154) | def _transform_to_unit_lower_cholesky(constraint):
function iterated (line 158) | def iterated(repeats, base_fn, *args, **kwargs):
FILE: pyro/distributions/transforms/affine_autoregressive.py
class AffineAutoregressive (line 19) | class AffineAutoregressive(TransformModule):
method __init__ (line 100) | def __init__(
method _call (line 122) | def _call(self, x):
method _inverse (line 141) | def _inverse(self, y):
method log_abs_det_jacobian (line 174) | def log_abs_det_jacobian(self, x, y):
method _call_stable (line 196) | def _call_stable(self, x):
method _inverse_stable (line 214) | def _inverse_stable(self, y):
class ConditionalAffineAutoregressive (line 238) | class ConditionalAffineAutoregressive(ConditionalTransformModule):
method __init__ (line 326) | def __init__(self, autoregressive_nn, **kwargs):
method condition (line 331) | def condition(self, context):
function affine_autoregressive (line 343) | def affine_autoregressive(input_dim, hidden_dims=None, **kwargs):
function conditional_affine_autoregressive (line 376) | def conditional_affine_autoregressive(
FILE: pyro/distributions/transforms/affine_coupling.py
class AffineCoupling (line 20) | class AffineCoupling(TransformModule):
method __init__ (line 91) | def __init__(
method domain (line 112) | def domain(self):
method codomain (line 116) | def codomain(self):
method _call (line 119) | def _call(self, x):
method _inverse (line 146) | def _inverse(self, y):
method log_abs_det_jacobian (line 172) | def log_abs_det_jacobian(self, x, y):
class ConditionalAffineCoupling (line 192) | class ConditionalAffineCoupling(ConditionalTransformModule):
method __init__ (line 270) | def __init__(self, split_dim, hypernet, **kwargs):
method condition (line 276) | def condition(self, context):
function affine_coupling (line 281) | def affine_coupling(input_dim, hidden_dims=None, split_dim=None, dim=-1,...
function conditional_affine_coupling (line 337) | def conditional_affine_coupling(
FILE: pyro/distributions/transforms/basic.py
class ELUTransform (line 15) | class ELUTransform(Transform):
method __eq__ (line 25) | def __eq__(self, other):
method _call (line 28) | def _call(self, x):
method _inverse (line 31) | def _inverse(self, y, eps=1e-8):
method log_abs_det_jacobian (line 36) | def log_abs_det_jacobian(self, x, y):
function elu (line 40) | def elu():
class LeakyReLUTransform (line 52) | class LeakyReLUTransform(Transform):
method __eq__ (line 62) | def __eq__(self, other):
method _call (line 65) | def _call(self, x):
method _inverse (line 68) | def _inverse(self, y):
method log_abs_det_jacobian (line 71) | def log_abs_det_jacobian(self, x, y):
function leaky_relu (line 77) | def leaky_relu():
function tanh (line 86) | def tanh():
FILE: pyro/distributions/transforms/batchnorm.py
class BatchNorm (line 14) | class BatchNorm(TransformModule):
method __init__ (line 77) | def __init__(self, input_dim, momentum=0.1, epsilon=1e-5):
method constrained_gamma (line 90) | def constrained_gamma(self):
method _call (line 93) | def _call(self, x):
method _inverse (line 107) | def _inverse(self, y):
method log_abs_det_jacobian (line 131) | def log_abs_det_jacobian(self, x, y):
function batchnorm (line 143) | def batchnorm(input_dim, **kwargs):
FILE: pyro/distributions/transforms/block_autoregressive.py
function log_matrix_product (line 19) | def log_matrix_product(A, B):
class BlockAutoregressive (line 29) | class BlockAutoregressive(TransformModule):
method __init__ (line 78) | def __init__(
method _call (line 126) | def _call(self, x):
method _inverse (line 174) | def _inverse(self, y):
method log_abs_det_jacobian (line 189) | def log_abs_det_jacobian(self, x, y):
class MaskedBlockLinear (line 202) | class MaskedBlockLinear(torch.nn.Module):
method __init__ (line 209) | def __init__(self, in_features, out_features, dim, bias=True):
method get_weights (line 258) | def get_weights(self):
method forward (line 282) | def forward(self, x):
function block_autoregressive (line 287) | def block_autoregressive(input_dim, **kwargs):
FILE: pyro/distributions/transforms/cholesky.py
class CorrLCholeskyTransform (line 13) | class CorrLCholeskyTransform(CorrCholeskyTransform): # DEPRECATED
method __init__ (line 14) | def __init__(self, cache_size=0):
class CholeskyTransform (line 22) | class CholeskyTransform(Transform):
method __eq__ (line 32) | def __eq__(self, other):
method _call (line 35) | def _call(self, x):
method _inverse (line 38) | def _inverse(self, y):
method log_abs_det_jacobian (line 41) | def log_abs_det_jacobian(self, x, y):
class CorrMatrixCholeskyTransform (line 50) | class CorrMatrixCholeskyTransform(CholeskyTransform):
method __eq__ (line 61) | def __eq__(self, other):
method log_abs_det_jacobian (line 64) | def log_abs_det_jacobian(self, x, y):
FILE: pyro/distributions/transforms/discrete_cosine.py
class DiscreteCosineTransform (line 12) | class DiscreteCosineTransform(Transform):
method __init__ (line 30) | def __init__(self, dim=-1, smooth=0.0, cache_size=0):
method __hash__ (line 37) | def __hash__(self):
method __eq__ (line 40) | def __eq__(self, other):
method domain (line 48) | def domain(self):
method codomain (line 52) | def codomain(self):
method _weight (line 56) | def _weight(self, y):
method _call (line 66) | def _call(self, x):
method _inverse (line 77) | def _inverse(self, y):
method log_abs_det_jacobian (line 88) | def log_abs_det_jacobian(self, x, y):
method with_cache (line 91) | def with_cache(self, cache_size=1):
method forward_shape (line 96) | def forward_shape(self, shape):
method inverse_shape (line 101) | def inverse_shape(self, shape):
FILE: pyro/distributions/transforms/generalized_channel_permute.py
class ConditionedGeneralizedChannelPermute (line 16) | class ConditionedGeneralizedChannelPermute(Transform):
method __init__ (line 21) | def __init__(self, permutation=None, LU=None):
method U_diag (line 28) | def U_diag(self):
method L (line 32) | def L(self):
method U (line 38) | def U(self):
method _call (line 41) | def _call(self, x):
method _inverse (line 65) | def _inverse(self, y):
method log_abs_det_jacobian (line 97) | def log_abs_det_jacobian(self, x, y):
class GeneralizedChannelPermute (line 111) | class GeneralizedChannelPermute(ConditionedGeneralizedChannelPermute, Tr...
method __init__ (line 169) | def __init__(self, channels=3, permutation=None):
class ConditionalGeneralizedChannelPermute (line 200) | class ConditionalGeneralizedChannelPermute(ConditionalTransformModule):
method __init__ (line 267) | def __init__(self, nn, channels=3, permutation=None):
method condition (line 280) | def condition(self, context):
function generalized_channel_permute (line 286) | def generalized_channel_permute(**kwargs):
function conditional_generalized_channel_permute (line 300) | def conditional_generalized_channel_permute(context_dim, channels=3, hid...
FILE: pyro/distributions/transforms/haar.py
class HaarTransform (line 11) | class HaarTransform(Transform):
method __init__ (line 30) | def __init__(self, dim=-1, flip=False, cache_size=0):
method __hash__ (line 36) | def __hash__(self):
method __eq__ (line 39) | def __eq__(self, other):
method domain (line 47) | def domain(self):
method codomain (line 51) | def codomain(self):
method _call (line 54) | def _call(self, x):
method _inverse (line 65) | def _inverse(self, y):
method log_abs_det_jacobian (line 76) | def log_abs_det_jacobian(self, x, y):
method with_cache (line 79) | def with_cache(self, cache_size=1):
method forward_shape (line 84) | def forward_shape(self, shape):
method inverse_shape (line 89) | def inverse_shape(self, shape):
FILE: pyro/distributions/transforms/householder.py
class ConditionedHouseholder (line 19) | class ConditionedHouseholder(Transform):
method __init__ (line 25) | def __init__(self, u_unnormed=None):
method u (line 30) | def u(self):
method _call (line 35) | def _call(self, x):
method _inverse (line 52) | def _inverse(self, y):
method log_abs_det_jacobian (line 70) | def log_abs_det_jacobian(self, x, y):
class Householder (line 82) | class Householder(ConditionedHouseholder, TransformModule):
method __init__ (line 131) | def __init__(self, input_dim, count_transforms=1):
method reset_parameters (line 151) | def reset_parameters(self):
class ConditionalHouseholder (line 157) | class ConditionalHouseholder(ConditionalTransformModule):
method __init__ (line 217) | def __init__(self, input_dim, nn, count_transforms=1):
method _u_unnormed (line 236) | def _u_unnormed(self, context):
method condition (line 246) | def condition(self, context):
function householder (line 251) | def householder(input_dim, count_transforms=None):
function conditional_householder (line 270) | def conditional_householder(
FILE: pyro/distributions/transforms/lower_cholesky_affine.py
class LowerCholeskyAffine (line 12) | class LowerCholeskyAffine(Transform):
method __init__ (line 32) | def __init__(self, loc, scale_tril, cache_size=0):
method _call (line 42) | def _call(self, x):
method _inverse (line 53) | def _inverse(self, y):
method log_abs_det_jacobian (line 64) | def log_abs_det_jacobian(self, x, y):
method with_cache (line 74) | def with_cache(self, cache_size=1):
FILE: pyro/distributions/transforms/matrix_exponential.py
class ConditionedMatrixExponential (line 19) | class ConditionedMatrixExponential(Transform):
method __init__ (line 24) | def __init__(self, weights=None, iterations=8, normalization="none", b...
method _exp (line 39) | def _exp(self, x, M):
method _trace (line 52) | def _trace(self, M):
method _call (line 63) | def _call(self, x):
method _inverse (line 75) | def _inverse(self, y):
method log_abs_det_jacobian (line 85) | def log_abs_det_jacobian(self, x, y):
class MatrixExponential (line 95) | class MatrixExponential(ConditionedMatrixExponential, TransformModule):
method __init__ (line 154) | def __init__(self, input_dim, iterations=8, normalization="none", boun...
method reset_parameters (line 162) | def reset_parameters(self):
class ConditionalMatrixExponential (line 168) | class ConditionalMatrixExponential(ConditionalTransformModule):
method __init__ (line 235) | def __init__(self, input_dim, nn, iterations=8, normalization="none", ...
method _params (line 243) | def _params(self, context):
method condition (line 246) | def condition(self, context):
function matrix_exponential (line 262) | def matrix_exponential(input_dim, iterations=8, normalization="none", bo...
function conditional_matrix_exponential (line 292) | def conditional_matrix_exponential(
FILE: pyro/distributions/transforms/neural_autoregressive.py
class NeuralAutoregressive (line 23) | class NeuralAutoregressive(TransformModule):
method __init__ (line 68) | def __init__(self, autoregressive_nn, hidden_units=16, activation="sig...
method _call (line 91) | def _call(self, x):
method log_abs_det_jacobian (line 121) | def log_abs_det_jacobian(self, x, y):
class ConditionalNeuralAutoregressive (line 144) | class ConditionalNeuralAutoregressive(ConditionalTransformModule):
method __init__ (line 194) | def __init__(self, autoregressive_nn, **kwargs):
method condition (line 199) | def condition(self, context):
function neural_autoregressive (line 212) | def neural_autoregressive(input_dim, hidden_dims=None, activation="sigmo...
function conditional_neural_autoregressive (line 239) | def conditional_neural_autoregressive(
FILE: pyro/distributions/transforms/normalize.py
class Normalize (line 13) | class Normalize(Transform):
method __init__ (line 23) | def __init__(self, p=2, cache_size=0):
method __eq__ (line 29) | def __eq__(self, other):
method _call (line 32) | def _call(self, x):
method _inverse (line 35) | def _inverse(self, y):
method with_cache (line 38) | def with_cache(self, cache_size=1):
FILE: pyro/distributions/transforms/ordered.py
class OrderedTransform (line 10) | class OrderedTransform(Transform):
method _call (line 23) | def _call(self, x):
method _inverse (line 27) | def _inverse(self, y):
method log_abs_det_jacobian (line 31) | def log_abs_det_jacobian(self, x, y):
FILE: pyro/distributions/transforms/permute.py
class Permute (line 14) | class Permute(Transform):
method __init__ (line 50) | def __init__(self, permutation, *, dim=-1, cache_size=1):
method domain (line 60) | def domain(self):
method codomain (line 64) | def codomain(self):
method inv_permutation (line 68) | def inv_permutation(self):
method _call (line 75) | def _call(self, x):
method _inverse (line 87) | def _inverse(self, y):
method log_abs_det_jacobian (line 96) | def log_abs_det_jacobian(self, x, y):
method with_cache (line 109) | def with_cache(self, cache_size=1):
function permute (line 115) | def permute(input_dim, permutation=None, dim=-1):
FILE: pyro/distributions/transforms/planar.py
class ConditionedPlanar (line 20) | class ConditionedPlanar(Transform):
method __init__ (line 25) | def __init__(self, params):
method u_hat (line 31) | def u_hat(self, u, w):
method _call (line 36) | def _call(self, x):
method _inverse (line 67) | def _inverse(self, y):
method log_abs_det_jacobian (line 81) | def log_abs_det_jacobian(self, x, y):
class Planar (line 95) | class Planar(ConditionedPlanar, TransformModule):
method __init__ (line 137) | def __init__(self, input_dim):
method _params (line 159) | def _params(self):
method reset_parameters (line 162) | def reset_parameters(self):
class ConditionalPlanar (line 170) | class ConditionalPlanar(ConditionalTransformModule):
method __init__ (line 221) | def __init__(self, nn):
method _params (line 225) | def _params(self, context):
method condition (line 228) | def condition(self, context):
function planar (line 233) | def planar(input_dim):
function conditional_planar (line 246) | def conditional_planar(input_dim, context_dim, hidden_dims=None):
FILE: pyro/distributions/transforms/polynomial.py
class Polynomial (line 17) | class Polynomial(TransformModule):
method __init__ (line 76) | def __init__(self, autoregressive_nn, input_dim, count_degree, count_s...
method reset_parameters (line 102) | def reset_parameters(self):
method _call (line 106) | def _call(self, x):
method _inverse (line 142) | def _inverse(self, y):
method log_abs_det_jacobian (line 157) | def log_abs_det_jacobian(self, x, y):
function polynomial (line 170) | def polynomial(input_dim, hidden_dims=None):
FILE: pyro/distributions/transforms/power.py
class PositivePowerTransform (line 9) | class PositivePowerTransform(Transform):
method __init__ (line 26) | def __init__(self, exponent, *, cache_size=0, validate_args=None):
method with_cache (line 38) | def with_cache(self, cache_size=1):
method __eq__ (line 43) | def __eq__(self, other):
method _call (line 48) | def _call(self, x):
method _inverse (line 51) | def _inverse(self, y):
method log_abs_det_jacobian (line 54) | def log_abs_det_jacobian(self, x, y):
method forward_shape (line 57) | def forward_shape(self, shape):
method inverse_shape (line 60) | def inverse_shape(self, shape):
FILE: pyro/distributions/transforms/radial.py
class ConditionedRadial (line 20) | class ConditionedRadial(Transform):
method __init__ (line 25) | def __init__(self, params):
method u_hat (line 31) | def u_hat(self, u, w):
method _call (line 36) | def _call(self, x):
method _inverse (line 66) | def _inverse(self, y):
method log_abs_det_jacobian (line 80) | def log_abs_det_jacobian(self, x, y):
class Radial (line 94) | class Radial(ConditionedRadial, TransformModule):
method __init__ (line 134) | def __init__(self, input_dim):
method _params (line 155) | def _params(self):
method reset_parameters (line 158) | def reset_parameters(self):
class ConditionalRadial (line 166) | class ConditionalRadial(ConditionalTransformModule):
method __init__ (line 215) | def __init__(self, nn):
method _params (line 219) | def _params(self, context):
method condition (line 222) | def condition(self, context):
function radial (line 227) | def radial(input_dim):
function conditional_radial (line 240) | def conditional_radial(input_dim, context_dim, hidden_dims=None):
FILE: pyro/distributions/transforms/simplex_to_ordered.py
class SimplexToOrderedTransform (line 12) | class SimplexToOrderedTransform(Transform):
method __init__ (line 31) | def __init__(self, anchor_point=None):
method _call (line 37) | def _call(self, x):
method _inverse (line 42) | def _inverse(self, y):
method log_abs_det_jacobian (line 53) | def log_abs_det_jacobian(self, x, y):
method __eq__ (line 61) | def __eq__(self, other):
method forward_shape (line 66) | def forward_shape(self, shape):
method inverse_shape (line 69) | def inverse_shape(self, shape):
FILE: pyro/distributions/transforms/softplus.py
function softplus_inv (line 9) | def softplus_inv(y):
class SoftplusTransform (line 14) | class SoftplusTransform(Transform):
method __eq__ (line 24) | def __eq__(self, other):
method _call (line 27) | def _call(self, x):
method _inverse (line 30) | def _inverse(self, y):
method log_abs_det_jacobian (line 33) | def log_abs_det_jacobian(self, x, y):
class SoftplusLowerCholeskyTransform (line 37) | class SoftplusLowerCholeskyTransform(Transform):
method __eq__ (line 47) | def __eq__(self, other):
method _call (line 50) | def _call(self, x):
method _inverse (line 54) | def _inverse(self, y):
FILE: pyro/distributions/transforms/spline.py
function _searchsorted (line 27) | def _searchsorted(sorted_sequence, values):
function _select_bins (line 37) | def _select_bins(x, idx):
function _calculate_knots (line 59) | def _calculate_knots(lengths, lower, upper):
function _monotonic_rational_spline (line 83) | def _monotonic_rational_spline(
class ConditionedSpline (line 303) | class ConditionedSpline(Transform):
method __init__ (line 313) | def __init__(self, params, bound=3.0, order="linear"):
method _call (line 321) | def _call(self, x):
method _inverse (line 326) | def _inverse(self, y):
method log_abs_det_jacobian (line 338) | def log_abs_det_jacobian(self, x, y):
method spline_op (line 350) | def spline_op(self, x, **kwargs):
class Spline (line 359) | class Spline(ConditionedSpline, TransformModule):
method __init__ (line 412) | def __init__(self, input_dim, count_bins=8, bound=3.0, order="linear"):
method _params (line 442) | def _params(self):
class ConditionalSpline (line 455) | class ConditionalSpline(ConditionalTransformModule):
method __init__ (line 521) | def __init__(self, nn, input_dim, count_bins, bound=3.0, order="linear"):
method _params (line 530) | def _params(self, context):
method condition (line 567) | def condition(self, context):
function spline (line 572) | def spline(input_dim, **kwargs):
function conditional_spline (line 588) | def conditional_spline(
FILE: pyro/distributions/transforms/spline_autoregressive.py
class SplineAutoregressive (line 18) | class SplineAutoregressive(TransformModule):
method __init__ (line 78) | def __init__(
method _call (line 87) | def _call(self, x):
method _inverse (line 101) | def _inverse(self, y):
method log_abs_det_jacobian (line 120) | def log_abs_det_jacobian(self, x, y):
class ConditionalSplineAutoregressive (line 134) | class ConditionalSplineAutoregressive(ConditionalTransformModule):
method __init__ (line 201) | def __init__(self, input_dim, autoregressive_nn, **kwargs):
method condition (line 207) | def condition(self, context):
function spline_autoregressive (line 220) | def spline_autoregressive(
function conditional_spline_autoregressive (line 254) | def conditional_spline_autoregressive(
FILE: pyro/distributions/transforms/spline_coupling.py
class SplineCoupling (line 15) | class SplineCoupling(TransformModule):
method __init__ (line 81) | def __init__(
method _call (line 102) | def _call(self, x):
method _inverse (line 129) | def _inverse(self, y):
method log_abs_det_jacobian (line 155) | def log_abs_det_jacobian(self, x, y):
function spline_coupling (line 168) | def spline_coupling(
FILE: pyro/distributions/transforms/sylvester.py
class Sylvester (line 14) | class Sylvester(Householder):
method __init__ (line 60) | def __init__(self, input_dim, count_transforms=1):
method dtanh_dx (line 79) | def dtanh_dx(self, x):
method R (line 83) | def R(self):
method S (line 87) | def S(self):
method Q (line 91) | def Q(self, x):
method reset_parameters2 (line 105) | def reset_parameters2(self):
method _call (line 109) | def _call(self, x):
method _inverse (line 133) | def _inverse(self, y):
method log_abs_det_jacobian (line 147) | def log_abs_det_jacobian(self, x, y):
function sylvester (line 160) | def sylvester(input_dim, count_transforms=None):
FILE: pyro/distributions/transforms/unit_cholesky.py
class UnitLowerCholeskyTransform (line 11) | class UnitLowerCholeskyTransform(Transform):
method __eq__ (line 20) | def __eq__(self, other):
method _call (line 23) | def _call(self, x):
method _inverse (line 26) | def _inverse(self, y):
FILE: pyro/distributions/transforms/utils.py
function clamp_preserve_gradients (line 5) | def clamp_preserve_gradients(x, min, max):
FILE: pyro/distributions/unit.py
class Unit (line 11) | class Unit(TorchDistribution):
method __init__ (line 23) | def __init__(self, log_factor, *, has_rsample=None, validate_args=None):
method expand (line 32) | def expand(self, batch_shape, _instance=None):
method sample (line 42) | def sample(self, sample_shape=torch.Size()):
method rsample (line 45) | def rsample(self, sample_shape=torch.Size()):
method log_prob (line 48) | def log_prob(self, value):
FILE: pyro/distributions/util.py
function copy_docs_from (line 33) | def copy_docs_from(source_class, full_text=False):
function weakmethod (line 72) | def weakmethod(fn):
class _DetachMemo (line 116) | class _DetachMemo(dict):
method get (line 117) | def get(self, key, default=None):
function detach (line 129) | def detach(obj):
class _DeepToMemo (line 141) | class _DeepToMemo(dict):
method __init__ (line 142) | def __init__(self, to_args, to_kwargs):
method get (line 147) | def get(self, key, default=None):
function deep_to (line 159) | def deep_to(obj, *args, **kwargs):
function is_identically_zero (line 188) | def is_identically_zero(x):
function is_identically_one (line 201) | def is_identically_one(x):
function broadcast_shape (line 214) | def broadcast_shape(*shapes, **kwargs):
function gather (line 242) | def gather(value, index, dim):
function sum_rightmost (line 253) | def sum_rightmost(value, dim):
function sum_leftmost (line 279) | def sum_leftmost(value, dim):
function scale_and_mask (line 311) | def scale_and_mask(tensor, scale=1.0, mask=None):
function scalar_like (line 331) | def scalar_like(prototype, fill_value):
function eye_like (line 336) | def eye_like(value, m, n=None):
function enable_validation (line 344) | def enable_validation(is_validate):
function is_validation_enabled (line 350) | def is_validation_enabled():
function validation_enabled (line 355) | def validation_enabled(is_validate=True):
FILE: pyro/distributions/von_mises_3d.py
class VonMises3D (line 12) | class VonMises3D(TorchDistribution):
method __init__ (line 35) | def __init__(self, concentration, validate_args=None):
method log_prob (line 46) | def log_prob(self, value):
method expand (line 60) | def expand(self, batch_shape):
FILE: pyro/distributions/zero_inflated.py
class ZeroInflatedDistribution (line 18) | class ZeroInflatedDistribution(TorchDistribution):
method __init__ (line 35) | def __init__(self, base_dist, *, gate=None, gate_logits=None, validate...
method support (line 58) | def support(self):
method gate (line 62) | def gate(self):
method gate_logits (line 66) | def gate_logits(self):
method log_prob (line 69) | def log_prob(self, value):
method sample (line 86) | def sample(self, sample_shape=torch.Size()):
method mean (line 95) | def mean(self):
method variance (line 99) | def variance(self):
method expand (line 104) | def expand(self, batch_shape, _instance=None):
class ZeroInflatedPoisson (line 121) | class ZeroInflatedPoisson(ZeroInflatedDistribution):
method __init__ (line 137) | def __init__(self, rate, *, gate=None, gate_logits=None, validate_args...
method rate (line 146) | def rate(self):
class ZeroInflatedNegativeBinomial (line 150) | class ZeroInflatedNegativeBinomial(ZeroInflatedDistribution):
method __init__ (line 171) | def __init__(
method total_count (line 194) | def total_count(self):
method probs (line 198) | def probs(self):
method logits (line 202) | def logits(self):
FILE: pyro/infer/abstract_infer.py
class EmpiricalMarginal (line 17) | class EmpiricalMarginal(Empirical):
method __init__ (line 33) | def __init__(self, trace_posterior, sites=None, validate_args=None):
method _get_samples_and_weights (line 46) | def _get_samples_and_weights(self):
method _add_sample (line 71) | def _add_sample(self, value, log_weight=None, chain_id=0):
method _populate_traces (line 101) | def _populate_traces(self, trace_posterior, sites):
class Marginals (line 116) | class Marginals:
method __init__ (line 128) | def __init__(self, trace_posterior, sites=None, validate_args=None):
method _populate_traces (line 144) | def _populate_traces(self, trace_posterior, validate):
method support (line 150) | def support(self, flatten=False):
method empirical (line 174) | def empirical(self):
class TracePosterior (line 184) | class TracePosterior(object, metaclass=ABCMeta):
method __init__ (line 192) | def __init__(self, num_chains=1):
method _reset (line 196) | def _reset(self):
method marginal (line 205) | def marginal(self, sites=None):
method _traces (line 217) | def _traces(self, *args, **kwargs):
method __call__ (line 226) | def __call__(self, *args, **kwargs):
method run (line 241) | def run(self, *args, **kwargs):
method information_criterion (line 265) | def information_criterion(self, pointwise=False):
class TracePredictive (line 313) | class TracePredictive(TracePosterior):
method __init__ (line 330) | def __init__(self, model, posterior, num_samples, keep_sites=None):
method _traces (line 342) | def _traces(self, *args, **kwargs):
method _remove_dropped_nodes (line 355) | def _remove_dropped_nodes(self, trace):
method _adjust_to_data (line 363) | def _adjust_to_data(self, trace, data_trace):
method marginal (line 392) | def marginal(self, sites=None):
FILE: pyro/infer/autoguide/effect.py
class AutoMessengerMeta (line 21) | class AutoMessengerMeta(type(GuideMessenger), type(PyroModule)):
class AutoMessenger (line 25) | class AutoMessenger(GuideMessenger, PyroModule, metaclass=AutoMessengerM...
method __init__ (line 35) | def __init__(self, model: Callable, *, amortized_plates: Tuple[str, .....
method __call__ (line 40) | def __call__(self, *args, **kwargs):
method call (line 51) | def call(self, *args, **kwargs):
method _adjust_plates (line 67) | def _adjust_plates(self, value: torch.Tensor, event_dim: int) -> torch...
class AutoNormalMessenger (line 84) | class AutoNormalMessenger(AutoMessenger):
method __init__ (line 147) | def __init__(
method get_posterior (line 162) | def get_posterior(
method _get_params (line 177) | def _get_params(self, name: str, prior: Distribution):
method median (line 202) | def median(self, *args, **kwargs):
method _get_posterior_median (line 209) | def _get_posterior_median(self, name, prior):
class AutoHierarchicalNormalMessenger (line 215) | class AutoHierarchicalNormalMessenger(AutoNormalMessenger):
method __init__ (line 249) | def __init__(
method get_posterior (line 268) | def get_posterior(
method _get_params (line 289) | def _get_params(self, name: str, prior: Distribution):
method median (line 348) | def median(self, *args, **kwargs):
method _get_posterior_median (line 355) | def _get_posterior_median(self, name, prior):
class AutoRegressiveMessenger (line 365) | class AutoRegressiveMessenger(AutoMessenger):
method __init__ (line 401) | def __init__(
method get_posterior (line 415) | def get_posterior(
method _get_params (line 429) | def _get_params(self, name: str, prior: Distribution):
FILE: pyro/infer/autoguide/gaussian.py
class AutoGaussianMeta (line 36) | class AutoGaussianMeta(type(AutoGuide), ABCMeta):
method __init__ (line 40) | def __init__(cls, *args, **kwargs):
method __call__ (line 46) | def __call__(cls, *args, **kwargs):
class AutoGaussian (line 53) | class AutoGaussian(AutoGuide, metaclass=AutoGaussianMeta):
method __init__ (line 109) | def __init__(
method _prototype_hide_fn (line 125) | def _prototype_hide_fn(msg):
method _setup_prototype (line 130) | def _setup_prototype(self, *args, **kwargs) -> None:
method _compress_site (line 233) | def _compress_site(site):
method forward (line 247) | def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
method median (line 268) | def median(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
method _transform_values (line 280) | def _transform_values(
method _sample_aux_values (line 307) | def _sample_aux_values(self, *, temperature: float) -> Dict[str, torch...
class AutoGaussianDense (line 311) | class AutoGaussianDense(AutoGaussian):
method _setup_prototype (line 321) | def _setup_prototype(self, *args, **kwargs):
method _sample_aux_values (line 385) | def _sample_aux_values(self, *, temperature: float) -> Dict[str, torch...
method _dense_unflatten (line 400) | def _dense_unflatten(self, flat_samples: torch.Tensor) -> Dict[str, to...
method _dense_flatten (line 415) | def _dense_flatten(self, samples: Dict[str, torch.Tensor]) -> torch.Te...
method _dense_get_mvn (line 424) | def _dense_get_mvn(self):
class AutoGaussianFunsor (line 444) | class AutoGaussianFunsor(AutoGaussian):
method __init__ (line 453) | def __init__(self, *args, **kwargs):
method _setup_prototype (line 457) | def _setup_prototype(self, *args, **kwargs):
method _sample_aux_values (line 497) | def _sample_aux_values(self, *, temperature: float) -> Dict[str, torch...
function _precision_to_scale_tril (line 554) | def _precision_to_scale_tril(P):
function _try_possibly_intractable (line 565) | def _try_possibly_intractable(fn, *args, **kwargs):
function _plates_to_shape (line 580) | def _plates_to_shape(plates):
function _break_plates (line 587) | def _break_plates(x, all_plates, kept_plates):
function _import_funsor (line 616) | def _import_funsor():
FILE: pyro/infer/autoguide/guides.py
function prototype_hide_fn (line 45) | def prototype_hide_fn(msg):
class AutoGuide (line 50) | class AutoGuide(PyroModule):
method __init__ (line 67) | def __init__(self, model, *, create_plates=None):
method model (line 77) | def model(self):
method __getstate__ (line 80) | def __getstate__(self):
method __setstate__ (line 86) | def __setstate__(self, state):
method _update_master (line 94) | def _update_master(self, master_ref):
method call (line 100) | def call(self, *args, **kwargs):
method sample_latent (line 115) | def sample_latent(*args, **kwargs):
method __setattr__ (line 122) | def __setattr__(self, name, value):
method _create_plates (line 128) | def _create_plates(self, *args, **kwargs):
method _setup_prototype (line 155) | def _setup_prototype(self, *args, **kwargs):
method median (line 174) | def median(self, *args, **kwargs):
class AutoGuideList (line 184) | class AutoGuideList(AutoGuide, nn.ModuleList):
method _check_prototype (line 198) | def _check_prototype(self, part_trace):
method append (line 205) | def append(self, part):
method add (line 222) | def add(self, part):
method forward (line 230) | def forward(self, *args, **kwargs):
method median (line 253) | def median(self, *args, **kwargs):
method quantiles (line 265) | def quantiles(self, quantiles, *args, **kwargs):
class AutoCallable (line 279) | class AutoCallable(AutoGuide):
method __init__ (line 309) | def __init__(self, model, guide, median=lambda *args, **kwargs: {}):
method forward (line 314) | def forward(self, *args, **kwargs):
class AutoDelta (line 319) | class AutoDelta(AutoGuide):
method __init__ (line 352) | def __init__(self, model, init_loc_fn=init_to_median, *, create_plates...
method _setup_prototype (line 357) | def _setup_prototype(self, *args, **kwargs):
method forward (line 376) | def forward(self, *args, **kwargs):
method median (line 404) | def median(self, *args, **kwargs):
class AutoNormal (line 415) | class AutoNormal(AutoGuide):
method __init__ (line 448) | def __init__(
method _setup_prototype (line 460) | def _setup_prototype(self, *args, **kwargs):
method _get_loc_and_scale (line 494) | def _get_loc_and_scale(self, name):
method forward (line 499) | def forward(self, *args, **kwargs):
method median (line 556) | def median(self, *args, **kwargs):
method quantiles (line 574) | def quantiles(self, quantiles, *args, **kwargs):
class AutoContinuous (line 605) | class AutoContinuous(AutoGuide):
method __init__ (line 632) | def __init__(self, model, init_loc_fn=init_to_median):
method _setup_prototype (line 636) | def _setup_prototype(self, *args, **kwargs):
method _init_loc (line 661) | def _init_loc(self):
method get_base_dist (line 674) | def get_base_dist(self):
method get_transform (line 688) | def get_transform(self, *args, **kwargs):
method get_posterior (line 702) | def get_posterior(self, *args, **kwargs):
method sample_latent (line 710) | def sample_latent(self, *args, **kwargs):
method _unpack_latent (line 720) | def _unpack_latent(self, latent):
method forward (line 748) | def forward(self, *args, **kwargs):
method _loc_scale (line 795) | def _loc_scale(self, *args, **kwargs):
method median (line 803) | def median(self, *args, **kwargs):
method quantiles (line 818) | def quantiles(self, quantiles, *args, **kwargs):
class AutoMultivariateNormal (line 844) | class AutoMultivariateNormal(AutoContinuous):
method __init__ (line 869) | def __init__(self, model, init_loc_fn=init_to_median, init_scale=0.1):
method _setup_prototype (line 875) | def _setup_prototype(self, *args, **kwargs):
method get_base_dist (line 886) | def get_base_dist(self):
method get_transform (line 891) | def get_transform(self, *args, **kwargs):
method get_posterior (line 895) | def get_posterior(self, *args, **kwargs):
method _loc_scale (line 902) | def _loc_scale(self, *args, **kwargs):
class AutoDiagonalNormal (line 909) | class AutoDiagonalNormal(AutoContinuous):
method __init__ (line 932) | def __init__(self, model, init_loc_fn=init_to_median, init_scale=0.1):
method _setup_prototype (line 938) | def _setup_prototype(self, *args, **kwargs):
method get_base_dist (line 947) | def get_base_dist(self):
method get_transform (line 952) | def get_transform(self, *args, **kwargs):
method get_posterior (line 955) | def get_posterior(self, *args, **kwargs):
method _loc_scale (line 961) | def _loc_scale(self, *args, **kwargs):
class AutoLowRankMultivariateNormal (line 965) | class AutoLowRankMultivariateNormal(AutoContinuous):
method __init__ (line 993) | def __init__(self, model, init_loc_fn=init_to_median, init_scale=0.1, ...
method _setup_prototype (line 1002) | def _setup_prototype(self, *args, **kwargs):
method get_posterior (line 1018) | def get_posterior(self, *args, **kwargs):
method _loc_scale (line 1027) | def _loc_scale(self, *args, **kwargs):
class AutoNormalizingFlow (line 1032) | class AutoNormalizingFlow(AutoContinuous):
method __init__ (line 1054) | def __init__(self, model, init_transform_fn):
method get_base_dist (line 1060) | def get_base_dist(self):
method get_transform (line 1065) | def get_transform(self, *args, **kwargs):
method get_posterior (line 1068) | def get_posterior(self, *args, **kwargs):
class AutoIAFNormal (line 1079) | class AutoIAFNormal(AutoNormalizingFlow):
method __init__ (line 1107) | def __init__(
class AutoLaplaceApproximation (line 1133) | class AutoLaplaceApproximation(AutoContinuous):
method _setup_prototype (line 1156) | def _setup_prototype(self, *args, **kwargs):
method get_posterior (line 1161) | def get_posterior(self, *args, **kwargs):
method laplace_approximation (line 1167) | def laplace_approximation(self, *args, **kwargs):
class AutoDiscreteParallel (line 1199) | class AutoDiscreteParallel(AutoGuide):
method _setup_prototype (line 1205) | def _setup_prototype(self, *args, **kwargs):
method forward (line 1254) | def forward(self, *args, **kwargs):
FILE: pyro/infer/autoguide/initialization.py
function _is_multivariate (line 29) | def _is_multivariate(d):
function init_to_feasible (line 35) | def init_to_feasible(site=None):
function init_to_sample (line 50) | def init_to_sample(site=None):
function init_to_median (line 62) | def init_to_median(
function init_to_mean (line 102) | def init_to_mean(
function init_to_uniform (line 136) | def init_to_uniform(
function init_to_value (line 157) | def init_to_value(
class _InitToGenerated (line 184) | class _InitToGenerated:
method __init__ (line 185) | def __init__(self, generate):
method __call__ (line 190) | def __call__(self, site):
function init_to_generated (line 197) | def init_to_generated(site=None, generate=lambda: init_to_uniform):
class InitMessenger (line 220) | class InitMessenger(Messenger):
method __init__ (line 229) | def __init__(self, init_fn):
method _pyro_sample (line 233) | def _pyro_sample(self, msg):
method _pyro_get_init_messengers (line 253) | def _pyro_get_init_messengers(self, msg):
FILE: pyro/infer/autoguide/structured.py
function _config_auxiliary (line 26) | def _config_auxiliary(msg):
class AutoStructured (line 30) | class AutoStructured(AutoGuide):
method __init__ (line 104) | def __init__(
method _auto_config (line 138) | def _auto_config(self, sample_sites, args, kwargs):
method _setup_prototype (line 165) | def _setup_prototype(self, *args, **kwargs):
method _compress_site (line 255) | def _compress_site(site):
method get_deltas (line 268) | def get_deltas(self, save_params=None):
method forward (line 352) | def forward(self, *args, **kwargs):
method median (line 369) | def median(self, *args, **kwargs):
FILE: pyro/infer/autoguide/utils.py
function _product (line 11) | def _product(shape):
function deep_setattr (line 21) | def deep_setattr(obj, key, val):
function mean_field_entropy (line 41) | def mean_field_entropy(model, args, whitelist=None):
function helpful_support_errors (line 63) | def helpful_support_errors(site):
FILE: pyro/infer/csis.py
class CSIS (line 16) | class CSIS(Importance):
method __init__ (line 40) | def __init__(
method set_validation_batch (line 57) | def set_validation_batch(self, *args, **kwargs):
method step (line 68) | def step(self, *args, **kwargs):
method loss_and_grads (line 91) | def loss_and_grads(self, grads, batch, *args, **kwargs):
method _differentiable_loss_particle (line 142) | def _differentiable_loss_particle(self, guide_trace):
method validation_loss (line 145) | def validation_loss(self, *args, **kwargs):
method _get_matched_trace (line 161) | def _get_matched_trace(self, model_trace, *args, **kwargs):
method _sample_from_joint (line 190) | def _sample_from_joint(self, *args, **kwargs):
FILE: pyro/infer/discrete.py
function _make_ring (line 24) | def _make_ring(temperature, cache, dim_to_size):
class SamplePosteriorMessenger (line 31) | class SamplePosteriorMessenger(ReplayMessenger):
method _pyro_sample (line 34) | def _pyro_sample(self, msg):
function _sample_posterior (line 41) | def _sample_posterior(
function _sample_posterior_from_trace (line 58) | def _sample_posterior_from_trace(
function infer_discrete (line 181) | def infer_discrete(
class TraceEnumSample_ELBO (line 234) | class TraceEnumSample_ELBO(TraceEnum_ELBO):
method _get_trace (line 256) | def _get_trace(self, model, guide, args, kwargs):
method sample_saved (line 269) | def sample_saved(self):
FILE: pyro/infer/elbo.py
class ELBOModule (line 19) | class ELBOModule(torch.nn.Module):
method __init__ (line 20) | def __init__(self, model: torch.nn.Module, guide: torch.nn.Module, elb...
method forward (line 26) | def forward(self, *args, **kwargs):
class ELBO (line 30) | class ELBO(object, metaclass=ABCMeta):
method __init__ (line 110) | def __init__(
method __call__ (line 139) | def __call__(self, model: torch.nn.Module, guide: torch.nn.Module) -> ...
method _guess_max_plate_nesting (line 146) | def _guess_max_plate_nesting(self, model, guide, args, kwargs):
method _vectorized_num_particles (line 188) | def _vectorized_num_particles(self, fn):
method _get_vectorized_trace (line 207) | def _get_vectorized_trace(self, model, guide, args, kwargs):
method _get_trace (line 221) | def _get_trace(self, model, guide, args, kwargs):
method _get_traces (line 228) | def _get_traces(self, model, guide, args, kwargs):
FILE: pyro/infer/energy_distance.py
function _squared_error (line 19) | def _squared_error(x, y, scale, mask):
class EnergyDistance (line 29) | class EnergyDistance:
method __init__ (line 79) | def __init__(
method _pow (line 96) | def _pow(self, x):
method _get_traces (line 101) | def _get_traces(self, model, guide, args, kwargs):
method __call__ (line 157) | def __call__(self, model, guide, *args, **kwargs):
method loss (line 225) | def loss(self, *args, **kwargs):
FILE: pyro/infer/enum.py
function iter_discrete_escape (line 16) | def iter_discrete_escape(trace, msg):
function iter_discrete_extend (line 25) | def ite
Copy disabled (too large)
Download .json
Condensed preview — 797 files, each showing path, character count, and a content snippet. Download the .json file for the full structured content (21,437K chars).
[
{
"path": ".codecov.yml",
"chars": 213,
"preview": "# Copyright Contributors to the Pyro project.\n#\n# SPDX-License-Identifier: Apache-2.0\n\nignore:\n - \"pyro/docutil.py\"\n -"
},
{
"path": ".coveragerc",
"chars": 433,
"preview": "# Copyright Contributors to the Pyro project.\n#\n# SPDX-License-Identifier: Apache-2.0\n\n[report]\nomit =\n pyro/docutil."
},
{
"path": ".gitattributes",
"chars": 118,
"preview": "# Copyright Contributors to the Pyro project.\n#\n# SPDX-License-Identifier: Apache-2.0\n\n*.ipynb linguist-documentation\n"
},
{
"path": ".github/FUNDING.yml",
"chars": 899,
"preview": "# Copyright Contributors to the Pyro project.\n#\n# SPDX-License-Identifier: Apache-2.0\n\n# These are supported funding mod"
},
{
"path": ".github/ISSUE_TEMPLATE/config.yml",
"chars": 199,
"preview": "blank_issues_enabled: false\ncontact_links:\n - name: Pyro Discussion Forum\n url: https://forum.pyro.ai/\n about: Fo"
},
{
"path": ".github/ISSUE_TEMPLATE/issue_template.md",
"chars": 1136,
"preview": "---\nname: General Issue\nabout: Report a bug or request a feature\n---\n\n<!--\nCopyright Contributors to the Pyro project.\n\n"
},
{
"path": ".github/workflows/ci.yml",
"chars": 10343,
"preview": "name: CI\n\non:\n push:\n branches: [dev, master]\n pull_request:\n branches: [dev, master]\n\nenv:\n CXX: g++-9\n CC: g"
},
{
"path": ".github/workflows/publish.yml",
"chars": 990,
"preview": "# This workflow uses actions that are not certified by GitHub.\n# They are provided by a third-party and are governed by\n"
},
{
"path": ".gitignore",
"chars": 1810,
"preview": "# Copyright Contributors to the Pyro project.\n#\n# SPDX-License-Identifier: Apache-2.0\n\nrun_outputs*\n.DS_Store\n.benchmark"
},
{
"path": ".readthedocs.yml",
"chars": 500,
"preview": "# Copyright Contributors to the Pyro project.\n#\n# SPDX-License-Identifier: Apache-2.0\n\n# Required\nversion: 2\n\nbuild:\n o"
},
{
"path": "CODE_OF_CONDUCT.md",
"chars": 3336,
"preview": "\n<!--\nCopyright Contributors to the Pyro project.\n\nSPDX-License-Identifier: Apache-2.0\n-->\n\n# Contributor Covenant Code "
},
{
"path": "CONTRIBUTING.md",
"chars": 4047,
"preview": "<!--\nCopyright Contributors to the Pyro project.\n\nSPDX-License-Identifier: Apache-2.0\n-->\n\n# Development\n\nPlease follow "
},
{
"path": "LICENSE.md",
"chars": 11358,
"preview": "\n Apache License\n Version 2.0, January 2004\n "
},
{
"path": "LICENSES/Apache-2.0.txt",
"chars": 10280,
"preview": "Apache License\nVersion 2.0, January 2004\nhttp://www.apache.org/licenses/\n\nTERMS AND CONDITIONS FOR USE, REPRODUCTION, AN"
},
{
"path": "LICENSES/BSD-3-Clause.txt",
"chars": 1460,
"preview": "Copyright (c) <year> <owner>. \n\nRedistribution and use in source and binary forms, with or without modification, are per"
},
{
"path": "LICENSES/MIT.txt",
"chars": 1078,
"preview": "MIT License\n\nCopyright (c) <year> <copyright holders>\n\nPermission is hereby granted, free of charge, to any person obtai"
},
{
"path": "MANIFEST.in",
"chars": 60,
"preview": "include LICENSE.md MANIFEST.in\nrecursive-include pyro *.cpp\n"
},
{
"path": "Makefile",
"chars": 2493,
"preview": ".PHONY: all install docs lint format test integration-test clean FORCE\n\nall: docs test\n\ninstall: FORCE\n\tpip install -e ."
},
{
"path": "README.md",
"chars": 4629,
"preview": "<!--\nCopyright Contributors to the Pyro project.\n\nSPDX-License-Identifier: Apache-2.0\n-->\n\n<div align=\"center\">\n <a hre"
},
{
"path": "RELEASE-MANAGEMENT.md",
"chars": 1719,
"preview": "<!--\nCopyright Contributors to the Pyro project.\n\nSPDX-License-Identifier: Apache-2.0\n-->\n\n# Pyro release management\n\nTh"
},
{
"path": "docker/Dockerfile",
"chars": 2106,
"preview": "# Copyright Contributors to the Pyro project.\n#\n# SPDX-License-Identifier: Apache-2.0\n\nARG base_img=ubuntu:24.04\nFROM ${"
},
{
"path": "docker/Makefile",
"chars": 5178,
"preview": "# Copyright Contributors to the Pyro project.\n#\n# SPDX-License-Identifier: Apache-2.0\n\n.PHONY: help create-host-workspac"
},
{
"path": "docker/README.md",
"chars": 2740,
"preview": "<!--\nCopyright Contributors to the Pyro project.\n\nSPDX-License-Identifier: Apache-2.0\n-->\n\n## Using Pyro Docker\n\nSome ut"
},
{
"path": "docker/install.sh",
"chars": 1157,
"preview": "#!/usr/bin/env bash\n\n# Copyright Contributors to the Pyro project.\n#\n# SPDX-License-Identifier: Apache-2.0\n\nset -xe\n\npip"
},
{
"path": "docs/Makefile",
"chars": 810,
"preview": "# Copyright Contributors to the Pyro project.\n#\n# SPDX-License-Identifier: Apache-2.0\n\n# Minimal makefile for Sphinx doc"
},
{
"path": "docs/README.md",
"chars": 636,
"preview": "<!--\nCopyright Contributors to the Pyro project.\n\nSPDX-License-Identifier: Apache-2.0\n-->\n\n# Documentation #\nPyro Docume"
},
{
"path": "docs/requirements.txt",
"chars": 258,
"preview": "# Copyright Contributors to the Pyro project.\n#\n# SPDX-License-Identifier: Apache-2.0\n\nsphinx==4.2.0\nsphinx-rtd-theme==1"
},
{
"path": "docs/source/_static/css/pyro.css",
"chars": 575,
"preview": "/*\n * Copyright Contributors to the Pyro project.\n *\n * SPDX-License-Identifier: Apache-2.0\n */\n\n@import url(\"theme.css\""
},
{
"path": "docs/source/_static/img/favicon/browserconfig.xml",
"chars": 373,
"preview": "<?xml version=\"1.0\" encoding=\"utf-8\"?>\n\n<!--\nCopyright Contributors to the Pyro project.\n\nSPDX-License-Identifier: Apach"
},
{
"path": "docs/source/_static/img/favicon/manifest.json",
"chars": 720,
"preview": "{\n \"name\": \"App\",\n \"icons\": [\n {\n \"src\": \"\\/android-icon-36x36.png\",\n \"sizes\": \"36x36\",\n \"type\": \"image\\/png\",\n "
},
{
"path": "docs/source/conf.py",
"chars": 6925,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport os\nimport sys\n\n# import "
},
{
"path": "docs/source/contrib.autoname.rst",
"chars": 501,
"preview": "Automatic Name Generation\n==========================\n\n.. automodule:: pyro.contrib.autoname\n :members:\n :undoc-mem"
},
{
"path": "docs/source/contrib.bnn.rst",
"chars": 212,
"preview": "Bayesian Neural Networks\n=========================\n\n.. automodule:: pyro.contrib.bnn\n\nHiddenLayer\n----------------------"
},
{
"path": "docs/source/contrib.cevae.rst",
"chars": 1737,
"preview": "Causal Effect VAE\n=================\n\n.. automodule:: pyro.contrib.cevae\n\nCEVAE Class\n-----------\n.. autoclass:: pyro.con"
},
{
"path": "docs/source/contrib.easyguide.rst",
"chars": 490,
"preview": "Easy Custom Guides\n==================\n\n.. automodule:: pyro.contrib.easyguide\n\nEasyGuide\n---------\n.. autoclass:: pyro.c"
},
{
"path": "docs/source/contrib.epidemiology.rst",
"chars": 1493,
"preview": "Epidemiology\n============\n.. automodule:: pyro.contrib.epidemiology\n\n.. warning:: Code in ``pyro.contrib.epidemiology`` "
},
{
"path": "docs/source/contrib.examples.rst",
"chars": 708,
"preview": "Pyro Examples\n=============\n\nDatasets\n--------\n\nMulti MNIST\n~~~~~~~~~~~\n.. automodule:: pyro.contrib.examples.multi_mnis"
},
{
"path": "docs/source/contrib.forecast.rst",
"chars": 2562,
"preview": "Forecasting\n===========\n.. automodule:: pyro.contrib.forecast\n\n``pyro.contrib.forecast`` is a lightweight framework for "
},
{
"path": "docs/source/contrib.funsor.rst",
"chars": 1527,
"preview": "Funsor-based Pyro\n==========================\n\n\nPrimitives\n----------\n\n.. automodule:: pyro.contrib.funsor\n :members:\n"
},
{
"path": "docs/source/contrib.gp.rst",
"chars": 1629,
"preview": "Gaussian Processes\n==================\n\nSee the `Gaussian Processes tutorial <http://pyro.ai/examples/gp.html>`_ for an i"
},
{
"path": "docs/source/contrib.minipyro.rst",
"chars": 144,
"preview": "Minipyro\n========\n.. automodule:: pyro.contrib.minipyro\n :members:\n :undoc-members:\n :special-members: __call__"
},
{
"path": "docs/source/contrib.mue.rst",
"chars": 1305,
"preview": "Biological Sequence Models with MuE\n===================================\n.. automodule:: pyro.contrib.mue\n\n.. warning:: C"
},
{
"path": "docs/source/contrib.oed.rst",
"chars": 406,
"preview": "Optimal Experiment Design\n=========================\n\n.. automodule:: pyro.contrib.oed\n\nExpected Information Gain\n-------"
},
{
"path": "docs/source/contrib.randomvariable.rst",
"chars": 243,
"preview": "Random Variables\n================\n\n.. automodule:: pyro.contrib.randomvariable\n\nRandom Variable\n---------------\n.. autoc"
},
{
"path": "docs/source/contrib.timeseries.rst",
"chars": 815,
"preview": "Time Series\n===========\n.. automodule:: pyro.contrib.timeseries\n\nSee the `GP example <http://pyro.ai/examples/timeseries"
},
{
"path": "docs/source/contrib.tracking.rst",
"chars": 807,
"preview": "Tracking\n========\n\n.. automodule:: pyro.contrib.tracking\n\nData Association\n----------------\n.. automodule:: pyro.contrib"
},
{
"path": "docs/source/contrib.zuko.rst",
"chars": 75,
"preview": "Zuko in Pyro\n============\n\n.. automodule:: pyro.contrib.zuko\n :members:\n"
},
{
"path": "docs/source/distributions.rst",
"chars": 20493,
"preview": "Distributions\n=============\n\n.. toctree::\n :glob:\n :maxdepth: 2\n :caption: Contents:\n\nPyTorch Distributions\n~~~~~~"
},
{
"path": "docs/source/getting_started.rst",
"chars": 318,
"preview": "Getting Started\n===============\n\n - `Install Pyro <http://pyro.ai#install>`_.\n\n - Learn the basic concepts of Pyro:\n `"
},
{
"path": "docs/source/index.rst",
"chars": 991,
"preview": ".. Pyro documentation master file, created by\n sphinx-quickstart on Thu Jun 15 17:16:14 2017.\n You can adapt this fi"
},
{
"path": "docs/source/infer.autoguide.rst",
"chars": 3698,
"preview": "Automatic Guide Generation\n==========================\n\n.. automodule:: pyro.infer.autoguide\n\nAutoGuide\n---------\n.. auto"
},
{
"path": "docs/source/infer.reparam.rst",
"chars": 3622,
"preview": "Reparameterizers\n================\n.. automodule:: pyro.infer.reparam\n\nThe :mod:`pyro.infer.reparam` module contains repa"
},
{
"path": "docs/source/infer.util.rst",
"chars": 446,
"preview": "Inference utilities\n===================\n\n.. autofunction:: pyro.infer.util.enable_validation\n.. autofunction:: pyro.infe"
},
{
"path": "docs/source/inference.rst",
"chars": 754,
"preview": "Inference\n=========\n\nIn the context of probabilistic modeling, learning is usually called inference.\nIn the particular c"
},
{
"path": "docs/source/inference_algos.rst",
"chars": 2245,
"preview": "SVI\n---\n\n.. automodule:: pyro.infer.svi\n :members:\n :undoc-members:\n :show-inheritance:\n\nELBO\n----\n\n.. automodu"
},
{
"path": "docs/source/mcmc.rst",
"chars": 44,
"preview": "MCMC\n====\n\n.. include:: pyro.infer.mcmc.txt\n"
},
{
"path": "docs/source/nn.rst",
"chars": 897,
"preview": "Neural Networks\n===============\n\nThe module `pyro.nn` provides implementations of neural network modules\nthat are useful"
},
{
"path": "docs/source/ops.rst",
"chars": 2132,
"preview": "Miscellaneous Ops\n=================\n\nThe ``pyro.ops`` module implements tensor utilities\nthat are mostly independent of "
},
{
"path": "docs/source/optimization.rst",
"chars": 417,
"preview": "Optimization\n============\n\nThe module `pyro.optim` provides support for optimization in Pyro. In particular\nit provides "
},
{
"path": "docs/source/parameters.rst",
"chars": 592,
"preview": "Parameters\n==========\n\nParameters in Pyro are basically thin wrappers around PyTorch Tensors that carry unique names. \nA"
},
{
"path": "docs/source/poutine.rst",
"chars": 1323,
"preview": "Poutine (Effect handlers)\n==========================\n\nBeneath the built-in inference algorithms, Pyro has a library of c"
},
{
"path": "docs/source/primitives.rst",
"chars": 158,
"preview": "Primitives\n==========\n\n.. automodule:: pyro.primitives\n :members:\n :show-inheritance:\n :member-order: bysource\n"
},
{
"path": "docs/source/pyro.infer.mcmc.txt",
"chars": 1053,
"preview": "MCMC\n----\n\n.. autoclass:: pyro.infer.mcmc.api.MCMC\n :members:\n :undoc-members:\n :show-inheritance:\n\nStreamingMC"
},
{
"path": "docs/source/pyro.optim.txt",
"chars": 1181,
"preview": "Pyro Optimizers\n---------------\n\n.. automodule:: pyro.optim.optim\n :members:\n :undoc-members:\n :special-members"
},
{
"path": "docs/source/pyro.poutine.txt",
"chars": 3521,
"preview": "Messenger\n__________\n\n.. automodule:: pyro.poutine.messenger\n :members:\n :undoc-members:\n :show-inheritance:\n\nB"
},
{
"path": "docs/source/settings.rst",
"chars": 89,
"preview": "Settings\n--------\n\n.. automodule:: pyro.settings\n :members:\n :member-order: bysource\n"
},
{
"path": "docs/source/testing.rst",
"chars": 124,
"preview": "Testing Utilities\n-----------------\n\n.. automodule:: pyro.distributions.testing.gof\n :members:\n :member-order: bysou"
},
{
"path": "examples/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "examples/air/air.py",
"chars": 14031,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nAn implementation of the mo"
},
{
"path": "examples/air/main.py",
"chars": 14648,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nAIR applied to the multi-mn"
},
{
"path": "examples/air/modules.py",
"chars": 2991,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport torch\nimport torch.nn as"
},
{
"path": "examples/air/viz.py",
"chars": 2528,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport math\nfrom collections im"
},
{
"path": "examples/baseball.py",
"chars": 16394,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport argparse\nimport logging\n"
},
{
"path": "examples/capture_recapture/cjs.py",
"chars": 14374,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nWe show how to implement se"
},
{
"path": "examples/contrib/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "examples/contrib/autoname/mixture.py",
"chars": 2572,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport argparse\n\nimport torch\nf"
},
{
"path": "examples/contrib/autoname/scoping_mixture.py",
"chars": 2239,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport argparse\n\nimport torch\nf"
},
{
"path": "examples/contrib/autoname/tree_data.py",
"chars": 3560,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport argparse\n\nimport torch\nf"
},
{
"path": "examples/contrib/cevae/synthetic.py",
"chars": 3802,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nThis example demonstrates h"
},
{
"path": "examples/contrib/epidemiology/regional.py",
"chars": 7900,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nimport argparse\nimport logging\n\nimp"
},
{
"path": "examples/contrib/epidemiology/sir.py",
"chars": 15144,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\n# This script aims to replicate the"
},
{
"path": "examples/contrib/forecast/bart.py",
"chars": 7530,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nimport argparse\nimport logging\n\nimp"
},
{
"path": "examples/contrib/funsor/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "examples/contrib/funsor/hmm.py",
"chars": 35805,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nThis example is largely cop"
},
{
"path": "examples/contrib/gp/sv-dkl.py",
"chars": 8529,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nAn example to use Pyro Gaus"
},
{
"path": "examples/contrib/mue/FactorMuE.py",
"chars": 13776,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nA probabilistic PCA model with "
},
{
"path": "examples/contrib/mue/ProfileHMM.py",
"chars": 10356,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nA standard profile HMM model [1"
},
{
"path": "examples/contrib/oed/ab_test.py",
"chars": 4658,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport argparse\nfrom functools "
},
{
"path": "examples/contrib/oed/gp_bayes_opt.py",
"chars": 5456,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport torch\nimport torch.autog"
},
{
"path": "examples/contrib/timeseries/gp_models.py",
"chars": 7165,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport argparse\nfrom os.path im"
},
{
"path": "examples/cvae/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "examples/cvae/baseline.py",
"chars": 3456,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nimport copy\nfrom pathlib import Pat"
},
{
"path": "examples/cvae/cvae.py",
"chars": 6895,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nfrom pathlib import Path\n\nimport nu"
},
{
"path": "examples/cvae/main.py",
"chars": 4147,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nimport argparse\n\nimport baseline\nim"
},
{
"path": "examples/cvae/mnist.py",
"chars": 3204,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nimport numpy as np\nimport torch\nfro"
},
{
"path": "examples/cvae/util.py",
"chars": 4845,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nfrom pathlib import Path\n\nimport ma"
},
{
"path": "examples/dmm.py",
"chars": 25375,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nAn implementation of a Deep"
},
{
"path": "examples/eight_schools/README.md",
"chars": 767,
"preview": "<!--\nCopyright (c) 2017-2019 Uber Technologies, Inc.\n\nSPDX-License-Identifier: Apache-2.0\n-->\n\nAnalysis of the eight sch"
},
{
"path": "examples/eight_schools/data.py",
"chars": 249,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport torch\n\nJ = 8\ny = torch.t"
},
{
"path": "examples/eight_schools/mcmc.py",
"chars": 1818,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport argparse\nimport logging\n"
},
{
"path": "examples/eight_schools/svi.py",
"chars": 3025,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport argparse\nimport logging\n"
},
{
"path": "examples/einsum.py",
"chars": 7580,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nThis example demonstrates h"
},
{
"path": "examples/hmm.py",
"chars": 32364,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nThis example shows how to m"
},
{
"path": "examples/inclined_plane.py",
"chars": 5455,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\n\nimport argparse\n\nimport numpy "
},
{
"path": "examples/lda.py",
"chars": 6847,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nThis example implements amo"
},
{
"path": "examples/lkj.py",
"chars": 2735,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport argparse\n\nimport torch\n\n"
},
{
"path": "examples/minipyro.py",
"chars": 2971,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nThis example demonstrates t"
},
{
"path": "examples/mixed_hmm/README.md",
"chars": 4612,
"preview": "<!--\nCopyright (c) 2017-2019 Uber Technologies, Inc.\n\nSPDX-License-Identifier: Apache-2.0\n-->\n\n# Hierarchical mixed-effe"
},
{
"path": "examples/mixed_hmm/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "examples/mixed_hmm/experiment.py",
"chars": 5959,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport argparse\nimport functool"
},
{
"path": "examples/mixed_hmm/model.py",
"chars": 9884,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport torch\nfrom torch.distrib"
},
{
"path": "examples/mixed_hmm/seal_data.py",
"chars": 2560,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport os\nfrom urllib.request i"
},
{
"path": "examples/neutra.py",
"chars": 9381,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nThis example illustrates the us"
},
{
"path": "examples/rsa/README.md",
"chars": 1067,
"preview": "<!--\nCopyright (c) 2017-2019 Uber Technologies, Inc.\n\nSPDX-License-Identifier: Apache-2.0\n-->\n\n## Rational Speech Acts ("
},
{
"path": "examples/rsa/generics.py",
"chars": 5389,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nInterpreting generic statem"
},
{
"path": "examples/rsa/hyperbole.py",
"chars": 6737,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nInterpreting hyperbole with"
},
{
"path": "examples/rsa/schelling.py",
"chars": 2839,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nSchelling coordination game"
},
{
"path": "examples/rsa/schelling_false.py",
"chars": 3334,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nSchelling coordination game"
},
{
"path": "examples/rsa/search_inference.py",
"chars": 7562,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nInference algorithms and ut"
},
{
"path": "examples/rsa/semantic_parsing.py",
"chars": 8669,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nCombining models of RSA pra"
},
{
"path": "examples/scanvi/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "examples/scanvi/scanvi.py",
"chars": 16794,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nWe use a semi-supervised deep g"
},
{
"path": "examples/sir_hmc.py",
"chars": 24193,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\n# Introduction\n# ============\n#\n# T"
},
{
"path": "examples/smcfilter.py",
"chars": 3358,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport argparse\nimport logging\n"
},
{
"path": "examples/sparse_gamma_def.py",
"chars": 12106,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\n# This is an implementation of "
},
{
"path": "examples/sparse_regression.py",
"chars": 13618,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport argparse\nimport math\n\nim"
},
{
"path": "examples/svi_horovod.py",
"chars": 6668,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\n# Distributed training via Horovod."
},
{
"path": "examples/svi_lightning.py",
"chars": 4919,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\n# Distributed training via Pytorch "
},
{
"path": "examples/svi_torch.py",
"chars": 5054,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\n# Using vanilla PyTorch to perform "
},
{
"path": "examples/toy_mixture_model_discrete_enumeration.py",
"chars": 5184,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nA toy mixture model to provide "
},
{
"path": "examples/vae/ss_vae_M2.py",
"chars": 20428,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport argparse\n\nimport torch\ni"
},
{
"path": "examples/vae/utils/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "examples/vae/utils/custom_mlp.py",
"chars": 6871,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nfrom inspect import isclass\n\nim"
},
{
"path": "examples/vae/utils/mnist_cached.py",
"chars": 8857,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport errno\nimport os\nfrom fun"
},
{
"path": "examples/vae/utils/vae_plots.py",
"chars": 3678,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport torch\n\n\ndef plot_conditi"
},
{
"path": "examples/vae/vae.py",
"chars": 9201,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport argparse\n\nimport numpy a"
},
{
"path": "examples/vae/vae_comparison.py",
"chars": 9094,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport argparse\nimport itertool"
},
{
"path": "profiler/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "profiler/distributions.py",
"chars": 5360,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport argparse\n\nimport torch\nf"
},
{
"path": "profiler/gaussianhmm.py",
"chars": 2678,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nimport argparse\n\nimport torch\nfrom "
},
{
"path": "profiler/hmm.py",
"chars": 2834,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport argparse\nimport os\nimpor"
},
{
"path": "profiler/profiling_utils.py",
"chars": 4006,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport cProfile\nimport functool"
},
{
"path": "pyproject.toml",
"chars": 216,
"preview": "# Copyright Contributors to the Pyro project.\n#\n# SPDX-License-Identifier: Apache-2.0\n\n[tool.ruff]\nextend-exclude = [\"*."
},
{
"path": "pyro/__init__.py",
"chars": 1324,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport pyro.poutine as poutine\n"
},
{
"path": "pyro/contrib/README.md",
"chars": 699,
"preview": "<!--\nCopyright Contributors to the Pyro project.\n\nSPDX-License-Identifier: Apache-2.0\n-->\n\n# Contributed Code\n\nCode in `"
},
{
"path": "pyro/contrib/__init__.py",
"chars": 693,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nr\"\"\"\nContributed Code\n========="
},
{
"path": "pyro/contrib/autoguide.py",
"chars": 317,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport warnings\n\nfrom pyro.infe"
},
{
"path": "pyro/contrib/autoname/__init__.py",
"chars": 486,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nThe :mod:`pyro.contrib.auto"
},
{
"path": "pyro/contrib/autoname/autoname.py",
"chars": 5100,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nfrom collections import defaultdict"
},
{
"path": "pyro/contrib/autoname/named.py",
"chars": 8624,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nThe ``pyro.contrib.named`` "
},
{
"path": "pyro/contrib/autoname/scoping.py",
"chars": 6482,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\n``pyro.contrib.autoname.sco"
},
{
"path": "pyro/contrib/bnn/__init__.py",
"chars": 177,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nfrom pyro.contrib.bnn.hidden_la"
},
{
"path": "pyro/contrib/bnn/hidden_layer.py",
"chars": 5384,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport torch\nimport torch.nn.fu"
},
{
"path": "pyro/contrib/bnn/utils.py",
"chars": 490,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport math\n\nimport torch\n\n\ndef"
},
{
"path": "pyro/contrib/cevae/__init__.py",
"chars": 23079,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nThis module implements the "
},
{
"path": "pyro/contrib/conjugate/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "pyro/contrib/conjugate/infer.py",
"chars": 9073,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nfrom collections import default"
},
{
"path": "pyro/contrib/easyguide/__init__.py",
"chars": 206,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nfrom pyro.contrib.easyguide.eas"
},
{
"path": "pyro/contrib/easyguide/easyguide.py",
"chars": 12922,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport re\nimport weakref\nfrom a"
},
{
"path": "pyro/contrib/epidemiology/__init__.py",
"chars": 406,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nfrom pyro.distributions.coalescent "
},
{
"path": "pyro/contrib/epidemiology/compartmental.py",
"chars": 49835,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nimport functools\nimport logging\nimp"
},
{
"path": "pyro/contrib/epidemiology/distributions.py",
"chars": 13525,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nimport math\nfrom contextlib import "
},
{
"path": "pyro/contrib/epidemiology/models.py",
"chars": 51215,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nimport re\n\nimport torch\nfrom torch."
},
{
"path": "pyro/contrib/epidemiology/util.py",
"chars": 10993,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nimport numpy as np\nimport torch\n\nim"
},
{
"path": "pyro/contrib/examples/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "pyro/contrib/examples/bart.py",
"chars": 6858,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport argparse\nimport bz2\nimpo"
},
{
"path": "pyro/contrib/examples/finance.py",
"chars": 705,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nimport os\nimport urllib\n\nimport pan"
},
{
"path": "pyro/contrib/examples/multi_mnist.py",
"chars": 2920,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nThis script generates a dat"
},
{
"path": "pyro/contrib/examples/nextstrain.py",
"chars": 1567,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nimport logging\nimport os\nimport sub"
},
{
"path": "pyro/contrib/examples/polyphonic_data_loader.py",
"chars": 6919,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nData loader logic with two "
},
{
"path": "pyro/contrib/examples/scanvi_data.py",
"chars": 7442,
"preview": "# Copyright Contributors to the Pyro project.\n# Copyright (c) 2020, YosefLab.\n# SPDX-License-Identifier: Apache-2.0 AND "
},
{
"path": "pyro/contrib/examples/util.py",
"chars": 2479,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport os\nimport sys\n\nimport to"
},
{
"path": "pyro/contrib/forecast/__init__.py",
"chars": 360,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nfrom .evaluate import backtest, eva"
},
{
"path": "pyro/contrib/forecast/evaluate.py",
"chars": 8823,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nimport logging\nimport warnings\nfrom"
},
{
"path": "pyro/contrib/forecast/forecaster.py",
"chars": 23353,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nimport logging\nfrom abc import ABCM"
},
{
"path": "pyro/contrib/forecast/util.py",
"chars": 16054,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nimport numbers\nfrom functools impor"
},
{
"path": "pyro/contrib/funsor/__init__.py",
"chars": 1223,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nimport pyroapi\n\nfrom pyro.contrib.f"
},
{
"path": "pyro/contrib/funsor/handlers/__init__.py",
"chars": 1408,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nfrom pyro.poutine import ( # noqa:"
},
{
"path": "pyro/contrib/funsor/handlers/enum_messenger.py",
"chars": 9727,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nThis file contains reimplementa"
},
{
"path": "pyro/contrib/funsor/handlers/named_messenger.py",
"chars": 7623,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nfrom collections import OrderedDict"
},
{
"path": "pyro/contrib/funsor/handlers/plate_messenger.py",
"chars": 15224,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nfrom collections import OrderedDict"
},
{
"path": "pyro/contrib/funsor/handlers/primitives.py",
"chars": 1076,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nimport pyro.poutine.runtime\nfrom py"
},
{
"path": "pyro/contrib/funsor/handlers/replay_messenger.py",
"chars": 1806,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nfrom pyro.contrib.funsor.handlers.p"
},
{
"path": "pyro/contrib/funsor/handlers/runtime.py",
"chars": 8558,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nfrom collections import Counter, Or"
},
{
"path": "pyro/contrib/funsor/handlers/trace_messenger.py",
"chars": 3632,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nimport funsor\nimport torch\n\nfrom py"
},
{
"path": "pyro/contrib/funsor/infer/__init__.py",
"chars": 514,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nfrom pyro.infer import SVI, config_"
},
{
"path": "pyro/contrib/funsor/infer/discrete.py",
"chars": 2921,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nimport functools\n\nimport funsor\n\nfr"
},
{
"path": "pyro/contrib/funsor/infer/elbo.py",
"chars": 1652,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nimport pyro.ops.jit\nfrom pyro.infer"
},
{
"path": "pyro/contrib/funsor/infer/trace_elbo.py",
"chars": 1724,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nimport contextlib\n\nimport funsor\n\nf"
},
{
"path": "pyro/contrib/funsor/infer/traceenum_elbo.py",
"chars": 12731,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nimport contextlib\n\nimport funsor\nfr"
},
{
"path": "pyro/contrib/funsor/infer/tracetmc_elbo.py",
"chars": 1916,
"preview": "# Copyright Contributors to the Pyro project.\n# SPDX-License-Identifier: Apache-2.0\n\nimport contextlib\n\nimport funsor\n\nf"
},
{
"path": "pyro/contrib/gp/__init__.py",
"chars": 340,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nfrom pyro.contrib.gp import ker"
},
{
"path": "pyro/contrib/gp/kernels/__init__.py",
"chars": 1533,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nfrom pyro.contrib.gp.kernels.br"
},
{
"path": "pyro/contrib/gp/kernels/brownian.py",
"chars": 1577,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport torch\nfrom torch.distrib"
},
{
"path": "pyro/contrib/gp/kernels/coregionalize.py",
"chars": 3699,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport torch\nfrom torch.distrib"
},
{
"path": "pyro/contrib/gp/kernels/dot_product.py",
"chars": 2646,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport torch\nfrom torch.distrib"
},
{
"path": "pyro/contrib/gp/kernels/isotropic.py",
"chars": 5740,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport torch\nfrom torch.distrib"
},
{
"path": "pyro/contrib/gp/kernels/kernel.py",
"chars": 8446,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport numbers\n\nfrom pyro.contr"
},
{
"path": "pyro/contrib/gp/kernels/periodic.py",
"chars": 2424,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport math\n\nimport torch\nfrom "
},
{
"path": "pyro/contrib/gp/kernels/static.py",
"chars": 1536,
"preview": "# Copyright (c) 2017-2019 Uber Technologies, Inc.\n# SPDX-License-Identifier: Apache-2.0\n\nimport torch\nfrom torch.distrib"
}
]
// ... and 597 more files (download for full content)
About this extraction
This page contains the full source code of the pyro-ppl/pyro GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 797 files (20.1 MB), approximately 5.3M tokens, and a symbol index with 5897 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.