[
  {
    "path": ".JuliaFormatter.toml",
    "content": "style = \"sciml\"\n"
  },
  {
    "path": ".git-blame-ignore-revs",
    "content": "# Run this command to always ignore these in local `git blame`:\n# git config blame.ignoreRevsFile .git-blame-ignore-revs\n\n# Run formatter\n70fd432667fb431e08ba52728734108d822a1922\n# Run formatter \n21038a047c023330876feb9259cd5c92add3ca81\n# Run formatter after bracket alignment removal\n799277f9652258282a91ecfe976df5fb8ab64c82\n# Format\ndb4333c604cc23c3c36420f09aa998d01ef0214b\n"
  },
  {
    "path": ".github/workflows/CI.yml",
    "content": "name: CI\non:\n  pull_request:\n  push:\n    branches:\n      - main \n    tags: '*'\njobs:\n  unittest:\n    runs-on: ubuntu-latest\n    strategy:\n      matrix:\n        group:\n          - Core\n        version:\n          - '1'\n          - '1.7'\n    steps:\n      - uses: actions/checkout@v2\n      - uses: julia-actions/setup-julia@v1\n        with:\n          version: ${{ matrix.version }}\n      - uses: julia-actions/julia-buildpkg@v1\n      - uses: julia-actions/julia-runtest@v1\n      - uses: julia-actions/julia-processcoverage@v1\n      - uses: codecov/codecov-action@v2\n        with:\n          file: lcov.info\n"
  },
  {
    "path": ".github/workflows/CompatHelper.yml",
    "content": "name: CompatHelper\non:\n  schedule:\n    - cron: 0 0 * * *\n  workflow_dispatch:\njobs:\n  CompatHelper:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Pkg.add(\"CompatHelper\")\n        run: julia -e 'using Pkg; Pkg.add(\"CompatHelper\")'\n      - name: CompatHelper.main()\n        env:\n          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n          COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }}\n        run: julia -e 'using CompatHelper; CompatHelper.main()'\n"
  },
  {
    "path": ".github/workflows/Documentation.yml",
    "content": "name: Documentation\n\non:\n  push:\n    branches:\n      - main\n    tags: '*'\n  pull_request:\n\njobs:\n  build:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v2\n      - uses: julia-actions/setup-julia@v1\n        with:\n          version: '1'\n      - name: Install dependencies\n        run: julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()'\n      - name: Build and deploy\n        env:\n          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # For authentication with GitHub Actions token\n          DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # For authentication with SSH deploy key\n          DATADEPS_ALWAYS_ACCEPT: true\n        run: julia --project=docs/ docs/make.jl\n"
  },
  {
    "path": ".github/workflows/FormatCheck.yml",
    "content": "name: format-check\n\non:\n  push:\n    branches:\n      - 'main'\n      - 'release-'\n    tags: '*'\n  pull_request:\n\njobs:\n  build:\n    runs-on: ${{ matrix.os }}\n    strategy:\n      matrix:\n        julia-version: [1]\n        julia-arch: [x86]\n        os: [ubuntu-latest]\n    steps:\n      - uses: julia-actions/setup-julia@latest\n        with:\n          version: ${{ matrix.julia-version }}\n\n      - uses: actions/checkout@v1\n      - name: Install JuliaFormatter and format\n        # This will use the latest version by default but you can set the version like so:\n        #\n        # julia  -e 'using Pkg; Pkg.add(PackageSpec(name=\"JuliaFormatter\", version=\"0.13.0\"))'\n        run: |\n          julia  -e 'using Pkg; Pkg.add(PackageSpec(name=\"JuliaFormatter\"))'\n          julia  -e 'using JuliaFormatter; format(\".\", verbose=true)'\n      - name: Format check\n        run: |\n          julia -e '\n          out = Cmd(`git diff`) |> read |> String\n          if out == \"\"\n              exit(0)\n          else\n              @error \"Some files have not been formatted !!!\"\n              write(stdout, out)\n              exit(1)\n          end'\n"
  },
  {
    "path": ".github/workflows/TagBot.yml",
    "content": "name: TagBot\non:\n  issue_comment:\n    types:\n      - created\n  workflow_dispatch:\njobs:\n  TagBot:\n    if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot'\n    runs-on: ubuntu-latest\n    steps:\n      - uses: JuliaRegistries/TagBot@v1\n        with:\n          token: ${{ secrets.GITHUB_TOKEN }}\n          ssh: ${{ secrets.DOCUMENTER_KEY }}\n"
  },
  {
    "path": ".github/workflows/benchmark.yml",
    "content": "\nname: Benchmarks\n\non:\n  pull_request:\n  push:\n    branches:\n      - main \n    tags: '*'\n\njobs:\n  benchmark:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v2\n      - uses: julia-actions/setup-julia@latest\n        with:\n          version: 1\n      - name: Install dependencies\n        run: julia -e 'using Pkg; Pkg.activate(\"tutorials\"); Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate();'\n      - name: Run benchmarks\n        run: julia --project=tutorials --color=yes benchmark/runbenchmarks.jl \n"
  },
  {
    "path": ".gitignore",
    "content": "Manifest.toml"
  },
  {
    "path": "CITATION.bib",
    "content": "@inproceedings{arya2022automatic,\n author = {Arya, Gaurav and Schauer, Moritz and Sch\\\"{a}fer, Frank and Rackauckas, Christopher},\n booktitle = {Advances in Neural Information Processing Systems},\n editor = {S. Koyejo and S. Mohamed and A. Agarwal and D. Belgrave and K. Cho and A. Oh},\n pages = {10435--10447},\n publisher = {Curran Associates, Inc.},\n title = {Automatic Differentiation of Programs with Discrete Randomness},\n url = {https://proceedings.neurips.cc/paper_files/paper/2022/file/43d8e5fc816c692f342493331d5e98fc-Paper-Conference.pdf},\n volume = {35},\n year = {2022}\n}\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2022 Gaurav Arya <aryag@mit.edu> and contributors\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "Project.toml",
    "content": "name = \"StochasticAD\"\nuuid = \"e4facb34-4f7e-4bec-b153-e122c37934ac\"\nauthors = [\"Gaurav Arya <aryag@mit.edu> and contributors\"]\nversion = \"0.1.26\"\n\n[deps]\nChainRulesCore = \"d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4\"\nChainRulesOverloadGeneration = \"f51149dc-2911-5acf-81fc-2076a2a81d4f\"\nDictionaries = \"85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4\"\nDistributions = \"31c24e10-a181-5473-b8eb-7969acd0382f\"\nDistributionsAD = \"ced4e74d-a319-5a8a-b0ac-84af2272839c\"\nExprTools = \"e2ba6199-217a-4e67-a87a-7c52f15ade04\"\nForwardDiff = \"f6369f11-7733-5829-9624-2563aa707210\"\nFunctors = \"d9f16b24-f501-4c13-a1f2-28368ffc5196\"\nRandom = \"9a3f8284-a2c9-5f02-9a11-845980a1fd5c\"\n\n[weakdeps]\nEnzyme = \"7da242da-08ed-463a-9acd-ee780be4f1d9\"\n\n[extensions]\nStochasticADEnzymeExt = \"Enzyme\"\n\n[compat]\nChainRulesCore = \"1.15\"\nChainRulesOverloadGeneration = \"0.1\"\nDictionaries = \"0.3\"\nDistributions = \"0.25\"\nDistributionsAD = \"0.6\"\nExprTools = \"0.1\"\nForwardDiff = \"0.10\"\nFunctors = \"0.4.3\"\njulia = \"1\"\n\n[extras]\nChainRulesCore = \"d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4\"\nDiffResults = \"163ba53b-c6d8-5494-b064-1a9d43ac40c5\"\nForwardDiff = \"f6369f11-7733-5829-9624-2563aa707210\"\nGaussianDistributions = \"43dcc890-d446-5863-8d1a-14597580bb8d\"\nLinearAlgebra = \"37e2e46d-f89d-539d-b4ee-838fcccc9c8e\"\nMeasurements = \"eff96d63-e80a-5855-80a2-b1b0885c5ab7\"\nOffsetArrays = \"6fe1bfb0-de20-5000-8ca7-80f57d26f881\"\nPkg = \"44cfe95a-1eb2-52ea-b672-e2afdf69b78f\"\nPrintf = \"de0858da-6303-5e67-8744-51eddeeeb8d7\"\nSafeTestsets = \"1bc83da4-3b8d-516f-aca4-4fe02f6d838f\"\nStaticArrays = \"90137ffa-7385-5640-81b9-e52037218182\"\nStatistics = \"10745b16-79ce-11e8-11f9-7d13ad32a3b2\"\nStatsBase = \"2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91\"\nTest = \"8dfed614-e22c-5e08-85e1-65c5234f0b40\"\nUnPack = \"3a884ed6-31ef-47d7-9d2a-63182c4928ed\"\nZygote = \"e88e6eb3-aa80-5325-afca-941959d7151f\"\n\n[targets]\ntest = [\"LinearAlgebra\", \"Pkg\", \"Printf\", \"Test\", \"Statistics\", \"SafeTestsets\", \"OffsetArrays\", \"StaticArrays\", \"Zygote\", \"ForwardDiff\", \"GaussianDistributions\", \"Measurements\", \"UnPack\", \"StatsBase\", \"DiffResults\", \"ChainRulesCore\"]\n"
  },
  {
    "path": "README.md",
    "content": "![](docs/src/images/path_skeleton.png#gh-light-mode-only)\n![](docs/src/images/path_skeleton_dark.png#gh-dark-mode-only)\n\n# StochasticAD\n\n[![Build Status](https://github.com/gaurav-arya/StochasticAD.jl/workflows/CI/badge.svg?branch=main)](https://github.com/gaurav-arya/StochasticAD.jl/actions?query=workflow:CI)\n[![](https://img.shields.io/badge/docs-main-blue.svg)](https://gaurav-arya.github.io/StochasticAD.jl/dev/)\n[![arXiv article](https://img.shields.io/badge/article-arXiv%3A10.48550-B31B1B)](https://arxiv.org/abs/2210.08572)\n\nStochasticAD is an experimental, research package for automatic differentiation (AD) of stochastic programs. It implements AD algorithms for handling programs that can contain *discrete* randomness, based on the methodology developed in [this NeurIPS 2022 paper](https://doi.org/10.48550/arXiv.2210.08572). We're still working on docs and code cleanup!\n\n## Installation\n\nThe package can be installed with the Julia package manager:\n\n```julia\njulia> using Pkg;\njulia> Pkg.add(\"StochasticAD\");\n```\n\n## Citation\n\n```\n@inproceedings{arya2022automatic,\n author = {Arya, Gaurav and Schauer, Moritz and Sch\\\"{a}fer, Frank and Rackauckas, Christopher},\n booktitle = {Advances in Neural Information Processing Systems},\n editor = {S. Koyejo and S. Mohamed and A. Agarwal and D. Belgrave and K. Cho and A. Oh},\n pages = {10435--10447},\n publisher = {Curran Associates, Inc.},\n title = {Automatic Differentiation of Programs with Discrete Randomness},\n url = {https://proceedings.neurips.cc/paper_files/paper/2022/file/43d8e5fc816c692f342493331d5e98fc-Paper-Conference.pdf},\n volume = {35},\n year = {2022}\n}\n```\n"
  },
  {
    "path": "benchmark/benchmarks.jl",
    "content": "using BenchmarkTools\n\ninclude(\"random_walk.jl\")\ninclude(\"game_of_life.jl\")\ninclude(\"iteration.jl\")\ninclude(\"simple_ops.jl\")\n\nconst SUITE = BenchmarkGroup()\nSUITE[\"random_walk\"] = RandomWalkBenchmark.suite\nSUITE[\"game_of_life\"] = GoLBenchmark.suite\nSUITE[\"iteration\"] = IterationBenchmark.suite\nSUITE[\"simple_ops\"] = SimpleOpsBenchmark.suite\n"
  },
  {
    "path": "benchmark/game_of_life.jl",
    "content": "module GoLBenchmark\n\nusing BenchmarkTools\n\nusing StochasticAD\nusing Statistics\nusing ForwardDiff: derivative\ninclude(\"../tutorials/game_of_life/core.jl\")\nusing .GoLCore: play, p\n\nconst suite = BenchmarkGroup()\n\nsuite[\"original\"] = @benchmarkable $play($p)\nsuite[\"PrunedFIs\"] = @benchmarkable derivative_estimate($play, $p;\n    backend = PrunedFIsBackend())\nsuite[\"PrunedFIsAggressive\"] = @benchmarkable derivative_estimate($play, $p;\n    backend = PrunedFIsAggressiveBackend())\nsuite[\"SmoothedFIs\"] = @benchmarkable derivative_estimate($play, $p;\n    backend = SmoothedFIsBackend())\n\nend\n"
  },
  {
    "path": "benchmark/iteration.jl",
    "content": "\"\"\"\nIn the library we have tried to avoid generated functions, instead reductions from base with the\nhope that the iteration will be optimized to be zero-cost.\nThis suite tests the performance of iteration on small nested structures, which crop up when `propagate` is called\non small structures of scalars.\nThe `couple` and `combine` operations of FIss, which use iteration, are benchmarked.\n\"\"\"\nmodule IterationBenchmark\n\nusing BenchmarkTools\nusing StochasticAD\nusing StaticArrays\n\nconst suite = BenchmarkGroup()\n\n# Examples consist of flat and non-flat versions of structures, to test zero-cost iteration.\ntups = Dict(\"easy\" => (ntuple(identity, 3), (1, (2, 3))),\n    \"hard\" => (ntuple(identity, 9), (1, (2, 3), (4, (5, (6, 7, 8), 9)))))\nSAs = Dict(\"easy\" => (SA[1, 2, 3], (1, SA[2, 3])),\n    \"hard\" => (SA[1, 2, 3, 4, 5, 6, 7, 8, 9],\n        (1, SA[2, 3], (4, (5, SA[6, 7, 8], 9)))))\n\nfor (setname, set) in ((\"tups\", tups), (\"SAs\", SAs))\n    suite[setname] = BenchmarkGroup()\n    setsuite = suite[setname]\n    for case in [\"easy\", \"hard\"]\n        casesuite = setsuite[case] = BenchmarkGroup()\n        for isflat in [false, true]\n            flatsuite = casesuite[isflat ? \"flat\" : \"not flat\"] = BenchmarkGroup()\n            values = set[case][isflat ? 1 : 2]\n            flatsuite[\"make_iterate_values\"] = @benchmarkable StochasticAD.structural_iterate($values)\n            iter_values = StochasticAD.structural_iterate(values)\n            flatsuite[\"foldl_values\"] = @benchmarkable foldl(+, $(iter_values))\n            flatsuite[\"iterate_values\"] = @benchmarkable for i in $(iter_values)\n            end\n            for backend in [PrunedFIsBackend(), PrunedFIsAggressiveBackend()]\n                FIs_suite = flatsuite[backend] = BenchmarkGroup()\n                Δs = StochasticAD.create_Δs(backend, Int)\n                Δs1 = StochasticAD.similar_new(Δs, 1, 1)\n                Δs_all = StochasticAD.structural_map(x -> map(Δ -> x, Δs1), values)\n                FIs_suite[\"make_iterate_Δs\"] = @benchmarkable StochasticAD.structural_iterate($Δs_all)\n                # We don't interpolate backend directly in below (i.e. do $FIs) because string interpolating a type\n                # seems to lead to slow benchmarks.\n                FIs_suite[\"couple_same\"] = @benchmarkable StochasticAD.couple(typeof($Δs),\n                    $Δs_all)\n                FIs_suite[\"combine_same\"] = @benchmarkable StochasticAD.combine(\n                    typeof($Δs),\n                    $Δs_all)\n            end\n        end\n    end\nend\n\nend\n"
  },
  {
    "path": "benchmark/random_walk.jl",
    "content": "module RandomWalkBenchmark\n\nusing BenchmarkTools\n\nusing StochasticAD\nusing Statistics\nusing ForwardDiff: derivative\ninclude(\"../tutorials/random_walk/core.jl\")\nusing .RandomWalkCore: n, p, nsamples\nusing .RandomWalkCore: fX, get_dfX\n\nconst suite = BenchmarkGroup()\n\nsuite[\"original\"] = @benchmarkable $(fX)($p)\nsuite[\"PrunedFIs\"] = @benchmarkable derivative_estimate($fX, $p;\n    backend = PrunedFIsBackend())\nsuite[\"PrunedFIsAggressive\"] = @benchmarkable derivative_estimate($fX, $p;\n    backend = PrunedFIsAggressiveBackend())\nsuite[\"SmoothedFIs\"] = @benchmarkable derivative_estimate($fX, $p;\n    backend = SmoothedFIsBackend())\nforwarddiff_func = p -> fX(p; hardcode_leftright_step = true)\nsuite[\"ForwardDiff_smoothing\"] = @benchmarkable derivative($forwarddiff_func, $p)\n\nend\n"
  },
  {
    "path": "benchmark/runbenchmarks.jl",
    "content": "using PkgBenchmark\n\ninclude(\"utils.jl\")\nusing .Utils\n\nresults = benchmarkpkg(dirname(@__DIR__),\n    BenchmarkConfig(env = Dict(\"JULIA_NUM_THREADS\" => \"1\",\n        \"OMP_NUM_THREADS\" => \"1\")),\n    resultfile = joinpath(@__DIR__, \"result.json\"))\n@show results = print_group(results.benchmarkgroup)\n"
  },
  {
    "path": "benchmark/simple_ops.jl",
    "content": "module SimpleOpsBenchmark\n\nusing BenchmarkTools\n\nusing StochasticAD\n\nconst suite = BenchmarkGroup()\n\nsuite[\"add\"] = BenchmarkGroup()\nsuite[\"add_via_propagate_nodeltas\"] = BenchmarkGroup()\nsuite[\"add_via_propagate\"] = BenchmarkGroup()\n\nsuite[\"add\"][\"original\"] = @benchmarkable +(0.5, 0.5)\nsuite[\"add_via_propagate_nodeltas\"][\"original\"] = @benchmarkable StochasticAD.propagate(+,\n    0.5,\n    0.5)\nsuite[\"add_via_propagate\"][\"original\"] = @benchmarkable StochasticAD.propagate(+, 0.5, 0.5;\n    keep_deltas = Val{\n        true,\n    })\nfor backend in [PrunedFIsBackend(), PrunedFIsAggressiveBackend()]\n    suite[\"add\"][backend] = @benchmarkable +(st, st) setup=(st = stochastic_triple(0.5;\n        backend = $backend))\n    suite[\"add_via_propagate_nodeltas\"][backend] = @benchmarkable StochasticAD.propagate(+,\n        st,\n        st) setup=(st = stochastic_triple(0.5;\n        backend = $backend))\n    suite[\"add_via_propagate\"][backend] = @benchmarkable StochasticAD.propagate(+, st, st;\n        keep_deltas = Val{\n            true,\n        }) setup=(st = stochastic_triple(0.5;\n        backend = $backend))\nend\n\nend\n"
  },
  {
    "path": "benchmark/utils.jl",
    "content": "module Utils\n\nexport print_group\n\nusing Functors\nusing BenchmarkTools\n\n## Printing\n\n# Type piracy, fine since just in benchmarking. (design of Functors should probably allow for user-customized functors)\n@functor BenchmarkTools.BenchmarkGroup\n\nfunction print_trial(t)\n    ptime = BenchmarkTools.prettytime(time(t))\n    pallocs = \"$(allocs(t)) allocs\"\n    return \"$ptime, $pallocs\"\nend\n\nfunction print_group(b)\n    fmap(t -> (t isa BenchmarkTools.Trial ? print_trial(t) : t), b)\nend\n\nend\n"
  },
  {
    "path": "docs/Project.toml",
    "content": "[deps]\nDistributions = \"31c24e10-a181-5473-b8eb-7969acd0382f\"\nDocThemeIndigo = \"8bac0ac5-51bf-41f9-885e-2bf1ac2bec5f\"\nDocumenter = \"e30172f5-a6a5-5a46-863b-614d45cd2de4\"\nLiterate = \"98b081ad-f1c9-55d3-8b20-4c87d4299306\"\nStochasticAD = \"e4facb34-4f7e-4bec-b153-e122c37934ac\"\n"
  },
  {
    "path": "docs/make.jl",
    "content": "using Pkg\n\nusing Documenter\nusing StochasticAD\nusing DocThemeIndigo\nusing Literate\n\n### Formatting\n\nindigo = DocThemeIndigo.install(StochasticAD)\nformat = Documenter.HTML(prettyurls = false,\n    assets = [indigo, \"assets/extra_styles.css\"],\n    repolink = \"https://github.com/gaurav-arya/StochasticAD.jl\",\n    edit_link = \"main\")\n\n### Pagination\n\npages = [\n    \"Overview\" => \"index.md\",\n    \"Tutorials\" => [\n        \"tutorials/random_walk.md\",\n        \"tutorials/game_of_life.md\",\n        \"tutorials/particle_filter.md\",\n        \"tutorials/optimizations.md\",\n        \"tutorials/reverse_demo.md\"\n    ],\n    \"Public API\" => \"public_api.md\",\n    \"Developer documentation\" => \"devdocs.md\",\n    \"Limitations\" => \"limitations.md\"\n]\n\n### Prepare literate tutorials\n\n# TODO (for now they are manually built into docs/src/tutorials and checked into repo)\n\n### Make docs\n\nmakedocs(sitename = \"StochasticAD.jl\",\n    authors = \"Gaurav Arya and other contributors\",\n    modules = [StochasticAD],\n    format = format,\n    pages = pages,\n    warnonly = [:missing_docs])\n\ntry\n    deploydocs(repo = \"github.com/gaurav-arya/StochasticAD.jl\",\n        devbranch = \"main\",\n        push_preview = true)\ncatch e\n    println(\"Error encountered while deploying docs:\")\n    showerror(stdout, e)\nend\n"
  },
  {
    "path": "docs/src/assets/extra_styles.css",
    "content": ".display-light-only {display: block;}\n.display-dark-only {display: none;}\n.theme--documenter-dark .display-light-only {display: none;}\n.theme--documenter-dark .display-dark-only {display: block;}"
  },
  {
    "path": "docs/src/devdocs.md",
    "content": "# Developer documentation (WIP)\n\n## Writing a custom rule for stochastic triples\n\n### via `StochasticAD.propagate`\n\nTo handle a deterministic discrete construct that `StochasticAD` does not automatically handle (e.g. branching via `if`, boolean comparisons), it is often sufficient to simply add a dispatch rule that calls out to `StochasticAD.propagate`.\n\n```@docs\nStochasticAD.propagate\n```\n\n### via a custom dispatch\n\nIf a function does not meet the conditions of `StochasticAD.propagate` and is not already supported, a custom\ndispatch may be necessary. For example, consider the following function which manually implements a geometric random variable:\n\n```@example rule\nimport Random\nRandom.seed!(1234) # hide\nusing Distributions\n# make rng input explicit\nfunction mygeometric(rng, p)\n    x = 0\n    while !(rand(rng, Bernoulli(p)))\n        x += 1\n    end\n    return x\nend\nmygeometric(p) = mygeometric(Random.default_rng(), p)\n```\n\nThis is equivalent to `rand(Geometric(p))` which is already supported, but for pedagogical purposes we will\nimplement our own rule from scratch. Using the stochastic derivative formulas from [Automatic Differentiation of Programs with Discrete Randomness](https://doi.org/10.48550/arXiv.2210.08572), the right stochastic derivative of this program is given by\n```math\nY_R = X - 1, w_R = \\frac{x}{p(1-p)},\n```\nand the left stochastic derivative of this program is given by\n```math\nY_L = X + 1, w_L = -\\frac{x+1}{p}.\n```\n\nUsing these expressions, we can now write the dispatch rule for stochastic triples:\n\n```@example rule\nusing StochasticAD\nimport StochasticAD: StochasticTriple, similar_new, similar_empty, combine\nfunction mygeometric(rng, p_st::StochasticTriple{T}) where {T}\n    p = p_st.value\n    rng_copy = copy(rng) # save a copy for coupling later\n    x = mygeometric(rng, p)\n\n    # Form the new discrete perturbations (combinations of weight w and perturbation Y - X)\n    Δs1 = if p_st.δ > 0\n        # right stochastic derivative\n        w = p_st.δ * x / (p * (1 - p))\n        x > 0 ? similar_new(p_st.Δs, -1, w) : similar_empty(p_st.Δs, Int)\n    elseif p_st.δ < 0\n        # left stochastic derivative\n        w = -p_st.δ * (x + 1) / p # positive since the negativity of p_st.δ cancels out the negativity of w_L\n        similar_new(p_st.Δs, 1, w)\n    else\n        similar_empty(p_st.Δs, Int)\n    end\n\n    # Propagate any existing perturbations to p through the function\n    function map_func(Δ)\n        # Couple the samples by using the same RNG. (A simpler strategy would have been independent sampling, i.e. mygeometric(p + Δ) - x)\n        mygeometric(copy(rng_copy), p + Δ) - x \n    end\n    Δs2 = map(map_func, p_st.Δs)\n\n    # Return the output stochastic triple\n    StochasticTriple{T}(x, zero(x), combine((Δs2, Δs1)))\nend\n```\nIn the above, we used some of the interface functions supported by a collection of perturbations `Δs::StochasticAD.AbstractFIs`. These were `similar_empty(Δs, V)`, which created an empty perturbation of type `V`, `similar_new(Δs, Δ, w)`, which created a new perturbation of size `Δ` and weight `w`, `map(map_func, Δs)`,\nwhich propagates a collection of perturbations through a mapping function, and `combine((Δs2, Δs1))` which combines multiple collections of perturbations together.\n\nWe can test out our rule:\n```@example rule\n@show stochastic_triple(mygeometric, 0.1)\n\n# try feeding an input that already has a perturbation\nf(x) = mygeometric(2 * x + 0.1 * rand(Bernoulli(x)))^2\n@show stochastic_triple(f, 0.1)\n\n# verify against black-box finite differences\nN = 1000000\nsamples_stochad = [derivative_estimate(f, 0.1) for i in 1:N]\nsamples_fd = [(f(0.105) - f(0.095)) / 0.01 for i in 1:N]\n\nprintln(\"Stochastic AD: $(mean(samples_stochad)) ± $(std(samples_stochad) / sqrt(N))\")\nprintln(\"Finite differences: $(mean(samples_fd)) ± $(std(samples_fd) / sqrt(N))\")\n\nnothing # hide\n```\n\n## Distribution-specific customization of differentiation algorithm \n\n```@docs\nrandst\nInversionMethodDerivativeCoupling\n```"
  },
  {
    "path": "docs/src/index.md",
    "content": "```@raw html\n<img class=\"display-light-only\" src=\"images/path_skeleton.png\">\n<img class=\"display-dark-only\" src=\"images/path_skeleton_dark.png\">\n```\n\n# StochasticAD\n\n[StochasticAD](https://github.com/gaurav-arya/StochasticAD.jl) is an experimental, research package for automatic differentiation (AD) of stochastic programs.\nIt implements AD algorithms for handling programs that can contain *discrete* randomness, based on the methodology developed in [this NeurIPS 2022 paper](https://doi.org/10.48550/arXiv.2210.08572).\n\n## Introduction\n\nDerivatives are all about how functions are affected by a tiny change `ε` in their input. First, let's imagine perturbing the input of a deterministic, differentiable function such as $f(p) = p^2$ at $p = 2$.\n```@example continuous\nusing StochasticAD\nf(p) = p^2\nstochastic_triple(f, 2) # Feeds 2 + ε into f\n```\nThe output tells us that if we change the input `2` by a tiny amount `ε`, the output of `f` will change by approximately `4ε`. This is the case we're familiar with: we can get the value `4` by applying the chain rule, $\\frac{\\mathrm{d}}{\\mathrm{d} p} p^2 = 2p = 4$. Thinking in terms of tiny changes, the output above looks a lot like a [dual number](https://en.wikipedia.org/wiki/Dual_number). But what happens with a discrete random function? Let's give it a try. \n```@example discrete\nimport Random # hide\nRandom.seed!(4321) # hide\nusing StochasticAD, Distributions\nf(p) = rand(Bernoulli(p)) # 1 with probability p, 0 otherwise\nstochastic_triple(f, 0.5) # Feeds 0.5 + ε into f\n```\nThe output of a [Bernoulli variable](https://en.wikipedia.org/wiki/Bernoulli_distribution) cannot change by a tiny amount: it is either `0` or `1`. But in the probabilistic world, there is another way to change by a tiny amount *on average*: jump by a large amount, with tiny probability. `StochasticAD` introduces a stochastic triple object, which generalizes dual numbers by including a *third* component to describe these perturbations. Here, the stochastic triple says that the original random output was `0`, but given a small change `ε` in the input, the output will jump up to `1` with probability approximately `2ε`.\n\nStochastic triples can be used to construct a new random program whose average is the derivative of the average of the original program. We simply propagate stochastic triples through the program via [`stochastic_triple`](@ref), and then sum up the \"dual\" and \"triple\" components at the end via [`derivative_contribution`](@ref). This process is packaged together in the function [`derivative_estimate`](@ref). Let's try a crazier example, where we mix discrete and continuous randomness!\n```@example estimate\nusing StochasticAD, Distributions\nimport Random # hide\nRandom.seed!(1234) # hide\n\nfunction X(p)\n    a = p * (1 - p)\n    b = rand(Binomial(10, p))\n    c = 2 * b + 3 * rand(Bernoulli(p))\n    return a * c * rand(Normal(b, a))\nend\n\nst = @show stochastic_triple(X, 0.6) # sample a single stochastic triple at p = 0.6\n@show derivative_contribution(st) # which produces a single derivative estimate...\n\nsamples = [derivative_estimate(X, 0.6) for i in 1:1000] # many samples from derivative program\nderivative = mean(samples)\nuncertainty = std(samples) / sqrt(1000)\nprintln(\"derivative of 𝔼[X(p)] = $derivative ± $uncertainty\")\n```\n\n## Index\n\nSee the [public API](public_api.md) for a walkthrough of the API, and the tutorials on differentiating a [random walk](tutorials/random_walk.md), a [stochastic game of life](tutorials/game_of_life.md), and a [particle filter](tutorials/particle_filter.md), and solving [stochastic optimization and variational problems](tutorials/optimizations.md) with discrete randomness. This is a prototype package with a number of [limitations](limitations.md).\n\n"
  },
  {
    "path": "docs/src/limitations.md",
    "content": "# Limitations of StochasticAD\n\n`StochasticAD` has a number of limitations that are important to be aware of:\n\n* `StochasticAD` uses operator-overloading just like [ForwardDiff](https://juliadiff.org/ForwardDiff.jl/stable/), so all of the [limitations](https://juliadiff.org/ForwardDiff.jl/stable/user/limitations/) listed there apply here too. Also note that some useful features of `ForwardDiff`, such as chunking for greater efficiency with a large number of parameters, have not yet been implemented here.\n* We have limited support for reverse-mode AD via [smoothing](public_api.md#Smoothing), which cannot be guaranteed to be unbiased in all cases. \n* We do not yet support `if` statements with discrete random input. A workaround can be to use array indexing to express discrete random choices (see [the random walk tutorial](tutorials/random_walk.md) for an example).\n* We do not yet support non-real values as intermediate values (e.g. a function such as `length(A[rand(Bernoulli(p))])` where `A` is an array of strings is in theory differentiable).\n* We do not support discrete random variables that are implicitly implemented using continuous random variables, e.g. `rand() < p`.\n* We support a limited assortment of discrete random variables: currently `Bernoulli`, `Binomial`, `Geometric`, `Poisson`, and `Categorical` from [Distributions](https://juliastats.org/Distributions.jl/). We are working on increasing coverage across `Distributions` as well as other libraries providing discrete random samplers such as [MeasureTheory](https://cscherrer.github.io/MeasureTheory.jl/stable/).\n* Higher-order differentiation is not supported.\n\n`StochasticAD` is still in active development! PRs are welcome.\n\n"
  },
  {
    "path": "docs/src/public_api.md",
    "content": "# API walkthrough\n \nThe function [`derivative_estimate`](@ref) transforms a stochastic program containing discrete randomness into a new program whose average is the derivative of the original.\n```@docs\nderivative_estimate\n```\nWhile [`derivative_estimate`](@ref) is self-contained, we can also use the functions below to work with stochastic triples directly.\n```@docs\nStochasticAD.stochastic_triple\nStochasticAD.derivative_contribution\nStochasticAD.value\nStochasticAD.delta\nStochasticAD.perturbations\n```\nNote that [`derivative_estimate`](@ref) is simply the composition of [`stochastic_triple`](@ref) and [`derivative_contribution`](@ref). We also provide a convenience function for mimicking the behaviour\nof standard AD, where derivatives of discrete random steps are dropped:\n```@docs\nStochasticAD.dual_number\n```\n\n## Algorithms \n\n```@docs\nStochasticAD.ForwardAlgorithm\nStochasticAD.EnzymeReverseAlgorithm\n```\n\n## Smoothing\n\nWhat happens if we were to run [`derivative_contribution`](@ref) after each step, instead of only at the end? This is *smoothing*, which combines the second and third components of a single stochastic triple into a single dual component. \nSmoothing no longer has a guarantee of unbiasedness, but is surprisingly accurate in a number of situations. \nFor example, the popular [straight through gradient estimator](https://stackoverflow.com/questions/38361314/the-concept-of-straight-through-estimator-ste) can be viewed as a special case of smoothing.\nForward smoothing rules are provided through `ForwardDiff`, and backward rules through `ChainRules`, so that e.g. `Zygote.gradient` and `ForwardDiff.derivative` will use smoothed rules for discrete random variables rather than dropping the gradients entirely. \nCurrently, special discrete->discrete constructs such as array indexing are not supported for smoothing.\n\n\n## Optimization\n\nWe also provide utilities to make it easier to get started with forming and training a model via stochastic gradient descent:\n```@docs\nStochasticAD.StochasticModel\nStochasticAD.stochastic_gradient\n```\nThese are used in the [tutorial on stochastic optimization](tutorials/optimizations.md).\n"
  },
  {
    "path": "docs/src/tutorials/game_of_life.md",
    "content": "# Stochastic Game of Life\n\nWe consider a stochastic version of [Conway's Game of Life](https://en.wikipedia.org/wiki/Conway%27s_Game_of_Life), played on a two-dimensional board. We shall use the following packages,\n```@setup game_of_life\nimport Pkg\nPkg.activate(\"../../../tutorials\")\nPkg.develop(path=\"../../..\")\nPkg.instantiate()\n```\n```@example game_of_life\nusing Distributions\nusing StochasticAD\nusing OffsetArrays \nusing StaticArrays\n```\n\n## Setting up the stochastic Game of Life\n\nEach turn, the standard Game of Life applies the following rules to each cell,\n```math\n\\text{dead and 3 neighbours alive} \\to \\text{ alive}, \\\\\n\\text{alive and 0, 1, or 4 neighbours alive} \\to \\text{ dead}.\n```\nThe cell's status does not change otherwise. In our stochastic version, these rules instead occur with probability `1-θ`, while the opposite event has probability `θ`. To initialize the board at the beginning of the game, we randomly set each cell alive with probability `p`. \n\nThe following high level function sets up the probabilities and provides them to `play_game_of_life`.\n```@example game_of_life\nfunction play(p, θ=0.1, N=12, T=10; log=false)\n    # N is the board half-length, T are game time steps\n    low = θ\n    high = 1-θ\n    birth_probs = SA[low, low, low, high, low] # 0, 1, 2, 3, 4 neighbours\n    death_probs = SA[high, high, low, low, high] # 0, 1, 2, 3, 4 neighbours \n    return play_game_of_life(p, vcat(birth_probs, death_probs), N, T; log)\nend\n```\nWe can now implement the Game of Life based on the specification. 
At the end of the game, we return the total number of alive cells.\n```@example game_of_life\n# A single turn of the game\nfunction update_state(all_probs, N, board_new, board_old)\n    for i in -N:N\n        for j in -N:N\n            neighbours = board_old[i+1, j] + board_old[i-1, j] + board_old[i, j-1] + board_old[i, j+1]\n            index = board_new[i,j] * 5 + neighbours + 1 \n            b = rand(Bernoulli(all_probs[index]))\n            board_new[i,j] += (1 - 2 * board_new[i,j]) * b \n        end\n    end\nend\n\nfunction play_game_of_life(p, all_probs, N, T; log=false)\n    dual_type = promote_type(typeof(rand(Bernoulli(p))), typeof.(rand.(Bernoulli.(all_probs)))...) # a hacky way of getting the correct array type \n    board = OffsetArray(zeros(dual_type, 2*N + 3, 2*N + 3), -(N+1):(N+1), -(N+1):(N+1)) # center board at (0,0), pad by 1 \n\n    # initialize the board\t\n    for i in -N:N\n        for j in -N:N\n            board[i,j] = rand(Bernoulli(p))\n        end\n    end\n    board_old = similar(board)\n    log && (history = [])\n\n    # play the game\n    for time_step in 1:T\n        copy!(board_old, board)\n        update_state(all_probs, N, board, board_old)\n        log && push!(history, copy(board))\n    end\n\n    if !log\n        return sum(board)\n    else\n        return sum(board), board, history\n    end\nend\n\nplay(0.5, 0.1) # play the game with p = 0.5 and θ = 0.1\n```\n\n!!! note \n    Note that we did have to be careful to write this program to be compatible with the [current capabilities of `StochasticAD`](../limitations.md). For example, we concatenated `birth_probs` and `death_probs` into a single array `all_probs` and used the index `board[i, j] * 5 + neighbours + 1` to find the probability, rather than use the more natural `if alive... 
else...` syntax.\n\n## Differentiating the Game of Life\n\nLet's differentiate the Game of Life!\n```@example game_of_life\n@show stochastic_triple(play, 0.5) # let's take a look at a single stochastic triple\n\nsamples = [derivative_estimate(play, 0.5) for i in 1:10000] # take many samples\nderivative = mean(samples)\nuncertainty = std(samples) / sqrt(10000)\nprintln(\"derivative of 𝔼[play(p)] = $derivative ± $uncertainty\")\n```\n\nThe following sketch of the final state of the board for a single run gives some insight into what the stochastic triples are doing. The original board is depicted in grey and white for dead and alive, and the cells which flip from dead to alive in the \"alternative\" path considered by the triples are marked with + signs, while the cells which flip from alive to dead are marked with X signs.\n\n```@raw html\n<img src=\"../images/final_gol_board.png\" width=\"50%\"/>\n``` ⠀\n\n\n\n\n"
  },
  {
    "path": "docs/src/tutorials/optimizations.md",
    "content": "# Stochastic optimizations with discrete randomness\n\n```@setup random_walk\nimport Pkg\nPkg.activate(\"../../../tutorials/toy_optimizations\")\nPkg.develop(path=\"../../..\")\nPkg.instantiate()\n```\n\nIn this tutorial, we solve two stochastic optimization problems using `StochasticAD` where the optimization objective is formed using discrete distributions. We will need the following packages:\n```@example optimizations\nusing Distributions # defines several supported discrete distributions \nusing StochasticAD\nusing CairoMakie # for plotting\nusing Optimisers # for stochastic gradient descent\n```\n\n## Optimizing our toy program\n\nRecall the \"crazy\" program from the intro:\n```@example optimizations\nfunction X(p)\n    a = p * (1 - p)\n    b = rand(Binomial(10, p))\n    c = 2 * b + 3 * rand(Bernoulli(p))\n    return a * c * rand(Normal(b, a))\nend\n```\n\nLet's maximize $\\mathbb{E}[X(p)]$! First, let's setup the problem, using the [`StochasticModel`](@ref) helper utility to create a trainable model:\n```@example optimizations\np0 = [0.5] # initial value of p, wrapped in an array for use in the stochastic model\nm = StochasticModel(p -> -X(p[1]), p0) # formulate as minimization problem\n```\nNow, let's perform stochastic gradient descent using [Adam](https://arxiv.org/abs/1412.6980), where we use [`stochastic_gradient`](@ref) to obtain a gradient of the model.\n```@example optimizations\niterations = 1000\ntrace = Float64[]\no = Adam() # use Adam for optimization\ns = Optimisers.setup(o, m)\nfor i in 1:iterations\n    # Perform a gradient step\n    Optimisers.update!(s, m, stochastic_gradient(m))\n    push!(trace, m.p[])\nend\np_opt = m.p[] # Our optimized value of p\n```\nFinally, let's plot the results of our optimization, and also perform a sweep through the parameter space to verify the accuracy of our estimator:\n```@example optimizations\n## Sweep through parameters to find average and derivative\nps = 0.02:0.02:0.98 # values of p to 
sweep\nN = 1000 # number of samples at each p\navg = [mean(X(p) for _ in 1:N) for p in ps]\nderivative = [mean(derivative_estimate(X, p) for _ in 1:N) for p in ps]\n\n## Make plots\nf = Figure()\nax = f[1, 1] = Axis(f, title = \"Estimates\", xlabel=\"Value of p\")\nlines!(ax, ps, avg, label = \"≈ E[X(p)]\")\nlines!(ax, ps, derivative, label = \"≈ d/dp E[X(p)]\")\nvlines!(ax, [p_opt], label = \"p_opt\", color = :green, linewidth = 2.0)\nhlines!(ax, [0.0], color = :black, linewidth = 1.0)\nylims!(ax, (-50, 80))\n\nf[1, 2] = Legend(f, ax, framevisible = false)\nax = f[2, 1:2] = Axis(f, title = \"Optimizer trace\", xlabel=\"Iterations\", ylabel=\"Value of p\")\nlines!(ax, trace, color = :green, linewidth = 2.0)\nsave(\"crazy_opt.png\", f,  px_per_unit = 4) # hide\nnothing # hide\n```\n![](crazy_opt.png)\n\n## Solving a variational problem\n\nLet's consider a toy variational program: we find a [Poisson distribution](https://en.wikipedia.org/wiki/Poisson_distribution) that is close to the distribution of a [negative Binomial](https://en.wikipedia.org/wiki/Negative_binomial_distribution), via minimization of the [Kullback-Leibler divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) $D_{\\mathrm{KL}}$. 
Concretely, let us solve\n```math\n\\underset{p \\in \\mathbb{R}}{\\operatorname{argmin}}\\; D_{\\mathrm{KL}}\\left(\\mathrm{Pois}(p) \\hspace{.3em}\\middle\\|\\hspace{.3em} \\mathrm{NBin}(10, 0.25) \\right).\n```\nThe following program produces an unbiased estimate of the objective:\n```@example optimizations\nfunction X(p)\n    i = rand(Poisson(p))\n    return logpdf(Poisson(p), i) - logpdf(NegativeBinomial(10, 0.25), i)\nend\n```\nWe can now optimize the KL-divergence via stochastic gradient descent!\n```@example optimizations\n# Minimize E[X] = KL(Poisson(p)| NegativeBinomial(10, 0.25))\niterations = 1000\np0 = [10.0]\nm = StochasticModel(p -> X(p[1]), p0)\ntrace = Float64[]\no = Adam(0.1)\ns = Optimisers.setup(o, m)\nfor i in 1:iterations\n    Optimisers.update!(s, m, stochastic_gradient(m))\n    push!(trace, m.p[])\nend\np_opt = m.p[]\n```\nLet's plot our results in the same way as before:\n```@example optimizations\nps = 10:0.5:50\nN = 1000\navg = [mean(X(p) for _ in 1:N) for p in ps]\nderivative = [mean(derivative_estimate(X, p) for _ in 1:N) for p in ps]\nf = Figure()\nax = f[1, 1] = Axis(f, title = \"Estimates\", xlabel=\"Value of p\")\nlines!(ax, ps, avg, label = \"≈ E[X(p)]\")\nlines!(ax, ps, derivative, label = \"≈ d/dp E[X(p)]\")\nvlines!(ax, [p_opt], label = \"p_opt\", color = :green, linewidth = 2.0)\nhlines!(ax, [0.0], color = :black, linewidth = 1.0)\nylims!(ax, (-2.5, 5))\n\nf[1, 2] = Legend(f, ax, framevisible = false)\nax = f[2, 1:2] = Axis(f, title = \"Optimizer trace\", ylabel=\"Value of p\", xlabel=\"Iterations\")\nlines!(ax, trace, color = :green, linewidth = 2.0)\nsave(\"variational.png\", f, px_per_unit = 4) # hide\nnothing # hide\n```\n![](variational.png)\n"
  },
  {
    "path": "docs/src/tutorials/particle_filter.md",
    "content": "# Differentiable particle filter\n\nUsing a bootstrap particle sampler, we can approximate the posterior distributions\nof the states given noisy and partial observations of the state of a hidden Markov\nmodel by a cloud of `K` weighted particles with weights `W`.\n\nIn this tutorial, we are going to:\n- implement a differentiable particle filter based on `StochasticAD.jl`.\n- visualize the particle filter in ``d = 2`` dimensions.\n- compare the gradient based on the differentiable particle filter to a biased\n  gradient estimator as well as to the gradient of a differentiable Kalman filter.\n- show how to benchmark primal evaluation, forward- and reverse-mode AD of the\n  particle filter.\n\n## Setup\n\nWe will make use of several julia packages. For example, we are going to use\n`Distributions` and `DistributionsAD` that implement the reparameterization trick\nfor Gaussian distributions used in the observation and state-transition model, which\nwe specify below. We also import `GaussianDistributions.jl` to implement the\ndifferentiable Kalman filter.\n\n### Package dependencies\n\n```@setup particle_filter\nimport Pkg\nPkg.activate(\"../../../tutorials\")\nPkg.develop(path=\"../../..\")\nPkg.instantiate()\n```\n\n```@example particle_filter\n# activate tutorial project file\n\n# load dependencies\nusing StochasticAD\nusing Distributions\nusing DistributionsAD\nusing Random\nusing Statistics\nusing StatsBase\nusing LinearAlgebra\nusing Zygote\nusing ForwardDiff\nusing GaussianDistributions\nusing GaussianDistributions: correct, ⊕\nusing Measurements\nusing UnPack\nusing Plots\nusing LaTeXStrings\nusing BenchmarkTools\n```\n\n### Particle filter\n\nFor convenience, we first introduce the new type `StochasticModel` with the following\nfields:\n\n- `T`: total number of time steps.\n- `start`: starting distribution for the initial state. 
For example, in the form of a narrow\n   Gaussian `start(θ) = Gaussian(x0, 0.001 * I(d))`.\n- `dyn`: pointwise differentiable stochastic program in the form of Markov transition densities.\n   For example, `dyn(x, θ) = MvNormal(reshape(θ, d, d) * x, Q(θ))`, where `Q(θ)` denotes the\n   covariance matrix.\n- `obs`: observation model having a smooth conditional probability density depending on\n   current state `x` and parameters `θ`. For example, `obs(x, θ) = MvNormal(x, R(θ))`,\n   where `R(θ)` denotes the covariance matrix.\n\nFor parameters `θ`,  `rand(start(θ))` gives a sample from the prior distribution of the\nstarting distribution. For current state `x` and parameters `θ`, `xnew = rand(dyn(x, θ))`\nsamples the new state (i.e. `dyn` gives for each `x, θ` a distribution-like object). Finally,\n`y = rand(obs(x, θ))` samples an observation.\n\nWe can then define the `ParticleFilter` type that wraps a stochastic model `StochM::StochasticModel`,\na sampling strategy (with arguments `p, K, sump=1`) and observational data `ys`.\nFor simplicity, our implementation assumes an observation-likelihood function being available\nvia `pdf(obs(x, θ), y)`.\n\n```@example particle_filter\nstruct StochasticModel{TType<:Integer,T1,T2,T3}\n    T::TType # time steps\n    start::T1 # prior\n    dyn::T2 # dynamical model\n    obs::T3 # observation model\nend\n\nstruct ParticleFilter{mType<:Integer,MType<:StochasticModel,yType,sType}\n    m::mType # number of particles\n    StochM::MType # stochastic model\n    ys::yType # observations\n    sample_strategy::sType # sampling function\nend\n```\n\n### Kalman filter\n\nWe consider a stochastic program that fulfills the assumptions of a Kalman filter.\nWe follow [Kalman.jl](https://github.com/mschauer/Kalman.jl/blob/master/README.md) to implement a differentiable version.\nOur `KalmanFilter` type wraps a stochastic model `StochM::StochasticModel` and observational data `ys`. 
It assumes an\nobservation-likelihood function is implemented via `llikelihood(yres, S)`. The Kalman filter\ncontains the following fields:\n\n- `d`: dimension of the state-transition matrix ``\Phi`` according to ``x = \Phi x + w`` with ``w \sim \operatorname{Normal}(0,Q)``.\n- `StochM`: Stochastic model of type `StochasticModel`.\n- `H`: linear map from the state space into the observed space according to ``y = H x + \nu`` with ``\nu \sim \operatorname{Normal}(0,R)``.\n- `R`: covariance matrix entering the observation model according to ``y = H x + \nu`` with ``\nu \sim \operatorname{Normal}(0,R)``.\n- `Q`: covariance matrix entering the state-transition model according to ``x = \Phi x + w`` with ``w \sim \operatorname{Normal}(0,Q)``.\n- `ys`: observations.\n\n\n```@example particle_filter\nllikelihood(yres, S) = GaussianDistributions.logpdf(Gaussian(zero(yres), Symmetric(S)), yres)\nstruct KalmanFilter{dType<:Integer,MType<:StochasticModel,HType,RType,QType,yType}\n    # H, R = obs\n    # θ, Q = dyn\n    d::dType\n    StochM::MType # stochastic model\n    H::HType # observation model, maps the true state space into the observed space\n    R::RType # observation model, covariance matrix\n    Q::QType # dynamical model, covariance matrix\n    ys::yType # observations\nend\n```\n\nTo get observations `ys` from the latent states `xs` based on the\n(true, potentially unknown) parameters `θ`, we simulate a single particle\nfrom the forward model returning a vector of observations (no resampling steps).\n\n```@example particle_filter\nfunction simulate_single(StochM::StochasticModel, θ)\n    @unpack T, start, dyn, obs = StochM\n    x = rand(start(θ))\n    y = rand(obs(x, θ))\n    xs = [x]\n    ys = [y]\n    for t in 2:T\n        x = rand(dyn(x, θ))\n        y = rand(obs(x, θ))\n        push!(xs, x)\n        push!(ys, y)\n    end\n    xs, ys\nend\n```\n\nA particle filter becomes efficient if resampling steps are included. 
Resampling\nis numerically attractive because particles with small weight are discarded, so\ncomputational resources are not wasted on particles with vanishing weight.\n\nHere, let us implement a stratified resampling strategy, see for example\n[Murray (2012)](https://arxiv.org/abs/1202.6163), where `p` denotes the probabilities of `K` particles\nwith `sump = sum(p)`.\n\n```@example particle_filter\nfunction sample_stratified(p, K, sump=1)\n    n = length(p)\n    U = rand()\n    is = zeros(Int, K)\n    i = 1\n    cw = p[1]\n    for k in 1:K\n        t = sump * (k - 1 + U) / K\n        while cw < t && i < n\n            i += 1\n            @inbounds cw += p[i]\n        end\n        is[k] = i\n    end\n    return is\nend\n```\n\nThis sampling strategy can be used within a differentiable resampling step in our\nparticle filter using the `use_new_weight` function as implemented in\n`StochasticAD.jl`. The `resample` function below returns the states `X_new`\nand weights `W_new` of the resampled particles.\n\n- `m`: number of particles.\n- `X`: current particle states.\n- `W`: current weight vector of the particles.\n- `ω == sum(W)` is an invariant.\n- `sample_strategy`: specific resampling strategy to be used. 
For example, `sample_stratified`.\n- `use_new_weight=true`: Allows one to switch between biased, stop-gradient method and\n   differentiable resampling step.\n\n```@example particle_filter\nfunction resample(m, X, W, ω, sample_strategy, use_new_weight=true)\n    js = Zygote.ignore(() -> sample_strategy(W, m, ω))\n    X_new = X[js]\n    if use_new_weight\n        # differentiable resampling\n        W_chosen = W[js]\n        W_new = map(w -> ω * new_weight(w / ω) / m, W_chosen)\n    else\n        # stop gradient, biased approach\n        W_new = fill(ω / m, m)\n    end\n    X_new, W_new\nend\n```\n\nNote that we added a `if` condition that allows us to switch between the differentiable\nresampling step and the stop-gradient approach.\n\nWe're now equipped with all primitive operations to set up the particle filter,\nwhich propagates particles with weights `W` preserving the invariant `ω == sum(W)`.\nWe never normalize `W` and, therefore, `ω` in the code below contains likelihood\ninformation. The particle-filter implementation defaults to return particle\npositions and weights at `T` if `store_path=false` and takes the following input\narguments:\n\n- `θ`: parameters for the stochastic program (state-transition and observation model).\n- `store_path=false`: Option to store the path of the particles, e.g. to visualize/inspect\n  their trajectories.\n- `use_new_weight=true`: Option to switch between the stop-gradient and our differentiable\n  resampling step method. 
Defaults to using differentiable resampling.\n- `s`: controls the number of resampling steps according to `t > 1 && t < T && (t % s == 0)`.\n\n\n```@example particle_filter\nfunction (F::ParticleFilter)(θ; store_path=false, use_new_weight=true, s=1)\n    # s controls the number of resampling steps\n    @unpack m, StochM, ys, sample_strategy = F\n    @unpack T, start, dyn, obs = StochM\n\n\n    X = [rand(start(θ)) for j in 1:m] # particles\n    W = [1 / m for i in 1:m] # weights\n    ω = 1 # total weight\n    store_path && (Xs = [X])\n    for (t, y) in zip(1:T, ys)\n        # update weights & likelihood using observations\n        wi = map(x -> pdf(obs(x, θ), y), X)\n        W = W .* wi\n        ω_old = ω\n        ω = sum(W)\n        # resample particles\n        if t > 1 && t < T && (t % s == 0) # && 1 / sum((W / ω) .^ 2) < length(W) ÷ 32\n            X, W = resample(m, X, W, ω, sample_strategy, use_new_weight)\n        end\n        # update particle states\n        if t < T\n            X = map(x -> rand(dyn(x, θ)), X)\n            store_path && Zygote.ignore(() -> push!(Xs, X))\n        end\n    end\n    (store_path ? Xs : X), W\nend\n```\n\nFollowing [Kalman.jl](https://github.com/mschauer/Kalman.jl/blob/master/README.md), we implement\na differentiable Kalman filter to check the ground-truth gradient. 
Our Kalman filter\nreturns an updated posterior state estimate and the log-likelihood and takes the\nparameters of the stochastic program as an input.\n\n```@example particle_filter\nfunction (F::KalmanFilter)(θ)\n    @unpack d, StochM, H, R, Q = F\n    @unpack start = StochM\n\n    x = start(θ)\n    Φ = reshape(θ, d, d)\n\n    x, yres, S = GaussianDistributions.correct(x, ys[1] + R, H)\n    ll = llikelihood(yres, S)\n    xs = Any[x]\n    for i in 2:length(ys)\n        x = Φ * x ⊕ Q\n        x, yres, S = GaussianDistributions.correct(x, ys[i] + R, H)\n        ll += llikelihood(yres, S)\n\n        push!(xs, x)\n    end\n    xs, ll\nend\n```\n\nFor both filters, it is straightforward to obtain the log-likelihood via:\n\n```@example particle_filter\nfunction log_likelihood(F::ParticleFilter, θ, use_new_weight=true, s=1)\n    _, W = F(θ; store_path=false, use_new_weight=use_new_weight, s=s)\n    log(sum(W))\nend\n```\nand\n```@example particle_filter\nfunction log_likelihood(F::KalmanFilter, θ)\n    _, ll = F(θ)\n    ll\nend\n```\n\nFor convenience, we define functions for\n- forward-mode AD (and differentiable resampling step) to compute the gradient of\n  the log-likelihood of the particle filter.\n- reverse-mode AD (and differentiable resampling step) to compute the gradient of\n  the log-likelihood of the particle filter.\n- forward-mode AD (and stop-gradient method) to compute the gradient of\n  the log-likelihood of the particle filter (without the `new_weight` function).\n- forward-mode AD to compute the gradient of the log-likelihood of the Kalman filter.\n\n```@example particle_filter\n\nforw_grad(θ, F::ParticleFilter; s=1) = ForwardDiff.gradient(θ -> log_likelihood(F, θ, true, s), θ)\nback_grad(θ, F::ParticleFilter; s=1) = Zygote.gradient(θ -> log_likelihood(F, θ, true, s), θ)[1]\nforw_grad_biased(θ, F::ParticleFilter; s=1) = ForwardDiff.gradient(θ -> log_likelihood(F, θ, false, s), θ)\nforw_grad_Kalman(θ, F::KalmanFilter) = ForwardDiff.gradient(θ -> 
log_likelihood(F, θ), θ)\n```\n\n## Model\n\nHaving set up all core functionalities, we can now define the specific stochastic\nmodel.\n\nWe consider the following system with a ``d``-dimensional latent process,\n\n```math\n\begin{aligned}\nx_i &= \Phi x_{i-1} + w_i &\text{ with } w_i \sim \operatorname{Normal}(0,Q),\\\ny_i &= x_i + \nu_i &\text{ with } \nu_i \sim \operatorname{Normal}(0,R),\n\end{aligned}\n```\n\nwhere ``\Phi`` is a ``d``-dimensional rotation matrix.\n\n```@example particle_filter\nseed = 423897\n\n### Define model\n# here: n-dimensional rotation matrix\nRandom.seed!(seed)\nT = 20 # time steps\nd = 2 # dimension\n# generate a rotation matrix\nM = randn(d, d)\nc = 0.3 # scaling\nO = exp(c * (M - transpose(M)) / 2)\n@assert det(O) ≈ 1\n@assert transpose(O) * O ≈ I(d)\nθtrue = vec(O) # true parameter\n\n# observation model\nR = 0.01 * collect(I(d))\nobs(x, θ) = MvNormal(x, R) # y = H x + ν with ν ~ Normal(0, R)\n\n# dynamical model\nQ = 0.02 * collect(I(d))\ndyn(x, θ) = MvNormal(reshape(θ, d, d) * x, Q) #  x = Φ*x + w with w ~ Normal(0,Q)\n\n# starting position\nx0 = randn(d)\n# prior distribution\nstart(θ) = Gaussian(x0, 0.001 * collect(I(d)))\n\n# put it all together\nstochastic_model = StochasticModel(T, start, dyn, obs)\n\n# relevant corresponding Kalman filtering defs\nH_Kalman = collect(I(d))\nR_Kalman = Gaussian(zeros(Float64, d), R)\n# Φ_Kalman = O\nQ_Kalman = Gaussian(zeros(Float64, d), Q)\n###\n\n### simulate model\nRandom.seed!(seed)\nxs, ys = simulate_single(stochastic_model, θtrue)\n```\n\n## Visualization\n\nUsing `particle_filter(θ; store_path=true)` and `kalman_filter(θ)`, it is\nstraightforward to visualize both filters for our observed data.\n\n```@example particle_filter\nm = 1000\nkalman_filter = KalmanFilter(d, stochastic_model, H_Kalman, R_Kalman, Q_Kalman, ys)\nparticle_filter = ParticleFilter(m, stochastic_model, ys, sample_stratified)\n```\n\n\n```@example particle_filter\n### run and visualize filters\nXs, W = 
particle_filter(θtrue; store_path=true)\nfig = plot(getindex.(xs, 1), getindex.(xs, 2), legend=false, xlabel=L\"x_1\", ylabel=L\"x_2\") # x1 and x2 are bad names..conflicting notation\nscatter!(fig, getindex.(ys, 1), getindex.(ys, 2))\nfor i in 1:min(m, 100) # note that Xs has obs noise.\n    local xs = [Xs[t][i] for t in 1:T]\n    scatter!(fig, getindex.(xs, 1), getindex.(xs, 2), marker_z=1:T, color=:cool, alpha=0.1) # color to indicate time step\nend\n\nxs_Kalman, ll_Kalman = kalman_filter(θtrue)\nplot!(getindex.(mean.(xs_Kalman), 1), getindex.(mean.(xs_Kalman), 2), legend=false, color=\"red\")\npng(\"pf_1\") # hide\n```\n![](pf_1.png)\n\n## Bias\n\nWe can also investigate the distribution of the gradients from the particle filter\nwith and without differentiable resampling step, as compared to the gradient computed\nby differentiating the Kalman filter.\n\n```@example particle_filter\n### compute gradients\nRandom.seed!(seed)\nX = [forw_grad(θtrue, particle_filter) for i in 1:200] # gradient of the particle filter *with* differentiation of the resampling step\nRandom.seed!(seed)\nXbiased = [forw_grad_biased(θtrue, particle_filter) for i in 1:200] # Gradient of the particle filter *without* differentiation of the resampling step\n# pick an arbitrary coordinate\nindex = 1 # take derivative with respect to first parameter (2-dimensional example has a rotation matrix with four parameters in total)\n# plot histograms for the sampled derivative values\nfig = plot(normalize(fit(Histogram, getindex.(X, index), nbins=20), mode=:pdf), legend=false) # ours\nplot!(normalize(fit(Histogram, getindex.(Xbiased, index), nbins=20), mode=:pdf)) # biased\nvline!([mean(X)[index]], color=1)\nvline!([mean(Xbiased)[index]], color=2)\n# add derivative of differentiable Kalman filter as a comparison\nXK = forw_grad_Kalman(θtrue, kalman_filter)\nvline!([XK[index]], color=\"black\")\npng(\"pf_2\") # hide\n```\n![](pf_2.png)\n\nThe estimator using the `new_weight` function agrees with the 
gradient value from\nthe Kalman filter and the [particle filter AD scheme developed by Ścibior and Wood](https://arxiv.org/abs/2106.10314),\nunlike biased estimators that neglect the contribution of the derivative from the\nresampling step. However, the biased estimator displays a smaller variance.\n\n## Benchmark\n\nFinally, we can use `BenchmarkTools.jl` to benchmark the run times of the primal\npass with respect to forward-mode and reverse-mode AD of the particle filter. As\nexpected, forward-mode AD outperforms reverse-mode AD for the small number of\nparameters considered here.\n\n```@example particle_filter\n# secs for how long the benchmark should run, see https://juliaci.github.io/BenchmarkTools.jl/stable/\nsecs = 1\n\nsuite = BenchmarkGroup()\nsuite[\"scaling\"] = BenchmarkGroup([\"grads\"])\n\nsuite[\"scaling\"][\"primal\"] = @benchmarkable log_likelihood(particle_filter, θtrue)\nsuite[\"scaling\"][\"forward\"] = @benchmarkable forw_grad(θtrue, particle_filter)\nsuite[\"scaling\"][\"backward\"] = @benchmarkable back_grad(θtrue, particle_filter)\n\ntune!(suite)\nresults = run(suite, verbose=true, seconds=secs)\n\nt1 = measurement(mean(results[\"scaling\"][\"primal\"].times), std(results[\"scaling\"][\"primal\"].times) / sqrt(length(results[\"scaling\"][\"primal\"].times)))\nt2 = measurement(mean(results[\"scaling\"][\"forward\"].times), std(results[\"scaling\"][\"forward\"].times) / sqrt(length(results[\"scaling\"][\"forward\"].times)))\nt3 = measurement(mean(results[\"scaling\"][\"backward\"].times), std(results[\"scaling\"][\"backward\"].times) / sqrt(length(results[\"scaling\"][\"backward\"].times)))\n@show t1 t2 t3\n\nts = (t1, t2, t3) ./ 10^6 # ms\n@show ts\n```\n"
  },
  {
    "path": "docs/src/tutorials/random_walk.md",
    "content": "# Random walk\n\n```@setup random_walk\nimport Pkg\nPkg.activate(\"../../../tutorials\")\nPkg.develop(path=\"../../..\")\nPkg.instantiate()\n```\n\nIn this tutorial, we differentiate a random walk over the integers using `StochasticAD`. We will need the following packages,\n\n```@example random_walk\nusing Distributions # defines several supported discrete distributions \nusing StochasticAD\nusing StaticArrays # for more efficient small arrays\n```\n\n## Setting up the random walk\n\nLet's define a function for simulating the walk.\n```@example random_walk\nfunction simulate_walk(probs, steps, n)\n    state = 0\n    for i in 1:n\n        probs_here = probs(state) # transition probabilities for possible steps\n        step_index = rand(Categorical(probs_here)) # which step do we take?\n        step = steps[step_index] # get size of step \n        state += step\n    end\n    return state\nend\n```\nHere, `steps` is a (1-indexed) array of the possible steps we can take. Each of these steps has a certain probability. To make things more interesting, we take in a *function* `probs` to produce these probabilities that can depend on the current state of the random walk.\n\nLet's zoom in on the two lines where discrete randomness is involved. \n```\nstep_index = rand(Categorical(probs_here)) # which step do we take?\nstep = steps[step_index] # get size of step \n```\nThis is a cute pattern for making a discrete choice. First, we sample from a `Categorical` distribution from `Distributions.jl`, using the probabilities `probs_here` at our current position. This gives us an index between `1` and `length(steps)`, which we can use to pick the actual step to take. Stochastic triples propagate through both steps!\n\n## Differentiating the random walk\n\nLet's define a toy problem. We consider a random walk with `-1` and `+1` steps, where the probability of `+1` starts off high but decays exponentially with a decay length of `p`. 
We take `n = 100` steps and set `p = 50`.\n```@example random_walk\nusing StochasticAD\n\nconst steps = SA[-1, 1] # move left or move right\nmake_probs(p) = X -> SA[1 - exp(-X / p), exp(-X / p)]\n\nf(p, n) = simulate_walk(make_probs(p), steps, n)\n@show f(50, 100) # let's run a single random walk with p = 50\n@show stochastic_triple(p -> f(p, 100), 50) # let's see what a single stochastic triple looks like at p = 50\n```\nTime to differentiate! For fun, let's differentiate the *square* of the output of the random walk.\n```@example random_walk\nf_squared(p, n) = f(p, n)^2\n\nsamples = [derivative_estimate(p -> f_squared(p, 100), 50) for i in 1:1000] # many samples from derivative program at p = 50\nderivative = mean(samples)\nuncertainty = std(samples) / sqrt(1000)\nprintln(\"derivative of 𝔼[f_squared] = $derivative ± $uncertainty\")\n```\n\n## Computing variance\n\nA crucial figure of merit for a derivative estimator is its variance. We compute the standard deviation (square root of the variance) of our estimator over a range of `n`.\n```@example random_walk\nn_range = 10:10:100 # range for testing asymptotic variance behaviour\np_range = 2 .* n_range\nnsamples = 10000\n\nstds_triple = Float64[]\nfor (n, p) in zip(n_range, p_range)\n    std_triple = std(derivative_estimate(p -> f_squared(p, n), p)\n                     for i in 1:(nsamples))\n    push!(stds_triple, std_triple)\nend\n@show stds_triple\n```\nFor comparison with other unbiased estimators, we also compute `stds_score` and `stds_score_baseline` for the\n[score function gradient estimator](https://arxiv.org/pdf/1906.10652.pdf), both without and with a variance-reducing batch-average control variate (CV). (For details, see [`core.jl`](https://github.com/gaurav-arya/StochasticAD.jl/blob/main/tutorials/random_walk/core.jl) and [`compare_score.jl`](https://github.com/gaurav-arya/StochasticAD.jl/blob/main/tutorials/random_walk/compare_score.jl).) 
We can now graph the standard deviation of each estimator versus $n$, observing lower variance in the unbiased derivative estimate produced by stochastic triples:\n\n```@raw html\n<img src=\"../images/compare_score.png\" width=\"50%\"/>\n``` ⠀\n\n"
  },
  {
    "path": "docs/src/tutorials/reverse_demo.md",
    "content": "```@meta\nEditURL = \"../../../tutorials/reverse_example/reverse_demo.jl\"\n```\n\n# Simple reverse mode example\n\n```@setup random_walk\nimport Pkg\nPkg.activate(\"../../../tutorials\")\nPkg.develop(path=\"../../..\")\nPkg.instantiate()\n\nimport Random\nRandom.seed!(1234)\n```\n\nLoad our packages\n\n````@example reverse_demo\nusing StochasticAD\nusing Distributions\nusing Enzyme\nusing LinearAlgebra\n````\n\nLet us define our target function.\n\n````@example reverse_demo\n# Define a toy `StochasticAD`-differentiable function for computing an integer value from a string.\nstring_value(strings, index) = Int(sum(codepoint, strings[index]))\nstring_value(strings, index::StochasticTriple) = StochasticAD.propagate(index -> string_value(strings, index), index)\n\nfunction f(θ; derivative_coupling = StochasticAD.InversionMethodDerivativeCoupling())\n    strings = [\"cat\", \"dog\", \"meow\", \"woofs\"]\n    index = randst(Categorical(θ); derivative_coupling)\n    return string_value(strings, index)\nend\n\nθ = [0.1, 0.5, 0.3, 0.1]\n@show f(θ)\nnothing\n````\n\nFirst, let's compute the sensitivity of `f` in a particular direction via forward-mode Stochastic AD.\n\n````@example reverse_demo\nu = [1.0, 2.0, 4.0, -7.0]\n@show derivative_estimate(f, θ, StochasticAD.ForwardAlgorithm(PrunedFIsBackend()); direction = u)\nnothing\n````\n\nNow, let's do the same with reverse-mode.\n\n````@example reverse_demo\n@show derivative_estimate(f, θ, StochasticAD.EnzymeReverseAlgorithm(PrunedFIsBackend(Val(:wins))))\n````\n\nLet's verify that our reverse-mode gradient is consistent with our forward-mode directional derivative.\n\n````@example reverse_demo\nforward() = derivative_estimate(f, θ, StochasticAD.ForwardAlgorithm(PrunedFIsBackend()); direction = u)\nreverse() = derivative_estimate(f, θ, StochasticAD.EnzymeReverseAlgorithm(PrunedFIsBackend(Val(:wins))))\n\nN = 40000\ndirectional_derivs_fwd = [forward() for i in 1:N]\nderivs_bwd = [reverse() for i in 
1:N]\ndirectional_derivs_bwd = [dot(u, δ) for δ in derivs_bwd]\nprintln(\"Forward mode: $(mean(directional_derivs_fwd)) ± $(std(directional_derivs_fwd) / sqrt(N))\")\nprintln(\"Reverse mode: $(mean(directional_derivs_bwd)) ± $(std(directional_derivs_bwd) / sqrt(N))\")\n@assert isapprox(mean(directional_derivs_fwd), mean(directional_derivs_bwd), rtol = 3e-2)\n\nnothing\n````\n\n---\n\n*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*\n\n"
  },
  {
    "path": "ext/StochasticADEnzymeExt.jl",
    "content": "module StochasticADEnzymeExt\n\nusing StochasticAD\nusing Enzyme\n\nfunction enzyme_target(u, X, p, backend)\n    # equivalent to derivative_estimate(X, p; backend, direction = u), but specialize to real output to make Enzyme happier\n    st = StochasticAD.stochastic_triple_direction(X, p, u; backend)\n    if !(StochasticAD.valtype(st) <: Real)\n        error(\"EnzymeReverseAlgorithm only supports real-valued outputs.\")\n    end\n    return derivative_contribution(st)\nend\n\nfunction StochasticAD.derivative_estimate(X, p, alg::StochasticAD.EnzymeReverseAlgorithm;\n        direction = nothing, alg_data = (; forward_u = nothing))\n    if !isnothing(direction)\n        error(\"EnzymeReverseAlgorithm does not support keyword argument `direction`\")\n    end\n    if p isa AbstractVector\n        Δu = zeros(float(eltype(p)), length(p))\n        u = isnothing(alg_data.forward_u) ?\n            rand(StochasticAD.RNG, float(eltype(p)), length(p)) : alg_data.forward_u\n        autodiff(Enzyme.Reverse, enzyme_target, Active, Duplicated(u, Δu),\n            Const(X), Const(p), Const(alg.backend))\n        return Δu\n    elseif p isa Real\n        u = isnothing(alg_data.forward_u) ? rand(StochasticAD.RNG, float(typeof(p))) :\n            alg_data.forward_u\n        ((du, _, _, _),) = autodiff(Enzyme.Reverse, enzyme_target, Active, Active(u),\n            Const(X), Const(p), Const(alg.backend))\n        return du\n    else\n        error(\"EnzymeReverseAlgorithm only supports p::Real or p::AbstractVector\")\n    end\nend\n\nend\n"
  },
  {
    "path": "src/StochasticAD.jl",
    "content": "module StochasticAD\n\n### Public API\n\nexport stochastic_triple, derivative_contribution, perturbations, smooth_triple,\n       dual_number, StochasticTriple # For working with stochastic triples\nexport derivative_estimate, StochasticModel, stochastic_gradient # Higher level functionality\nexport new_weight # Particle resampling\nexport PrunedFIsBackend,\n       PrunedFIsAggressiveBackend, DictFIsBackend, SmoothedFIsBackend,\n       StrategyWrapperFIsBackend\nexport PrunedFIs, PrunedFIsAggressive, DictFIs, SmoothedFIs, StrategyWrapperFIs\nexport randst\nexport InversionMethodDerivativeCoupling\n\n### Imports\n\nusing Random\nusing Distributions\nusing DistributionsAD\nusing ChainRulesCore\nusing ChainRulesOverloadGeneration\nusing ExprTools\nusing ForwardDiff\nusing Functors\nimport ChainRulesCore\n# resolve conflicts while this code exists in both.\nconst on_new_rule = ChainRulesOverloadGeneration.on_new_rule\nconst refresh_rules = ChainRulesOverloadGeneration.refresh_rules\n\nconst RNG = copy(Random.default_rng())\n\n### Files responsible for backends\n\ninclude(\"finite_infinitesimals.jl\")\ninclude(\"backends/pruned.jl\")\ninclude(\"backends/pruned_aggressive.jl\")\ninclude(\"backends/dict.jl\")\ninclude(\"backends/smoothed.jl\")\ninclude(\"backends/abstract_wrapper.jl\")\ninclude(\"backends/strategy_wrapper.jl\")\nusing .PrunedFIsModule\nusing .PrunedFIsAggressiveModule\nusing .DictFIsModule\nusing .SmoothedFIsModule\nusing .AbstractWrapperFIsModule\nusing .StrategyWrapperFIsModule\n\ninclude(\"prelude.jl\") # Defines global constants\ninclude(\"smoothing.jl\") # Smoothing rules. 
Placed before general rules so that new_weight frule is caught by overload generation.\ninclude(\"stochastic_triple.jl\") # Defines stochastic triple object and higher level functions\ninclude(\"general_rules.jl\") # Defines rules for propagation through deterministic functions\ninclude(\"discrete_randomness.jl\") # Defines rules for propagation through discrete random functions\ninclude(\"propagate.jl\") # Experimental generalized forward propagation functionality\ninclude(\"algorithms.jl\") # Add algorithm-based higher-level interface \ninclude(\"misc.jl\") # Miscellaneous functions that do not fit in the usual flow\n\nend\n"
  },
  {
    "path": "src/algorithms.jl",
    "content": "abstract type AbstractStochasticADAlgorithm end\n\n\"\"\"\n    ForwardAlgorithm(backend::StochasticAD.AbstractFIsBackend) <: AbstractStochasticADAlgorithm\n    \nA differentiation algorithm relying on forward propagation of stochastic triples.\n\nThe `backend` argument controls the algorithm used by the third component of the stochastic triples.\n\n!!! note \n    The required computation time for forward-mode AD scales linearly with the number of \n    parameters in `p` (but is unaffected by the number of parameters in `X(p)`).\n\"\"\"\nstruct ForwardAlgorithm{B <: StochasticAD.AbstractFIsBackend} <:\n       AbstractStochasticADAlgorithm\n    backend::B\nend\n\n\"\"\"\n    EnzymeReverseAlgorithm(backend::StochasticAD.AbstractFIsBackend) <: AbstractStochasticADAlgorithm\n\nA differentiation algorithm relying on transposing the propagation of stochastic triples to\nproduce a reverse-mode algorithm. The transposition is performed by [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl),\nwhich must be loaded for the algorithm to run.\n\nCurrently, only real- and vector-valued inputs are supported, and only real-valued outputs are supported.\n\nThe `backend` argument controls the algorithm used by the third component of the stochastic triples.\n\nIn the call to `derivative_estimate`, this algorithm optionally accepts `alg_data` with the field `forward_u`,\nwhich specifies the directional derivative used in the forward pass that will be transposed. \nIf `forward_u` is not provided, it is randomly generated.\n\n!!! warning\n    For the reverse-mode algorithm to yield correct results, the employed `backend` cannot use input-dependent pruning  \n    strategies. 
A suggested reverse-mode compatible backend is `PrunedFIsBackend(Val(:wins))`.\n    \n    Additionally, this algorithm relies on the ability of `Enzyme.jl` to differentiate the forward stochastic triple run.\n    It is recommended to check that the primal function `X` is type stable for its input `p` using a tool such as\n    [JET.jl](https://github.com/aviatesk/JET.jl), with all code executed in a function with no global state. \n    In addition, sometimes `X` may be type stable but stochastic triples introduce additional type instabilities.\n    This can be debugged by checking type stability of Enzyme's target, which is\n    `Base.get_extension(StochasticAD, :StochasticADEnzymeExt).enzyme_target(u, X, p, backend)`,\n    where `u` is a test direction.\n    \n!!! note\n    For more details on the reverse-mode approach, see the following papers and talks:\n    \n    * [\"You Only Linearize Once: Tangents Transpose to Gradients\"](https://arxiv.org/abs/2204.10923), Radul et al. 2022.\n    * [\"Reverse mode ADEV via YOLO: tangent estimators transpose to gradient estimators\"](https://www.youtube.com/watch?v=pnPmk-leSsE), Becker et al. 2024\n    * [\"Probabilistic Programming with Programmable Variational Inference\"](https://pldi24.sigplan.org/details/pldi-2024-papers/87/Probabilistic-Programming-with-Programmable-Variational-Inference), Becker et al. 
2024\n\"\"\"\nstruct EnzymeReverseAlgorithm{B <: StochasticAD.AbstractFIsBackend} <:\n       AbstractStochasticADAlgorithm\n    backend::B\nend\n\nfunction derivative_estimate(\n        X, p, alg::ForwardAlgorithm; direction = nothing, alg_data::NamedTuple = (;))\n    return derivative_estimate(X, p; backend = alg.backend, direction)\nend\n\n@doc raw\"\"\"\n    derivative_estimate(X, p, alg::AbstractStochasticADAlgorithm = ForwardAlgorithm(PrunedFIsBackend()); direction=nothing, alg_data::NamedTuple = (;))\n\nCompute an unbiased estimate of ``\\frac{\\mathrm{d}\\mathbb{E}[X(p)]}{\\mathrm{d}p}``, \nthe derivative of the expectation of the random function `X(p)` with respect to its input `p`.\n\nBoth `p` and `X(p)` can be any object supported by [`Functors.jl`](https://fluxml.ai/Functors.jl/stable/),\ne.g. scalars or abstract arrays. \nThe output of `derivative_estimate` has the same outer structure as `p`, but with each\nscalar in `p` replaced by a derivative estimate of `X(p)` with respect to that entry.\nFor example, if `X(p) <: AbstractMatrix` and `p <: Real`, then the output would be a matrix.\n\nThe `alg` keyword argument specifies the [algorithm](public_api.md#Algorithms) used to compute the derivative estimate.\nFor backward compatibility, an additional signature `derivative_estimate(X, p; backend, direction=nothing)`\nis supported, which uses `ForwardAlgorithm` by default with the supplied `backend`.\nThe `alg_data` keyword argument can specify any additional data that specific algorithms accept or require.\n\nWhen `direction` is provided, the output is only differentiated with respect to a perturbation\nof `p` in that direction.\n\n# Example\n```jldoctest\njulia> using Distributions, Random, StochasticAD; Random.seed!(4321);\n\njulia> derivative_estimate(rand ∘ Bernoulli, 0.5) # A random quantity that averages to the true derivative.\n2.0\n\njulia> derivative_estimate(x -> [rand(Bernoulli(x * i/4)) for i in 1:3], 0.5)\n3-element Vector{Float64}:\n 0.2857142857142857\n 0.6666666666666666\n 
0.0\n```\n\"\"\"\nderivative_estimate\n"
  },
  {
    "path": "src/backends/abstract_wrapper.jl",
    "content": "module AbstractWrapperFIsModule\n\nimport ..StochasticAD\n\nexport AbstractWrapperFIs\n\n\"\"\"\n    AbstractWrapperFIs{V, FIs} <: StochasticAD.AbstractFIs{V}\n\nA convenience type for backend strategies that wrap another backend. A subtype `WrapperFIs <: AbstractWrapperFIs`\nshould have a field called Δs containing the wrapped backend, and should also define the following methods:\n* `StochasticAD.similar_type(::Type{<:WrapperFIs}, V, FIs)`: return the type of a new\n    `WrapperFIs` with value type `V` and wrapped backend type `FIs`,\n* `AbstractWrapperFIsModule.reconstruct_wrapper(wrapper_Δs::WrapperFIs, Δs::AbstractFIs)`: construct\na new `WrapperFIs` wrapping `Δs` given an existing wrapped instace `wrapper_Δs`. \n* `AbstractWrapperFIsModule.reconstruct_wrapper(::Type{<:WrapperFIs}, Δs::AbstractFIs)`: construct\na new `WrapperFIs` wrapping `Δs` given the type of an existing `WrapperFIs`.\n\nThen, all other methods will generically be forwarded to the inner backend, except those overloaded by the\nspecific wrapper type.\n\"\"\"\nabstract type AbstractWrapperFIs{V, FIs} <: StochasticAD.AbstractFIs{V} end\n\nfunction reconstruct_wrapper end\n\nfunction StochasticAD.similar_new(Δs::AbstractWrapperFIs, Δ, w)\n    reconstruct_wrapper(Δs, StochasticAD.similar_new(Δs.Δs, Δ, w))\nend\nfunction StochasticAD.similar_empty(Δs::AbstractWrapperFIs, V)\n    reconstruct_wrapper(Δs, StochasticAD.similar_empty(Δs.Δs, V))\nend\n\nfunction StochasticAD.similar_type(WrapperFIs::Type{<:AbstractWrapperFIs{V0, FIs}},\n        V) where {V0, FIs}\n    return StochasticAD.similar_type(WrapperFIs, V, StochasticAD.similar_type(FIs, V))\nend\n\nStochasticAD.valtype(Δs::AbstractWrapperFIs) = StochasticAD.valtype(Δs.Δs)\n\nfunction StochasticAD.couple(WrapperFIs::Type{<:AbstractWrapperFIs{V, FIs}},\n        Δs_all;\n        rep = nothing,\n        kwargs...) where {V, FIs}\n    _Δs_all = StochasticAD.structural_map(Δs -> Δs.Δs, Δs_all)\n    _rep_kwarg = !isnothing(rep) ? 
(; rep = rep.Δs) : (;)\n    return reconstruct_wrapper(StochasticAD.get_any(Δs_all),\n        StochasticAD.couple(FIs, _Δs_all; _rep_kwarg..., kwargs...))\nend\n\nfunction StochasticAD.combine(WrapperFIs::Type{<:AbstractWrapperFIs{V, FIs}},\n        Δs_all;\n        rep = nothing,\n        kwargs...) where {V, FIs}\n    _Δs_all = StochasticAD.structural_map(Δs -> Δs.Δs, Δs_all)\n    _rep_kwarg = !isnothing(rep) ? (; rep = rep.Δs) : (;)\n    return reconstruct_wrapper(StochasticAD.get_any(Δs_all),\n        StochasticAD.combine(FIs, _Δs_all; _rep_kwarg..., kwargs...))\nend\n\nfunction StochasticAD.get_rep(WrapperFIs::Type{<:AbstractWrapperFIs{V, FIs}},\n        Δs_all;\n        kwargs...) where {V, FIs}\n    _Δs_all = StochasticAD.structural_map(Δs -> Δs.Δs, Δs_all)\n    return reconstruct_wrapper(StochasticAD.get_any(Δs_all),\n        StochasticAD.get_rep(FIs, _Δs_all; kwargs...))\nend\n\nfunction StochasticAD.scalarize(Δs::AbstractWrapperFIs; rep = nothing, kwargs...)\n    _rep_kwarg = !isnothing(rep) ? 
(; rep = rep.Δs) : (;)\n    return StochasticAD.structural_map(StochasticAD.scalarize(\n        Δs.Δs; _rep_kwarg..., kwargs...)) do _Δs\n        reconstruct_wrapper(Δs, _Δs)\n    end\nend\n\nfunction StochasticAD.derivative_contribution(Δs::AbstractWrapperFIs, Δs_all; kwargs...)\n    StochasticAD.derivative_contribution(Δs.Δs, Δs_all; kwargs...)\nend\n\nStochasticAD.alltrue(f, Δs::AbstractWrapperFIs) = StochasticAD.alltrue(f, Δs.Δs)\n\nStochasticAD.perturbations(Δs::AbstractWrapperFIs) = StochasticAD.perturbations(Δs.Δs)\n\nfunction StochasticAD.filter_state(Δs::AbstractWrapperFIs, state)\n    StochasticAD.filter_state(Δs.Δs, state)\nend\n\nfunction StochasticAD.weighted_map_Δs(f, Δs::AbstractWrapperFIs; kwargs...)\n    reconstruct_wrapper(Δs, StochasticAD.weighted_map_Δs(f, Δs.Δs; kwargs...))\nend\n\nStochasticAD.new_Δs_strategy(Δs::AbstractWrapperFIs) = StochasticAD.new_Δs_strategy(Δs.Δs)\n\nfunction Base.empty(WrapperFIs::Type{<:AbstractWrapperFIs{V, FIs}}) where {V, FIs}\n    return reconstruct_wrapper(WrapperFIs, empty(FIs))\nend\n\nBase.empty(Δs::AbstractWrapperFIs) = reconstruct_wrapper(Δs, empty(Δs.Δs))\nBase.isempty(Δs::AbstractWrapperFIs) = isempty(Δs.Δs)\nBase.length(Δs::AbstractWrapperFIs) = length(Δs.Δs)\nBase.iszero(Δs::AbstractWrapperFIs) = iszero(Δs.Δs)\n\nfunction StochasticAD.derivative_contribution(Δs::AbstractWrapperFIs)\n    StochasticAD.derivative_contribution(Δs.Δs)\nend\n\nfunction Base.convert(::Type{<:AbstractWrapperFIs{V}}, Δs::AbstractWrapperFIs) where {V}\n    reconstruct_wrapper(Δs, convert(StochasticAD.similar_type(typeof(Δs.Δs), V), Δs.Δs))\nend\n\nfunction StochasticAD.send_signal(\n        Δs::AbstractWrapperFIs, signal::StochasticAD.AbstractPerturbationSignal)\n    reconstruct_wrapper(Δs, StochasticAD.send_signal(Δs.Δs, signal))\nend\n\nfunction Base.show(io::IO, Δs::AbstractWrapperFIs)\n    return show(io, Δs.Δs)\nend\n\nend\n"
  },
  {
    "path": "src/backends/dict.jl",
    "content": "module DictFIsModule\n\nexport DictFIsBackend, DictFIs\n\nimport ..StochasticAD\nusing Dictionaries\n\n\"\"\"\n    DictFIsBackend <: StochasticAD.AbstractFIsBackend\n\nA dictionary backend algorithm which keeps entries for each perturbation that has occurred without pruning. \nCurrently very unoptimized.\n\"\"\"\nstruct DictFIsBackend <: StochasticAD.AbstractFIsBackend end\n\n\"\"\"\n    DictFIsState    \n\nState maintained by dictionary backend.\n\"\"\"\nmutable struct DictFIsState\n    tag_count::Int64\n    valid::Bool\n    DictFIsState(valid = true) = new(0, valid)\nend\n\nstruct InfinitesimalEvent\n    tag::Any # unique identifier\n    w::Float64 # weight (infinitesimal probability wε) \nend\n\nBase.:<(event1::InfinitesimalEvent, event2::InfinitesimalEvent) = event1.tag < event2.tag\nfunction Base.:(==)(event1::InfinitesimalEvent, event2::InfinitesimalEvent)\n    event1.tag == event2.tag\nend\nBase.:isless(event1::InfinitesimalEvent, event2::InfinitesimalEvent) = event1 < event2\n\n\"\"\"\n    DictFIs{V} <: StochasticAD.AbstractFIs{V}\n\nThe implementing backend structure for DictFIsBackend.\n\"\"\"\nstruct DictFIs{V} <: StochasticAD.AbstractFIs{V}\n    dict::Dictionary{InfinitesimalEvent, V}\n    state::DictFIsState\nend\n\nstate(Δs::DictFIs) = Δs.state\n\n### Empty / no perturbation\n\nfunction DictFIs{V}(state::DictFIsState) where {V}\n    DictFIs{V}(Dictionary{InfinitesimalEvent, V}(), state)\nend\nStochasticAD.similar_empty(Δs::DictFIs, V::Type) = DictFIs{V}(Δs.state)\nBase.empty(Δs::DictFIs{V}) where {V} = StochasticAD.similar_empty(Δs::DictFIs, V::Type)\nfunction Base.empty(::Type{<:DictFIs{V}}) where {V}\n    DictFIs{V}(DictFIsState(false))\nend\n\n### Create a new perturbation with infinitesimal probability\n\nfunction new_perturbation(Δ::V, w::Real, state::DictFIsState) where {V}\n    state.tag_count += 1\n    event = InfinitesimalEvent(state.tag_count, w)\n    DictFIs{V}(Dictionary([event], [Δ]), state)\nend\nfunction 
StochasticAD.similar_new(Δs::DictFIs, Δ::V, w::Real) where {V}\n    new_perturbation(Δ, w, Δs.state)\nend\n\n### Create Δs backend for the first stochastic triple of computation\n\nStochasticAD.create_Δs(::DictFIsBackend, V) = DictFIs{V}(DictFIsState())\n\n### Convert type of a backend\n\nfunction Base.convert(::Type{DictFIs{V}}, Δs::DictFIs) where {V}\n    DictFIs{V}(convert(Dictionary{InfinitesimalEvent, V}, Δs.dict), Δs.state)\nend\n\n### Getting information about Δs\n\nBase.isempty(Δs::DictFIs) = isempty(Δs.dict)\nBase.length(Δs::DictFIs) = length(Δs.dict)\nBase.iszero(Δs::DictFIs) = isempty(Δs) || all(iszero.(Δs.dict))\nfunction StochasticAD.derivative_contribution(Δs::DictFIs{V}) where {V}\n    sum((Δ * event.w for (event, Δ) in pairs(Δs.dict)), init = zero(V) * 0.0)\nend\n\nfunction StochasticAD.perturbations(Δs::DictFIs)\n    [(; Δ, weight = event.w, state = event) for (event, Δ) in pairs(Δs.dict)]\nend\n\n### Unary propagation\n\nfunction StochasticAD.weighted_map_Δs(f, Δs::DictFIs; kwargs...)\n    # Pass key as state in map\n    mapped_values_and_weights = map(f, collect(Δs.dict), keys(Δs.dict))\n    mapped_values = first.(mapped_values_and_weights)\n    mapped_weights = last.(mapped_values_and_weights)\n    scaled_events = map((event, a) -> InfinitesimalEvent(event.tag, event.w * a),\n        keys(Δs.dict),\n        mapped_weights) # TODO: should original events (with old tag) also be modified?\n    dict = Dictionary(scaled_events, mapped_values)\n    DictFIs(dict, Δs.state)\nend\n\nStochasticAD.alltrue(f, Δs::DictFIs) = all(map(f, collect(Δs.dict)))\n\n### Coupling\n\nfunction StochasticAD.get_rep(::Type{<:DictFIs}, Δs_all)\n    for Δs in StochasticAD.structural_iterate(Δs_all)\n        if Δs.state.valid\n            return Δs\n        end\n    end\n    return first(Δs_all)\nend\n\nfunction StochasticAD.couple(FIs::Type{<:DictFIs}, Δs_all;\n        rep = StochasticAD.get_rep(FIs, Δs_all),\n        out_rep = nothing,\n        kwargs...)\n    all_keys = 
Iterators.map(StochasticAD.structural_iterate(Δs_all)) do Δs\n        keys(Δs.dict)\n    end\n    distinct_keys = unique(all_keys |> Iterators.flatten)\n    Δs_coupled_dict = [StochasticAD.structural_map(\n                           Δs -> isassigned(Δs.dict, key) ?\n                                 Δs.dict[key] :\n                                 zero(eltype(Δs.dict)),\n                           Δs_all)\n                       for key in distinct_keys]\n    DictFIs(Dictionary(distinct_keys, Δs_coupled_dict), rep.state)\nend\n\nfunction StochasticAD.combine(FIs::Type{<:DictFIs}, Δs_all;\n        rep = StochasticAD.get_rep(FIs, Δs_all), kwargs...)\n    Δs_dicts = Iterators.map(Δs -> Δs.dict, StochasticAD.structural_iterate(Δs_all))\n    Δs_combined_dict = reduce(Δs_dicts) do Δs_dict1, Δs_dict2\n        mergewith((x, y) -> StochasticAD.structural_map(+, x, y), Δs_dict1, Δs_dict2)\n    end\n    DictFIs(Δs_combined_dict, rep.state)\nend\n\nfunction StochasticAD.scalarize(Δs::DictFIs; out_rep = nothing)\n    # TODO: use vcat here?\n    tupleify(Δ1, Δ2) = StochasticAD.structural_map(tuple, Δ1, Δ2)\n    Δ_all_allkeys = foldl(tupleify, values(Δs.dict))\n    Δ_all_rep = first(values(Δs.dict))\n    _keys = keys(Δs.dict)\n    return StochasticAD.structural_map(Δ_all_rep, Δ_all_allkeys) do _, Δ_allkeys\n        return DictFIs(Dictionary(_keys, Δ_allkeys), Δs.state)\n    end\nend\n\nfunction StochasticAD.filter_state(Δs::DictFIs{V}, key) where {V}\n    haskey(Δs.dict, key) ? Δs.dict[key] : zero(V)\nend\n\n### Miscellaneous\n\nStochasticAD.similar_type(::Type{<:DictFIs}, V::Type) = DictFIs{V}\nStochasticAD.valtype(::Type{<:DictFIs{V}}) where {V} = V\n\nend\n"
  },
  {
    "path": "src/backends/pruned.jl",
    "content": "module PrunedFIsModule\n\nimport ..StochasticAD\n\nexport PrunedFIsBackend, PrunedFIs\n\n\"\"\"\n    PrunedFIsBackend <: StochasticAD.AbstractFIsBackend\n\nA backend algorithm that prunes between perturbations as soon as they clash (e.g. added together).\nCurrently chooses uniformly between all perturbations.\n\"\"\"\nstruct PrunedFIsBackend{M <: Val} <: StochasticAD.AbstractFIsBackend\n    pruning_mode::M\n    function PrunedFIsBackend(pruning_mode::M = Val(:weights)) where {M}\n        if pruning_mode isa Val{:weights} || pruning_mode isa Val{:wins}\n            return new{M}(pruning_mode)\n        else\n            error(\"Unsupported pruning_mode $pruning_mode for `PrunedFIsBackend.\")\n        end\n    end\nend\n\n\"\"\"\n    PrunedFIsState\n\nState maintained by pruning backend.\n\"\"\"\nmutable struct PrunedFIsState{M, W}\n    tag::Int32\n    weight::Float64\n    valid::Bool\n    # TODO: generalize (wins, pruning_mode) into a general interface for accumulating state\n    # that informs future pruning decisions.\n    wins::W\n    pruning_mode::M\n    function PrunedFIsState(pruning_mode::M, valid = true) where {M <: Val}\n        wins = pruning_mode isa Val{:wins} ? (valid ? 1 : 0) : nothing\n        state::PrunedFIsState = new{M, typeof(wins)}(0, 0.0, valid, wins)\n        state.tag = objectid(state) % typemax(Int32)\n        return state\n    end\nend\n\nBase.:(==)(state1::PrunedFIsState, state2::PrunedFIsState) = state1.tag == state2.tag\n# c.f. 
https://github.com/JuliaLang/julia/blob/61c3521613767b2af21dfa5cc5a7b8195c5bdcaf/base/hashing.jl#L38C45-L38C51\nBase.hash(state::PrunedFIsState) = state.tag\n\n\"\"\"\n    PrunedFIs{V} <: StochasticAD.AbstractFIs{V}\n\nThe implementing backend structure for PrunedFIsBackend.\n\"\"\"\nstruct PrunedFIs{V, S <: PrunedFIsState} <: StochasticAD.AbstractFIs{V}\n    Δ::V\n    state::S\nend\n\n### Empty / no perturbation\n\nPrunedFIs{V}(Δ::V, state::S) where {V, S <: PrunedFIsState} = PrunedFIs{V, S}(Δ, state)\nPrunedFIs{V}(state::PrunedFIsState) where {V} = PrunedFIs{V}(zero(V), state)\n# TODO: avoid allocations here\nfunction StochasticAD.similar_empty(Δs::PrunedFIs, V::Type)\n    PrunedFIs{V}(PrunedFIsState(Δs.state.pruning_mode, false))\nend\nBase.empty(Δs::PrunedFIs{V}) where {V} = StochasticAD.similar_empty(Δs::PrunedFIs, V::Type)\n# we truly have no clue what the state is here, so use an invalidated state\nfunction Base.empty(::Type{<:PrunedFIs{V, S}}) where {V, M, S <: PrunedFIsState{M}}\n    PrunedFIs{V}(PrunedFIsState(M(), false))\nend\n\n### Create a new perturbation with infinitesimal probability\n\nfunction StochasticAD.similar_new(Δs::PrunedFIs, Δ::V, w::Real) where {V}\n    if iszero(w)\n        return StochasticAD.similar_empty(Δs, V)\n    end\n    state = PrunedFIsState(Δs.state.pruning_mode)\n    state.weight += w\n    Δs = PrunedFIs{V}(Δ, state)\n    return Δs\nend\n\n### Create Δs backend for the first stochastic triple of computation\n\nfunction StochasticAD.create_Δs(backend::PrunedFIsBackend, V)\n    PrunedFIs{V}(PrunedFIsState(backend.pruning_mode, false))\nend\n\n### Convert type of a backend\n\nfunction Base.convert(::Type{<:PrunedFIs{V}}, Δs::PrunedFIs) where {V}\n    PrunedFIs{V}(convert(V, Δs.Δ), Δs.state)\nend\n\n### Getting information about perturbations\n\n# \"empty\" here means no perturbation or a perturbation that has been pruned away\nBase.isempty(Δs::PrunedFIs) = !Δs.state.valid\nBase.length(Δs::PrunedFIs) = isempty(Δs) ? 
0 : 1\nfunction Base.iszero(Δs::PrunedFIs)\n    isempty(Δs) || all(iszero, StochasticAD.structural_iterate(Δs.Δ))\nend\nBase.iszero(Δs::PrunedFIs{<:Real}) = isempty(Δs) || iszero(Δs.Δ)\nBase.iszero(Δs::PrunedFIs{<:Tuple}) = isempty(Δs) || all(iszero.(Δs.Δ))\nisapproxzero(Δs::PrunedFIs) = isempty(Δs) || isapprox(Δs.Δ, zero(Δs.Δ))\n\n# we lazily prune, so check if empty first\nfunction pruned_value(Δs::PrunedFIs{V}) where {V}\n    isempty(Δs) ? StochasticAD.structural_map(zero, Δs.Δ) : Δs.Δ\nend\npruned_value(Δs::PrunedFIs{<:Real}) = isempty(Δs) ? zero(Δs.Δ) : Δs.Δ\npruned_value(Δs::PrunedFIs{<:Tuple}) = isempty(Δs) ? zero.(Δs.Δ) : Δs.Δ\npruned_value(Δs::PrunedFIs{<:AbstractArray}) = isempty(Δs) ? zero.(Δs.Δ) : Δs.Δ\n\nStochasticAD.derivative_contribution(Δs::PrunedFIs) = pruned_value(Δs) * Δs.state.weight\nfunction StochasticAD.perturbations(Δs::PrunedFIs)\n    ((; Δ = pruned_value(Δs), weight = Δs.state.weight, state = Δs.state),)\nend\n\n### Unary propagation\n\nfunction StochasticAD.weighted_map_Δs(f, Δs::PrunedFIs; kwargs...)\n    Δ_out, weight_out = f(pruned_value(Δs), Δs.state)\n    # TODO: we could add a direct overload for map_Δs that elides the below line\n    Δs.state.weight *= weight_out\n    PrunedFIs(Δ_out, Δs.state)\nend\n\nStochasticAD.alltrue(f, Δs::PrunedFIs) = f(pruned_value(Δs))\n\n### Coupling\n\nfunction StochasticAD.get_rep(FIs::Type{<:PrunedFIs}, Δs_all)\n    return empty(FIs) #StochasticAD.get_any(Δs_all)\nend\n\nfunction get_pruned_state(Δs_all; Δ_func = nothing, rep, out_rep = nothing)\n    if !isnothing(Δ_func) && isnothing(out_rep)\n        error(\"Specifying Δ_func requires out_rep to be specified.\")\n    end\n    function op(cur_state, Δs)\n        # lazy pruning optimization temporarily disabled with custom Δ_func \n        # (because custom Δ_func's may prefer not to lazily prune)\n        (isnothing(Δ_func) && isapproxzero(Δs)) && return cur_state\n        candidate_state = Δs.state\n        if !candidate_state.valid ||\n           
(candidate_state == cur_state)\n            return cur_state\n        end\n        if !cur_state.valid\n            return candidate_state\n        end\n\n        # Compute \"strength\" of each perturbation for pruning proposal\n        if !isnothing(Δ_func)\n            # TODO: structural_map for each state can take asymptotically more time than necessary when combining many distinct states\n            candidate_Δ = StochasticAD.structural_map(\n                Base.Fix2(StochasticAD.filter_state, candidate_state), Δs_all)\n            candidate_Δ_func::Float64 = Δ_func(candidate_Δ, candidate_state, out_rep)\n            cur_Δ = StochasticAD.structural_map(\n                Base.Fix2(StochasticAD.filter_state, cur_state), Δs_all)\n            cur_Δ_func::Float64 = Δ_func(cur_Δ, cur_state, out_rep)\n        else\n            candidate_Δ_func = 1.0\n            cur_Δ_func = 1.0\n        end\n        candidate_intrinsic_strength = Δs.state.pruning_mode isa Val{:wins} ?\n                                       candidate_state.wins : abs(candidate_state.weight)\n        cur_intrinsic_strength = Δs.state.pruning_mode isa Val{:wins} ? 
cur_state.wins :\n                                 abs(cur_state.weight)\n        candidate_strength = candidate_intrinsic_strength * candidate_Δ_func\n        cur_strength = cur_intrinsic_strength * cur_Δ_func\n\n        both_states_bad = iszero(candidate_strength) && iszero(cur_strength)\n        if both_states_bad\n            cur_state.valid = false\n            candidate_state.valid = false\n            return cur_state\n        end\n\n        # Prune between perturbations\n        total_strength = cur_strength + candidate_strength\n        p = candidate_strength / total_strength\n        if isone(p) || (rand(StochasticAD.RNG) < p)\n            cur_state.valid = false\n            if Δs.state.pruning_mode isa Val{:wins}\n                candidate_state.wins += 1\n            end\n            candidate_state.weight *= 1 / p\n            return candidate_state\n        else\n            candidate_state.valid = false\n            if Δs.state.pruning_mode isa Val{:wins}\n                cur_state.wins += 1\n            end\n            cur_state.weight *= 1 / (1 - p)\n            return cur_state\n        end\n    end\n    dummy_state = PrunedFIsState(rep.state.pruning_mode, false) # For type stability, as well as retval if no better state found. 
TODO: can this be avoided?\n    _new_state = foldl(op, StochasticAD.structural_iterate(Δs_all); init = dummy_state)\n    return _new_state::PrunedFIsState\nend\n\n# for pruning, coupling amounts to getting rid of perturbed values that have been\n# lazily kept around even after (aggressive or lazy) pruning made the perturbation invalid.\nfunction StochasticAD.couple(\n        FIs::Type{<:PrunedFIs}, Δs_all; rep = StochasticAD.get_rep(FIs, Δs_all),\n        out_rep = nothing, Δ_func = nothing, kwargs...)\n    state = get_pruned_state(Δs_all; rep, Δ_func)\n    Δ_coupled = StochasticAD.structural_map(pruned_value, Δs_all) # TODO: perhaps a performance optimization possible here\n    PrunedFIs(Δ_coupled, state)\nend\n\n# basically couple combined with a sum.\nfunction StochasticAD.combine(\n        FIs::Type{<:PrunedFIs}, Δs_all; rep = StochasticAD.get_rep(FIs, Δs_all),\n        Δ_func = nothing, out_rep = nothing, kwargs...)\n    state = get_pruned_state(Δs_all;\n        rep,\n        out_rep,\n        Δ_func = !isnothing(Δ_func) ? (Δ, state, val) -> Δ_func(sum(Δ), state, val) :\n                 Δ_func)\n    Δ_combined = sum(pruned_value, StochasticAD.structural_iterate(Δs_all))\n    PrunedFIs(Δ_combined, state)\nend\n\nfunction StochasticAD.scalarize(Δs::PrunedFIs; out_rep = nothing)\n    return StochasticAD.structural_map(Δs.Δ) do Δ\n        return PrunedFIs(Δ, Δs.state)\n    end\nend\n\nfunction StochasticAD.filter_state(Δs::PrunedFIs{V}, state) where {V}\n    Δs.state == state ? pruned_value(Δs) : zero(V)\nend\n\n### Miscellaneous\n\nfunction StochasticAD.similar_type(::Type{<:PrunedFIs{V0, M}}, V::Type) where {V0, M}\n    PrunedFIs{V, M}\nend\nStochasticAD.valtype(::Type{<:PrunedFIs{V}}) where {V} = V\n\nfunction Base.show(io::IO, Δs::PrunedFIs{V}) where {V}\n    print(io, \"$(pruned_value(Δs)) with probability $(Δs.state.weight)ε\")\nend\n\nend\n"
  },
  {
    "path": "src/backends/pruned_aggressive.jl",
    "content": "module PrunedFIsAggressiveModule\n\nimport ..StochasticAD\n\nexport PrunedFIsAggressiveBackend, PrunedFIsAggressive\n\n\"\"\"\n    PrunedFIsAggressiveBackend <: StochasticAD.AbstractFIsBackend\n\nA backend algorithm that aggressively prunes between perturbations as soon as they are created.\n\"\"\"\nstruct PrunedFIsAggressiveBackend <: StochasticAD.AbstractFIsBackend end\n\n\"\"\"\n    PrunedFIsAggressiveState\n\nState maintained by aggressive pruning backend.\n\"\"\"\nmutable struct PrunedFIsAggressiveState\n    active_tag::Int64 # 0 is always a dummy tag\n    weight::Float64\n    tag_count::Int64\n    valid::Bool\n    PrunedFIsAggressiveState(valid = true) = new(0, 0.0, 0, valid)\nend\n\n\"\"\"\n    PrunedFIsAggressive{V} <: StochasticAD.AbstractFIs{V}\n\nThe implementing backend structure for PrunedFIsAggressiveBackend.\n\"\"\"\nstruct PrunedFIsAggressive{V} <: StochasticAD.AbstractFIs{V}\n    Δ::V\n    tag::Int\n    state::PrunedFIsAggressiveState\n    # directly called when propagating an existing perturbation\nend\n\n### Empty / no perturbation\n\nfunction PrunedFIsAggressive{V}(state::PrunedFIsAggressiveState) where {V}\n    PrunedFIsAggressive{V}(zero(V), -1, state)\nend\nfunction StochasticAD.similar_empty(Δs::PrunedFIsAggressive, V::Type)\n    PrunedFIsAggressive{V}(Δs.state)\nend\nfunction Base.empty(Δs::PrunedFIsAggressive{V}) where {V}\n    StochasticAD.similar_empty(Δs, V)\nend\n# we truly have no clue what the state is here, so use an invalidated state\nfunction Base.empty(::Type{<:PrunedFIsAggressive{V}}) where {V}\n    PrunedFIsAggressive{V}(PrunedFIsAggressiveState(false))\nend\n\n### Create a new perturbation with infinitesimal probability\n\nfunction new_perturbation(Δ::V, w::Real, state::PrunedFIsAggressiveState) where {V}\n    total_weight = state.weight + w\n    if rand(StochasticAD.RNG) * total_weight < state.weight\n        state.weight += w\n        return PrunedFIsAggressive{V}(state)\n    else\n        state.tag_count += 
1\n        state.active_tag = state.tag_count\n        state.weight += w\n        return PrunedFIsAggressive{V}(Δ, state.active_tag, state)\n    end\nend\nfunction StochasticAD.similar_new(Δs::PrunedFIsAggressive, Δ::V, w::Real) where {V}\n    new_perturbation(Δ, w, Δs.state)\nend\n\n### Create Δs backend for the first stochastic triple of computation\n\nfunction StochasticAD.create_Δs(::PrunedFIsAggressiveBackend, V)\n    PrunedFIsAggressive{V}(PrunedFIsAggressiveState())\nend\n\n### Convert type of a backend\n\nfunction Base.convert(::Type{PrunedFIsAggressive{V}}, Δs::PrunedFIsAggressive) where {V}\n    PrunedFIsAggressive{V}(convert(V, Δs.Δ), Δs.tag, Δs.state)\nend\n\n### Getting information about perturbations\n\n# \"empty\" here means no perturbation or a perturbation that has been pruned away\nBase.isempty(Δs::PrunedFIsAggressive) = Δs.tag != Δs.state.active_tag\nBase.length(Δs::PrunedFIsAggressive) = isempty(Δs) ? 0 : 1\nBase.iszero(Δs::PrunedFIsAggressive) = isempty(Δs) || iszero(Δs.Δ)\n\n# we lazily prune, so check if empty first\npruned_value(Δs::PrunedFIsAggressive{V}) where {V} = isempty(Δs) ? zero(V) : Δs.Δ\n\nfunction StochasticAD.derivative_contribution(Δs::PrunedFIsAggressive)\n    pruned_value(Δs) * Δs.state.weight\nend\n\nfunction StochasticAD.perturbations(Δs::PrunedFIsAggressive)\n    ((; Δ = pruned_value(Δs), weight = Δs.state.weight, state = Δs.state),)\nend\n\n### Unary propagation\n\nfunction StochasticAD.weighted_map_Δs(f, Δs::PrunedFIsAggressive; kwargs...)\n    Δ_out, weight_out = f(Δs.Δ, nothing)\n    Δs.state.weight *= weight_out\n    PrunedFIsAggressive(Δ_out, Δs.tag, Δs.state)\nend\n\nStochasticAD.alltrue(f, Δs::PrunedFIsAggressive) = f(Δs.Δ)\n\n### Coupling\n\nfunction StochasticAD.get_rep(::Type{<:PrunedFIsAggressive}, Δs_all)\n    # Get some Δs with a valid state, or any if all are invalid.\n    return reduce((Δs1, Δs2) -> Δs1.state.valid ? 
Δs1 : Δs2,\n        StochasticAD.structural_iterate(Δs_all))\nend\n\n# for pruning, coupling amounts to getting rid of perturbed values that have been\n# lazily kept around even after (aggressive or lazy) pruning made the perturbation invalid.\nfunction StochasticAD.couple(FIs::Type{<:PrunedFIsAggressive}, Δs_all;\n        rep = StochasticAD.get_rep(FIs, Δs_all),\n        out_rep = nothing, kwargs...)\n    state = rep.state\n    Δ_coupled = StochasticAD.structural_map(pruned_value, Δs_all) # TODO: perhaps a performance optimization possible here\n    PrunedFIsAggressive(Δ_coupled, state.active_tag, state)\nend\n\n# basically couple combined with a sum.\nfunction StochasticAD.combine(FIs::Type{<:PrunedFIsAggressive}, Δs_all;\n        rep = StochasticAD.get_rep(FIs, Δs_all), kwargs...)\n    state = rep.state\n    Δ_combined = sum(pruned_value, StochasticAD.structural_iterate(Δs_all))\n    PrunedFIsAggressive(Δ_combined, state.active_tag, state)\nend\n\nfunction StochasticAD.scalarize(Δs::PrunedFIsAggressive; out_rep = nothing)\n    return StochasticAD.structural_map(Δs.Δ) do Δ\n        return PrunedFIsAggressive(Δ, Δs.tag, Δs.state)\n    end\nend\n\nStochasticAD.filter_state(Δs::PrunedFIsAggressive, _) = pruned_value(Δs)\n\n### Miscellaneous\n\nStochasticAD.similar_type(::Type{<:PrunedFIsAggressive}, V::Type) = PrunedFIsAggressive{V}\nStochasticAD.valtype(::Type{<:PrunedFIsAggressive{V}}) where {V} = V\n\n# should I have a mime input?\nfunction Base.show(io::IO, mime::MIME\"text/plain\",\n        Δs::PrunedFIsAggressive{V}) where {V}\n    print(io, \"$(pruned_value(Δs)) with probability $(Δs.state.weight)ε, tag $(Δs.tag)\")\nend\n\nfunction Base.show(io::IO, Δs::PrunedFIsAggressive{V}) where {V}\n    print(io, \"$(pruned_value(Δs)) with probability $(Δs.state.weight)ε, tag $(Δs.tag)\")\nend\n\nend\n"
  },
  {
    "path": "src/backends/smoothed.jl",
    "content": "module SmoothedFIsModule\n\nimport ..StochasticAD\n\nexport SmoothedFIsBackend, SmoothedFIs\n\n\"\"\"\n    SmoothedFIsBackend <: StochasticAD.AbstractFIsBackend\n\nA backend algorithm that smooths perturbations togethers. \n\"\"\"\nstruct SmoothedFIsBackend <: StochasticAD.AbstractFIsBackend end\n\n\"\"\"\n    SmoothedFIs{V} <: StochasticAD.AbstractFIs{V}\n\nThe implementing backend structure for SmoothedFIsBackend.\n\"\"\"\n# TODO: make type of δ generic\nstruct SmoothedFIs{V, V_float} <: StochasticAD.AbstractFIs{V}\n    δ::V_float\n    function SmoothedFIs{V}(δ) where {V}\n        # hardcode Float64 representation for now, for simplicity.\n        δ_f64 = StochasticAD.structural_map(Base.Fix1(convert, Float64), δ)\n        return new{V, typeof(δ_f64)}(δ_f64)\n    end\nend\n\n### Empty / no perturbation\n\nStochasticAD.similar_empty(::SmoothedFIs, V::Type) = SmoothedFIs{V}(0.0)\nBase.empty(::Type{<:SmoothedFIs{V}}) where {V} = SmoothedFIs{V}(0.0)\nBase.empty(Δs::SmoothedFIs) = empty(typeof(Δs))\n\n### Create a new perturbation with infinitesimal probability\n\nfunction StochasticAD.similar_new(::SmoothedFIs, Δ::V, w::Real) where {V}\n    SmoothedFIs{V}(Δ * w)\nend\n\nStochasticAD.new_Δs_strategy(::SmoothedFIs) = StochasticAD.TwoSidedStrategy()\n\n### Create Δs backend for the first stochastic triple of computation\n\nStochasticAD.create_Δs(::SmoothedFIsBackend, V) = SmoothedFIs{V}(0.0)\n\n### Convert type of a backend\n\nfunction Base.convert(FIs::Type{<:SmoothedFIs{V}}, Δs::SmoothedFIs) where {V}\n    SmoothedFIs{V}(Δs.δ)::FIs\nend\n\n### Getting information about perturbations\n\nBase.isempty(Δs::SmoothedFIs) = false\nBase.iszero(Δs::SmoothedFIs) = iszero(Δs.δ)\nBase.iszero(Δs::SmoothedFIs{<:Tuple}) = all(iszero.(Δs.δ))\nStochasticAD.derivative_contribution(Δs::SmoothedFIs) = Δs.δ\n\n### Unary propagation\n\nfunction StochasticAD.weighted_map_Δs(f, Δs::SmoothedFIs; deriv, out_rep, kwargs...)\n    
SmoothedFIs{typeof(out_rep)}(deriv(Δs.δ))\nend\n\nStochasticAD.alltrue(f, Δs::SmoothedFIs) = true\n\n### Coupling\n\nStochasticAD.get_rep(::Type{<:SmoothedFIs}, Δs_all) = StochasticAD.get_any(Δs_all)\n\nfunction StochasticAD.couple(\n        ::Type{<:SmoothedFIs}, Δs_all; rep = nothing, out_rep, kwargs...)\n    SmoothedFIs{typeof(out_rep)}(StochasticAD.structural_map(Δs -> Δs.δ, Δs_all))\nend\n\nfunction StochasticAD.combine(::Type{<:SmoothedFIs}, Δs_all; rep = nothing, kwargs...)\n    V_out = StochasticAD.valtype(first(StochasticAD.structural_iterate(Δs_all)))\n    Δ_combined = sum(Δs -> Δs.δ, StochasticAD.structural_iterate(Δs_all))\n    SmoothedFIs{V_out}(Δ_combined)\nend\n\nfunction StochasticAD.scalarize(Δs::SmoothedFIs; out_rep)\n    return StochasticAD.structural_map(out_rep, Δs.δ) do out, δ\n        return SmoothedFIs{typeof(out)}(δ)\n    end\nend\n\n### Miscellaneous\n\nStochasticAD.similar_type(::Type{<:SmoothedFIs}, V::Type) = SmoothedFIs{V, Float64}\nStochasticAD.valtype(::Type{<:SmoothedFIs{V}}) where {V} = V\n\nfunction Base.show(io::IO, Δs::SmoothedFIs)\n    print(io, \"$(Δs.δ)ε\")\nend\n\nend\n"
  },
  {
    "path": "src/backends/strategy_wrapper.jl",
    "content": "module StrategyWrapperFIsModule\n\nusing ..StochasticAD\nusing ..StochasticAD.AbstractWrapperFIsModule\n\nexport StrategyWrapperFIsBackend, StrategyWrapperFIs\n\nstruct StrategyWrapperFIsBackend{\n    B <: StochasticAD.AbstractFIsBackend,\n    S <: StochasticAD.AbstractPerturbationStrategy\n} <:\n       StochasticAD.AbstractFIsBackend\n    backend::B\n    strategy::S\nend\n\nstruct StrategyWrapperFIs{\n    V,\n    FIs <: StochasticAD.AbstractFIs{V},\n    S <: StochasticAD.AbstractPerturbationStrategy\n} <:\n       AbstractWrapperFIs{V, FIs}\n    Δs::FIs\n    strategy::S\nend\n\nfunction StochasticAD.create_Δs(backend::StrategyWrapperFIsBackend, V)\n    return StrategyWrapperFIs(StochasticAD.create_Δs(backend.backend, V), backend.strategy)\nend\n\nfunction StochasticAD.similar_type(::Type{<:StrategyWrapperFIs{V0, FIs0, S}},\n        V,\n        FIs) where {V0, FIs0, S}\n    return StrategyWrapperFIs{V, FIs, S}\nend\n\nfunction AbstractWrapperFIsModule.reconstruct_wrapper(wrapper_Δs::StrategyWrapperFIs, Δs)\n    return StrategyWrapperFIs(Δs, wrapper_Δs.strategy)\nend\n\nfunction AbstractWrapperFIsModule.reconstruct_wrapper(\n        ::Type{\n            <:StrategyWrapperFIs{V, FIs, S},\n        },\n        Δs) where {V, FIs, S}\n    return StrategyWrapperFIs(Δs, S())\nend\n\nStochasticAD.new_Δs_strategy(Δs::StrategyWrapperFIs) = Δs.strategy\n\nend\n"
  },
  {
    "path": "src/discrete_randomness.jl",
    "content": "## Helper functions for discrete distributions \n\n# index of the parameter p\n_param_index(::Geometric) = 1\n_param_index(::Bernoulli) = 1\n_param_index(::Binomial) = 2\n_param_index(::Poisson) = 1\n_param_index(::Categorical) = 1\n\n_get_parameter(d) = params(d)[_param_index(d)]\n\n# constructors\nfor dist in [:Geometric, :Bernoulli, :Binomial, :Poisson, :Categorical]\n    @eval _constructor(::$dist) = $dist\nend\n\n# reconstruct probability distribution with new paramter value\nfunction _reconstruct(d, p)\n    i = _param_index(d)\n    return _constructor(d)(params(d)[1:(i - 1)]..., p, params(d)[(i + 1):end]...)\nend\n\n# support of probability distribution\n_has_finite_support(d) = false\n_has_finite_support(d::Union{Bernoulli, Binomial, Categorical}) = true\n\n_get_support(d::Union{Bernoulli, Binomial, Categorical}) = minimum(d):maximum(d)\n# manual overloads to ensure that static-ness is preserved for Bernoulli's and Categoricals with static arrays.\n# since mapping over the range above could result in allocating vectors.\n_get_support(::Bernoulli) = (0, 1)\n# the map below looks a bit silly, but it gives us a collection of the categories with the same structure as probs(d). 
\n_get_support(d::Categorical) = map((val, prob) -> val, 1:ncategories(d), probs(d))\n\n## Derivative couplings\n\n# Derivative coupling approaches, determining which weighted perturbations to consider\nabstract type AbstractDerivativeCoupling end\n\n\"\"\"\n    InversionMethodDerivativeCoupling(; mode::Val = Val(:positive_weight), handle_zeroprob::Val = Val(true))\n\nSpecifies an inversion method coupling for generating perturbations from a univariate distribution.\nValid choices of `mode` are `Val(:positive_weight)`, `Val(:always_right)`, and `Val(:always_left)`.\n\n# Example\n```jldoctest\njulia> using StochasticAD, Distributions, Random; Random.seed!(4321);\n\njulia> function X(p)\n           return randst(Bernoulli(1 - p); derivative_coupling = InversionMethodDerivativeCoupling(; mode = Val(:always_right)))\n       end\nX (generic function with 1 method)\n\njulia> stochastic_triple(X, 0.5)\nStochasticTriple of Int64:\n0 + 0ε + (1 with probability -2.0ε)\n```\n\"\"\"\nBase.@kwdef struct InversionMethodDerivativeCoupling{M, HZP}\n    mode::M = Val(:positive_weight)\n    handle_zeroprob::HZP = Val(true)\nend\n\n# Strategies for precisely which perturbations to form given a derivative coupling\nstruct SingleSidedStrategy <: AbstractPerturbationStrategy end\nstruct TwoSidedStrategy <: AbstractPerturbationStrategy end\nstruct SmoothedStraightThroughStrategy <: AbstractPerturbationStrategy end\nstruct StraightThroughStrategy <: AbstractPerturbationStrategy end\nstruct IgnoreDiscreteStrategy <: AbstractPerturbationStrategy end\n\nnew_Δs_strategy(Δs) = SingleSidedStrategy()\n\n# Derivative coupling high-level interface\n\n\"\"\"\n    δtoΔs(d, val, δ, Δs::AbstractFIs)\n\nGiven the parameter `val` of a distribution `d` and an infinitesimal change `δ`,\nreturn the discrete change in the output, with a similar representation to `Δs`.\n\"\"\"\nδtoΔs(d, val, δ, Δs, derivative_coupling) = δtoΔs(\n    d, val, δ, Δs, derivative_coupling, new_Δs_strategy(Δs))\nfunction δtoΔs(d, 
val, δ, Δs, derivative_coupling, ::SingleSidedStrategy)\n    _δtoΔs(d, val, δ, Δs, derivative_coupling)\nend\nfunction δtoΔs(d, val, δ, Δs, derivative_coupling, ::TwoSidedStrategy)\n    Δs1 = _δtoΔs(d, val, δ, Δs, derivative_coupling)\n    Δs2 = _δtoΔs(d, val, -δ, Δs, derivative_coupling)\n    return combine((scale(Δs1, 0.5), scale(Δs2, -0.5)))\nend\n# TODO: implement this ST for other distributions and couplings, if meaningful?\nfunction δtoΔs(d::Union{Bernoulli, Binomial},\n        val,\n        δ,\n        Δs,\n        derivative_coupling::InversionMethodDerivativeCoupling,\n        ::StraightThroughStrategy)\n    p = succprob(d)\n    Δs1 = _δtoΔs(d, val, δ, Δs, derivative_coupling)\n    Δs2 = _δtoΔs(d, val, -δ, Δs, derivative_coupling)\n    return combine((scale(Δs1, 1 - p), scale(Δs2, -p)))\nend\nfunction δtoΔs(d, val::V, δ, Δs, derivative_coupling, ::IgnoreDiscreteStrategy) where {V}\n    similar_empty(Δs, V)\nend\n\n# Implement straight through strategy, works for all distrs, but does something that is only\n# meaningful for smoothed backends (using one(val))\nfunction δtoΔs(d, val, δ, Δs, derivative_coupling, ::SmoothedStraightThroughStrategy)\n    p = _get_parameter(d)\n    δout = ForwardDiff.derivative(a -> mean(_reconstruct(d, p + a * δ)), 0.0)\n    return similar_new(Δs, one(val), δout)\nend\n\n# Derivative coupling low-level implementations \n\nfunction _δtoΔs(d::Geometric,\n        val::V,\n        δ::Real,\n        Δs::AbstractFIs,\n        derivative_coupling::InversionMethodDerivativeCoupling) where {V <: Signed}\n    p = succprob(d)\n    if (derivative_coupling.mode isa Val{:positive_weight} && δ > 0) ||\n       (derivative_coupling.mode isa Val{:always_right})\n        return val > 0 ? 
similar_new(Δs, -one(V), δ * val / p / (1 - p)) :\n               similar_empty(Δs, V)\n    elseif (derivative_coupling.mode isa Val{:positive_weight} && δ < 0) ||\n           (derivative_coupling.mode isa Val{:always_left})\n        return similar_new(Δs, one(V), -δ * (val + 1) / p)\n    else\n        return similar_empty(Δs, V)\n    end\nend\n\nfunction _δtoΔs(d::Bernoulli,\n        val::V,\n        δ::Real,\n        Δs::AbstractFIs,\n        derivative_coupling::InversionMethodDerivativeCoupling) where {V <: Signed}\n    p = succprob(d)\n    if (derivative_coupling.mode isa Val{:positive_weight} && δ > 0) ||\n       (derivative_coupling.mode isa Val{:always_right})\n        return isone(val) ? similar_empty(Δs, V) : similar_new(Δs, one(V), δ / (1 - p))\n    elseif (derivative_coupling.mode isa Val{:positive_weight} && δ < 0) ||\n           (derivative_coupling.mode isa Val{:always_left})\n        return isone(val) ? similar_new(Δs, -one(V), -δ / p) : similar_empty(Δs, V)\n    else\n        return similar_empty(Δs, V)\n    end\nend\n\nfunction _δtoΔs(d::Binomial,\n        val::V,\n        δ::Real,\n        Δs::AbstractFIs,\n        derivative_coupling::InversionMethodDerivativeCoupling) where {V <: Signed}\n    p = succprob(d)\n    n = ntrials(d)\n    if (derivative_coupling.mode isa Val{:positive_weight} && δ > 0) ||\n       (derivative_coupling.mode isa Val{:always_right})\n        return val == n ? similar_empty(Δs, V) :\n               similar_new(Δs, one(V), δ * (n - val) / (1 - p))\n    elseif (derivative_coupling.mode isa Val{:positive_weight} && δ < 0) ||\n           (derivative_coupling.mode isa Val{:always_left})\n        return !iszero(val) ? 
similar_new(Δs, -one(V), -δ * val / p) : similar_empty(Δs, V)\n    else\n        return similar_empty(Δs, V)\n    end\nend\n\nfunction _δtoΔs(d::Poisson,\n        val::V,\n        δ::Real,\n        Δs::AbstractFIs,\n        derivative_coupling::InversionMethodDerivativeCoupling) where {V <: Signed}\n    p = mean(d) # rate\n    if (derivative_coupling.mode isa Val{:positive_weight} && δ > 0) ||\n       (derivative_coupling.mode isa Val{:always_right})\n        return similar_new(Δs, 1, δ)\n    elseif (derivative_coupling.mode isa Val{:positive_weight} && δ < 0) ||\n           (derivative_coupling.mode isa Val{:always_left})\n        return val > 0 ? similar_new(Δs, -1, -δ * val / p) : similar_empty(Δs, V)\n    else\n        return similar_empty(Δs, V)\n    end\nend\n\nfunction _δtoΔs(d::Categorical,\n        val::V,\n        δs,\n        Δs::AbstractFIs,\n        derivative_coupling::InversionMethodDerivativeCoupling) where {V <: Signed}\n    p = params(d)[1]\n    # NB: Although we might expect sum(δs) = 0, it is useful to handle things more generally, viewing δs\n    # as perturbing the Categorical distribution locally along some direction in the space of general measures.\n    # The below formulation gets things right in this case too. 
\n    left_sum = sum(δs[1:(val - 1)], init = zero(eltype(δs)))\n    right_sum = sum(δs[1:val], init = zero(eltype(δs)))\n\n    if (derivative_coupling.mode isa Val{:positive_weight} && left_sum > 0) ||\n       (derivative_coupling.mode isa Val{:always_left} && !iszero(left_sum))\n        # compute left_nonzero\n        if derivative_coupling.handle_zeroprob isa Val{true}\n            stop = rand() * left_sum\n            upto = zero(eltype(δs)) # The \"upto\" logic handles an edge case of probability 0 events that have non-zero derivative.\n            # It's a lot of logic to handle an edge case, but hopefully it's optimized away.\n            left_nonzero = val\n            for i in (val - 1):-1:1\n                if !iszero(p[i]) || ((upto += δs[i]) > stop)\n                    left_nonzero = i\n                    break\n                end\n            end\n        else\n            left_nonzero = val - 1\n        end\n        Δs_left = similar_new(Δs, left_nonzero - val, left_sum / p[val])\n    else\n        Δs_left = similar_empty(Δs, typeof(val))\n    end\n\n    if (derivative_coupling.mode isa Val{:positive_weight} && right_sum < 0) ||\n       (derivative_coupling.mode isa Val{:always_right} && !iszero(right_sum))\n        # compute right_nonzero\n        if derivative_coupling.handle_zeroprob isa Val{true}\n            stop = -rand() * right_sum\n            upto = zero(eltype(δs))\n            right_nonzero = val\n            for i in (val + 1):length(p)\n                if !iszero(p[i]) || ((upto += δs[i]) > stop)\n                    right_nonzero = i\n                    break\n                end\n            end\n        else\n            right_nonzero = val + 1\n        end\n        Δs_right = similar_new(Δs, right_nonzero - val, -right_sum / p[val])\n    else\n        Δs_right = similar_empty(Δs, typeof(val))\n    end\n\n    return combine((Δs_left, Δs_right); rep = Δs)\nend\n\n## Propagation couplings\n\nabstract type AbstractPropagationCoupling 
end\n\n\"\"\"\n    InversionMethodPropagationCoupling \n\nSpecifies an inversion method coupling for propagating perturbations.\n\"\"\"\nstruct InversionMethodPropagationCoupling <: AbstractPropagationCoupling end\n\nfunction _map_func(d, val, Δ, ::InversionMethodPropagationCoupling)\n    # construct alternative distribution\n    p = _get_parameter(d)\n    alt_d = _reconstruct(d, p + Δ)\n    # compute bounds on original ω\n    low = cdf(d, val - 1)\n    high = cdf(d, val)\n    # sample alternative value\n    alt_val = quantile(alt_d, rand(RNG) * (high - low) + low)\n    return convert(Signed, alt_val - val)\nend\n\nfunction _map_enumeration(d, val, Δ, ::InversionMethodPropagationCoupling)\n    # construct alternative distribution\n    p = _get_parameter(d)\n    alt_d = _reconstruct(d, p + Δ)\n    # compute bounds on original ω\n    low = cdf(d, val - 1)\n    high = cdf(d, val)\n    if _has_finite_support(alt_d)\n        map(_get_support(alt_d)) do alt_val\n            # interval intersect of (cdf(alt_d, alt_val - 1), cdf(alt_d, alt_val)) and (low, high)\n            alt_low = cdf(alt_d, alt_val - 1)\n            alt_high = cdf(alt_d, alt_val)\n            prob_alt = max(0.0, min(alt_high, high) - max(alt_low, low)) /\n                       (high - low)\n            return (alt_val - val, prob_alt)\n        end\n    else\n        error(\"enumeration not supported for distribution $d. Does $d have finite support?\")\n    end\nend\n\n## Overloading of random sampling \n\n# Define randst interface\n\n\"\"\"\n    randst(rng, d::Distributions.Sampleable; kwargs...)\n\nWhen no keyword arguments are provided, `randst` behaves identically to `rand(rng, d)` in both ordinary computation\nand for stochastic triple dispatches. However, `randst` also allows the user to provide various keyword arguments\nfor customizing the differentiation logic. 
The set of allowed keyword arguments depends on the type of `d`: a couple\ncommon ones are `derivative_coupling` and `propagation_coupling`.\n\nFor developers: if you wish to accept custom keyword arguments in a stochastic triple dispatch, you should overload\n`randst`, and redirect `rand` to your `randst` method. If you do not, it suffices to just overload `rand`.\n\"\"\"\nrandst(rng, d::Distributions.Sampleable; kwargs...) = rand(rng, d)\nrandst(d::Distributions.Sampleable; kwargs...) = randst(Random.default_rng(), d; kwargs...)\n\n# Define stochastic triple rules\n\nfor dist in [:Geometric, :Bernoulli, :Binomial, :Poisson]\n    @eval function Base.rand(rng::AbstractRNG,\n            d_st::$dist{StochasticTriple{T, V, FIs}}) where {T, V, FIs}\n        return randst(rng, d_st)\n    end\n    @eval function randst(rng::AbstractRNG,\n            d_st::$dist{StochasticTriple{T, V, FIs}};\n            Δ_kwargs = (;),\n            derivative_coupling = InversionMethodDerivativeCoupling(),\n            propagation_coupling = InversionMethodPropagationCoupling()) where {T, V, FIs}\n        st = _get_parameter(d_st)\n        d = _reconstruct(d_st, st.value)\n        val = convert(Signed, rand(rng, d))\n        Δs1 = δtoΔs(d, val, st.δ, st.Δs, derivative_coupling)\n\n        Δs2 = map(Δ -> _map_func(d, val, Δ, propagation_coupling),\n            st.Δs;\n            enumeration = (Δ, _) -> _map_enumeration(d, val, Δ, propagation_coupling),\n            deriv = δ -> smoothed_delta(d, val, δ, derivative_coupling),\n            out_rep = val,\n            Δ_kwargs...)\n\n        StochasticTriple{T}(val, zero(val), combine((Δs2, Δs1); rep = Δs1)) # ensure that tags are in order in combine, in case backend wishes to exploit this \n    end\nend\n\n# currently handle Categorical separately since parameter is a vector\n# what if some elements in vector are not stochastic triples... 
promotion should take care of that?\nfunction Base.rand(rng::AbstractRNG,\n        d_st::Categorical{StochasticTriple{T, V, FIs}}) where {T, V, FIs}\n    return randst(rng, d_st)\nend\nfunction randst(rng::AbstractRNG,\n        d_st::Categorical{<:StochasticTriple{T},\n            <:AbstractVector{<:StochasticTriple{T, V}}};\n        Δ_kwargs = (;),\n        derivative_coupling = InversionMethodDerivativeCoupling(),\n        propagation_coupling = InversionMethodPropagationCoupling()) where {T, V}\n    sts = _get_parameter(d_st) # stochastic triple for each probability\n    p = map(st -> st.value, sts) # try to keep the same type. e.g. static array -> static array. TODO: avoid allocations \n    d = _reconstruct(d_st, p)\n    val = convert(Signed, rand(rng, d))\n\n    Δs_all = map(st -> st.Δs, sts)\n    Δs_rep = get_rep(Δs_all)\n\n    Δs1 = δtoΔs(d, val, map(st -> st.δ, sts), Δs_rep, derivative_coupling)\n\n    Δs_coupled = couple(Δs_all; rep = Δs_rep, out_rep = p) # TODO: again, there are possible allocations here\n    Δs2 = map(Δ -> _map_func(d, val, Δ, propagation_coupling),\n        Δs_coupled;\n        enumeration = (Δ, _) -> _map_enumeration(d, val, Δ, propagation_coupling),\n        deriv = δ -> smoothed_delta(d, val, δ, derivative_coupling),\n        out_rep = val,\n        Δ_kwargs...)\n\n    Δs = combine((Δs2, Δs1); rep = Δs1, out_rep = val, Δ_kwargs...)\n\n    StochasticTriple{T}(val, zero(val), Δs)\nend\n\n## Handling finite perturbation to Binomial number of trials\n\n\"\"\"\n    DiscreteDeltaStochasticTriple{T, V, FIs <: AbstractFIs}\n\nAn experimental discrete stochastic triple type used internally for representing perturbations\nto non-real quantities. 
Currently only used to represent a finite perturbation to the Binomial \nparameter n.\n\n## Constructor\n\n- `value`: the primal value.\n- `Δs``: some representation of the perturbation to the primal, which can have an unconventional\n         interpretation depending on `T`.\n\"\"\"\nstruct DiscreteDeltaStochasticTriple{T, V, FIs <: AbstractFIs}\n    value::V\n    Δs::FIs\n    function DiscreteDeltaStochasticTriple{T, V, FIs}(value::V,\n            Δs::FIs) where {T, V,\n            FIs <: AbstractFIs}\n        new{T, V, FIs}(value, Δs)\n    end\nend\n\nfunction DiscreteDeltaStochasticTriple{T}(val::V, Δs::FIs) where {T, V, FIs <: AbstractFIs}\n    DiscreteDeltaStochasticTriple{T, V, FIs}(val, Δs)\nend\n\nfunction Distributions.Binomial(n::StochasticTriple{T}, p::Real) where {T}\n    return DiscreteDeltaStochasticTriple{T}(Binomial(n.value, p), n.Δs)\nend\n\n# TODO: Support functions other than `rand` called on a perturbed Binomial.\nfunction Base.rand(rng::AbstractRNG,\n        d_st::DiscreteDeltaStochasticTriple{T, <:Binomial}) where {T}\n    return randst(rng, d_st)\nend\nfunction randst(rng::AbstractRNG,\n        d_st::DiscreteDeltaStochasticTriple{T, <:Binomial}) where {T}\n    d = d_st.value\n    val = rand(rng, d)\n    function map_func(Δ)\n        if Δ >= 0\n            return rand(StochasticAD.RNG, Binomial(Δ, value(succprob(d))))\n        else\n            return -rand(StochasticAD.RNG,\n                Hypergeometric(value(val), ntrials(d) - value(val), -Δ))\n        end\n    end\n    Δs = map(map_func, d_st.Δs)\n    if val isa StochasticTriple\n        return StochasticTriple{T}(val.value, val.δ, combine((Δs, val.Δs); rep = Δs))\n    else\n        return StochasticTriple{T}(val, zero(val), Δs)\n    end\nend\n"
  },
  {
    "path": "src/finite_infinitesimals.jl",
    "content": "# TODO: make this a module, with the interface exported?\n\n## \n\"\"\"\n    AbstractFIsBackend\n\nAn abstract type for backend strategies of Finite perturbations that occur with Infinitesimal probability (FIs).\n\"\"\"\nabstract type AbstractFIsBackend end\n\n\"\"\"\n    AbstractFIs{V}\n\nAn abstract type for concrete backend representations of Finite Infinitesimals. \n\"\"\"\nabstract type AbstractFIs{V} end\n\n### Some of the necessary interface notes below.\n# TODO: document\n\nfunction create_Δs end\n\nfunction similar_new end\nfunction similar_empty end\nfunction similar_type end\n\nvaltype(Δs::AbstractFIs) = valtype(typeof(Δs))\n\n# TODO: typeof ∘ first is a loose check, should make more robust.\n# TODO: perhaps deprecate these methods in favor of an explicit first argument?\ncouple(Δs_all; kwargs...) = couple(typeof(first(Δs_all)), Δs_all; kwargs...)\ncombine(Δs_all; kwargs...) = combine(typeof(first(Δs_all)), Δs_all; kwargs...)\nget_rep(Δs_all; kwargs...) = get_rep(typeof(first(Δs_all)), Δs_all; kwargs...)\nfunction scalarize end\n\nfunction derivative_contribution end\n\nfunction alltrue end\n\nfunction perturbations end\n\nfunction filter_state end\n\nfunction weighted_map_Δs end\nfunction map_Δs(f, Δs::AbstractFIs; kwargs...)\n    StochasticAD.weighted_map_Δs((Δs, state) -> (f(Δs, state), 1.0), Δs; kwargs...)\nend\nfunction Base.map(f, Δs::AbstractFIs; kwargs...)\n    StochasticAD.map_Δs((Δs, _) -> f(Δs), Δs; kwargs...)\nend\n# We also add a scale to deriv for scaling smoothed perturbations \nfunction scale(Δs::AbstractFIs, a::Real)\n    StochasticAD.weighted_map_Δs((Δ, state) -> (Δ, a),\n        Δs;\n        deriv = Base.Fix1(*, a),\n        out_rep = Δs)\nend\n\nfunction new_Δs_strategy end\n\n# utility function useful e.g. 
for get_rep in some backends\nfunction get_any(Δs_all)\n    # The code below is a bit ridiculous, but it's faster than `first` for small structures:)\n    foldl((Δs1, Δs2) -> Δs1, StochasticAD.structural_iterate(Δs_all))\nend\n\nabstract type AbstractPerturbationStrategy end\n\nabstract type AbstractPerturbationSignal end\n\nfunction send_signal end\n\n# Ignore signals by default since they do not change semantics.\nfunction StochasticAD.send_signal(\n        Δs::StochasticAD.AbstractFIs, ::StochasticAD.AbstractPerturbationSignal)\n    return Δs\nend\n"
  },
  {
    "path": "src/general_rules.jl",
    "content": "\"\"\"\nOperators which have already been overloaded by StochasticAD. \n\"\"\"\nconst handled_ops = Tuple{DataType, Int}[]\n\n\"\"\"\n    define_triple_overload(sig)\n\nGiven the signature type-type of the primal function, define operator\noverloading rules for stochastic triples.\nCurrently supports functions with all-real inputs and one real output.\n\"\"\"\n# TODO: special case optimizations\n# TODO: generalizations to not-all-real inputs and/or not-one-real output\nfunction define_triple_overload(sig)\n    opT, argTs = Iterators.peel(ExprTools.parameters(sig))\n    opT <: Type{<:Type} && return  # not handling constructors\n    sig <: Tuple{Type, Vararg{Any}} && return\n    opT <: Core.Builtin && return false  # can't do operator overloading for builtins\n\n    isabstracttype(opT) || fieldcount(opT) == 0 || return false  # not handling functors\n    isempty(argTs) && return false  # we are an operator overloading AD, need operands\n    all(argT isa Type && Real <: argT for argT in argTs) || return\n\n    N = length(ExprTools.parameters(sig)) - 1  # skip the op\n\n    # Skip already-handled ops, as well as ops that will be handled manually later (and more correctly, see #79).\n    if (opT, N) in handled_ops || (opT.instance in UNARY_TYPEFUNCS_WRAP)\n        return\n    end\n\n    push!(handled_ops, (opT, N))\n\n    if opT.instance in UNARY_PREDICATES && (N == 1)\n        @eval function (f::$opT)(st::StochasticTriple)\n            val = value(st)\n            out = f(val)\n            if !alltrue(Δ -> (f(val + Δ) == out), st.Δs)\n                error(\"Output of boolean predicate cannot depend on input (unsupported by StochasticAD)\")\n            end\n            return out\n        end\n    elseif opT.instance in BINARY_PREDICATES && (N == 2)\n        # Special case equality comparisons as in https://github.com/JuliaDiff/ForwardDiff.jl/pull/481\n        if opT.instance == Base.:(==)\n            return_value_real = quote\n                out && 
iszero(delta(st))\n            end\n            return_value_st = quote\n                out2 = out && (delta(st1) == delta(st2))\n            end\n        else\n            return_value_real = quote\n                out\n            end\n            return_value_st = quote\n                out\n            end\n        end\n        @eval function (f::$opT)(st::StochasticTriple, x::Real)\n            val = value(st)\n            out = f(val, x)\n            if !alltrue(Δ -> (f(val + Δ, x) == out), st.Δs)\n                error(\"Output of boolean predicate cannot depend on input (unsupported by StochasticAD)\")\n            end\n            return $return_value_real\n        end\n        @eval function (f::$opT)(x::Real, st::StochasticTriple)\n            val = value(st)\n            out = f(x, val)\n            if !alltrue(Δ -> (f(x, val + Δ) == out), st.Δs)\n                error(\"Output of boolean predicate cannot depend on input (unsupported by StochasticAD)\")\n            end\n            return $return_value_real\n        end\n        @eval function (f::$opT)(st1::StochasticTriple, st2::StochasticTriple)\n            val1 = value(st1)\n            val2 = value(st2)\n            out = f(val1, val2)\n\n            Δs_coupled = couple((st1.Δs, st2.Δs); out_rep = (val1, val2))\n            safe_perturb = alltrue(Δs -> f(val1 + Δs[1], val2 + Δs[2]) == out, Δs_coupled)\n            if !safe_perturb\n                error(\"Output of boolean predicate cannot depend on input (unsupported by StochasticAD)\")\n            end\n            return $return_value_st\n        end\n    elseif N == 1\n        if Base.return_types(frule, (Tuple{NoTangent, Real}, opT, Real))[1] <:\n           Tuple{Any, NoTangent}\n            return\n        end\n        @eval function (f::$opT)(st::StochasticTriple{T}; kwargs...) 
where {T}\n            run_frule = δ -> begin\n                args_tangent = (NoTangent(), δ)\n                return frule(args_tangent, f, value(st); kwargs...)\n            end\n            val, δ0 = run_frule(delta(st))\n            δ::typeof(val) = (δ0 isa ZeroTangent || δ0 isa NoTangent) ? zero(value(st)) : δ0\n            if !iszero(st.Δs)\n                Δs = map(Δ -> f(st.value + Δ; kwargs...) - val, st.Δs;\n                    deriv = last ∘ run_frule, out_rep = val)\n            else\n                Δs = similar_empty(st.Δs, typeof(val))\n            end\n            return StochasticTriple{T}(val, δ, Δs)\n        end\n    elseif N == 2\n        if Base.return_types(frule, (Tuple{NoTangent, Real, Real}, opT, Real, Real))[1] <:\n           Tuple{Any, NoTangent}\n            return\n        end\n        for R in AMBIGUOUS_TYPES\n            @eval function (f::$opT)(st::StochasticTriple{T}, x::$R; kwargs...) where {T}\n                run_frule = δ -> begin\n                    args_tangent = (NoTangent(), δ, zero(x))\n                    return frule(args_tangent, f, value(st), x; kwargs...)\n                end\n                val, δ0 = run_frule(delta(st))\n                δ::typeof(val) = (δ0 isa ZeroTangent || δ0 isa NoTangent) ?\n                                 zero(value(st)) : δ0\n                if !iszero(st.Δs)\n                    Δs = map(Δ -> f(st.value + Δ, x; kwargs...) - val, st.Δs;\n                        deriv = last ∘ run_frule, out_rep = val)\n                else\n                    Δs = similar_empty(st.Δs, typeof(val))\n                end\n                return StochasticTriple{T}(val, δ, Δs)\n            end\n            @eval function (f::$opT)(x::$R, st::StochasticTriple{T}; kwargs...) 
where {T}\n                run_frule = δ -> begin\n                    args_tangent = (NoTangent(), zero(x), δ)\n                    return frule(args_tangent, f, x, value(st); kwargs...)\n                end\n                val, δ0 = run_frule(delta(st))\n                δ::typeof(val) = (δ0 isa ZeroTangent || δ0 isa NoTangent) ?\n                                 zero(value(st)) : δ0\n                if !iszero(st.Δs)\n                    Δs = map(Δ -> f(x, st.value + Δ; kwargs...) - val, st.Δs;\n                        deriv = last ∘ run_frule, out_rep = val)\n                else\n                    Δs = similar_empty(st.Δs, typeof(val))\n                end\n                return StochasticTriple{T}(val, δ, Δs)\n            end\n        end\n        @eval function (f::$opT)(sts::Vararg{StochasticTriple{T}, 2}; kwargs...) where {T}\n            run_frule = δs -> begin\n                args_tangent = (NoTangent(), δs...)\n                args = (f, value.(sts)...)\n                return frule(args_tangent, args...; kwargs...)\n            end\n            val, δ0 = run_frule(delta.(sts))\n            δ::typeof(val) = (δ0 isa ZeroTangent || δ0 isa NoTangent) ? zero(value(st)) : δ0\n\n            Δs_all = map(st -> getfield(st, :Δs), sts)\n            if all(iszero.(Δs_all))\n                Δs = similar_empty(first(sts).Δs, typeof(val))\n            else\n                vals_in = value.(sts)\n                Δs_coupled = couple(Tuple(Δs_all); out_rep = vals_in)\n                mapfunc = let vals_in = vals_in\n                    Δ -> (f((vals_in .+ Δ)...; kwargs...) 
- val)\n                end\n                Δs = map(mapfunc, Δs_coupled; deriv = last ∘ run_frule, out_rep = val)\n            end\n            return StochasticTriple{T}(val, δ, Δs)\n        end\n    end\nend\n\non_new_rule(define_triple_overload, frule)\n\n### Extra overloads\n\n# TODO: generalize the below logic to compactly handle a wider range of functions.\n# See also https://github.com/JuliaDiff/ForwardDiff.jl/blob/master/src/dual.jl.\n\nfunction Base.hash(st::StochasticTriple, hsh::UInt)\n    if !isempty(st.Δs)\n        error(\"Hashing a stochastic triple with perturbations not yet supported.\")\n    end\n    hash(StochasticAD.value(st), hsh)\nend\n\n#=\nThis is a hacky experimental way to convert a float-like stochastic triple\ninto an integer-like one, to facilitate generic coding.\n=#\nfunction Base.round(I::Type{<:Integer}, st::StochasticTriple{T, V}) where {T, V}\n    return StochasticTriple{T}(round(I, st.value), map(Δ -> round(I, st.value + Δ), st.Δs))\nend\n\nfor op in UNARY_TYPEFUNCS_NOWRAP\n    function (::typeof(op))(::Type{<:StochasticTriple{T, V, FIs}}) where {T, V, FIs}\n        return op(V)\n    end\nend\n\nfor op in UNARY_TYPEFUNCS_WRAP\n    function (::typeof(op))(::Type{StochasticTriple{T, V, FIs}}) where {T, V, FIs}\n        return StochasticTriple{T, V, FIs}(op(V), zero(V), empty(FIs))\n    end\n    function (::typeof(op))(st::StochasticTriple)\n        return op(typeof(st))\n    end\nend\n\nfor op in RNG_TYPEFUNCS_WRAP\n    function (::typeof(op))(rng::AbstractRNG,\n            ::Type{StochasticTriple{T, V, FIs}}) where {T, V, FIs}\n        return StochasticTriple{T, V, FIs}(op(rng, V), zero(V), empty(FIs))\n    end\nend\n\n#=\nThe short-circuit \"x == y\" case in Base.isapprox is bad for us\nbecause it could unnecessarily lead to a boolean-predicate\ndepends on output error where StochasticAD cannot prove correctness.\nWe patch up the rule by removing the short-circuit, allowing some common\ncases to work.\n\nIn the future, we will 
ideally handle the overloading rule in a more general\nway. (E.g. by catching the chain rule for isapprox and recursively calling isapprox\non the values.)\n=#\nfunction Base.isapprox(st1::StochasticTriple, st2::StochasticTriple;\n        atol::Real = 0, rtol::Real = Base.rtoldefault(st1, st2, atol),\n        nans::Bool = false, norm::Function = abs)\n    (isfinite(st1) && isfinite(st2) &&\n     norm(st1 - st2) <= max(atol, rtol * max(norm(st1), norm(st2)))) ||\n        (nans && isnan(st1) && isnan(st2))\nend\nfunction Base.isapprox(st1::StochasticTriple, x::Real;\n        atol::Real = 0, rtol::Real = Base.rtoldefault(st1, x, atol),\n        nans::Bool = false, norm::Function = abs)\n    (isfinite(st1) && isfinite(x) &&\n     norm(st1 - x) <= max(atol, rtol * max(norm(st1), norm(x)))) ||\n        (nans && isnan(st1) && isnan(x))\nend\nfunction Base.isapprox(x::Real, st::StochasticTriple; kwargs...)\n    return Base.isapprox(st, x; kwargs...)\nend\n\n# Alternate version of _isassigned that does not fall back on try/catch.\n_isassigned(C, i) = (i in eachindex(C))\n\n\"\"\"\n    Base.getindex(C::AbstractArray, st::StochasticTriple{T})\n\nA simple prototype rule for array indexing. 
Assumes that underlying type of `st` can index into collection C.\n\"\"\"\n# TODO: support multiple indices, cartesian indices, non abstract array indexables, other use cases...\n# Example to fix: A[:, :, st]\nfunction Base.getindex(C::AbstractArray, st::StochasticTriple{T, V, FIs}) where {T, V, FIs}\n    val = C[st.value]\n    do_map = (Δ, state) -> begin\n        return value(C[st.value + Δ], state) - value(val, state)\n    end\n\n    # TODO: below doesn't support sparse arrays, use something like nextind\n    deriv = δ -> begin\n        scale = if _isassigned(C, st.value + 1) && _isassigned(C, st.value - 1)\n            1 / 2 * (value(C[st.value + 1]) - value(C[st.value - 1]))\n        elseif _isassigned(C, st.value + 1)\n            value(C[st.value + 1]) - value(C[st.value])\n        elseif _isassigned(C, st.value - 1)\n            value(C[st.value]) - value(C[st.value - 1])\n        else\n            zero(eltype(C))\n        end\n        return scale * δ\n    end\n\n    Δs = StochasticAD.map_Δs(do_map, st.Δs; deriv, out_rep = value(val))\n    if val isa StochasticTriple\n        Δs = combine((Δs, val.Δs))\n    end\n    return StochasticTriple{T}(value(val), delta(val), Δs)\nend\n"
  },
  {
    "path": "src/misc.jl",
    "content": "@doc raw\"\"\"\n    StochasticModel(X, p)\n\nCombine stochastic program `X` with parameter `p` into \na trainable model using [Functors](https://fluxml.ai/Functors.jl/stable/), where\n`p <: AbstractArray`.\nFormulate as a minimization problem, i.e. find ``p`` that minimizes ``\\mathbb{E}[X(p)]``.\n\"\"\"\nstruct StochasticModel{S <: AbstractArray, T}\n    X::T\n    p::S\nend\n@functor StochasticModel (p,)\n\n\"\"\"\n    stochastic_gradient(m::StochasticModel)\n\nCompute gradient with respect to the trainable parameter `p` of `StochasticModel(X, p)`.\n\"\"\"\nfunction stochastic_gradient(m::StochasticModel)\n    fmap(p -> derivative_estimate(m.X, p), m)\nend\n"
  },
  {
    "path": "src/prelude.jl",
    "content": "const AMBIGUOUS_TYPES = (AbstractFloat, Irrational, Integer, Rational, Real, RoundingMode)\n\nconst UNARY_PREDICATES = [isinf, isnan, isfinite, iseven, isodd, isreal, isinteger]\n\nconst BINARY_PREDICATES = [\n    isequal,\n    isless,\n    <,\n    >,\n    ==,\n    !=,\n    <=,\n    >=\n]\n\nconst UNARY_TYPEFUNCS_NOWRAP = [Base.rtoldefault]\nconst UNARY_TYPEFUNCS_WRAP = [\n    typemin,\n    typemax,\n    floatmin,\n    floatmax,\n    zero,\n    one\n]\nconst RNG_TYPEFUNCS_WRAP = [rand, randn, randexp]\n\n\"\"\"\n    structural_iterate(args)\n\nInternal helper function for iterating through the scalar values of a functor, \nwhere AbstractFIs are also counted as scalars.\n\"\"\"\nfunction structural_iterate(args)\n    make_iterator(x) = x isa AbstractArray ? x : (x,)\n    exclude(x) = Functors.isleaf(x) || (x isa AbstractFIs)\n    iter = fmap(make_iterator, args; walk = Functors.IterateWalk(), cache = nothing,\n        exclude)\n    return iter\nend\nstructural_iterate(args::NTuple{N, Union{Real, AbstractFIs}}) where {N} = args\nstructural_iterate(args::AbstractArray{T}) where {T <: Union{Real, AbstractFIs}} = args\nstructural_iterate(args::T) where {T <: Real} = (args,)\n\n\"\"\"\n    structural_map(f, args)\n\nInternal helper function for a structure-preserving map, \noften to be used on a function's input/output arguments. \nCurrently uses [fmap](https://fluxml.ai/Functors.jl/stable/api/#Functors.fmap) \nfrom Functors.jl as a backend.\n\"\"\"\nfunction structural_map(f, args...; only_vals = nothing)\n    walk = if only_vals isa Val{true}\n        Functors.StructuralWalk()\n    elseif (only_vals isa Val{false}) || isnothing(only_vals)\n        Functors.DefaultWalk()\n    else\n        error(\"Unsupported argument only_vals = $only_vals\")\n    end\n    fmap((args...) -> args[1] isa AbstractArray ? f.(args...) : f(args...), args...;\n        cache = nothing,\n        walk)\nend\n"
  },
  {
    "path": "src/propagate.jl",
    "content": "\"\"\"\nA version of `value`` that allows unrecognized args to pass through. \n\"\"\"\nfunction get_value(arg)\n    if arg isa StochasticTriple\n        return value(arg)\n    else\n        # potentially dangerous, see also note in get_Δs\n        return arg\n    end\nend\n\nfunction get_Δs(arg, FIs)\n    if arg isa StochasticTriple\n        return arg.Δs\n    else\n        #=\n        this case is a bit dangerous: perturbations could be dropped here\n        if a leaf of a functor somehow contains a type that is not one of \n        the two above.\n        =#\n        return empty(similar_type(FIs, typeof(arg)))\n    end\nend\n\nfunction strip_Δs(arg; use_dual = Val(true))\n    if arg isa StochasticTriple\n        # TODO: replace check below with a more robust notion of discreteness.\n        if valtype(arg) <: Integer\n            return value(arg)\n        else\n            if use_dual isa Val{true}\n                return ForwardDiff.Dual{tag(arg)}(value(arg), delta(arg))\n            else\n                return StochasticAD.StochasticTriple{tag(arg)}(\n                    value(arg), delta(arg), empty(arg.Δs))\n            end\n        end\n    else\n        return arg\n    end\nend\n\n\"\"\"\n    propagate(f, args...; keep_deltas = Val(false))\n\nPropagates `args` through a function `f`, handling stochastic triples by independently running `f` on the primal\nand the alternatives, rather than by inspecting the internals of `f` (which may possibly be unsupported by `StochasticAD`).\nCurrently handles deterministic functions `f` with any input and output that is `fmap`-able by `Functors.jl`.\nIf `f` has a continuously differentiable component, provide `keep_deltas = Val(true)`.\n\nThis functionality is orthogonal to dispatch: the idea is for this function to be the \"backend\" for operator \noverloading rules based on dispatch. 
For example:\n\n```jldoctest\nusing StochasticAD, Distributions\nimport Random # hide\nRandom.seed!(4321) # hide\n\nfunction mybranch(x)\n    str = repr(x) # string-valued intermediate!\n    if length(str) < 2\n        return 3\n    else\n        return 7\n    end\nend\n\nfunction f(x)\n    return mybranch(9 + rand(Bernoulli(x)))\nend\n\n# stochastic_triple(f, 0.5) # this would fail\n\n# Add a dispatch rule for mybranch using StochasticAD.propagate\nmybranch(x::StochasticAD.StochasticTriple) = StochasticAD.propagate(mybranch, x)\n\nstochastic_triple(f, 0.5) # now works\n\n# output\n\nStochasticTriple of Int64:\n3 + 0ε + (4 with probability 2.0ε)\n```\n\n!!! warning\n    This function is experimental and subject to change.\n\"\"\"\nfunction propagate(f,\n        args...;\n        keep_deltas = Val(false),\n        provided_st_rep = nothing,\n        deriv = nothing)\n    # TODO: support kwargs to f (or just use kwfunc in macro)\n    #= \n    TODO: maybe don't iterate through every scalar of array below, \n    but rather have special array dispatch\n    =#\n    st_rep = if provided_st_rep === nothing\n        args_iter = structural_iterate(args)\n        function args_fold(arg1, arg2)\n            if arg1 isa StochasticTriple\n                if (arg2 isa StochasticTriple) && (tag(arg1) !== tag(arg2))\n                    throw(ArgumentError(\"Tags of combined stochastic triples do not match!\"))\n                end\n                return arg1\n            else\n                return arg2\n            end\n        end\n        foldl(args_fold, args_iter)\n    else\n        provided_st_rep\n    end\n\n    if !(st_rep isa StochasticTriple)\n        return f(args...)\n    end\n\n    primal_args = structural_map(get_value, args)\n    input_args = keep_deltas isa Val{false} ? 
primal_args : structural_map(strip_Δs, args)\n    #= \n    TODO: the below is dangerous is general.\n    It should be safe so long as f does not close over stochastic triples.\n    (If f is a closure, the parameters of f should be treated like any other parameters;\n    if they are stochastic triples and we are ignoring them, dangerous in general.)\n    =#\n    out = f(input_args...)\n    val = structural_map(value, out)\n    # TODO: what does the only_vals do in the below and why?\n    Δs_all = structural_map(Base.Fix2(get_Δs, backendtype(st_rep)), args;\n        only_vals = Val{true}())\n    # TODO: Coupling approach below needs to handle non-perturbable objects.\n    Δs_coupled = couple(backendtype(st_rep), Δs_all; rep = st_rep.Δs, out_rep = val)\n\n    function map_func(Δ_coupled)\n        perturbed_args = structural_map(+, primal_args, Δ_coupled)\n        #= \n        TODO: for f discrete random with randomness independent of params,\n        could couple here. But difficult without a splittable RNG. \n        =#\n        alt = f(perturbed_args...)\n        return structural_map((x, y) -> value(x) - y, alt, val)\n    end\n    Δs = map(map_func, Δs_coupled; out_rep = val, deriv)\n    # TODO: make sure all FI backends support interface needed below\n    new_out = structural_map(out, scalarize(Δs; out_rep = val)) do leaf_out, leaf_Δs\n        StochasticAD.StochasticTriple{tag(st_rep)}(value(leaf_out), delta(leaf_out),\n            leaf_Δs)\n    end\n    return new_out\nend\n"
  },
  {
    "path": "src/smoothing.jl",
    "content": "### Particle resampling\n\n@doc raw\"\"\"\n    new_weight(p::Real)\n\n    Simulate a Bernoulli variable whose primal output is always 1. \n    Uses a smoothing rule for use in forward and reverse-mode AD, which is exactly unbiased when the quantity is only\n    used in linear functions  (e.g. used as an [importance weight](https://en.wikipedia.org/wiki/Importance_sampling)).\n\"\"\"\nnew_weight(p::Real) = 1\n\nfunction new_weight(p::ForwardDiff.Dual{T}) where {T}\n    Δp = ForwardDiff.partials(p)\n    val_p = ForwardDiff.value(p)\n    val_p = max(1e-5, val_p) # TODO: is this necessary?\n    ForwardDiff.Dual{T}(one(val_p), Δp / val_p)\nend\n\nfunction ChainRulesCore.frule((_, Δp), ::typeof(new_weight), p::Real)\n    val_p = max(1e-5, p) # TODO: is this necessary?\n    return one(p), Δp / val_p\nend\n\nfunction ChainRulesCore.rrule(::typeof(new_weight), p)\n    function new_weight_pullback(∇Ω)\n        return (ChainRulesCore.NoTangent(), ∇Ω / p)\n    end\n    return (one(p), new_weight_pullback)\nend\n\n# Smoothed rules for univariate single-parameter distributions. 
\n\nfunction smoothed_delta(d, val, δ, derivative_coupling)\n    Δs_empty = SmoothedFIs{typeof(val)}(0.0)\n    return derivative_contribution(δtoΔs(d, val, δ, Δs_empty, derivative_coupling))\nend\n\nfor (dist, i, field) in [\n    (:Geometric, :1, :p),\n    (:Bernoulli, :1, :p),\n    (:Binomial, :2, :p),\n    (:Poisson, :1, :λ),\n    (:Categorical, :1, :p)\n] # i = index of parameter p\n    # dual overloading \n    @eval function Base.rand(rng::AbstractRNG,\n            d_dual::$dist{<:ForwardDiff.Dual{T}}) where {T}\n        return randst(rng, d_dual)\n    end\n    @eval function randst(rng::AbstractRNG,\n            d_dual::$dist{<:ForwardDiff.Dual{T}};\n            derivative_coupling = InversionMethodDerivativeCoupling()) where {T}\n        dual = params(d_dual)[$i]\n        # dual could represent an array of duals or a single one; map handles both cases.\n        p = map(value, dual)\n        # Generate a δ for each partial component.\n        partials_indices = ntuple(identity, length(first(dual).partials))\n        δs = map(i -> map(d -> ForwardDiff.partials(d)[i], dual), partials_indices)\n        d = $dist(params(d_dual)[1:($i - 1)]..., p,\n            params(d_dual)[($i + 1):end]...)\n        val = convert(Signed, rand(rng, d))\n        partials = ForwardDiff.Partials(map(\n            δ -> smoothed_delta(d, val, δ, derivative_coupling), δs))\n        ForwardDiff.Dual{T}(val, partials)\n    end\n    # frule\n    @eval function ChainRulesCore.frule(Δargs, ::typeof(rand), rng::AbstractRNG,\n            d::$dist)\n        return frule(Δargs, randst, rng, d)\n    end\n    @eval function ChainRulesCore.frule((_, _, Δd), ::typeof(randst), rng::AbstractRNG,\n            d::$dist; derivative_coupling = InversionMethodDerivativeCoupling())\n        val = convert(Signed, rand(rng, d))\n        Δval = smoothed_delta(d, val, Δd, derivative_coupling)\n        return (val, Δval)\n    end\n    # rrule\n    @eval function ChainRulesCore.rrule(::typeof(rand), 
rng::AbstractRNG, d::$dist)\n        return rrule(randst, rng, d)\n    end\n    @eval function ChainRulesCore.rrule(::typeof(randst),\n            rng::AbstractRNG,\n            d::$dist;\n            derivative_coupling = InversionMethodDerivativeCoupling())\n        val = convert(Signed, rand(rng, d))\n        function rand_pullback(∇out)\n            p = params(d)[$i]\n            if p isa Real\n                Δp = smoothed_delta(d, val, one(val), derivative_coupling)\n            else\n                # TODO: this rule is O(length(p)^2), whereas we should be able to do O(length(p)) by reversing through δtoΔs.\n                I = eachindex(p)\n                V = eltype(p)\n                onehot(i) = map(j -> j == i ? one(V) : zero(V), I)\n                Δp = map(i -> smoothed_delta(d, val, onehot(i), derivative_coupling), I)\n            end\n            # rrule_via_ad approach below not used because slow.\n            # Δp = rrule_via_ad(config, smoothed_delta, d, val, map(one, p))[2](∇out)[4]\n            Δd = ChainRulesCore.Tangent{typeof(d)}(; $field = ∇out * Δp)\n            return (ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), Δd)\n        end\n        return (val, rand_pullback)\n    end\nend\n"
  },
  {
    "path": "src/stochastic_triple.jl",
    "content": "\"\"\" \n    StochasticTriple{T, V <: Real, FIs <: AbstractFIs{V}}\n\nStores the primal value of the computation, alongside a \"dual\" component\nrepresenting an infinitesimal change, and a \"triple\" component that tracks\ndiscrete change(s) with infinitesimal probability. \n\nPretty printed as \"value + δε + ({pretty print of Δs})\".\n\n## Constructor\n\n- `value`: the primal value.\n- `δ`: the value of the almost-sure derivative, i.e. the rate of \"infinitesimal\" change.\n- `Δs`: alternate values with associated weights, i.e. Finite perturbations with Infinitesimal probability,\n        represented by a backend `FIs <: AbstractFIs`.\n\"\"\"\nstruct StochasticTriple{T, V <: Real, FIs <: AbstractFIs{V}} <: Real\n    value::V\n    δ::V # infinitesimal change\n    Δs::FIs # finite changes with infinitesimal probabilities # (Δ = 3, p = 1*h)\n    function StochasticTriple{T, V, FIs}(value::V, δ::V,\n            Δs::FIs) where {T, V, FIs <: AbstractFIs{V}}\n        new{T, V, FIs}(value, δ, Δs)\n    end\nend\n\n\"\"\"\n    value(st::StochasticTriple)\n\nReturn the primal value of `st`.\n\"\"\"\nvalue(x::Real, state = nothing) = x\n# Experimental method for obtaining the alternate value of a stochastic triple associated with a certain backend state.\nvalue(st::StochasticTriple) = st.value\nfunction value(st::StochasticTriple, state)\n    st.value + filter_state(st.Δs, state)\nend\n#=\nSupport ForwardDiff.Dual for internal usage.\nAssumes batch size is 1.\n=#\nvalue(d::ForwardDiff.Dual, state = nothing) = ForwardDiff.value(d)\n\n\"\"\"\n    delta(st::StochasticTriple)\n\nReturn the almost-sure derivative of `st`, i.e. 
the rate of infinitesimal change.\n\"\"\"\ndelta(x::Real) = zero(x)\ndelta(st::StochasticTriple) = st.δ\n# Support ForwardDiff.Dual for internal usage.\ndelta(d::ForwardDiff.Dual) = ForwardDiff.partials(d)[1]\n\n\"\"\"\n    perturbations(st::StochasticTriple)\n\nReturn the finite perturbation(s) of `st`, in a format dependent on the [backend](devdocs.md) used for storing perturbations.\n\"\"\"\nperturbations(x::Real) = ()\nperturbations(st::StochasticTriple) = perturbations(st.Δs)\n\n\"\"\"\n    send_signal(st::StochasticTriple, signal::AbstractPerturbationSignal)\n    send_signal(Δs::StochasticAD.AbstractFIs, signal::AbstractPerturbationSignal)\n\nSend a certain signal to a stochastic triple's perturbation collection `st.Δs` (or to a `Δs` directly), \nwhich the backend may process as it wishes. Semantically, unbiasedness should not be affected by the \nsending of the signal. The new version of the first argument (`st` or `Δs`) after signal processing is \nreturned.\n\"\"\"\nsend_signal(st::Real, ::AbstractPerturbationSignal) = st\nfunction send_signal(st::StochasticTriple{T}, signal::AbstractPerturbationSignal) where {T}\n    new_Δs = send_signal(st.Δs, signal)\n    return StochasticTriple{T}(st.value, st.δ, new_Δs)\nend\n\n\"\"\"\n    derivative_contribution(st::StochasticTriple)\n\nReturn the derivative estimate given by combining the dual and triple components of `st`.\n\"\"\"\nderivative_contribution(x::Real) = zero(x)\nderivative_contribution(d::ForwardDiff.Dual) = delta(d)\nderivative_contribution(st::StochasticTriple) = st.δ + derivative_contribution(st.Δs)\n\n\"\"\"\n    tag(st::StochasticTriple)\n    tag(::Type{<:StochasticTriple}))\n\nGet the tag of a stochastic triple.\n\"\"\"\ntag(::Type{<:StochasticTriple{T}}) where {T} = T\ntag(::Type{<:ForwardDiff.Dual{T}}) where {T} = T\ntag(::StochasticTriple{T}) where {T} = T\ntag(::ForwardDiff.Dual{T}) where {T} = T\n\n\"\"\"\n    valtype(st::StochasticTriple)\n    valtype(st::Type{<:StochasticTriple})\n\nGet 
the underlying type of the value tracked by a stochastic triple.\n\"\"\"\nvaltype(st::StochasticTriple) = valtype(typeof(st))\nvaltype(::Type{<:StochasticTriple{T, V}}) where {T, V} = V\n\n\"\"\"\n    backendtype(st::StochasticTriple)\n    backendtype(st::Type{<:StochasticTriple})\n\nGet the backend type of a stochastic triple.\n\"\"\"\nbackendtype(st::StochasticTriple) = backendtype(typeof(st))\nbackendtype(::Type{<:StochasticTriple{T, V, FIs}}) where {T, V, FIs} = FIs\n\n\"\"\"\n    smooth_triple(st::StochasticTriple)\n\nSmooth the dual and triple components of a stochastic triple into a single dual component.\nUseful for avoiding unnecessary pruning when running multilinear functions on triples.\n\"\"\"\nsmooth_triple(x::Real) = x\nfunction smooth_triple(st::StochasticTriple{T, V, FIs}) where {T, V, FIs}\n    return StochasticTriple{T}(value(st), derivative_contribution(st), empty(FIs))\nend\n\n### Extra constructors of stochastic triples\n\nfunction StochasticTriple{T}(value::V, δ::V, Δs::FIs) where {T, V, FIs <: AbstractFIs{V}}\n    StochasticTriple{T, V, FIs}(value, δ, Δs)\nend\n\nfunction StochasticTriple{T}(value::V, Δs::FIs) where {T, V, FIs <: AbstractFIs{V}}\n    StochasticTriple{T}(value, zero(value), Δs)\nend\n\nfunction StochasticTriple{T}(value::A, δ::B,\n        Δs::FIs) where {T, A, B, C, FIs <: AbstractFIs{C}}\n    V = promote_type(A, B, C)\n    StochasticTriple{T}(convert(V, value), convert(V, δ), convert(similar_type(FIs, V), Δs))\nend\n\n### Conversion rules\n\n# TODO: is this the right thing to do? Maybe, different from the promote case because there V was guaranteed to be an ancestor. 
\n# Also, bad to do when already same type?\nfunction Base.convert(::Type{StochasticTriple{T1, V, FIs}},\n        x::StochasticTriple{T2}) where {T1, T2, V, FIs}\n    (T1 !== T2) && throw(ArgumentError(\"Tags of combined stochastic triples do not match.\"))\n    StochasticTriple{T1, V, FIs}(convert(V, x.value), convert(V, x.δ), convert(FIs, x.Δs))\nend\n\n# TODO: ForwardDiff's promotion rules are a little more complicated, see https://github.com/JuliaDiff/ForwardDiff.jl/issues/322\n# May need to look into why and possibly use them here too.\nfunction Base.promote_rule(::Type{StochasticTriple{T, V1, FIs}},\n        ::Type{StochasticTriple{T, V2, FIs2}}) where {T, V1, FIs, V2,\n        FIs2}\n    V = promote_type(V1, V2)\n    StochasticTriple{T, V, similar_type(FIs, V)}\nend\n\nfunction Base.promote_rule(::Type{StochasticTriple{T, V1, FIs}},\n        ::Type{V2}) where {T, V1, FIs, V2 <: Real}\n    V = promote_type(V1, V2)\n    StochasticTriple{T, V, similar_type(FIs, V)}\nend\n\nfunction Base.convert(::Type{StochasticTriple{T, V, FIs}}, x::Real) where {T, V, FIs}\n    StochasticTriple{T, V, FIs}(convert(V, x), zero(V), empty(FIs))\nend\n\n### Creating the first stochastic triple in a computation\n\nfunction StochasticTriple{T}(value::V, δ::V, backend::AbstractFIsBackend) where {T, V}\n    StochasticTriple{T}(value, δ, create_Δs(backend, V))\nend\n\nfunction StochasticTriple{T}(value::V, backend::AbstractFIsBackend) where {T, V}\n    StochasticTriple{T}(value, zero(V), backend)\nend\n\nfunction StochasticTriple{T}(value::A, δ::B, backend::AbstractFIsBackend) where {T, A, B}\n    V = promote_type(A, B)\n    StochasticTriple{T}(convert(V, value), convert(V, δ), backend)\nend\n\n### Showing a stochastic triple\n\nfunction Base.summary(::StochasticTriple{T, V}) where {T, V}\n    return \"StochasticTriple of $V\"\nend\n\nfunction Base.show(io::IO, ::MIME\"text/plain\", st::StochasticTriple)\n    println(io, \"$(summary(st)):\")\n    show(io, st)\nend\n\nfunction 
Base.show(io::IO, st::StochasticTriple)\n    print(io, \"$(st.value) + $(st.δ)ε\")\n    if (!isempty(st.Δs))\n        print(io, \" + ($(repr(st.Δs)))\")\n    end\nend\n\n### Higher level functions\n\nstruct Tag{F, V}\nend\n\nfunction stochastic_triple_direction(f, p::V, direction; backend) where {V}\n    Δs = create_Δs(backend, Int) # TODO: necessity of hardcoding some type here suggests interface improvements\n    sts = structural_map(p, direction) do p_i, direction_i\n        StochasticTriple{Tag{typeof(f), V}}(p_i, direction_i,\n            similar_empty(Δs, typeof(p_i)))\n    end\n    return f(sts)\nend\n\n\"\"\"\n    stochastic_triple(X, p; backend=PrunedFIsBackend(), direction=nothing)\n    stochastic_triple(p; backend=PrunedFIsBackend(), direction=nothing)\n\nFor any `p` that is supported by [`Functors.jl`](https://fluxml.ai/Functors.jl/stable/),\ne.g. scalars or abstract arrays,\ndifferentiate the output with respect to each value of `p`,\nreturning an output of similar structure to `p`, where a particular value contains\nthe stochastic-triple output of `X` when perturbing the corresponding value in `p`\n(i.e. replacing the original value `x` with `x + ε`).\n\nWhen `direction` is provided, return only the stochastic-triple output of `X` with respect to a perturbation\nof `p` in that particular direction.\nWhen `X` is not provided, the identity function is used. 
\n\nThe `backend` keyword argument describes the algorithm used by the third component\nof the stochastic triple, see [technical details](devdocs.md) for more details.\n\n# Example\n```jldoctest\njulia> using Distributions, Random, StochasticAD; Random.seed!(4321);\n\njulia> stochastic_triple(rand ∘ Bernoulli, 0.5)\nStochasticTriple of Int64:\n0 + 0ε + (1 with probability 2.0ε)\n```\n\"\"\"\nfunction stochastic_triple(\n        f, p; direction = nothing, backend::AbstractFIsBackend = PrunedFIsBackend())\n    if direction !== nothing\n        return stochastic_triple_direction(f, p, direction; backend)\n    end\n    counter = begin\n        c = 0\n        (_) -> begin\n            c += 1\n            return c\n        end\n    end\n    indices = structural_map(counter, p)\n    map_func = perturbed_index -> begin\n        direction = structural_map(indices, p) do i, p_i\n            i == perturbed_index ? one(p_i) : zero(p_i)\n        end\n        stochastic_triple_direction(f, p, direction; backend)\n    end\n    return structural_map(map_func, indices)\nend\n\nstochastic_triple(p; kwargs...) = stochastic_triple(identity, p; kwargs...)\n\n\"\"\"\n    dual_number(X, p; backend=PrunedFIsBackend(), direction=nothing)\n    dual_number(p; backend=PrunedFIsBackend(), direction=nothing)\n\nA lightweight wrapper around [`stochastic_triple`](#StochasticAD.stochastic_triple) that entirely ignores the\nderivative contribution of all discrete random components, so that it behaves like a regular dual number.\nMostly for fun -- this, of course, leads to a useless derivative estimate for discrete random functions!\n\"\"\"\nfunction dual_number(f, p; backend = PrunedFIsBackend(), kwargs...)\n    backend = StrategyWrapperFIsBackend(backend, IgnoreDiscreteStrategy())\n    stochastic_triple(f, p; backend, kwargs...)\nend\ndual_number(p; kwargs...) 
= dual_number(identity, p; kwargs...)\n\nfunction derivative_estimate(f, p; kwargs...)\n    StochasticAD.structural_map(derivative_contribution, stochastic_triple(f, p; kwargs...))\nend\n"
  },
  {
    "path": "test/game_of_life.jl",
    "content": "using StochasticAD\nusing Test\nusing Statistics\n\ninclude(\"../tutorials/game_of_life/core.jl\")\nusing .GoLCore: fd_clever, play, p, nsamples\n\n@testset \"AD and Finite Differences\" begin\n    samples_fd_clever = [fd_clever(p) for i in 1:nsamples]\n    samples_st = [derivative_estimate(play, p) for i in 1:nsamples]\n\n    @test mean(samples_st)≈mean(samples_fd_clever) rtol=5e-2\nend\n"
  },
  {
    "path": "test/random_walk.jl",
    "content": "using StochasticAD\nusing Test\nusing Statistics\nusing ForwardDiff: derivative\n\ninclude(\"../tutorials/random_walk/core.jl\")\nusing .RandomWalkCore: n, p, nsamples\nusing .RandomWalkCore: fX, get_dfX\n\n@testset \"Check unbiasedness\" begin\n    fX_deriv = derivative(p -> get_dfX(p, n), p)\n    fX_deriv_estimate = mean(derivative_estimate(fX, p) for i in 1:nsamples)\n    @test isapprox(fX_deriv, fX_deriv_estimate; rtol = 1e-2)\nend\n"
  },
  {
    "path": "test/resampling.jl",
    "content": "using StochasticAD\nusing Random, Test\nusing Distributions\nusing LinearAlgebra\nusing ForwardDiff\n\n# test forward-mode AD and reverse-mode AD on the particle filter\n\n### Particle Filter Functions\ninclude(\"../tutorials/particle_filter/core.jl\")\nseed = 237347578\n\n### Define model\nRandom.seed!(seed)\n\nT = 3\nd = 2\nA(θ, a = 0.01) = [exp(-a)*cos(θ[]) exp(-a)*sin(θ[])\n                  -exp(-a)*sin(θ[]) exp(-a)*cos(θ[])]\nobs(x, θ) = MvNormal(x, 0.01 * collect(I(d)))\ndyn(x, θ) = MvNormal(A(θ) * x, 0.02 * collect(I(d)))\nx0 = [2.0, 0.0] # start value of the simulation\nstart(θ) = Dirac(x0)\nθtrue = [0.20]\n# put it all together\nstochastic_model = ParticleFilterCore.StochasticModel(T, start, dyn, obs)\n\n### simulate model\nRandom.seed!(seed)\nxs, ys = ParticleFilterCore.simulate_single(stochastic_model, θtrue)\n###\n\n### initialize sampler\nm = 1000\nparticle_filter = ParticleFilterCore.ParticleFilter(m, stochastic_model, ys,\n    ParticleFilterCore.sample_stratified)\n###\n\n@testset \"new weight\" begin\n    p = 0.5\n    st = stochastic_triple(p)\n    d = ForwardDiff.Dual(p, (1.0, 2.0))\n    @test new_weight(p) == one(p)\n    @test StochasticAD.value(new_weight(st)) == one(p)\n    @test StochasticAD.delta(new_weight(st)) == 1.0 / p\n    @test ForwardDiff.value(new_weight(d)) == one(p)\n    @test collect(ForwardDiff.partials(new_weight(d))) == [1.0 / p, 2.0 / p]\nend\n\n@testset \"forward-mode and reverse-mode AD: single run\" begin\n    Random.seed!(seed)\n    grad_forw = ParticleFilterCore.forw_grad(θtrue, particle_filter)\n    Random.seed!(seed)\n    grad_back = ParticleFilterCore.back_grad(θtrue, particle_filter)\n    @test grad_forw ≈ grad_back\nend\n\n@testset \"AD and Finite Differences\" begin\n    h = 0.02 # finite diff\n    N = 500 # number of samples\n    grad_fw = [ParticleFilterCore.forw_grad(θtrue, particle_filter)[1] for i in 1:N]\n    # grad_bw = @time [back_grad(θtrue, particle_filter) for i in 1:N]\n    grad_fd = 
[(ParticleFilterCore.log_likelihood(particle_filter, θtrue .+ h) -\n                ParticleFilterCore.log_likelihood(particle_filter, θtrue .- h)) / (2h)\n               for i in 1:N]\n\n    @test mean(grad_fd)≈mean(grad_fw) rtol=5e-2\nend\n"
  },
  {
    "path": "test/runtests.jl",
    "content": "using SafeTestsets\nusing Test, Pkg\nimport Random\n\nRandom.seed!(1234)\n\nconst GROUP = get(ENV, \"GROUP\", \"All\")\nconst is_APPVEYOR = Sys.iswindows() && haskey(ENV, \"APPVEYOR\")\n\n@time begin\n    if GROUP == \"All\"\n        @time @safetestset \"Triples\" begin\n            include(\"triples.jl\")\n        end\n        @time @safetestset \"Game of life\" begin\n            include(\"game_of_life.jl\")\n        end\n        @time @safetestset \"Random walk\" begin\n            include(\"random_walk.jl\")\n        end\n        @time @safetestset \"Resampling\" begin\n            include(\"resampling.jl\")\n        end\n    end\nend\n"
  },
  {
    "path": "test/triples.jl",
    "content": "using StochasticAD\nusing Test\nusing Distributions\nusing ForwardDiff\nusing OffsetArrays\nusing ChainRulesCore\nusing Random\nusing Zygote\n\nconst backends = [\n    PrunedFIsBackend(),\n    PrunedFIsAggressiveBackend(),\n    DictFIsBackend()\n]\n\nconst backends_smoothed = [\n    SmoothedFIsBackend(),\n    StrategyWrapperFIsBackend(SmoothedFIsBackend(), StochasticAD.TwoSidedStrategy())\n]\n\n@testset \"Distributions w.r.t. continuous parameter\" begin\n    for backend in vcat(backends,\n        backends_smoothed,\n        :smoothing_autodiff)\n        MAX = 10000\n        nsamples = 100000\n        rtol = 5e-2 # friendly tolerance for stochastic comparisons. TODO: more motivated choice of tolerance.\n\n        ### Make test cases\n\n        distributions = [\n            Bernoulli,\n            Geometric,\n            Poisson,\n            (p -> Categorical([p^2, 1 - p^2])),\n            (p -> Categorical([0, p^2, 0, 0, 1 - p^2])), # check that 0's are skipped over\n            (p -> Categorical([1.0, exp(p)] ./ (1.0 + exp(p)))), # test fix for #38 (floating point comparisons in Categorical logic)\n            (p -> Binomial(3, p)),\n            (p -> Binomial(20, p))\n        ]\n        p_ranges = [(0.2, 0.8) for _ in 1:8]\n        out_ranges = [0:1, 0:MAX, 0:MAX, 1:2, 1:5, 1:2, 0:3, 0:20]\n        test_cases = collect(zip(distributions, p_ranges, out_ranges))\n        test_funcs = [x -> 7 * x - 3, x -> (x + 1)^2, x -> sqrt(x + 1)]\n\n        if backend isa DictFIsBackend\n            # Only test dictionary backend on Bernoulli to speed things up. Should still cover interface.\n            test_cases = test_cases[1:1]\n        elseif backend == :smoothing_autodiff || backend in backends_smoothed\n            # Only test smoothing backend on each unique distribution once to seed tests up. 
\n            test_cases = vcat(test_cases[1:4], test_cases[7])\n            # Only test unbiasedness of smoothing for linear function\n            test_funcs = test_funcs[1:1]\n        end\n\n        for (distr, p_range, out_range) in test_cases\n            for f in test_funcs\n                function get_mean(p)\n                    dp = distr(p)\n                    sum(pdf(dp, i) * f(i) for i in out_range)\n                end\n\n                low_p, high_p = p_range\n                for g in [p -> p, p -> high_p + low_p - p] # test both sides of derivative\n                    full_func = f ∘ rand ∘ distr ∘ g\n                    p = low_p + (high_p - low_p) * rand()\n                    exact_deriv = ForwardDiff.derivative(p -> get_mean(g(p)), p)\n                    if backend == :smoothing_autodiff\n                        batched_full_func(p) = mean([full_func(p) for i in 1:nsamples])\n                        # The array input used for ForwardDiff below is a trick to test multiple partials\n                        triple_deriv_forward = mean(ForwardDiff.gradient(\n                            arr -> batched_full_func(sum(arr)),\n                            [2 * p, -p]))\n                        triple_deriv_backward = Zygote.gradient(batched_full_func, p)[1]\n                        @test isapprox(triple_deriv_forward, exact_deriv, rtol = rtol)\n                        @test isapprox(triple_deriv_backward, exact_deriv, rtol = rtol)\n                    else\n                        get_deriv = () -> derivative_estimate(full_func, p; backend)\n                        triple_deriv = mean(get_deriv() for i in 1:nsamples)\n                        @test isapprox(triple_deriv, exact_deriv, rtol = rtol)\n                    end\n                end\n            end\n        end\n    end\nend\n\n@testset \"Perturbing n of binomial\" begin\n    function get_triple_deriv(Δ)\n        # Manually create a finite perturbation to avoid any randomness in its creation\n 
       Δs = StochasticAD.similar_new(StochasticAD.create_Δs(PrunedFIsBackend(), Int), Δ,\n            3.5)\n        st = StochasticAD.StochasticTriple{0}(5, 0, Δs)\n        st_continuous = stochastic_triple(0.5)\n        return derivative_contribution(rand(Binomial(st, st_continuous)))\n    end\n    for Δ in -2:2\n        triple_deriv = mean(get_triple_deriv(Δ) for i in 1:100000)\n        exact_deriv = 3.5 * 0.5 * Δ + 5\n        @test isapprox(triple_deriv, exact_deriv, rtol = 5e-2)\n    end\nend\n\n@testset \"Nested binomials\" begin\n    binbin = p -> rand(Binomial(rand(Binomial(10, p)), p)) # ∼ Binomial(10, p^2)\n    for p in [0.3, 0.7]\n        triple_deriv = mean(derivative_estimate(binbin, p) for i in 1:100000)\n        exact_deriv = 10 * 2 * p\n        @test isapprox(triple_deriv, exact_deriv, rtol = 5e-2)\n    end\nend\n\n@testset \"Boolean comparisons\" begin\n    for backend in backends\n        tested = falses(2)\n        while !(all(tested))\n            st = stochastic_triple(rand ∘ Bernoulli, 0.5; backend)\n            x = StochasticAD.value(st)\n            if x == 0\n                # Ensure errors on unsafe/unsupported boolean comparisons\n                @test_throws Exception st>0.5\n                @test_throws Exception 0.5<st\n                @test_throws Exception st==0\n            else\n                @test st > 0.5\n                @test 0.5 < st\n                @test st == 1\n            end\n            tested[x + 1] = true\n        end\n        @test stochastic_triple(1.0; backend) != 1\n    end\nend\n\n@testset \"Array indexing\" begin\n    for backend in vcat(backends, backends_smoothed)\n        p = 0.3\n        # Test indexing into array of floats with stochastic triple index\n        arr = [3.5, 5.2, 8.4]\n        (backend in backends_smoothed) && (arr[3] = 6.9) # make linear for smoothing test\n        function array_index(p)\n            index = rand(Categorical([p / 2, p / 2, 1 - p]))\n            return arr[index]\n        
end\n        array_index_mean(p) = sum([p / 2, p / 2, (1 - p)] .* arr)\n        triple_array_index_deriv = mean(derivative_estimate(array_index, p; backend)\n        for i in 1:50000)\n        exact_array_index_deriv = ForwardDiff.derivative(array_index_mean, p)\n        @test isapprox(triple_array_index_deriv, exact_array_index_deriv, rtol = 5e-2)\n        # Don't run subsequent tests with smoothing backend\n        (backend in backends_smoothed) && continue\n        # Test indexing into array of stochastic triples with stochastic triple index\n        function array_index2(p)\n            arr2 = [rand(Bernoulli(p)), rand(Bernoulli(p)), rand(Bernoulli(p))] .* arr\n            index = rand(Categorical([p / 2, p / 2, 1 - p]))\n            return arr2[index]\n        end\n        array_index2_mean(p) = sum([p / 2 * p, p / 2 * p, (1 - p) * p] .* arr)\n        triple_array_index2_deriv = mean(derivative_estimate(array_index2, p; backend)\n        for i in 1:50000)\n        exact_array_index2_deriv = ForwardDiff.derivative(array_index2_mean, p)\n        @test isapprox(triple_array_index2_deriv, exact_array_index2_deriv, rtol = 5e-2)\n        # Test case where triple and alternate array value are coupled\n        function array_index3(p)\n            st = rand(Bernoulli(p))\n            arr2 = [-5, st]\n            return arr2[st + 1]\n        end\n        array_index3_mean(p) = -5 * (1 - p) + 1 * p\n        triple_array_index3_deriv = mean(derivative_estimate(array_index3, p; backend)\n        for i in 1:50000)\n        exact_array_index3_deriv = ForwardDiff.derivative(array_index3_mean, p)\n        @test isapprox(triple_array_index3_deriv, exact_array_index3_deriv, rtol = 5e-2)\n    end\nend\n\n@testset \"Array/functor inputs to higher level functions\" begin\n    for backend in backends\n        # Try a deterministic test function to compare to ForwardDiff\n        f(x) = (x[1] * x[2] * sin(x[3]) + exp(x[1] * x[2])) / x[3]\n        x = [1, 2, π / 2]\n\n        
stochastic_ad_grad = derivative_estimate(f, x; backend)\n        stochastic_ad_grad2 = derivative_contribution.(stochastic_triple(f, x; backend))\n        stochastic_ad_grad_firsttwo = derivative_estimate(\n            f, x; direction = [1.0, 1.0, 0.0],\n            backend)\n        fd_grad = ForwardDiff.gradient(f, x)\n        @test stochastic_ad_grad ≈ fd_grad\n        @test stochastic_ad_grad ≈ stochastic_ad_grad2\n        @test stochastic_ad_grad[1] + stochastic_ad_grad[2] ≈ stochastic_ad_grad_firsttwo\n\n        # Try an OffsetArray too\n        f_off(x) = (x[0] * x[1] * sin(x[2]) + exp(x[0] * x[1])) / x[2]\n        x_off = OffsetArray([1, 2, π / 2], 0:2)\n        stochastic_ad_grad_off = derivative_estimate(f_off, x_off)\n        @test stochastic_ad_grad_off ≈ OffsetArray(stochastic_ad_grad, 0:2)\n\n        # Try a Functor\n        f_func(x) = (x[1] * x[2][1] * sin(x[2][2]) + exp(x[1] * x[2][1])) / x[2][2]\n        x_func = (1, [2, π / 2])\n        stochastic_ad_grad_func = derivative_estimate(f_func, x_func)\n        stochastic_ad_grad_func_expected = (stochastic_ad_grad[1], stochastic_ad_grad[2:3])\n        compare_grad_funcs = StochasticAD.structural_map(≈, stochastic_ad_grad_func,\n            stochastic_ad_grad_func_expected)\n        @test all(compare_grad_funcs |> StochasticAD.structural_iterate)\n\n        # Test StochasticModel + stochastic_gradient combination\n        m = StochasticModel(f, x)\n        @test stochastic_gradient(m).p ≈ stochastic_ad_grad\n    end\nend\n\n@testset \"Propagation using frule with ZeroTangent\" begin\n    st = stochastic_triple(0.5)\n\n    # Verify that the rule for imag indeed gives a ZeroTangent\n    value = StochasticAD.value(st)\n    δ = StochasticAD.delta(st)\n    @test frule((NoTangent(), δ), imag, value)[2] isa ZeroTangent\n    # Test that stochastic triples flow through this rule\n    out_st = imag(st)\n    @test StochasticAD.value(out_st) ≈ 0\n    @test StochasticAD.delta(out_st) ≈ 0\n    @test 
isempty(out_st.Δs)\nend\n\n@testset \"Unary functions converting type to fixed instance\" begin\n    for val in [0.5, 1]\n        st = stochastic_triple(val)\n        for op in StochasticAD.UNARY_TYPEFUNCS_WRAP\n            f = getfield(Base, Symbol(op))\n            out_st = f(st)\n            @test out_st isa StochasticAD.StochasticTriple\n            @test StochasticAD.value(out_st) ≈ f(val) ≈ f(typeof(val))\n            @test StochasticAD.delta(out_st) ≈ 0\n            @test isempty(out_st.Δs)\n            @test f(typeof(st)) == out_st\n        end\n        #=\n        It so happens that the UNARY_TYPEFUNCS_WRAP funcs all support both instances and types\n        whereas UNARY_TYPEFUNCS_NOWRAP only supports types, so we only test types in the below,\n        but this is a coincidence that may not hold in the future.\n        =#\n        for op in StochasticAD.UNARY_TYPEFUNCS_NOWRAP\n            f = getfield(Base, Symbol(op))\n            out = f(typeof(st))\n            @test out isa typeof(val)\n            @test out ≈ f(typeof(val))\n        end\n        RNG = copy(Random.GLOBAL_RNG)\n        for op in StochasticAD.RNG_TYPEFUNCS_WRAP\n            f = getfield(Random, Symbol(op))\n            out_st = f(copy(RNG), typeof(st))\n            @test out_st isa StochasticAD.StochasticTriple\n            @test StochasticAD.value(out_st) ≈ f(copy(RNG), typeof(val))\n            @test StochasticAD.delta(out_st) ≈ 0\n            @test isempty(out_st.Δs)\n        end\n    end\nend\n\n@testset \"Hashing\" begin\n    st = stochastic_triple(3.0)\n    @test_nowarn hash(st)\n    @test_nowarn hash(st, UInt(5))\n    d = Dict()\n    @test_nowarn d[st] = 5\n    @test d[st] == 5\n    @test d[3] == 5\n    # Test that we get an error with discrete random dictionary indices,\n    # since this isn't supported and we want to avoid silent failures.\n    Δs = StochasticAD.similar_new(StochasticAD.create_Δs(PrunedFIsBackend(), Int), 1.0, 1.0)\n    st = 
StochasticAD.StochasticTriple{0}(1.0, 0, Δs)\n    @test_throws ErrorException d[rand(Bernoulli(st))]\nend\n\n@testset \"Coupled comparison\" begin\n    Δs_1 = StochasticAD.similar_new(StochasticAD.create_Δs(PrunedFIsBackend(), Int), 1.0,\n        1.0)\n    Δs_2 = StochasticAD.similar_new(StochasticAD.create_Δs(PrunedFIsBackend(), Int), 1.0,\n        1.0)\n    st_1 = StochasticAD.StochasticTriple{0}(1.0, 0, Δs_1)\n    st_2 = StochasticAD.StochasticTriple{0}(1.0, 0, Δs_2)\n    @test st_1 == st_1\n    @test_throws ErrorException st_1==st_2\nend\n\n@testset \"Converting float stochastic triples to integer triples\" begin\n    st = stochastic_triple(0.6)\n    @test round(Int, st) isa StochasticAD.StochasticTriple\n    @test StochasticAD.delta(round(Int, st)) ≈ 0\n    @test round(Int, st) ≈ 1\nend\n\n@testset \"Approximate comparisons\" begin\n    st = stochastic_triple(0.5)\n    @test st ≈ st\n    # Check that the rtol is indeed reasonable\n    @test st ≈ st + 1e-14\n    @test !(st ≈ st + 1)\n    @test_broken stochastic_triple(Inf) ≈ stochastic_triple(Inf)\nend\n\n@testset \"Error on unmatched tags\" begin\n    st1 = stochastic_triple(0.5)\n    st2 = stochastic_triple(x -> x^2, 0.5)\n    @test_throws ArgumentError convert(typeof(st1), st2)\nend\n\n@testset \"Finite perturbation backend interface\" begin\n    for backend in vcat(backends,\n        backends_smoothed)\n        # this boolean may need to become more fine-grained in the future\n        is_smoothed_backend = backend in backends_smoothed\n        #=\n        Test the backend interface across the finite perturbation backends,\n        which is currently a bit implicitly defined.\n        =#\n        V0 = Int\n        V1 = Float64\n        #=\n        All four of the below approaches should create an empty backend,\n        although the backend's internal state management may differ. 
\n        =#\n        Δs0 = StochasticAD.create_Δs(backend, V0) # used to create first triple in computation\n        FIs = typeof(Δs0)\n        Δs1 = empty(Δs0)\n        Δs2 = empty(typeof(Δs0))\n        Δs3 = StochasticAD.similar_empty(Δs0, V1)\n        for (Δs, V) in ((Δs0, V0), (Δs1, V0), (Δs2, V0), (Δs3, V1))\n            @test StochasticAD.valtype(Δs) === V\n            @test Δs isa StochasticAD.similar_type(FIs, V)\n            !is_smoothed_backend && @test isempty(Δs)\n            @test iszero(derivative_contribution(Δs))\n        end\n        # Test creation of a single perturbation\n        for Δ in (1, 3.0)\n            Δs0 = StochasticAD.create_Δs(backend, V0)\n            Δs1 = StochasticAD.similar_new(Δs0, Δ, 3.0)\n            @test StochasticAD.valtype(Δs1) === typeof(Δ)\n            @test Δs1 isa StochasticAD.similar_type(FIs, typeof(Δ))\n            !is_smoothed_backend && @test !isempty(Δs1)\n            @test derivative_contribution(Δs1) == 3Δ\n            # Test StochasticAD.alltrue\n            @test StochasticAD.alltrue(_Δ -> true, Δs1)\n            @test !StochasticAD.alltrue(_Δ -> false, Δs1) || is_smoothed_backend\n            # Test map\n            # We use a dummy deriv here and below. 
TODO: use a more interesting dummy for better testing.\n            Δs1_map = Base.map(Δ -> Δ^2, Δs1; deriv = identity, out_rep = Δ)\n            !is_smoothed_backend && @test derivative_contribution(Δs1_map) ≈ Δ^2 * 3.0\n            # Test map with weight (make a new copy so that original does not get reweighted)\n            Δs2 = StochasticAD.similar_new(StochasticAD.create_Δs(backend, V0), Δ, 3.0)\n            Δs2_weight_map = StochasticAD.weighted_map_Δs((Δ, _) -> (Δ^2, 2.0),\n                Δs2;\n                deriv = identity,\n                out_rep = Δ)\n            !is_smoothed_backend &&\n                @test derivative_contribution(Δs2_weight_map) ≈ Δ^2 * 3.0 * 2.0\n            # Also test scale\n            w2 = derivative_contribution(Δs2)\n            Δs2_scaled = StochasticAD.scale(Δs2, 2.0)\n            w2_scaled = derivative_contribution(Δs2_scaled)\n            @test w2_scaled ≈ 2.0 * w2\n            # Test map_Δs with filter state\n            if !is_smoothed_backend\n                Δs1_plus_Δs0 = StochasticAD.map_Δs(\n                    (Δ, state) -> Δ +\n                                  StochasticAD.filter_state(Δs0,\n                        state),\n                    Δs1)\n                @test derivative_contribution(Δs1_plus_Δs0) ≈ Δ * 3.0\n                Δs1_plus_mapped = StochasticAD.map_Δs(\n                    (Δ, state) -> Δ +\n                                  StochasticAD.filter_state(Δs1,\n                        state),\n                    Δs1_map)\n                @test derivative_contribution(Δs1_plus_mapped) ≈ Δ * 3.0 + Δ^2 * 3.0\n            end\n        end\n        # Test coupling\n        Δ_coupleds = (3, [4.0, 5.0], (2, [3.0, 4.0]))\n        for Δ_coupled in Δ_coupleds\n            function get_Δs_coupled(; do_combine = false, use_get_rep = false)\n                Δs0 = StochasticAD.create_Δs(backend, Int)\n                Δs1 = StochasticAD.similar_new(Δs0, 1, 3.0) # perturbation 1\n                Δs2 = 
StochasticAD.similar_new(Δs0, 1, 2.0) # perturbation 2\n                # A group of perturbations that all stem from perturbation 1. \n                Δs_all1 = StochasticAD.structural_map(Δ_coupled) do Δ\n                    Base.map(_Δ -> Δ, Δs1; deriv = identity, out_rep = Δ)\n                end\n                # A group of perturbations that all stem from perturbation 2. \n                Δs_all2 = StochasticAD.structural_map(Δ_coupled) do Δ\n                    Base.map(_Δ -> 2 * Δ, Δs2; deriv = (δ -> 2δ), out_rep = Δ)\n                end\n                # Join them into a single structure that should be coupled\n                Δs_all = (Δs_all1, Δs_all2)\n                kwargs = use_get_rep ? (; rep = StochasticAD.get_rep(FIs, Δs_all)) : (;)\n                if do_combine\n                    return StochasticAD.combine(FIs, Δs_all; kwargs...)\n                else\n                    return StochasticAD.couple(FIs, Δs_all;\n                        out_rep = (Δ_coupled, Δ_coupled),\n                        kwargs...)\n                end\n            end\n            #=\n            As a test function to apply to the coupled perturbation, we apply\n            a matmul followed by a sigmoid activation function and a sum.\n            =#\n            l = 2 * length(collect(StochasticAD.structural_iterate(Δ_coupled)))\n            A = rand(l, l)\n            function mapfunc(Δ_coupled)\n                arr = collect(StochasticAD.structural_iterate(Δ_coupled))\n                sum(x -> 1 / (1 + exp(-x)), A * arr)\n            end\n            # Test the above function, and also a simple sum.\n            for use_get_rep in (false, true)\n                Δs_coupled = get_Δs_coupled(; use_get_rep)\n                @test StochasticAD.valtype(Δs_coupled) == typeof((Δ_coupled, Δ_coupled))\n                for (mapfunc, check_combine) in ((mapfunc, false),\n                    (Δ_coupled -> sum(StochasticAD.structural_iterate(Δ_coupled)),\n                     
   true))\n                    function get_contribution()\n                        Δs_coupled = get_Δs_coupled(; use_get_rep)\n                        Δs_coupled_mapped = map(mapfunc, Δs_coupled; deriv = (δ -> 1.0),\n                            out_rep = 0.0)\n                        return derivative_contribution(Δs_coupled_mapped)\n                    end\n                    zero_Δ_coupled = StochasticAD.structural_map(zero, Δ_coupled)\n                    expected_contribution1 = 3.0 * mapfunc((Δ_coupled, zero_Δ_coupled))\n                    expected_contribution2 = 2.0 * mapfunc((zero_Δ_coupled,\n                        StochasticAD.structural_map(x -> 2x,\n                            Δ_coupled)))\n                    expected_contribution = expected_contribution1 + expected_contribution2\n                    if !is_smoothed_backend\n                        @test isapprox(mean(get_contribution() for i in 1:1000),\n                            expected_contribution; rtol = 5e-2)\n                    end\n                    # For a simple sum, this should be equivalent to the combine behaviour.\n                    if check_combine && !is_smoothed_backend\n                        @test isapprox(\n                            mean(derivative_contribution(get_Δs_coupled(;\n                                     do_combine = true))\n                            for i in 1:1000),\n                            expected_contribution;\n                            rtol = 5e-2)\n                    end\n                    # Check scalarize\n                    Δs_coupled2 = StochasticAD.couple(FIs,\n                        StochasticAD.scalarize(Δs_coupled;\n                            out_rep = (Δ_coupled,\n                                Δ_coupled)),\n                        out_rep = (Δ_coupled, Δ_coupled))\n                    @test derivative_contribution(map(mapfunc, Δs_coupled;\n                        deriv = (δ -> 1.0),\n                        out_rep = 0.0)) ≈\n 
                         derivative_contribution(map(mapfunc, Δs_coupled2;\n                        deriv = (δ -> 1.0),\n                        out_rep = 0.0))\n                end\n            end\n        end\n    end\nend\n\n@testset \"Getting information about stochastic triples\" begin\n    for backend in vcat(backends,\n        backends_smoothed)\n        Random.seed!(4321)\n        f(x) = rand(Bernoulli(x)) + x\n        st = stochastic_triple(f, 0.5; backend)\n        # Expected: 0.5 + 1.0ε + (1.0 with probability 2.0ε)\n        dual = ForwardDiff.Dual(0.5, 1.0)\n\n        @test StochasticAD.value(0.5) == 0.5\n        @test StochasticAD.value(st) == 0.5\n        @test StochasticAD.value(dual) == 0.5\n\n        @test iszero(StochasticAD.delta(0.5))\n        @test StochasticAD.delta(st) == 1.0\n        @test StochasticAD.delta(dual) == 1.0\n\n        if !(backend in backends_smoothed)\n            #= \n            NB: since the implementation of perturbations can be backend-specific, the\n            below property need not hold in general, but does for the current non-smoothed backends.\n            =#\n            p = only(perturbations(st))\n            @test p.Δ == 1 && p.weight == 2.0\n            @test derivative_contribution(st) == 3.0\n        else\n            # Since smoothed algorithm uses the two-sided strategy, we get a different derivative contribution.\n            @test derivative_contribution(st) == 2.0\n        end\n\n        @test StochasticAD.tag(st) === StochasticAD.Tag{typeof(f), Float64}\n        @test StochasticAD.valtype(st) === Float64\n        @test StochasticAD.valtype(st.Δs) === Float64\n    end\nend\n\n@testset \"Propagation via StochasticAD.propagate\" begin\n    for backend in backends\n        function form_triple(primal, δ, Δ, Δs_base)\n            Δs = map(_Δ -> Δ, Δs_base)\n            return StochasticAD.StochasticTriple{0}(primal, δ, Δs)\n        end\n\n        function test_propagate(f, primals, Δs; test_deltas = 
false)\n            Δs_base = StochasticAD.similar_new(StochasticAD.create_Δs(backend, Int),\n                0, 1.0)\n            _form_triple(x, δ, Δ) = form_triple(x, δ, Δ, Δs_base)\n            out = f(primals...)\n            out_Δ_expected = StochasticAD.structural_map(-,\n                f(StochasticAD.structural_map(+,\n                    primals,\n                    Δs)...),\n                f(primals...))\n            if test_deltas\n                duals = StochasticAD.structural_map(primals) do x\n                    x isa AbstractFloat ? ForwardDiff.Dual{0}(x, rand(typeof(x))) : x\n                end\n                δs = StochasticAD.structural_map(StochasticAD.delta, duals)\n                out_δ_expected = StochasticAD.structural_map(StochasticAD.delta,\n                    f(duals...))\n            else\n                δs = StochasticAD.structural_map(zero, primals)\n                out_δ_expected = StochasticAD.structural_map(zero, out)\n            end\n            input_sts = StochasticAD.structural_map(_form_triple, primals, δs, Δs)\n            out_st = StochasticAD.propagate(f, input_sts...; keep_deltas = Val{test_deltas})\n            # Test type\n            StochasticAD.structural_map(out_st, out, out_δ_expected,\n                out_Δ_expected) do x_st, x, δ, Δ\n                @test x_st isa StochasticAD.StochasticTriple{0, typeof(x)}\n                @test StochasticAD.value(x_st) == x\n                @test StochasticAD.delta(x_st) ≈ δ\n                p = only(perturbations(x_st))\n                @test p.Δ == Δ && p.weight == 1.0\n            end\n        end\n\n        #=\n        Test propagation through some simple functions. \n            f1: a simple if statement.\n            f2: involves array-containing-functor input and output.\n            f3: involves array-containing-functor input, but real output.\n            f4: length ∘ repr (real or array input, real output).\n            f5: mutates input array! 
Broken since unsupported.\n            f6: the first-arg (blob) should just be passed through without attempting\n                to perturb. Broken since unsupported.\n            f7: involves matrix-containing-functor input and output.\n        =#\n        function f1(x)\n            if x == 0\n                return 1\n            elseif x == 3\n                return 2\n            else\n                return 5\n            end\n        end\n\n        @test StochasticAD.propagate(f1, 0) === f1(0)\n        for (primal, Δ) in [(0, 3), (0, 4), (3, -1)]\n            test_propagate(f1, (primal,), (Δ,))\n        end\n\n        function f2(arr, scalar)\n            if sum(arr) + scalar <= 5\n                return arr .* scalar, sum(arr) * scalar\n            else\n                return arr .- scalar, sum(arr) - scalar\n            end\n        end\n        f3(arr, scalar) = f2(arr, scalar)[2]\n\n        primals1 = ([1, 1], 2)\n        Δs1 = ([2, 3], 5)\n        primals2 = ([1, 2], 1)\n        Δs2 = ([1, -2], 1)\n        primals3 = ([5, 2], -1)\n        Δs3 = ([-3, 1], 0)\n\n        for (primals, Δs) in [(primals1, Δs1), (primals2, Δs2), (primals3, Δs3)]\n            for test_deltas in (false, true)\n                if test_deltas\n                    primals = StochasticAD.structural_map(float, primals)\n                    Δs = StochasticAD.structural_map(float, Δs)\n                end\n                test_propagate(f2, primals, Δs; test_deltas)\n                test_propagate(f3, primals, Δs; test_deltas)\n            end\n        end\n\n        f4(x) = Base.length(repr(x))\n\n        for (primals, Δs) in [(2, 11), (([3, 14],), ([14, -152],))]\n            test_propagate(f4, primals, Δs)\n        end\n\n        function f5(arr)\n            if arr == [1, 2]\n                arr .+= 1\n            else\n                arr .-= 1\n            end\n        end\n\n        # Tests for f5 skipped (would break)\n        for (primals, Δs) in [([1, 2], [1, -1]), ([2, 
4], [-1, -2]), ([2, 4], [-1, -1])]\n            @test_skip \"propagate f5\"\n            # test_propagate(f5, primals, Δs)\n        end\n\n        f6(blob, arr) = blob, f5(arr)\n\n        # Tests for f6 missing (would break)\n        @test_skip \"propagate f6\"\n\n        function f7(mat, scalar)\n            return mat * scalar, scalar + sum(mat)\n        end\n\n        test_propagate(f7, (rand(2, 2), 4.0), (rand(2, 2), 1.0); test_deltas = true)\n    end\nend\n\n@testset \"zero'ing of Inf/NaN (#79)\" begin\n    st = stochastic_triple(0.5)\n    st_zero = zero(1 / zero(st))\n    @test iszero(StochasticAD.value(st_zero))\n    @test iszero(StochasticAD.delta(st_zero))\nend\n\n@testset \"smooth_triple\" begin\n    f(p) = sum(rand(Bernoulli(p)) * i for i in 1:100)\n    f2(p) = sum(smooth_triple(rand(Bernoulli(p))) * i for i in 1:100)\n    p = 0.6\n    f_est = mean(derivative_estimate(f, p) for i in 1:10000)\n    f2_est = mean(derivative_estimate(f2, p) for i in 1:10000)\n    @test f_est≈f2_est rtol=5e-2\nend\n\n@testset \"No unnecessary float promotion\" begin\n    f(p) = rand(Bernoulli(p))^2\n    st = stochastic_triple(f, 0.5)\n    @test StochasticAD.valtype(st) == typeof(convert(Signed, f(0.5)))\nend\n"
  },
  {
    "path": "tutorials/Project.toml",
    "content": "[deps]\nBenchmarkTools = \"6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf\"\nChainRulesCore = \"d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4\"\nDistributions = \"31c24e10-a181-5473-b8eb-7969acd0382f\"\nDistributionsAD = \"ced4e74d-a319-5a8a-b0ac-84af2272839c\"\nEnzyme = \"7da242da-08ed-463a-9acd-ee780be4f1d9\"\nFileIO = \"5789e2e9-d7fb-5bc7-8068-2c6fae9b9549\"\nForwardDiff = \"f6369f11-7733-5829-9624-2563aa707210\"\nFunctors = \"d9f16b24-f501-4c13-a1f2-28368ffc5196\"\nGR = \"28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71\"\nGaussianDistributions = \"43dcc890-d446-5863-8d1a-14597580bb8d\"\nGeometryBasics = \"5c1252a2-5f33-56bf-86c9-59e7332b4326\"\nLaTeXStrings = \"b964fa9f-0449-5b57-a5c2-d3ea65f4040f\"\nMeasurements = \"eff96d63-e80a-5855-80a2-b1b0885c5ab7\"\nOffsetArrays = \"6fe1bfb0-de20-5000-8ca7-80f57d26f881\"\nPkgBenchmark = \"32113eaa-f34f-5b0d-bd6c-c81e245fc73d\"\nPlots = \"91a5bcdd-55d7-5caf-9e0b-520d859cae80\"\nProgressMeter = \"92933f4c-e287-5a05-a399-4b506db050ca\"\nRandom = \"9a3f8284-a2c9-5f02-9a11-845980a1fd5c\"\nStaticArrays = \"90137ffa-7385-5640-81b9-e52037218182\"\nStatsBase = \"2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91\"\nStochasticAD = \"e4facb34-4f7e-4bec-b153-e122c37934ac\"\nUnPack = \"3a884ed6-31ef-47d7-9d2a-63182c4928ed\"\nZygote = \"e88e6eb3-aa80-5325-afca-941959d7151f\"\n"
  },
  {
    "path": "tutorials/README.md",
    "content": "The raw source code for our tutorials. You likely want to look at the documentation instead, where they are presented more clearly! \n"
  },
  {
    "path": "tutorials/game_of_life/core.jl",
    "content": "module GoLCore\n\nusing Random\nusing Distributions\nusing LinearAlgebra\nusing StochasticAD\nusing StaticArrays\nusing OffsetArrays\n\nfunction update_state!(all_probs, N, board, board_old)\n    for i in (-N):N\n        for j in (-N):N\n            neighbours = board_old[i + 1, j] + board_old[i - 1, j] + board_old[i, j - 1] +\n                         board_old[i, j + 1]\n            index = board[i, j] * 5 + neighbours + 1 # trick necessary because we do not have implementation support for stochastic triple not <: Real\n            b = rand(Bernoulli(all_probs[index]))\n            board[i, j] += (1 - 2 * board[i, j]) * b\n        end\n    end\nend\n\nfunction play_game_of_life(p, all_probs, N, T, log = false)\n    dual_type = promote_type(typeof(rand(Bernoulli(p))),\n        typeof.(rand.(Bernoulli.(all_probs)))...) # TODO: better way of getting the concrete type\n    board = OffsetArray(zeros(dual_type, 2 * N + 3, 2 * N + 3), (-(N + 1)):(N + 1),\n        (-(N + 1)):(N + 1)) # pad by 1\n    for i in (-N):N\n        for j in (-N):N\n            board[i, j] = rand(Bernoulli(p))\n        end\n    end\n    board_old = similar(board)\n    log && (history = [])\n    for time_step in 1:T\n        copy!(board_old, board)\n        update_state!(all_probs, N, board, board_old)\n        log && push!(history, copy(board))\n    end\n    if !log\n        return sum(board)\n    else\n        return sum(board), board, history\n    end\nend\n\nfunction play(p, θ = 0.1, N = 3, T = 3; log = false)\n    # N is the board half-length, T are game time steps\n    low = θ\n    high = 1 - θ\n    birth_probs = SA[low, low, low, high, low] # 0, 1, 2, 3, 4 neighbours\n    death_probs = SA[high, high, low, low, high] # 0, 1, 2, 3, 4 neighbours\n    return play_game_of_life(p, vcat(birth_probs, death_probs), N, T, log)\nend\n\n# An implementation of finite differences that uses \"common random numbers\"\n# (the same seed), for more accurate checking, albeit with a finite step 
size h\n# such that there is weight degeneracy as h → 0.\nfunction fd_clever(p, h = 0.01)\n    state = copy(Random.default_rng())\n    run1 = play(p + h)\n    copy!(Random.default_rng(), state)\n    run2 = play(p - h)\n    (run1 - run2) / (2h)\nend\n\n# Provide some default parameters\np = 0.5\nnsamples = 200_000\n\nend\n"
  },
  {
    "path": "tutorials/game_of_life/plot_board.jl",
    "content": "include(\"core.jl\")\nusing Plots\nusing Statistics\nusing BenchmarkTools\n\np = 0.5\n_, board, history = stochastic_triple(p -> GoLCore.play(p; log = true), p)\n\nanim1 = @animate for (i, board) in enumerate(history)\n    heatmap(collect(StochasticAD.value.(board)), title = \"time $i\", clim = (-1, 1),\n        c = :grays)\nend\nanim2 = @animate for (i, board) in enumerate(history)\n    heatmap(collect(StochasticAD.derivative_contribution.(board)), title = \"time $i\",\n        clim = (-1, 1), c = :grays)\nend\n\ngif(anim1, \"game.gif\", fps = 15)\ngif(anim2, \"perturbation.gif\", fps = 15)\nfig1 = heatmap(collect(StochasticAD.value.(board)), clim = (-1, 1), c = :grays)\nfig2 = heatmap(collect(derivative_contribution.(board)), clim = (-1, 1), c = :grays) # TODO: graph perturbed values instead of derivative contribution\nsavefig(fig1, \"board.png\")\nsavefig(fig2, \"perturbation.png\")\n"
  },
  {
    "path": "tutorials/particle_filter/benchmark.jl",
    "content": "include(\"core.jl\")\ninclude(\"model.jl\")\nusing Plots, LaTeXStrings\nusing BenchmarkTools\nusing Measurements\n\n# Benchmark for primal, forward- and reverse-mode AD of particle sampler\n\n### compute gradients\n# secs for how long the benchmark should run, see https://juliaci.github.io/BenchmarkTools.jl/stable/\nsecs = 10\n\nsuite = BenchmarkGroup()\nsuite[\"scaling\"] = BenchmarkGroup([\"grads\"])\n\nsuite[\"scaling\"][\"primal\"] = @benchmarkable ParticleFilterCore.log_likelihood(\n    particle_filter,\n    θtrue)\nsuite[\"scaling\"][\"forward\"] = @benchmarkable ParticleFilterCore.forw_grad(θtrue,\n    particle_filter)\nsuite[\"scaling\"][\"backward\"] = @benchmarkable ParticleFilterCore.back_grad(θtrue,\n    particle_filter)\n\ntune!(suite)\nresults = run(suite, verbose = true, seconds = secs)\n\nt1 = measurement(mean(results[\"scaling\"][\"primal\"].times),\n    std(results[\"scaling\"][\"primal\"].times) /\n    sqrt(length(results[\"scaling\"][\"primal\"].times)))\nt2 = measurement(mean(results[\"scaling\"][\"forward\"].times),\n    std(results[\"scaling\"][\"forward\"].times) /\n    sqrt(length(results[\"scaling\"][\"forward\"].times)))\nt3 = measurement(mean(results[\"scaling\"][\"backward\"].times),\n    std(results[\"scaling\"][\"backward\"].times) /\n    sqrt(length(results[\"scaling\"][\"backward\"].times)))\n@show t1 t2 t3\n\nts = (t1, t2, t3) ./ 10^6 # ms\n@show ts\n\nBenchmarkTools.save(\"benchmark_data_\" * string(d) * \".json\", results)\n"
  },
  {
    "path": "tutorials/particle_filter/bias.jl",
    "content": "include(\"core.jl\")\ninclude(\"model.jl\")\nusing Plots, LaTeXStrings\nusing Random\n\n# Comparison of the derivative of the particle filter with and without differentiating the resampling step.\n\n### compute gradients\nRandom.seed!(seed)\nX = [ParticleFilterCore.forw_grad(θtrue, particle_filter) for i in 1:1000] # gradient of the particle filter *with* differentiation of the resampling step\nRandom.seed!(seed)\nXbiased = [ParticleFilterCore.forw_grad_biased(θtrue, particle_filter) for i in 1:1000] # Gradient of the particle filter *without* differentiation of the resampling step\n# pick an arbitrary coordinate\nindex = 1 # take derivative with respect to first parameter (2-dimensional example has a rotation matrix with four parameters in total)\n# plot histograms for the sampled derivative values\nfig = plot(normalize(fit(Histogram, getindex.(X, index), nbins = 50), mode = :pdf),\n    legend = false) # ours\nplot!(normalize(fit(Histogram, getindex.(Xbiased, index), nbins = 50), mode = :pdf)) # biased\nvline!([mean(X)[index]], color = 1)\nvline!([mean(Xbiased)[index]], color = 2)\n# add derivative of differentiable Kalman filter as a comparison\nXK = ParticleFilterCore.forw_grad_Kalman(θtrue, kalman_filter)\nvline!([XK[index]], color = \"black\")\n\ndisplay(fig)\nsavefig(fig, \"tails.pdf\")\n"
  },
  {
    "path": "tutorials/particle_filter/core.jl",
    "content": "module ParticleFilterCore\n\n# load dependencies\nusing Distributions\nusing DistributionsAD\nusing Random\nusing Statistics\nusing StatsBase\nusing LinearAlgebra\nusing Zygote\nusing StochasticAD\nusing ForwardDiff\nusing GaussianDistributions\nusing GaussianDistributions: correct, ⊕\nusing UnPack\n\n### Particle Filter Functions\n\n# Model defs\n\n\"\"\"\n    StochasticModel{dType<:Integer,TType<:Integer,T1,T2,T3}\n\nFor parameters `θ`,  `rand(start(θ))` gives a sample from the prior distribution of the\nstarting distribution. For current state `x` and parameters `θ`, `xnew = rand(dyn(x, θ))`\nsamples the new state (i.e. `dyn` gives for each `x, θ` a distribution-like object). Finally,\n`y = rand(obs(x, θ))` samples an observation.\n\n## Constructor\n\n- `T`: total number of time steps.\n- `start`: starting distribution for the initial state. For example, in the form of a narrow\n   Gaussian `start(θ) = Gaussian(x0, 0.001 * I(d))`.\n- `dyn`: pointwise differentiable stochastic program in the form of Markov transition densities.\n   For example, `dyn(x, θ) = MvNormal(reshape(θ, d, d) * x, Q(θ))`, where `Q(θ)` denotes the\n   covariance matrix.\n- `obs`: observation model having a smooth conditional probability density depending on\n   current state `x` and parameters `θ`. 
For example, `obs(x, θ) = MvNormal(x, R(θ))`,\n   where `R(θ)` denotes the covariance matrix.\n\"\"\"\nstruct StochasticModel{TType <: Integer, T1, T2, T3}\n    T::TType # time steps\n    start::T1 # prior\n    dyn::T2 # dynamical model\n    obs::T3 # observation model\nend\n\n# Particle filter\n\"\"\"\n\n    ParticleFilter{mType<:Integer,MType<:StochasticModel,yType,sType}\n\nWraps a stochastic model `StochM::StochasticModel` and observational data `ys`.\nAssumes a observation-likelihood is available via `pdf(obs(x, θ), y)`.\n\n## Constructor\n\n- `m`: number of particles.\n- `StochM`: stochastic model of type `StochasticModel`.\n- `ys`: observations.\n- `sample_strategy`: strategy for the resampling step of the particle filter. For example,\n  stratified sampling as implemented in `sample_stratified`.\n\"\"\"\nstruct ParticleFilter{mType <: Integer, MType <: StochasticModel, yType, sType}\n    m::mType # number of particles\n    StochM::MType # stochastic model\n    ys::yType # observations\n    sample_strategy::sType # sampling function\nend\n\n# Kalman filter\n\"\"\"\n\n    KalmanFilter{dType<:Integer,MType<:StochasticModel,HType,RType,QType,yType}\n\nDifferentiable Kalman filter following https://github.com/mschauer/Kalman.jl/blob/master/README.md.\nWraps a stochastic model `StochM::StochasticModel` and observational data `ys`. Assumes a\nobservation-likelihood is implemented via `llikelihood(yres, S)`. 
For example:\n ```\n llikelihood(yres, S) = GaussianDistributions.logpdf(Gaussian(zero(yres), Symmetric(S)), yres)\n ```\n\n## Constructor\n\n- `d`: dimension of the state-transition matrix Φ according to x = Φ*x + w with w ~ Normal(0,Q).\n- `StochM`: Stochastic model of type `StochasticModel`.\n- `H`: linear map from the state space into the observed space according to y = H x + ν with ν ~ Normal(0, R).\n- `R`: covariance matrix entering the observation model according to y = H x + ν with ν ~ Normal(0, R).\n- `Q`: covariance matrix entering the state-transition model according to x = Φ*x + w with w ~ Normal(0,Q).\n- `ys`: observations.\n\"\"\"\nstruct KalmanFilter{dType <: Integer, MType <: StochasticModel, HType, RType, QType, yType}\n    # H, R = obs\n    # θ, Q = dyn\n    d::dType\n    StochM::MType # stochastic model\n    H::HType # observation model, maps the true state space into the observed space\n    R::RType # observation model, covariance matrix\n    Q::QType # dynamical model, covariance matrix\n    ys::yType # observations\nend\n\n\"\"\"\n    simulate_single(StochM::StochasticModel, θ)\n\nSimulate a single particle from the forward model returning\na vector of observations (no resampling steps), e.g.\n```\nRandom.seed!(seed)\nxs, ys = simulate_single(StochM, θtrue)\n```\nto get observations ys from the latent states xs based on the\n(true, potentially unknown) parameters θ.\n\"\"\"\nfunction simulate_single(StochM::StochasticModel, θ)\n    @unpack T, start, dyn, obs = StochM\n    x = rand(start(θ))\n    y = rand(obs(x, θ))\n    xs = [x]\n    ys = [y]\n    for t in 2:T\n        x = rand(dyn(x, θ))\n        y = rand(obs(x, θ))\n        push!(xs, x)\n        push!(ys, y)\n    end\n    xs, ys\nend\n\n\"\"\"\n    sample_stratified(p, K, sump=1)\n\nStratified resampling strategy, see for example https://arxiv.org/abs/1202.6163.\nHere, `p` denotes the probabilities of `K` particles with `sump = sum(p)`.\n\"\"\"\nfunction sample_stratified(p, K, sump = 1)\n   
 n = length(p)\n    U = rand()\n    is = zeros(Int, K)\n    i = 1\n    cw = p[1]\n    for k in 1:K\n        t = sump * (k - 1 + U) / K\n        while cw < t && i < n\n            i += 1\n            @inbounds cw += p[i]\n        end\n        is[k] = i\n    end\n    return is\nend\n\n\"\"\"\n    resample(m, X, W, ω, sample_strategy, use_new_weight=true)\n\nResampling step wrapped for use in particle filter using differentiable\nresampling from the article (`use_new_weight`). Returns states `X_new`\nand weights `W_new` of resampled particles.\n\n## args\n- `m`: number of particles.\n- `X`: current particle states.\n- `W`: current weight vector of the particles.\n- `ω == sum(W)` is an invariant.\n- `sample_strategy`: specific resampling strategy to be used. Currently, only `sample_stratified` is available.\n- `use_new_weight=true`: Allows one to switch between biased, stop-gradient method and\n   differentiable resampling step.\n\"\"\"\nfunction resample(m, X, W, ω, sample_strategy, use_new_weight = true)\n    js = Zygote.ignore(() -> sample_strategy(W, m, ω))\n    X_new = X[js]\n    if use_new_weight\n        # differentiable resampling\n        W_chosen = W[js]\n        W_new = map(w -> ω * new_weight(w / ω) / m, W_chosen)\n    else\n        # stop gradient, biased approach\n        W_new = fill(ω / m, m)\n    end\n    X_new, W_new\nend\n\n\"\"\"\n (F::ParticleFilter)(θ; store_path=false, use_new_weight=true, s=1)\n\nRun particle filter. The particle filter propagates particles with weights `W` preserving the\ninvariant `ω == sum(W)`. `W` is never normalized and `ω` contains therefore likelihood information.\nDefaults to return particle positions and weights at `T` if `store_path=false`.\n\n## args\n- `θ`: parameters for the stochastic program (state-transition and observation model).\n- `store_path=false`: Option to store the path of the particles, e.g. 
to visualize/inspect\n  their trajectories.\n- `use_new_weight=true`: Option to switch between the stop-gradient and our differentiable\n  resampling step method. Defaults to using differentiable resampling.\n- `s`: controls the number of resampling steps according to `t > 1 && t < T && (t % s == 0)`.\n\"\"\"\nfunction (F::ParticleFilter)(θ; store_path = false, use_new_weight = true, s = 1)\n    # s controls the number of resampling steps\n    @unpack m, StochM, ys, sample_strategy = F\n    @unpack T, start, dyn, obs = StochM\n\n    X = [rand(start(θ)) for j in 1:m] # particles\n    W = [1 / m for i in 1:m] # weights\n    ω = 1 # total weight\n    store_path && (Xs = [X])\n    for (t, y) in zip(1:T, ys)\n        # update weights & likelihood using observations\n        wi = map(x -> pdf(obs(x, θ), y), X)\n        W = W .* wi\n        ω_old = ω\n        ω = sum(W)\n        # resample particles\n        if t > 1 && t < T && (t % s == 0) # && 1 / sum((W / ω) .^ 2) < length(W) ÷ 32\n            X, W = resample(m, X, W, ω, sample_strategy, use_new_weight)\n        end\n        # update particle states\n        if t < T\n            X = map(x -> rand(dyn(x, θ)), X)\n            store_path && Zygote.ignore(() -> push!(Xs, X))\n        end\n    end\n    (store_path ? Xs : X), W\nend\n\n# differentiable Kalman filter, following https://github.com/mschauer/Kalman.jl/blob/master/README.md\nfunction llikelihood(yres, S)\n    GaussianDistributions.logpdf(Gaussian(zero(yres), Symmetric(S)), yres)\nend\n\n\"\"\"\n    (F::KalmanFilter)(θ)\n\nRun differentiable Kalman filter. 
Returns updated posterior state estimate and log likelihood.\n\n## args\n- `θ`: parameters for the stochastic program (state-transition and observation model).\n\"\"\"\nfunction (F::KalmanFilter)(θ)\n    @unpack d, StochM, H, R, Q, ys = F\n    @unpack start = StochM\n\n    x = start(θ)\n    Φ = reshape(θ, d, d)\n\n    x, yres, S = GaussianDistributions.correct(x, ys[1] + R, H)\n    ll = llikelihood(yres, S)\n    xs = Any[x]\n    for i in 2:length(ys)\n        x = Φ * x ⊕ Q\n        x, yres, S = GaussianDistributions.correct(x, ys[i] + R, H)\n        ll += llikelihood(yres, S)\n\n        push!(xs, x)\n    end\n    xs, ll\nend\n\n# compute log-likelihood of Particle Sampler\n\"\"\"\n   log_likelihood(F::ParticleFilter, θ, use_new_weight=true, s=1)\n\nCompute log-likelihood of particle sampler. See `ParticleFilter` for `use_new_weight` and `s`.\n\n## args\n- `θ`: parameters for the stochastic program (state-transition and observation model).\n\"\"\"\nfunction log_likelihood(F::ParticleFilter, θ, use_new_weight = true, s = 1)\n    _, W = F(θ; store_path = false, use_new_weight = use_new_weight, s = s)\n    log(sum(W))\nend\n\n# compute log-likelihood of Kalman Filter\n\"\"\"\n   log_likelihood(F::KalmanFilter, θ)\n\nCompute log-likelihood of Kalman filter.\n\n## args\n- `θ`: parameters for the stochastic program (state-transition and observation model).\n\"\"\"\nfunction log_likelihood(F::KalmanFilter, θ)\n    _, ll = F(θ)\n    ll\nend\n\n# forward differentiation of particle sampler\nfunction forw_grad(θ, F::ParticleFilter; s = 1)\n    ForwardDiff.gradient(θ -> log_likelihood(F, θ, true, s), θ)\nend\n# backward differentiation of particle sampler\nfunction back_grad(θ, F::ParticleFilter; s = 1)\n    Zygote.gradient(θ -> log_likelihood(F, θ, true, s), θ)[1]\nend\n# biased forward differentiation of particle sampler, avoiding differentiation of the resampling step\nfunction forw_grad_biased(θ, F::ParticleFilter; s = 1)\n    ForwardDiff.gradient(θ -> log_likelihood(F, θ, 
false, s), θ)\nend\n# forward-mode AD of Kalman filter\nforw_grad_Kalman(θ, F::KalmanFilter) = ForwardDiff.gradient(θ -> log_likelihood(F, θ), θ)\nend\n"
  },
  {
    "path": "tutorials/particle_filter/model.jl",
    "content": "# ParticleFilter Model\n\nusing Random, LinearAlgebra, GaussianDistributions, Distributions\n\n# particle filter core function definitions\ninclude(\"core.jl\")\n\n### Define model\n\nd = 2 # dimension\nT = 20 # time steps\n\n# generate a rotation matrix, dynamical model, observation model, prior distribution as a function of d\nfunction generate_system(d, T)\n    # here: n-dimensional rotation matrix\n    seed = 423897\n    Random.seed!(seed)\n\n    M = randn(d, d)\n    c = 0.3 # scaling\n    O = exp(c * (M - transpose(M)) / 2)\n    @assert det(O) ≈ 1\n    @assert transpose(O) * O ≈ I(d)\n    # true parameter\n    θtrue = vec(O)\n\n    # observation model\n    R = 0.01 * collect(I(d))\n    obs(x, θ) = MvNormal(x, R) # y = H x + ν with ν ~ Normal(0, R)\n\n    # dynamical model\n    Q = 0.02 * collect(I(d))\n    dyn(x, θ) = MvNormal(reshape(θ, d, d) * x, Q) #  x = Φ*x + w with w ~ Normal(0,Q)\n\n    # starting position\n    x0 = randn(d)\n    # prior distribution\n    start(θ) = Gaussian(x0, 0.001 * collect(I(d)))\n\n    # put it all together\n    stochastic_model = ParticleFilterCore.StochasticModel(T, start, dyn, obs)\n\n    # relevant corresponding Kalman filterng defs\n    H_Kalman = collect(I(d))\n    R_Kalman = Gaussian(zeros(Float64, d), R)\n    # Φ_Kalman = O\n    Q_Kalman = Gaussian(zeros(Float64, d), Q)\n\n    ### simulate model\n    Random.seed!(seed)\n    xs, ys = ParticleFilterCore.simulate_single(stochastic_model, θtrue)\n    ###\n\n    ### initialize filters\n    m = 1000 # number of particles\n    kalman_filter = ParticleFilterCore.KalmanFilter(\n        d, stochastic_model, H_Kalman, R_Kalman,\n        Q_Kalman, ys)\n    particle_filter = ParticleFilterCore.ParticleFilter(m, stochastic_model, ys,\n        ParticleFilterCore.sample_stratified)\n\n    return θtrue, xs, ys, stochastic_model, kalman_filter, particle_filter\nend\n\nθtrue, xs, ys, stochastic_model, kalman_filter, particle_filter = generate_system(d, T)\n"
  },
  {
    "path": "tutorials/particle_filter/variance.jl",
    "content": "include(\"core.jl\")\ninclude(\"model.jl\")\nusing Plots, LaTeXStrings\nusing Random\n\nRandom.seed!(seed)\n# Comparison of the variance of the particle filter with and without differentiating the resampling step *as a function of the time steps*.\nvars_pf = []\nvars_pf_biased = []\nTs = 5:5:30\nfor T in Ts\n    # Random.seed!(seed) is fixed in model!\n    θtrue, xs, ys, stochastic_model, kalman_filter, particle_filter = generate_system(d, T)\n    xs, ys = ParticleFilterCore.simulate_single(stochastic_model, θtrue)\n    particle_filter = ParticleFilterCore.ParticleFilter(m, stochastic_model, ys,\n        ParticleFilterCore.sample_stratified)\n    ### compute var of gradients\n    # Gradient of the particle filter *with* differentiation of the resampling step\n    var_pf = @time var([ParticleFilterCore.forw_grad(θtrue, particle_filter) for i in 1:100])\n    # Gradient of the particle filter *without* differentiation of the resampling step\n    var_pf_biased = @time var([ParticleFilterCore.forw_grad_biased(θtrue, particle_filter)\n                               for i in 1:100])\n\n    push!(vars_pf, var_pf)\n    push!(vars_pf_biased, var_pf_biased)\nend\n\n@show vars_pf\n@show vars_pf_biased\n\n# pick an arbitrary coordinate\nindex = 1 # take derivative with respect to first parameter\nfig = plot(Ts, getindex.(vars_pf, index), color = 1, label = \"unbiased\", size = (300, 250),\n    xlabel = L\"n\", ylabel = \"variance\", legend = :topleft, y_scale = :log)\nscatter!(Ts, getindex.(vars_pf, index), color = 1, label = false)\nplot!(Ts, getindex.(vars_pf_biased, index), color = 2, label = \"biased\")\nscatter!(Ts, getindex.(vars_pf_biased, index), color = 2, label = false)\ndisplay(fig)\nsavefig(fig, \"particle_filter_variance_steps.pdf\")\n\n# Comparison of the variance of the particle filter with and without differentiating the resampling step *as a function of the system size*.\nvars_pf = []\nvars_pf_biased = []\nds = 2:1:6\nfor d in ds\n    # 
Random.seed!(seed) is fixed in model!\n    θtrue, xs, ys, stochastic_model, kalman_filter, particle_filter = generate_system(d, 10)\n    xs, ys = ParticleFilterCore.simulate_single(stochastic_model, θtrue)\n    particle_filter = ParticleFilterCore.ParticleFilter(m, stochastic_model, ys,\n        ParticleFilterCore.sample_stratified)\n    ### compute var of gradients\n    # Gradient of the particle filter *with* differentiation of the resampling step\n    var_pf = @time var([ParticleFilterCore.forw_grad(θtrue, particle_filter) for i in 1:50])\n    # Gradient of the particle filter *without* differentiation of the resampling step\n    var_pf_biased = @time var([ParticleFilterCore.forw_grad_biased(θtrue, particle_filter)\n                               for i in 1:50])\n\n    push!(vars_pf, var_pf)\n    push!(vars_pf_biased, var_pf_biased)\nend\n\nfig = plot(ds, getindex.(vars_pf, index), color = 1, label = \"unbiased\", size = (300, 250),\n    xlabel = L\"d\", ylabel = \"variance\", legend = :topleft, y_scale = :log)\nscatter!(ds, getindex.(vars_pf, index), color = 1, label = false)\nplot!(ds, getindex.(vars_pf_biased, index), color = 2, label = \"biased\")\nscatter!(ds, getindex.(vars_pf_biased, index), color = 2, label = false)\ndisplay(fig)\nsavefig(fig, \"particle_filter_variance_size.pdf\")\n"
  },
  {
    "path": "tutorials/particle_filter/visualize.jl",
    "content": "include(\"core.jl\")\ninclude(\"model.jl\")\nusing Plots, LaTeXStrings\n\n# visualization of stochastic process (observations and latent states), particle filter, and Kalman filter\n\n### run and visualize filters\nXs, W = particle_filter(θtrue; store_path = true)\nfig = plot(getindex.(xs, 1), getindex.(xs, 2), legend = false, xlabel = L\"x_1\",\n    ylabel = L\"x_2\") # x1 and x2 are bad names..conflicting notation\nscatter!(fig, getindex.(ys, 1), getindex.(ys, 2))\nfor i in 1:min(m, 100) # note that Xs has obs noise.\n    local xs = [Xs[t][i] for t in 1:T]\n    scatter!(fig, getindex.(xs, 1), getindex.(xs, 2), marker_z = 1:T, color = :cool,\n        alpha = 0.1) # color to indicate time step\nend\n\nxs_Kalman, ll_Kalman = kalman_filter(θtrue)\nplot!(getindex.(mean.(xs_Kalman), 1), getindex.(mean.(xs_Kalman), 2), legend = false,\n    color = \"red\")\ndisplay(fig)\nsavefig(fig, \"filter.pdf\")\n"
  },
  {
    "path": "tutorials/random_walk/compare_score.jl",
    "content": "include(\"core.jl\")\nusing Plots, LaTeXStrings\nusing Statistics\nusing StochasticAD\nusing ForwardDiff: derivative\nusing ProgressMeter\n\nbegin\n    stds_triple = Float64[]\n    stds_smoothed = Float64[]\n    stds_score = Float64[]\n    stds_score_baseline = Float64[]\n    @showprogress for (n, p) in zip(RandomWalkCore.n_range, RandomWalkCore.p_range)\n        std_triple = std(derivative_estimate(p -> RandomWalkCore.fX(p, n), p)\n        for i in 1:(RandomWalkCore.nsamples))\n        std_smoothed = std(derivative(\n                               p -> RandomWalkCore.fX(p,\n                                   n;\n                                   hardcode_leftright_step = true),\n                               p)\n        for i in 1:(RandomWalkCore.nsamples))\n        std_score = std(RandomWalkCore.score_fX_deriv(p, n, 0.0)\n        for i in 1:(RandomWalkCore.nsamples))\n        avg = mean(RandomWalkCore.fX(p, n) for i in 1:10000)\n        std_score_baseline = std(RandomWalkCore.score_fX_deriv(p, n, avg)\n        for i in 1:(RandomWalkCore.nsamples))\n        push!(stds_triple, std_triple)\n        push!(stds_score, std_score)\n        push!(stds_score_baseline, std_score_baseline)\n        push!(stds_smoothed, std_smoothed)\n    end\nend\n\n@show stds_triple\n@show stds_score\n@show stds_score_baseline\n@show stds_smoothed\n\nbegin\n    show_smoothed = false\n    fig = plot(RandomWalkCore.n_range, stds_score, color = 2, label = \"score-function\",\n        size = (300, 250),\n        xlabel = L\"n\", ylabel = \"standard deviation\", legend = :topleft)\n    scatter!(RandomWalkCore.n_range, stds_score, color = 2, label = false)\n    plot!(RandomWalkCore.n_range, stds_score_baseline, color = 3,\n        label = \"score-function w/ CV\")\n    scatter!(RandomWalkCore.n_range, stds_score_baseline, color = 3, label = false)\n    plot!(RandomWalkCore.n_range, stds_triple, color = 1, label = \"stochastic triples\")\n    scatter!(RandomWalkCore.n_range, 
stds_triple, color = 1, label = false)\n    if show_smoothed\n        plot!(RandomWalkCore.n_range,\n            stds_smoothed,\n            color = 4,\n            label = \"smoothed stochastic triples\")\n        scatter!(RandomWalkCore.n_range, stds_smoothed, color = 4, label = false)\n    end\n    display(fig)\n    plot!(fig, dpi = 500)\n    savefig(fig, \"random_walk.png\")\nend\n"
  },
  {
    "path": "tutorials/random_walk/core.jl",
    "content": "module RandomWalkCore\n\nusing Random\nusing Statistics\nusing Distributions\nusing LinearAlgebra\nusing StochasticAD\nusing StaticArrays\nusing OffsetArrays: Origin\nimport ForwardDiff\nusing ForwardDiff: Dual, derivative, value, partials\n\n## Parameters\n\nsteps = SA[-1, 1]\nmake_probs(p) = X -> SA[1 - exp(-X / p), exp(-X / p)]\nf = x -> x^2 # function to apply to X\n\nn = 50# number of steps\np = 100 # default parameter value\nn_range = 10:10:100 # range for testing asymptotics\np_range = 2 .* n_range\n\nnsamples = 10000 # number of times to run gradient estimators\n\n## Simulate\n\nfunction simulate_walk(probs, steps, n; debug = false, hardcode_leftright_step = false)\n    X = 0\n    for i in 1:n\n        probs_X = probs(X) # transition probabilities\n        debug && @show probs_X\n        step_index = rand(Categorical(probs_X)) # produces an integer-valued StochasticTriple\n        debug && @show step_index\n        if hardcode_leftright_step\n            step = 2 * (step_index - 1) - 1\n        else\n            step = steps[step_index] # differentiate through array indexing\n        end\n        X += step\n        debug && @show X\n    end\n    return X\nend\n\nX(p, n; kwargs...) = simulate_walk(make_probs(p), steps, n; kwargs...)\nfX(p, n; kwargs...) = f(X(p, n; kwargs...))\nX(p; kwargs...) = X(p, n; kwargs...)\nfX(p; kwargs...) 
= fX(p, n; kwargs...)\n\n## Simulate with score method manually added on\n\nfunction simulate_walk_score(probs, steps, n; debug = false)\n    X = 0.0\n    dlogP = 0.0\n    for i in 1:n\n        probs_X = probs(X) # transition probabilities\n        step_index = convert(Int, ForwardDiff.value(rand(Categorical(probs_X)))) # just a number\n        step = steps[step_index] # differentiate through array indexing\n        dlogP += partials(log(probs_X[step_index]))[1]\n        X += step # take step\n    end\n    return (X, dlogP)\nend\n\nscore_X(p, n) = simulate_walk_score(make_probs(Dual(p, 1.0)), steps, n)\nfunction score_X_deriv(p, n, avg)\n    X, dlogP = score_X(p, n)\n    (X - avg) * dlogP\nend\nfunction score_fX_deriv(p, n, avg)\n    X, dlogP = score_X(p, n)\n    return (f(X) - avg) * dlogP\nend\nscore_X_deriv(p; avg = 0.0) = score_X_deriv(p, n, avg)\nscore_fX_deriv(p; avg = 0.0) = score_fX_deriv(p, n, avg)\n\n## Exactly compute transition matrix M\n\nrange = 0:n\nrange_start = 1 # range[range_start] = 0\n\nfunction get_M(p)\n    probs = make_probs(p)\n    M = zeros(eltype(first(probs(range[range_start]))), length(range), length(range))\n    low = minimum(range)\n    for x in range\n        for (step, prob) in zip(steps, probs(x))\n            if (x + step) in range\n                M[x + step - low + 1, x - low + 1] = prob\n            end\n        end\n    end\n    M\nend\n\nfunction probdensity(p, n)\n    M = get_M(p)\n    vec = zeros(length(range))\n    vec[range_start] = 1\n    M^n * vec\nend\n\nget_dX(p, n) = sum(probdensity(p, n) .* range)\nget_dfX(p, n) = sum(probdensity(p, n) .* (f.(range)))\n\nend\n"
  },
  {
    "path": "tutorials/random_walk/show_unbiased.jl",
    "content": "include(\"core.jl\")\nprintln(\"## Exact computation\\n\")\n\nusing ForwardDiff: derivative\nusing BenchmarkTools\nusing .RandomWalkCore: n, p, nsamples\nusing .RandomWalkCore: X, f, fX, get_dX, get_dfX\nusing .RandomWalkCore: score_X_deriv, score_fX_deriv\nusing StochasticAD\nusing Statistics\nimport Random\n\nX_deriv = derivative(p -> get_dX(p, n), p)\nfX_deriv = derivative(p -> get_dfX(p, n), p)\nprintln(\"X derivative: $X_deriv\")\nprintln(\"f(X) derivative: $fX_deriv\")\nprintln()\n\nprintln(\"## Stochastic triple computation\\n\")\n\n@btime fX(p)\n@btime derivative_estimate(fX, p; backend = PrunedFIsAggressiveBackend())\n@btime derivative_estimate(fX, p; backend = PrunedFIsBackend())\n\ntriple_X_derivs = [derivative_estimate(X, p) for i in 1:nsamples]\ntriple_fX_derivs = [derivative_estimate(fX, p) for i in 1:nsamples]\nprintln(\"Stochastic triple X derivative mean: $(mean(triple_X_derivs))\")\nprintln(\"Stochastic triple X derivative std : $(std(triple_X_derivs))\")\nprintln(\"Stochastic triple f(X) derivative mean: $(mean(triple_fX_derivs))\")\nprintln(\"Stochastic triple f(X) derivative std: $(std(triple_fX_derivs))\")\nprintln()\n\nsmoothed_X_derivs = [derivative(p -> X(p; hardcode_leftright_step = true), p)\n                     for i in 1:nsamples]\nsmoothed_fX_derivs = [derivative(p -> fX(p; hardcode_leftright_step = true), p)\n                      for i in 1:nsamples]\nprintln(\"Smoothed X derivative mean: $(mean(smoothed_X_derivs))\")\nprintln(\"Smoothed X derivative std : $(std(smoothed_X_derivs))\")\nprintln(\"Smoothed f(X) derivative mean: $(mean(smoothed_fX_derivs))\")\nprintln(\"Smoothed f(X) derivative std: $(std(smoothed_fX_derivs))\")\nprintln()\n\nprintln(\"## Score function computation\\n\")\n\n# baseline\navg_X = mean(X(p) for i in 1:10000)\navg_fX = mean(fX(p) for i in 1:10000)\nscore_X_derivs = [score_X_deriv(p; avg = avg_X)\n                  for i in 1:nsamples]\nscore_fX_derivs = [score_fX_deriv(p; avg = avg_fX)\n     
              for i in 1:nsamples]\nprintln(\"Score X derivative mean: $(mean(score_X_derivs))\")\nprintln(\"Score X derivative std: $(std(score_X_derivs))\")\nprintln(\"Score f(X) derivative mean: $(mean(score_fX_derivs))\")\nprintln(\"Score f(X) derivative std: $(std(score_fX_derivs))\")\nprintln()\n\nprintln(\"## Finite differences\\n\")\n\nfunction fd(X, p, h = 10)\n    state = copy(Random.default_rng())\n    run1 = X(p - h / 2)\n    copy!(Random.default_rng(), state)\n    run2 = X(p + h / 2)\n    (run2 - run1) / h\nend\n\nfd_X_derivs = [fd(X, p) for i in 1:nsamples]\nfd_fX_derivs = [fd(f ∘ X, p) for i in 1:nsamples]\nprintln(\"FD X derivative mean: $(mean(fd_X_derivs))\")\nprintln(\"FD X derivative std: $(std(fd_X_derivs))\")\nprintln(\"FD f(X) derivative mean: $(mean(fd_fX_derivs))\")\nprintln(\"FD f(X) derivative std: $(std(fd_fX_derivs))\")\nprintln()\n"
  },
  {
    "path": "tutorials/reverse_example/reverse_demo.jl",
    "content": "#text # Simple reverse mode example \n\n#text ```@setup random_walk\n#text import Pkg\n#text Pkg.activate(\"../../../tutorials\")\n#text Pkg.develop(path=\"../../..\")\n#text Pkg.instantiate()\n#text\n#text import Random \n#text Random.seed!(1234)\n#text ```\n\nimport Random #src\nRandom.seed!(1234) #src\n\n##cell\n#text Load our packages\n\nusing StochasticAD\nusing Distributions\nusing Enzyme\nusing LinearAlgebra\n\n##cell\n#text Let us define our target function.\n\n# Define a toy `StochasticAD`-differentiable function for computing an integer value from a string.\nstring_value(strings, index) = Int(sum(codepoint, strings[index]))\nfunction string_value(strings, index::StochasticTriple)\n    StochasticAD.propagate(index -> string_value(strings, index), index)\nend\n\nfunction f(θ; derivative_coupling = StochasticAD.InversionMethodDerivativeCoupling())\n    strings = [\"cat\", \"dog\", \"meow\", \"woofs\"]\n    index = randst(Categorical(θ); derivative_coupling)\n    return string_value(strings, index)\nend\n\nθ = [0.1, 0.5, 0.3, 0.1]\n@show f(θ)\nnothing\n\n##cell\n#text First, let's compute the sensitivity of `f` in a particular direction via forward-mode Stochastic AD.\nu = [1.0, 2.0, 4.0, -7.0]\n@show derivative_estimate(\n    f, θ, StochasticAD.ForwardAlgorithm(PrunedFIsBackend()); direction = u)\nnothing\n\n##cell\n#text Now, let's do the same with reverse-mode, via [`EnzymeReverseAlgorithm`](@ref).\n\n@show derivative_estimate(\n    f, θ, StochasticAD.EnzymeReverseAlgorithm(PrunedFIsBackend(Val(:wins))))\n\n##cell\n#text Let's verify that our reverse-mode gradient is consistent with our forward-mode directional derivative.\n\nfunction forward()\n    derivative_estimate(\n        f, θ, StochasticAD.ForwardAlgorithm(PrunedFIsBackend()); direction = u)\nend\nfunction reverse()\n    derivative_estimate(\n        f, θ, StochasticAD.EnzymeReverseAlgorithm(PrunedFIsBackend(Val(:wins))))\nend\n\nN = 40000\ndirectional_derivs_fwd = [forward() for i 
in 1:N]\nderivs_bwd = [reverse() for i in 1:N]\ndirectional_derivs_bwd = [dot(u, δ) for δ in derivs_bwd]\nprintln(\"Forward mode: $(mean(directional_derivs_fwd)) ± $(std(directional_derivs_fwd) / sqrt(N))\")\nprintln(\"Reverse mode: $(mean(directional_derivs_bwd)) ± $(std(directional_derivs_bwd) / sqrt(N))\")\n@assert isapprox(mean(directional_derivs_fwd), mean(directional_derivs_bwd), rtol = 3e-2)\n\nnothing\n\n##cell\n#! format: off #src\nusing Literate #src\ndo_documenter = true #src\n\nfunction preprocess(content) #src\n    new_lines = map(split(content, \"\\n\")) do line #src\n        if endswith(line, \"#src\") #src\n            line #src\n        elseif startswith(line, \"##cell\") #src\n            \"#src\" #src\n        elseif startswith(line, \"#text\") #src\n            replace(line, \"#text\" => \"#\") #src\n        # try and save comments; strip necessasry since Literate.jl also treats indented comments on their own line as markdown. #src\n        elseif startswith(strip(line), \"#\") && !startswith(strip(line), \"#=\") &&\n               !startswith(strip(line), \"#-\") #src\n            # TODO: should be replace first occurence only? #src\n            replace(line, \"#\" => \"##\") #src\n        else #src\n            line #src\n        end #src\n    end #src\n    return join(new_lines, \"\\n\") #src\nend #src\n\nwithenv(\"JULIA_DEBUG\" => \"Literate\") do #src\n    dir = joinpath(dirname(dirname(pathof(StochasticAD))), \"docs\", \"src\", \"tutorials\") #src\n    if do_documenter #src\n        @time Literate.markdown(\n            @__FILE__, dir; execute = false, flavor = Literate.DocumenterFlavor(),\n            preprocess = preprocess, documenter = true) #src\n    else #src\n        @time Literate.markdown(@__FILE__, dir; execute = true,\n            flavor = Literate.CommonMark(), preprocess = preprocess) #src\n    end #src\nend #src\n"
  },
  {
    "path": "tutorials/toy_optimizations/Project.toml",
    "content": "[deps]\nCairoMakie = \"13f3f980-e62b-5c42-98c6-ff1f3baf88f0\"\nDistributions = \"31c24e10-a181-5473-b8eb-7969acd0382f\"\nOptimisers = \"3bd65402-5787-11e9-1adc-39752487f4e2\"\nRandom = \"9a3f8284-a2c9-5f02-9a11-845980a1fd5c\"\nStochasticAD = \"e4facb34-4f7e-4bec-b153-e122c37934ac\"\nTilde = \"73a6ac3c-4b34-4cca-a813-308f7589d80d\"\n"
  },
  {
    "path": "tutorials/toy_optimizations/igarch.jl",
    "content": "# Poisson autoregression\ncd(@__DIR__)\nusing StochasticAD, Distributions\nusing Optimisers\nimport Random\nRandom.seed!(1234)\nRandom.seed!(StochasticAD.RNG, 1234)\nPLOT = true\nif PLOT\n    using CairoMakie\nend\n\n# Poisson autoregression model, returning end value after `n` iterations\nfunction igarch(a, b, c, n, λ)\n    z = rand(Poisson(λ))\n    λ = a + b * z + c * λ\n    for i in 2:n\n        z = rand(Poisson(λ))\n        λ = a + b * z + c * λ\n    end\n    return λ, z\nend\n\nλ0 = 5.42 # true starting value\n\n## Generate observations\nn = 10\na, b, c = [0.25, 0.9, 0.51]\n_, z_obs = igarch(a, b, c, n, λ0) # 140 in first run\n\n# Posterior density estimate of parameter p=λ0 given z_obs=140 (assume we don't know)\nfunction X(p, z_obs = 140, n = 10)\n    a, b, c = [0.25, 0.9, 0.51]\n    λ, _ = igarch(a, b, c, n - 1, p)\n    pdf(Exponential(100.0), λ) * pdf(Poisson(λ), z_obs)\nend\n\n# Maximize posterior with Adam and Optimize\np0 = [20.5]\niterations = 5000\nm = StochasticAD.StochasticModel(p0, x -> -X(x)) # Formulate as minimization problem\ntrace = Float64[]\no = Adam(0.08)\ns = Optimisers.setup(o, m)\nfor i in 1:iterations\n    Optimisers.update!(s, m, StochasticAD.stochastic_gradient(m))\n    push!(trace, m.p[])\nend\np_opt = m.p[]\n\nif PLOT\n    ps = range(0, 10, length = 50)\n    N = 1000\n    expected = [mean(X(p) for _ in 1:N) for p in ps]\n    slope = [mean(derivative_estimate(X, p) for _ in 1:N) for p in ps]\n\n    f = Figure()\n    ax = f[1, 1] = Axis(f, title = \"Estimates\")\n    lines!(ax, ps, expected, label = \"≈ E X(p)\")\n    lines!(ax, ps, slope, label = \"≈ (E X(p))'\")\n    vlines!(ax, [p_opt], label = \"p_opt\", color = :green, linewidth = 2.0)\n    vlines!(ax, [λ0], linestyle = :dot, linewidth = 2.0)\n    hlines!(ax, [0.0], color = :black, linewidth = 1.0)\n\n    f[1, 2] = Legend(f, ax, framevisible = false)\n    ylims!(ax, (-1e-5, 2e-5))\n    ax = f[2, 1:2] = Axis(f, title = \"Optimizer trace\")\n    lines!(ax, trace, color = :green, linewidth = 2.0)\n    hlines!(ax, [λ0], linestyle = :dot, linewidth = 2.0)\n    ylims!(ax, (0, 20))\n    save(\"igarch.png\", f)\n    display(f)\nend\n"
  },
  {
    "path": "tutorials/toy_optimizations/intro.jl",
    "content": "# Toy expectation optimization problem \ncd(@__DIR__)\nusing StochasticAD, Distributions, Optimisers\nimport Random # hide\nRandom.seed!(1234) # hide\nPLOT = true\nif PLOT\n    using CairoMakie\nend\n\n# The \"crazy\" stochastic program from the introduction\nfunction X(p)\n    a = p * (1 - p)\n    b = rand(Binomial(10, p))\n    c = 2 * b + 3 * rand(Bernoulli(p))\n    return a * c * rand(Normal(b, a))\nend\n\n# Maximize E[X(p)] using Adam and Optimize\np0 = [0.5]\niterations = 5000\nm = StochasticAD.StochasticModel(p0, x -> -X(x)) # Formulate as minimization problem\ntrace = Float64[]\no = Adam()\ns = Optimisers.setup(o, m)\nfor i in 1:iterations\n    Optimisers.update!(s, m, StochasticAD.stochastic_gradient(m))\n    push!(trace, m.p[])\nend\np_opt = m.p[]\n\nif PLOT\n    dp = 1 / 50\n    N = 1000\n    ps = dp:dp:(1 - dp)\n    avg = [mean(X(p) for _ in 1:N) for p in ps]\n    derivative = [mean(derivative_estimate(X, p) for _ in 1:N) for p in ps]\n\n    f = Figure()\n    ax = f[1, 1] = Axis(f, title = \"Estimates\")\n    lines!(ax, ps, avg, label = \"≈ E[X(p)]\")\n    lines!(ax, ps, derivative, label = \"≈ d/dp E[X(p)]\")\n    vlines!(ax, [p_opt], label = \"p_opt\", color = :green, linewidth = 2.0)\n    hlines!(ax, [0.0], color = :black, linewidth = 1.0)\n\n    f[1, 2] = Legend(f, ax, framevisible = false)\n    ylims!(ax, (-50, 80))\n    ax = f[2, 1:2] = Axis(f, title = \"Optimizer trace\")\n    lines!(ax, trace, color = :green, linewidth = 2.0)\n    save(\"intro.png\", f)\n    display(f)\nend\n"
  },
  {
    "path": "tutorials/toy_optimizations/variational.jl",
    "content": "# Toy variational problem: Find Poisson(p) close to NegativeBinomial(10, 1-30/(10+30))\n# by minimization of the Kullback Leibler distance\ncd(@__DIR__)\nusing StochasticAD, Distributions, Optimisers\nimport Random # hide\nRandom.seed!(1234) # hide\nPLOT = true\nif PLOT\n    using CairoMakie\nend\n\n# Sample the likelihood ratio. E[X(p)] is the Kullback-Leibler distance between the models\nfunction X(p)\n    i = rand(Poisson(p))\n    return logpdf(Poisson(p), i) - logpdf(NegativeBinomial(10, 1 - 30 / (10 + 30)), i)\nend\n\n# Minimize E[X] = KL(Poisson(p)| NegativeBinomial(10, 1-30/(10+30))) using Adam and Optimize.jl\niterations = 5000\np0 = [10.0]\nm = StochasticAD.StochasticModel(p0, X) # Formulate as minimization problem\ntrace = Float64[]\no = Adam(0.1)\ns = Optimisers.setup(o, m)\nfor i in 1:iterations\n    Optimisers.update!(s, m, StochasticAD.stochastic_gradient(m))\n    push!(trace, m.p[])\nend\np_opt = m.p[]\n\nif PLOT\n    dp = 1 / 2\n    N = 1000\n    ps = 10:dp:50\n    avg = [mean(X(p) for _ in 1:N) for p in ps]\n    derivative = [mean(derivative_estimate(X, p) for _ in 1:N) for p in ps]\n    f = Figure()\n    ax = f[1, 1] = Axis(f, title = \"Estimates\")\n    lines!(ax, ps, avg, label = \"≈ E[X(p)]\")\n    lines!(ax, ps, derivative, label = \"≈ d/dp E[X(p)]\")\n    vlines!(ax, [p_opt], label = \"p_opt\", color = :green, linewidth = 2.0)\n    hlines!(ax, [0.0], color = :black, linewidth = 1.0)\n\n    f[1, 2] = Legend(f, ax, framevisible = false)\n    ylims!(ax, (-10, 10))\n    ax = f[2, 1:2] = Axis(f, title = \"Optimizer trace\")\n    lines!(ax, trace, color = :green, linewidth = 2.0)\n    save(\"variational.png\", f)\n    display(f)\nend\n"
  }
]