Repository: gaurav-arya/StochasticAD.jl
Branch: main
Commit: 24788b5bc82d
Files: 72
Total size: 217.3 KB

Directory structure:
gitextract__gvecf0r/

├── .JuliaFormatter.toml
├── .git-blame-ignore-revs
├── .github/
│   └── workflows/
│       ├── CI.yml
│       ├── CompatHelper.yml
│       ├── Documentation.yml
│       ├── FormatCheck.yml
│       ├── TagBot.yml
│       └── benchmark.yml
├── .gitignore
├── CITATION.bib
├── LICENSE
├── Project.toml
├── README.md
├── benchmark/
│   ├── benchmarks.jl
│   ├── game_of_life.jl
│   ├── iteration.jl
│   ├── random_walk.jl
│   ├── runbenchmarks.jl
│   ├── simple_ops.jl
│   └── utils.jl
├── docs/
│   ├── Project.toml
│   ├── make.jl
│   └── src/
│       ├── assets/
│       │   └── extra_styles.css
│       ├── devdocs.md
│       ├── index.md
│       ├── limitations.md
│       ├── public_api.md
│       └── tutorials/
│           ├── game_of_life.md
│           ├── optimizations.md
│           ├── particle_filter.md
│           ├── random_walk.md
│           └── reverse_demo.md
├── ext/
│   └── StochasticADEnzymeExt.jl
├── src/
│   ├── StochasticAD.jl
│   ├── algorithms.jl
│   ├── backends/
│   │   ├── abstract_wrapper.jl
│   │   ├── dict.jl
│   │   ├── pruned.jl
│   │   ├── pruned_aggressive.jl
│   │   ├── smoothed.jl
│   │   └── strategy_wrapper.jl
│   ├── discrete_randomness.jl
│   ├── finite_infinitesimals.jl
│   ├── general_rules.jl
│   ├── misc.jl
│   ├── prelude.jl
│   ├── propagate.jl
│   ├── smoothing.jl
│   └── stochastic_triple.jl
├── test/
│   ├── game_of_life.jl
│   ├── random_walk.jl
│   ├── resampling.jl
│   ├── runtests.jl
│   └── triples.jl
└── tutorials/
    ├── Project.toml
    ├── README.md
    ├── game_of_life/
    │   ├── core.jl
    │   └── plot_board.jl
    ├── particle_filter/
    │   ├── benchmark.jl
    │   ├── bias.jl
    │   ├── core.jl
    │   ├── model.jl
    │   ├── variance.jl
    │   └── visualize.jl
    ├── random_walk/
    │   ├── compare_score.jl
    │   ├── core.jl
    │   └── show_unbiased.jl
    ├── reverse_example/
    │   └── reverse_demo.jl
    └── toy_optimizations/
        ├── Project.toml
        ├── igarch.jl
        ├── intro.jl
        └── variational.jl

================================================
FILE CONTENTS
================================================

================================================
FILE: .JuliaFormatter.toml
================================================
style = "sciml"


================================================
FILE: .git-blame-ignore-revs
================================================
# Run this command to always ignore these in local `git blame`:
# git config blame.ignoreRevsFile .git-blame-ignore-revs

# Run formatter
70fd432667fb431e08ba52728734108d822a1922
# Run formatter 
21038a047c023330876feb9259cd5c92add3ca81
# Run formatter after bracket alignment removal
799277f9652258282a91ecfe976df5fb8ab64c82
# Format
db4333c604cc23c3c36420f09aa998d01ef0214b


================================================
FILE: .github/workflows/CI.yml
================================================
name: CI
on:
  pull_request:
  push:
    branches:
      - main 
    tags: '*'
jobs:
  unittest:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        group:
          - Core
        version:
          - '1'
          - '1.7'
    steps:
      - uses: actions/checkout@v2
      - uses: julia-actions/setup-julia@v1
        with:
          version: ${{ matrix.version }}
      - uses: julia-actions/julia-buildpkg@v1
      - uses: julia-actions/julia-runtest@v1
      - uses: julia-actions/julia-processcoverage@v1
      - uses: codecov/codecov-action@v2
        with:
          file: lcov.info


================================================
FILE: .github/workflows/CompatHelper.yml
================================================
name: CompatHelper
on:
  schedule:
    - cron: 0 0 * * *
  workflow_dispatch:
jobs:
  CompatHelper:
    runs-on: ubuntu-latest
    steps:
      - name: Pkg.add("CompatHelper")
        run: julia -e 'using Pkg; Pkg.add("CompatHelper")'
      - name: CompatHelper.main()
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }}
        run: julia -e 'using CompatHelper; CompatHelper.main()'


================================================
FILE: .github/workflows/Documentation.yml
================================================
name: Documentation

on:
  push:
    branches:
      - main
    tags: '*'
  pull_request:

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - uses: julia-actions/setup-julia@v1
        with:
          version: '1'
      - name: Install dependencies
        run: julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()'
      - name: Build and deploy
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # For authentication with GitHub Actions token
          DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # For authentication with SSH deploy key
          DATADEPS_ALWAYS_ACCEPT: true
        run: julia --project=docs/ docs/make.jl


================================================
FILE: .github/workflows/FormatCheck.yml
================================================
name: format-check

on:
  push:
    branches:
      - 'main'
      - 'release-'
    tags: '*'
  pull_request:

jobs:
  build:
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        julia-version: [1]
        julia-arch: [x86]
        os: [ubuntu-latest]
    steps:
      - uses: julia-actions/setup-julia@latest
        with:
          version: ${{ matrix.julia-version }}

      - uses: actions/checkout@v1
      - name: Install JuliaFormatter and format
        # This will use the latest version by default but you can set the version like so:
        #
        # julia  -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter", version="0.13.0"))'
        run: |
          julia  -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))'
          julia  -e 'using JuliaFormatter; format(".", verbose=true)'
      - name: Format check
        run: |
          julia -e '
          out = Cmd(`git diff`) |> read |> String
          if out == ""
              exit(0)
          else
              @error "Some files have not been formatted !!!"
              write(stdout, out)
              exit(1)
          end'


================================================
FILE: .github/workflows/TagBot.yml
================================================
name: TagBot
on:
  issue_comment:
    types:
      - created
  workflow_dispatch:
jobs:
  TagBot:
    if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot'
    runs-on: ubuntu-latest
    steps:
      - uses: JuliaRegistries/TagBot@v1
        with:
          token: ${{ secrets.GITHUB_TOKEN }}
          ssh: ${{ secrets.DOCUMENTER_KEY }}


================================================
FILE: .github/workflows/benchmark.yml
================================================

name: Benchmarks

on:
  pull_request:
  push:
    branches:
      - main 
    tags: '*'

jobs:
  benchmark:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - uses: julia-actions/setup-julia@latest
        with:
          version: 1
      - name: Install dependencies
        run: julia -e 'using Pkg; Pkg.activate("tutorials"); Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate();'
      - name: Run benchmarks
        run: julia --project=tutorials --color=yes benchmark/runbenchmarks.jl 


================================================
FILE: .gitignore
================================================
Manifest.toml

================================================
FILE: CITATION.bib
================================================
@inproceedings{arya2022automatic,
 author = {Arya, Gaurav and Schauer, Moritz and Sch\"{a}fer, Frank and Rackauckas, Christopher},
 booktitle = {Advances in Neural Information Processing Systems},
 editor = {S. Koyejo and S. Mohamed and A. Agarwal and D. Belgrave and K. Cho and A. Oh},
 pages = {10435--10447},
 publisher = {Curran Associates, Inc.},
 title = {Automatic Differentiation of Programs with Discrete Randomness},
 url = {https://proceedings.neurips.cc/paper_files/paper/2022/file/43d8e5fc816c692f342493331d5e98fc-Paper-Conference.pdf},
 volume = {35},
 year = {2022}
}


================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2022 Gaurav Arya <aryag@mit.edu> and contributors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: Project.toml
================================================
name = "StochasticAD"
uuid = "e4facb34-4f7e-4bec-b153-e122c37934ac"
authors = ["Gaurav Arya <aryag@mit.edu> and contributors"]
version = "0.1.26"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesOverloadGeneration = "f51149dc-2911-5acf-81fc-2076a2a81d4f"
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[weakdeps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"

[extensions]
StochasticADEnzymeExt = "Enzyme"

[compat]
ChainRulesCore = "1.15"
ChainRulesOverloadGeneration = "0.1"
Dictionaries = "0.3"
Distributions = "0.25"
DistributionsAD = "0.6"
ExprTools = "0.1"
ForwardDiff = "0.10"
Functors = "0.4.3"
julia = "1"

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GaussianDistributions = "43dcc890-d446-5863-8d1a-14597580bb8d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["LinearAlgebra", "Pkg", "Printf", "Test", "Statistics", "SafeTestsets", "OffsetArrays", "StaticArrays", "Zygote", "ForwardDiff", "GaussianDistributions", "Measurements", "UnPack", "StatsBase", "DiffResults", "ChainRulesCore"]


================================================
FILE: README.md
================================================
![](docs/src/images/path_skeleton.png#gh-light-mode-only)
![](docs/src/images/path_skeleton_dark.png#gh-dark-mode-only)

# StochasticAD

[![Build Status](https://github.com/gaurav-arya/StochasticAD.jl/workflows/CI/badge.svg?branch=main)](https://github.com/gaurav-arya/StochasticAD.jl/actions?query=workflow:CI)
[![](https://img.shields.io/badge/docs-main-blue.svg)](https://gaurav-arya.github.io/StochasticAD.jl/dev/)
[![arXiv article](https://img.shields.io/badge/article-arXiv%3A10.48550-B31B1B)](https://arxiv.org/abs/2210.08572)

StochasticAD is an experimental research package for automatic differentiation (AD) of stochastic programs. It implements AD algorithms for handling programs that can contain *discrete* randomness, based on the methodology developed in [this NeurIPS 2022 paper](https://doi.org/10.48550/arXiv.2210.08572). We're still working on docs and code cleanup!

## Installation

The package can be installed with the Julia package manager:

```julia
julia> using Pkg;
julia> Pkg.add("StochasticAD");
```

## Citation

```
@inproceedings{arya2022automatic,
 author = {Arya, Gaurav and Schauer, Moritz and Sch\"{a}fer, Frank and Rackauckas, Christopher},
 booktitle = {Advances in Neural Information Processing Systems},
 editor = {S. Koyejo and S. Mohamed and A. Agarwal and D. Belgrave and K. Cho and A. Oh},
 pages = {10435--10447},
 publisher = {Curran Associates, Inc.},
 title = {Automatic Differentiation of Programs with Discrete Randomness},
 url = {https://proceedings.neurips.cc/paper_files/paper/2022/file/43d8e5fc816c692f342493331d5e98fc-Paper-Conference.pdf},
 volume = {35},
 year = {2022}
}
```


================================================
FILE: benchmark/benchmarks.jl
================================================
using BenchmarkTools

include("random_walk.jl")
include("game_of_life.jl")
include("iteration.jl")
include("simple_ops.jl")

const SUITE = BenchmarkGroup()
SUITE["random_walk"] = RandomWalkBenchmark.suite
SUITE["game_of_life"] = GoLBenchmark.suite
SUITE["iteration"] = IterationBenchmark.suite
SUITE["simple_ops"] = SimpleOpsBenchmark.suite


================================================
FILE: benchmark/game_of_life.jl
================================================
module GoLBenchmark

using BenchmarkTools

using StochasticAD
using Statistics
using ForwardDiff: derivative
include("../tutorials/game_of_life/core.jl")
using .GoLCore: play, p

const suite = BenchmarkGroup()

suite["original"] = @benchmarkable $play($p)
suite["PrunedFIs"] = @benchmarkable derivative_estimate($play, $p;
    backend = PrunedFIsBackend())
suite["PrunedFIsAggressive"] = @benchmarkable derivative_estimate($play, $p;
    backend = PrunedFIsAggressiveBackend())
suite["SmoothedFIs"] = @benchmarkable derivative_estimate($play, $p;
    backend = SmoothedFIsBackend())

end


================================================
FILE: benchmark/iteration.jl
================================================
"""
In the library we have tried to avoid generated functions, instead using reductions from Base in the
hope that the iteration will be optimized to be zero-cost.
This suite tests the performance of iteration on small nested structures, which crop up when `propagate` is called
on small structures of scalars.
The `couple` and `combine` operations of FIs, which use iteration, are benchmarked.
"""
module IterationBenchmark

using BenchmarkTools
using StochasticAD
using StaticArrays

const suite = BenchmarkGroup()

# Examples consist of flat and non-flat versions of structures, to test zero-cost iteration.
tups = Dict("easy" => (ntuple(identity, 3), (1, (2, 3))),
    "hard" => (ntuple(identity, 9), (1, (2, 3), (4, (5, (6, 7, 8), 9)))))
SAs = Dict("easy" => (SA[1, 2, 3], (1, SA[2, 3])),
    "hard" => (SA[1, 2, 3, 4, 5, 6, 7, 8, 9],
        (1, SA[2, 3], (4, (5, SA[6, 7, 8], 9)))))

for (setname, set) in (("tups", tups), ("SAs", SAs))
    suite[setname] = BenchmarkGroup()
    setsuite = suite[setname]
    for case in ["easy", "hard"]
        casesuite = setsuite[case] = BenchmarkGroup()
        for isflat in [false, true]
            flatsuite = casesuite[isflat ? "flat" : "not flat"] = BenchmarkGroup()
            values = set[case][isflat ? 1 : 2]
            flatsuite["make_iterate_values"] = @benchmarkable StochasticAD.structural_iterate($values)
            iter_values = StochasticAD.structural_iterate(values)
            flatsuite["foldl_values"] = @benchmarkable foldl(+, $(iter_values))
            flatsuite["iterate_values"] = @benchmarkable for i in $(iter_values)
            end
            for backend in [PrunedFIsBackend(), PrunedFIsAggressiveBackend()]
                FIs_suite = flatsuite[backend] = BenchmarkGroup()
                Δs = StochasticAD.create_Δs(backend, Int)
                Δs1 = StochasticAD.similar_new(Δs, 1, 1)
                Δs_all = StochasticAD.structural_map(x -> map(Δ -> x, Δs1), values)
                FIs_suite["make_iterate_Δs"] = @benchmarkable StochasticAD.structural_iterate($Δs_all)
                # We don't interpolate the backend directly below (i.e. write `$FIs`) because string-interpolating
                # a type seems to lead to slow benchmarks.
                FIs_suite["couple_same"] = @benchmarkable StochasticAD.couple(typeof($Δs),
                    $Δs_all)
                FIs_suite["combine_same"] = @benchmarkable StochasticAD.combine(
                    typeof($Δs),
                    $Δs_all)
            end
        end
    end
end

end


================================================
FILE: benchmark/random_walk.jl
================================================
module RandomWalkBenchmark

using BenchmarkTools

using StochasticAD
using Statistics
using ForwardDiff: derivative
include("../tutorials/random_walk/core.jl")
using .RandomWalkCore: n, p, nsamples
using .RandomWalkCore: fX, get_dfX

const suite = BenchmarkGroup()

suite["original"] = @benchmarkable $(fX)($p)
suite["PrunedFIs"] = @benchmarkable derivative_estimate($fX, $p;
    backend = PrunedFIsBackend())
suite["PrunedFIsAggressive"] = @benchmarkable derivative_estimate($fX, $p;
    backend = PrunedFIsAggressiveBackend())
suite["SmoothedFIs"] = @benchmarkable derivative_estimate($fX, $p;
    backend = SmoothedFIsBackend())
forwarddiff_func = p -> fX(p; hardcode_leftright_step = true)
suite["ForwardDiff_smoothing"] = @benchmarkable derivative($forwarddiff_func, $p)

end


================================================
FILE: benchmark/runbenchmarks.jl
================================================
using PkgBenchmark

include("utils.jl")
using .Utils

results = benchmarkpkg(dirname(@__DIR__),
    BenchmarkConfig(env = Dict("JULIA_NUM_THREADS" => "1",
        "OMP_NUM_THREADS" => "1")),
    resultfile = joinpath(@__DIR__, "result.json"))
@show results = print_group(results.benchmarkgroup)


================================================
FILE: benchmark/simple_ops.jl
================================================
module SimpleOpsBenchmark

using BenchmarkTools

using StochasticAD

const suite = BenchmarkGroup()

suite["add"] = BenchmarkGroup()
suite["add_via_propagate_nodeltas"] = BenchmarkGroup()
suite["add_via_propagate"] = BenchmarkGroup()

suite["add"]["original"] = @benchmarkable +(0.5, 0.5)
suite["add_via_propagate_nodeltas"]["original"] = @benchmarkable StochasticAD.propagate(+,
    0.5,
    0.5)
suite["add_via_propagate"]["original"] = @benchmarkable StochasticAD.propagate(+, 0.5, 0.5;
    keep_deltas = Val{
        true,
    })
for backend in [PrunedFIsBackend(), PrunedFIsAggressiveBackend()]
    suite["add"][backend] = @benchmarkable +(st, st) setup=(st = stochastic_triple(0.5;
        backend = $backend))
    suite["add_via_propagate_nodeltas"][backend] = @benchmarkable StochasticAD.propagate(+,
        st,
        st) setup=(st = stochastic_triple(0.5;
        backend = $backend))
    suite["add_via_propagate"][backend] = @benchmarkable StochasticAD.propagate(+, st, st;
        keep_deltas = Val{
            true,
        }) setup=(st = stochastic_triple(0.5;
        backend = $backend))
end

end


================================================
FILE: benchmark/utils.jl
================================================
module Utils

export print_group

using Functors
using BenchmarkTools

## Printing

# Type piracy, fine since just in benchmarking. (design of Functors should probably allow for user-customized functors)
@functor BenchmarkTools.BenchmarkGroup

function print_trial(t)
    ptime = BenchmarkTools.prettytime(time(t))
    pallocs = "$(allocs(t)) allocs"
    return "$ptime, $pallocs"
end

function print_group(b)
    fmap(t -> (t isa BenchmarkTools.Trial ? print_trial(t) : t), b)
end

end


================================================
FILE: docs/Project.toml
================================================
[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocThemeIndigo = "8bac0ac5-51bf-41f9-885e-2bf1ac2bec5f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
StochasticAD = "e4facb34-4f7e-4bec-b153-e122c37934ac"


================================================
FILE: docs/make.jl
================================================
using Pkg

using Documenter
using StochasticAD
using DocThemeIndigo
using Literate

### Formatting

indigo = DocThemeIndigo.install(StochasticAD)
format = Documenter.HTML(prettyurls = false,
    assets = [indigo, "assets/extra_styles.css"],
    repolink = "https://github.com/gaurav-arya/StochasticAD.jl",
    edit_link = "main")

### Pagination

pages = [
    "Overview" => "index.md",
    "Tutorials" => [
        "tutorials/random_walk.md",
        "tutorials/game_of_life.md",
        "tutorials/particle_filter.md",
        "tutorials/optimizations.md",
        "tutorials/reverse_demo.md"
    ],
    "Public API" => "public_api.md",
    "Developer documentation" => "devdocs.md",
    "Limitations" => "limitations.md"
]

### Prepare literate tutorials

# TODO (for now they are manually built into docs/src/tutorials and checked into repo)

### Make docs

makedocs(sitename = "StochasticAD.jl",
    authors = "Gaurav Arya and other contributors",
    modules = [StochasticAD],
    format = format,
    pages = pages,
    warnonly = [:missing_docs])

try
    deploydocs(repo = "github.com/gaurav-arya/StochasticAD.jl",
        devbranch = "main",
        push_preview = true)
catch e
    println("Error encountered while deploying docs:")
    showerror(stdout, e)
end


================================================
FILE: docs/src/assets/extra_styles.css
================================================
.display-light-only {display: block;}
.display-dark-only {display: none;}
.theme--documenter-dark .display-light-only {display: none;}
.theme--documenter-dark .display-dark-only {display: block;}

================================================
FILE: docs/src/devdocs.md
================================================
# Developer documentation (WIP)

## Writing a custom rule for stochastic triples

### via `StochasticAD.propagate`

To handle a deterministic discrete construct that `StochasticAD` does not automatically handle (e.g. branching via `if`, boolean comparisons), it is often sufficient to simply add a dispatch rule that calls out to `StochasticAD.propagate`.

```@docs
StochasticAD.propagate
```

### via a custom dispatch

If a function does not meet the conditions of `StochasticAD.propagate` and is not already supported, a custom
dispatch may be necessary. For example, consider the following function which manually implements a geometric random variable:

```@example rule
import Random
Random.seed!(1234) # hide
using Distributions
# make rng input explicit
function mygeometric(rng, p)
    x = 0
    while !(rand(rng, Bernoulli(p)))
        x += 1
    end
    return x
end
mygeometric(p) = mygeometric(Random.default_rng(), p)
```

This is equivalent to `rand(Geometric(p))` which is already supported, but for pedagogical purposes we will
implement our own rule from scratch. Using the stochastic derivative formulas from [Automatic Differentiation of Programs with Discrete Randomness](https://doi.org/10.48550/arXiv.2210.08572), the right stochastic derivative of this program is given by
```math
Y_R = X - 1, \quad w_R = \frac{X}{p(1-p)},
```
and the left stochastic derivative of this program is given by
```math
Y_L = X + 1, \quad w_L = -\frac{X+1}{p}.
```
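As a quick sanity check of these formulas, here is a plain-Julia sketch that sits outside StochasticAD and assumes only the standard libraries `Random` and `Statistics`. Since $\mathbb{E}[X] = (1-p)/p$, both weighted perturbations $w_R (Y_R - X)$ and $w_L (Y_L - X)$ should average to $-1/p^2$. The helper `geometric_failures` is a hypothetical stand-in for `mygeometric` that draws each Bernoulli trial as `rand(rng) < p`, so `Distributions` is not needed.

```julia
using Random, Statistics

# Hypothetical stand-in for `mygeometric`: count failures before the first success,
# drawing each Bernoulli(p) trial as `rand(rng) < p` (plain sampling, no AD involved).
function geometric_failures(rng, p)
    x = 0
    while !(rand(rng) < p)
        x += 1
    end
    return x
end

rng = MersenneTwister(1234)
p = 0.3
N = 10^6
xs = [geometric_failures(rng, p) for _ in 1:N]

# Right derivative: Y_R - X = -1 with weight w_R = X / (p(1-p)).
# (For X = 0 the weight is zero, matching the empty perturbation in the rule.)
right_estimate = mean(-x / (p * (1 - p)) for x in xs)
# Left derivative: Y_L - X = +1 with weight w_L = -(X + 1) / p.
left_estimate = mean(-(x + 1) / p for x in xs)
# Exact derivative of E[X] = (1 - p) / p.
exact = -1 / p^2

println("right: $right_estimate, left: $left_estimate, exact: $exact")
```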

Using these expressions, we can now write the dispatch rule for stochastic triples:

```@example rule
using StochasticAD
import StochasticAD: StochasticTriple, similar_new, similar_empty, combine
function mygeometric(rng, p_st::StochasticTriple{T}) where {T}
    p = p_st.value
    rng_copy = copy(rng) # save a copy for coupling later
    x = mygeometric(rng, p)

    # Form the new discrete perturbations (combinations of weight w and perturbation Y - X)
    Δs1 = if p_st.δ > 0
        # right stochastic derivative
        w = p_st.δ * x / (p * (1 - p))
        x > 0 ? similar_new(p_st.Δs, -1, w) : similar_empty(p_st.Δs, Int)
    elseif p_st.δ < 0
        # left stochastic derivative
        w = -p_st.δ * (x + 1) / p # positive since the negativity of p_st.δ cancels out the negativity of w_L
        similar_new(p_st.Δs, 1, w)
    else
        similar_empty(p_st.Δs, Int)
    end

    # Propagate any existing perturbations to p through the function
    function map_func(Δ)
        # Couple the samples by using the same RNG. (A simpler strategy would have been independent sampling, i.e. mygeometric(p + Δ) - x)
        mygeometric(copy(rng_copy), p + Δ) - x 
    end
    Δs2 = map(map_func, p_st.Δs)

    # Return the output stochastic triple
    StochasticTriple{T}(x, zero(x), combine((Δs2, Δs1)))
end
```
In the above, we used some of the interface functions supported by a collection of perturbations `Δs::StochasticAD.AbstractFIs`: `similar_empty(Δs, V)`, which creates an empty collection of perturbations with value type `V`; `similar_new(Δs, Δ, w)`, which creates a new perturbation of size `Δ` and weight `w`; `map(map_func, Δs)`,
which propagates a collection of perturbations through a mapping function; and `combine((Δs2, Δs1))`, which combines multiple collections of perturbations together.
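To make these operations concrete, here is a toy model in plain Julia. It is a sketch, not StochasticAD's implementation: a collection of perturbations is simply a vector of `(Δ, w)` pairs, and its contribution to a derivative estimate is the weighted sum of the `Δ`s. The names `ToyFIs`, `toy_empty`, `toy_new`, `toy_map`, `toy_combine` and `toy_contribution` are hypothetical; StochasticAD's backends (such as `PrunedFIsBackend`) may store and combine perturbations quite differently.

```julia
# Hypothetical toy: a collection of perturbations stored as (Δ, w) pairs,
# read as "jump by Δ with probability ≈ w·ε".
const ToyFIs = Vector{Tuple{Float64, Float64}}

toy_empty() = ToyFIs()                      # plays the role of similar_empty
toy_new(Δ, w) = ToyFIs([(Δ, w)])            # plays the role of similar_new
# Plays the role of map: push each input jump Δ through f, keeping its weight.
toy_map(f, Δs::ToyFIs) = ToyFIs([(f(Δ), w) for (Δ, w) in Δs])
# Plays the role of combine: pool several collections into one.
toy_combine(colls) = reduce(vcat, colls; init = toy_empty())
# Weighted sum of jumps: what these perturbations add to a derivative estimate.
toy_contribution(Δs::ToyFIs) = sum((w * Δ for (Δ, w) in Δs); init = 0.0)

# An input that may jump by +0.1 with weight 2.0, pushed through y = 10p,
# plus a new jump of -1 with weight 0.5 created by the current step.
Δs2 = toy_map(Δ -> 10 * Δ, toy_new(0.1, 2.0))
Δs1 = toy_new(-1, 0.5)
Δs = toy_combine((Δs2, Δs1))
println(toy_contribution(Δs))   # 2.0 * 1.0 + 0.5 * (-1.0) = 1.5
```

Here the mapping function receives an input perturbation and returns the corresponding output perturbation, just as `map_func` does in the rule above.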

We can test out our rule:
```@example rule
@show stochastic_triple(mygeometric, 0.1)

# try feeding an input that already has a perturbation
f(x) = mygeometric(2 * x + 0.1 * rand(Bernoulli(x)))^2
@show stochastic_triple(f, 0.1)

# verify against black-box finite differences
N = 1000000
samples_stochad = [derivative_estimate(f, 0.1) for i in 1:N]
samples_fd = [(f(0.105) - f(0.095)) / 0.01 for i in 1:N]

println("Stochastic AD: $(mean(samples_stochad)) ± $(std(samples_stochad) / sqrt(N))")
println("Finite differences: $(mean(samples_fd)) ± $(std(samples_fd) / sqrt(N))")

nothing # hide
```

## Distribution-specific customization of differentiation algorithm 

```@docs
randst
InversionMethodDerivativeCoupling
```

================================================
FILE: docs/src/index.md
================================================
```@raw html
<img class="display-light-only" src="images/path_skeleton.png">
<img class="display-dark-only" src="images/path_skeleton_dark.png">
```

# StochasticAD

[StochasticAD](https://github.com/gaurav-arya/StochasticAD.jl) is an experimental research package for automatic differentiation (AD) of stochastic programs.
It implements AD algorithms for handling programs that can contain *discrete* randomness, based on the methodology developed in [this NeurIPS 2022 paper](https://doi.org/10.48550/arXiv.2210.08572).

## Introduction

Derivatives are all about how functions are affected by a tiny change `ε` in their input. First, let's imagine perturbing the input of a deterministic, differentiable function such as $f(p) = p^2$ at $p = 2$.
```@example continuous
using StochasticAD
f(p) = p^2
stochastic_triple(f, 2) # Feeds 2 + ε into f
```
The output tells us that if we change the input `2` by a tiny amount `ε`, the output of `f` will change by approximately `4ε`. This is the case we're familiar with: we can get the value `4` by ordinary differentiation, $\frac{\mathrm{d}}{\mathrm{d} p} p^2 = 2p = 4$. Thinking in terms of tiny changes, the output above looks a lot like a [dual number](https://en.wikipedia.org/wiki/Dual_number). But what happens with a discrete random function? Let's give it a try.
```@example discrete
import Random # hide
Random.seed!(4321) # hide
using StochasticAD, Distributions
f(p) = rand(Bernoulli(p)) # 1 with probability p, 0 otherwise
stochastic_triple(f, 0.5) # Feeds 0.5 + ε into f
```
The output of a [Bernoulli variable](https://en.wikipedia.org/wiki/Bernoulli_distribution) cannot change by a tiny amount: it is either `0` or `1`. But in the probabilistic world, there is another way to change by a tiny amount *on average*: jump by a large amount, with tiny probability. `StochasticAD` introduces a stochastic triple object, which generalizes dual numbers by including a *third* component to describe these perturbations. Here, the stochastic triple says that the original random output was `0`, but given a small change `ε` in the input, the output will jump up to `1` with probability approximately `2ε`.
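The bookkeeping behind this can be mimicked by hand. The sketch below uses plain Julia with only `Random` and `Statistics`; `ToyTriple`, `toy_bernoulli` and `toy_derivative_contribution` are hypothetical names, not part of StochasticAD. It hard-codes a rule for a Bernoulli sample with input `p + ε`, assuming inversion sampling: a `0` flips to `1` with conditional probability `ε / (1 - p)`, i.e. weight `1 / (1 - p)`, which is `2` at `p = 0.5`. Averaging the dual part plus the weighted jumps should recover $\frac{\mathrm{d}}{\mathrm{d} p} \mathbb{E}[X] = 1$, while the dual part on its own is always `0`.

```julia
using Random, Statistics

# Hypothetical toy triple: primal value, dual part δ, and a list of (Δ, w) jumps,
# read as "jump by Δ with probability ≈ w·ε".
struct ToyTriple
    value::Int
    δ::Float64
    jumps::Vector{Tuple{Int, Float64}}
end

# Hand-written rule for Bernoulli(p + ε), assuming inversion sampling x = (u < p):
# raising p by ε can only turn a 0 into a 1, with conditional probability ε / (1 - p).
function toy_bernoulli(rng, p)
    x = rand(rng) < p ? 1 : 0
    jumps = x == 0 ? [(1, 1 / (1 - p))] : Tuple{Int, Float64}[]
    return ToyTriple(x, 0.0, jumps)
end

# Dual part plus weighted jumps: one derivative estimate per sample.
toy_derivative_contribution(t::ToyTriple) = t.δ + sum((w * Δ for (Δ, w) in t.jumps); init = 0.0)

rng = MersenneTwister(4321)
p = 0.5
estimates = [toy_derivative_contribution(toy_bernoulli(rng, p)) for _ in 1:10^5]
println("mean estimate: $(mean(estimates))  (exact d/dp E[X] = 1)")
```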

Stochastic triples can be used to construct a new random program whose average is the derivative of the average of the original program. We simply propagate stochastic triples through the program via [`stochastic_triple`](@ref), and then sum up the "dual" and "triple" components at the end via [`derivative_contribution`](@ref). This process is packaged together in the function [`derivative_estimate`](@ref). Let's try a crazier example, where we mix discrete and continuous randomness!
```@example estimate
using StochasticAD, Distributions
import Random # hide
Random.seed!(1234) # hide

function X(p)
    a = p * (1 - p)
    b = rand(Binomial(10, p))
    c = 2 * b + 3 * rand(Bernoulli(p))
    return a * c * rand(Normal(b, a))
end

st = @show stochastic_triple(X, 0.6) # sample a single stochastic triple at p = 0.6
@show derivative_contribution(st) # which produces a single derivative estimate...

samples = [derivative_estimate(X, 0.6) for i in 1:1000] # many samples from derivative program
derivative = mean(samples)
uncertainty = std(samples) / sqrt(1000)
println("derivative of 𝔼[X(p)] = $derivative ± $uncertainty")
```

## Index

See the [public API](public_api.md) for a walkthrough of the API, and the tutorials on differentiating a [random walk](tutorials/random_walk.md), a [stochastic game of life](tutorials/game_of_life.md), and a [particle filter](tutorials/particle_filter.md), and solving [stochastic optimization and variational problems](tutorials/optimizations.md) with discrete randomness. This is a prototype package with a number of [limitations](limitations.md).



================================================
FILE: docs/src/limitations.md
================================================
# Limitations of StochasticAD

`StochasticAD` has a number of limitations that are important to be aware of:

* `StochasticAD` uses operator-overloading just like [ForwardDiff](https://juliadiff.org/ForwardDiff.jl/stable/), so all of the [limitations](https://juliadiff.org/ForwardDiff.jl/stable/user/limitations/) listed there apply here too. Also note that some useful features of `ForwardDiff`, such as chunking for greater efficiency with a large number of parameters, have not yet been implemented here.
* We have limited support for reverse-mode AD via [smoothing](public_api.md#Smoothing), which cannot be guaranteed to be unbiased in all cases. 
* We do not yet support `if` statements with discrete random input. A workaround can be to use array indexing to express discrete random choices (see [the random walk tutorial](tutorials/random_walk.md) for an example).
* We do not yet support non-real values as intermediate values (e.g., a function such as `length(A[rand(Bernoulli(p)) + 1])`, where `A` is an array of strings, is in theory differentiable).
* We do not support discrete random variables that are implicitly implemented using continuous random variables, e.g. `rand() < p`.
* We support a limited assortment of discrete random variables: currently `Bernoulli`, `Binomial`, `Geometric`, `Poisson`, and `Categorical` from [Distributions](https://juliastats.org/Distributions.jl/). We are working on increasing coverage across `Distributions` as well as other libraries providing discrete random samplers such as [MeasureTheory](https://cscherrer.github.io/MeasureTheory.jl/stable/).
* Higher-order differentiation is not supported.
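A minimal sketch of the array-indexing workaround for the `if` limitation, assuming `Distributions.jl` (the names `payoff_branch`, `payoff_index`, and `payoffs` are hypothetical): the branch is replaced by a lookup into an array of outcomes. Since `rand(Bernoulli(p))` returns `false` or `true`, we add 1 to obtain a valid 1-based index.

```julia
using Distributions

# Not yet supported by StochasticAD: control flow that branches on a
# discrete random value.
function payoff_branch(p)
    if rand(Bernoulli(p))   # Bool outcome decides which branch runs
        return 2.0
    else
        return 5.0
    end
end

# Workaround: store the possible outcomes in an array and index into it.
# rand(Bernoulli(p)) is false or true, so adding 1 gives index 1 or 2.
const payoffs = [5.0, 2.0]  # payoffs[1] for false, payoffs[2] for true
payoff_index(p) = payoffs[rand(Bernoulli(p)) + 1]
```

Both functions have the same distribution, with mean `5 - 3p`; the indexed version is the form that should then be usable with `derivative_estimate(payoff_index, p)`, whose samples average to `-3`.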

`StochasticAD` is still in active development! PRs are welcome.



================================================
FILE: docs/src/public_api.md
================================================
# API walkthrough
 
The function [`derivative_estimate`](@ref) transforms a stochastic program containing discrete randomness into a new program whose average is the derivative of the original.
```@docs
derivative_estimate
```
While [`derivative_estimate`](@ref) is self-contained, we can also use the functions below to work with stochastic triples directly.
```@docs
StochasticAD.stochastic_triple
StochasticAD.derivative_contribution
StochasticAD.value
StochasticAD.delta
StochasticAD.perturbations
```
Note that [`derivative_estimate`](@ref) is simply the composition of [`stochastic_triple`](@ref) and [`derivative_contribution`](@ref). We also provide a convenience function for mimicking the behaviour
of standard AD, where derivatives of discrete random steps are dropped:
```@docs
StochasticAD.dual_number
```
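For instance, a minimal sketch (assuming `Distributions.jl` and a Binomial toy program that is not part of the original docs) that builds derivative samples from the two functions explicitly:

```julia
using Distributions, StochasticAD, Statistics

# Toy program (not from the docs): E[X(p)] = 10p, so d/dp E[X(p)] = 10.
X(p) = rand(Binomial(10, p))

st = stochastic_triple(X, 0.3)       # one stochastic triple at p = 0.3
StochasticAD.value(st)               # the sampled Binomial value
derivative_contribution(st)          # one derivative sample from this triple

# Composing the two steps is distributed like derivative_estimate(X, 0.3):
samples = [derivative_contribution(stochastic_triple(X, 0.3)) for _ in 1:10_000]
avg = mean(samples)                  # should be close to 10
```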

## Algorithms 

```@docs
StochasticAD.ForwardAlgorithm
StochasticAD.EnzymeReverseAlgorithm
```

## Smoothing

What happens if we were to run [`derivative_contribution`](@ref) after each step, instead of only at the end? This is *smoothing*, which combines the second and third components of a single stochastic triple into a single dual component. 
Smoothing no longer has a guarantee of unbiasedness, but is surprisingly accurate in a number of situations. 
For example, the popular [straight through gradient estimator](https://stackoverflow.com/questions/38361314/the-concept-of-straight-through-estimator-ste) can be viewed as a special case of smoothing.
Forward smoothing rules are provided through `ForwardDiff`, and backward rules through `ChainRules`, so that e.g. `Zygote.gradient` and `ForwardDiff.derivative` will use smoothed rules for discrete random variables rather than dropping the gradients entirely. 
Currently, special discrete->discrete constructs such as array indexing are not supported for smoothing.
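To see why smoothing can be biased, here is a hand-written, stdlib-only sketch for `b ~ Bernoulli(p)`. It does not use `StochasticAD` or its actual rules; it assumes a triple whose only finite perturbation is a `0 → 1` flip with weight `1/(1 - p)` when `b == 0`. Carrying the perturbation through `g(b) = b^2` and collapsing only at the end stays unbiased, while collapsing right after sampling and then applying the chain rule gives `g'(0) = 0`, so the smoothed estimate is `0` instead of the true derivative `1`.

```julia
using Random, Statistics

# Hand-written triple for b ~ Bernoulli(p) (an illustration, not StochasticAD's
# implementation): returns the sampled value b, a finite perturbation Δ, and
# its weight w. Assumption: increasing p can only flip 0 → 1, with weight
# 1 / (1 - p) when b == 0; the infinitesimal (dual) part is 0 since b is discrete.
function bernoulli_triple(rng, p)
    b = rand(rng) < p ? 1 : 0
    Δ, w = b == 0 ? (1, 1 / (1 - p)) : (0, 0.0)
    return b, Δ, w
end

# Keep the perturbation, evaluate g on the alternative path, collapse at the end.
function unbiased_estimate(rng, p, g)
    b, Δ, w = bernoulli_triple(rng, p)
    return (g(b + Δ) - g(b)) * w
end

# Smoothing: collapse Δ * w into a dual value right after sampling, then apply
# the ordinary chain rule g'(b) * (Δ * w).
function smoothed_estimate(rng, p, g, dg)
    b, Δ, w = bernoulli_triple(rng, p)
    return dg(b) * (Δ * w)
end

g(x) = x^2     # on {0, 1}, g(b) = b, so d/dp E[g(b)] = 1
dg(x) = 2x

rng = MersenneTwister(0)
p, N = 0.3, 100_000
est_unbiased = mean(unbiased_estimate(rng, p, g) for _ in 1:N)      # close to 1
est_smoothed = mean(smoothed_estimate(rng, p, g, dg) for _ in 1:N)  # exactly 0.0
```

A real smoothing rule may distribute the perturbation differently, so the exact numbers are specific to this sketch; the point is only that the chain rule applied after collapsing does not see the finite jump `g(1) - g(0)`.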


## Optimization

We also provide utilities to make it easier to get started with forming and training a model via stochastic gradient descent:
```@docs
StochasticAD.StochasticModel
StochasticAD.stochastic_gradient
```
These are used in the [tutorial on stochastic optimization](tutorials/optimizations.md).


================================================
FILE: docs/src/tutorials/game_of_life.md
================================================
# Stochastic Game of Life

We consider a stochastic version of [Conway's Game of Life](https://en.wikipedia.org/wiki/Conway%27s_Game_of_Life), played on a two-dimensional board. We shall use the following packages,
```@setup game_of_life
import Pkg
Pkg.activate("../../../tutorials")
Pkg.develop(path="../../..")
Pkg.instantiate()
```
```@example game_of_life
using Distributions
using StochasticAD
using OffsetArrays 
using StaticArrays
```

## Setting up the stochastic Game of Life

Each turn, the standard Game of Life applies the following rules to each cell,
```math
\text{dead and 3 neighbours alive} \to \text{ alive}, \\
\text{alive and 0, 1, or 4 neighbours alive} \to \text{ dead}.
```
The cell's status does not change otherwise. In our stochastic version, each cell follows these deterministic rules with probability `1-θ` and does the opposite (flips when the rules say it should stay, or stays when they say it should flip) with probability `θ`. To initialize the board at the beginning of the game, we set each cell alive independently with probability `p`. 

The following high-level function sets up the probabilities and provides them to `play_game_of_life`.
```@example game_of_life
function play(p, θ=0.1, N=12, T=10; log=false)
    # N is the board half-length, T is the number of game time steps
    low = θ
    high = 1-θ
    birth_probs = SA[low, low, low, high, low] # 0, 1, 2, 3, 4 neighbours
    death_probs = SA[high, high, low, low, high] # 0, 1, 2, 3, 4 neighbours 
    return play_game_of_life(p, vcat(birth_probs, death_probs), N, T; log)
end
```
We can now implement the Game of Life based on the specification. At the end of the game, we return the total number of alive cells.
```@example game_of_life
# A single turn of the game
function update_state(all_probs, N, board_new, board_old)
    for i in -N:N
        for j in -N:N
            neighbours = board_old[i+1, j] + board_old[i-1, j] + board_old[i, j-1] + board_old[i, j+1]
            index = board_new[i,j] * 5 + neighbours + 1 
            b = rand(Bernoulli(all_probs[index]))
            board_new[i,j] += (1 - 2 * board_new[i,j]) * b 
        end
    end
end

function play_game_of_life(p, all_probs, N, T; log=false)
    dual_type = promote_type(typeof(rand(Bernoulli(p))), typeof.(rand.(Bernoulli.(all_probs)))...) # a hacky way of getting the correct array type 
    board = OffsetArray(zeros(dual_type, 2*N + 3, 2*N + 3), -(N+1):(N+1), -(N+1):(N+1)) # center board at (0,0), pad by 1 

    # initialize the board	
    for i in -N:N
        for j in -N:N
            board[i,j] = rand(Bernoulli(p))
        end
    end
    board_old = similar(board)
    log && (history = [])

    # play the game
    for time_step in 1:T
        copy!(board_old, board)
        update_state(all_probs, N, board, board_old)
        log && push!(history, copy(board))
    end

    if !log
        return sum(board)
    else
        return sum(board), board, history
    end
end

play(0.5, 0.1) # play the game with p = 0.5 and θ = 0.1
```

!!! note 
    Note that we did have to be careful to write this program to be compatible with the [current capabilities of `StochasticAD`](../limitations.md). For example, we concatenated `birth_probs` and `death_probs` into a single array `all_probs` and used the index `board[i, j] * 5 + neighbours + 1` to find the probability, rather than use the more natural `if alive... else...` syntax.
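A small, dependency-free check of the two branch-free tricks used in `update_state` (helper names are hypothetical, and plain vectors stand in for `SA[...]`): the lookup `all_probs[alive * 5 + neighbours + 1]` agrees with the natural `if alive ... else ...` lookup, and the update `x + (1 - 2x) * b` flips a cell exactly when `b == 1`, i.e. it computes `xor(x, b)`.

```julia
# Same probability tables as in `play`, here for θ = 0.1.
θ = 0.1
low, high = θ, 1 - θ
birth_probs = [low, low, low, high, low]    # dead cell, 0..4 neighbours alive
death_probs = [high, high, low, low, high]  # alive cell, 0..4 neighbours alive
all_probs = vcat(birth_probs, death_probs)

# Branching lookup (not usable when `alive` is a discrete random value)
flip_prob_branch(alive, neighbours) =
    alive == 1 ? death_probs[neighbours + 1] : birth_probs[neighbours + 1]

# Branch-free lookup, as in `update_state`
flip_prob_index(alive, neighbours) = all_probs[alive * 5 + neighbours + 1]

# Branch-free flip used to update a cell: equals xor(x, b) for x, b ∈ {0, 1}
flip(x, b) = x + (1 - 2x) * b
```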

## Differentiating the Game of Life

Let's differentiate the Game of Life!
```@example game_of_life
@show stochastic_triple(play, 0.5) # let's take a look at a single stochastic triple

samples = [derivative_estimate(play, 0.5) for i in 1:10000] # take many samples
derivative = mean(samples)
uncertainty = std(samples) / sqrt(10000)
println("derivative of 𝔼[play(p)] = $derivative ± $uncertainty")
```

The following sketch of the final state of the board for a single run gives some insight into what the stochastic triples are doing. The original board is depicted in grey and white for dead and alive cells, respectively; cells that flip from dead to alive in the "alternative" path considered by the triples are marked with + signs, while cells that flip from alive to dead are marked with X signs.

```@raw html
<img src="../images/final_gol_board.png" width="50%"/>
``` ⠀






================================================
FILE: docs/src/tutorials/optimizations.md
================================================
# Stochastic optimizations with discrete randomness

```@setup optimizations
import Pkg
Pkg.activate("../../../tutorials/toy_optimizations")
Pkg.develop(path="../../..")
Pkg.instantiate()
```

In this tutorial, we solve two stochastic optimization problems using `StochasticAD` where the optimization objective is formed using discrete distributions. We will need the following packages:
```@example optimizations
using Distributions # defines several supported discrete distributions 
using StochasticAD
using CairoMakie # for plotting
using Optimisers # for stochastic gradient descent
```

## Optimizing our toy program

Recall the "crazy" program from the intro:
```@example optimizations
function X(p)
    a = p * (1 - p)
    b = rand(Binomial(10, p))
    c = 2 * b + 3 * rand(Bernoulli(p))
    return a * c * rand(Normal(b, a))
end
```

Let's maximize $\mathbb{E}[X(p)]$! First, let's set up the problem, using the [`StochasticModel`](@ref) helper utility to create a trainable model:
```@example optimizations
p0 = [0.5] # initial value of p, wrapped in an array for use in the stochastic model
m = StochasticModel(p -> -X(p[1]), p0) # formulate as minimization problem
```
Now, let's perform stochastic gradient descent using [Adam](https://arxiv.org/abs/1412.6980), where we use [`stochastic_gradient`](@ref) to obtain a gradient of the model.
```@example optimizations
iterations = 1000
trace = Float64[]
o = Adam() # use Adam for optimization
s = Optimisers.setup(o, m)
for i in 1:iterations
    # Perform a gradient step
    Optimisers.update!(s, m, stochastic_gradient(m))
    push!(trace, m.p[])
end
p_opt = m.p[] # Our optimized value of p
```
Finally, let's plot the results of our optimization, and also perform a sweep through the parameter space to verify the accuracy of our estimator:
```@example optimizations
## Sweep through parameters to find average and derivative
ps = 0.02:0.02:0.98 # values of p to sweep
N = 1000 # number of samples at each p
avg = [mean(X(p) for _ in 1:N) for p in ps]
derivative = [mean(derivative_estimate(X, p) for _ in 1:N) for p in ps]

## Make plots
f = Figure()
ax = f[1, 1] = Axis(f, title = "Estimates", xlabel="Value of p")
lines!(ax, ps, avg, label = "≈ E[X(p)]")
lines!(ax, ps, derivative, label = "≈ d/dp E[X(p)]")
vlines!(ax, [p_opt], label = "p_opt", color = :green, linewidth = 2.0)
hlines!(ax, [0.0], color = :black, linewidth = 1.0)
ylims!(ax, (-50, 80))

f[1, 2] = Legend(f, ax, framevisible = false)
ax = f[2, 1:2] = Axis(f, title = "Optimizer trace", xlabel="Iterations", ylabel="Value of p")
lines!(ax, trace, color = :green, linewidth = 2.0)
save("crazy_opt.png", f,  px_per_unit = 4) # hide
nothing # hide
```
![](crazy_opt.png)
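The expectation of this toy program also has a closed form, which gives a useful sanity check for `p_opt` and for the sweep above. The derivation in the comments is ours, not part of the original tutorial; the helper names are hypothetical.

```julia
# Closed form of E[X(p)] for the "crazy" program (hand derivation):
#   a = p(1 - p); b ~ Binomial(10, p) and B ~ Bernoulli(p) are independent, and
#   the Normal(b, a) draw has mean b, so
#   E[X(p)] = a * E[(2b + 3B) * b] = a * (2E[b^2] + 3E[B]E[b])
#           = p(1 - p) * (20p + 210p^2) = 20p^2 + 190p^3 - 210p^4.
mean_X(p) = p * (1 - p) * (20p + 210p^2)
dmean_X(p) = 40p + 570p^2 - 840p^3

# dmean_X(p) = 10p * (4 + 57p - 84p^2); its root in (0, 1) is the maximizer.
p_star = (57 + sqrt(57^2 + 4 * 84 * 4)) / (2 * 84)   # ≈ 0.743
```

If the optimization has converged, `p_opt` should lie near `p_star`, up to the noise of stochastic gradient descent, and the swept derivative curve should cross zero there.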

## Solving a variational problem

Let's consider a toy variational program: we find a [Poisson distribution](https://en.wikipedia.org/wiki/Poisson_distribution) that is close to the distribution of a [negative Binomial](https://en.wikipedia.org/wiki/Negative_binomial_distribution), via minimization of the [Kullback-Leibler divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) $D_{\mathrm{KL}}$. Concretely, let us solve
```math
\underset{p \in \mathbb{R}}{\operatorname{argmin}}\; D_{\mathrm{KL}}\left(\mathrm{Pois}(p) \hspace{.3em}\middle\|\hspace{.3em} \mathrm{NBin}(10, 0.25) \right).
```
The following program produces an unbiased estimate of the objective:
```@example optimizations
function X(p)
    i = rand(Poisson(p))
    return logpdf(Poisson(p), i) - logpdf(NegativeBinomial(10, 0.25), i)
end
```
We can now optimize the KL-divergence via stochastic gradient descent!
```@example optimizations
# Minimize E[X] = KL(Poisson(p)| NegativeBinomial(10, 0.25))
iterations = 1000
p0 = [10.0]
m = StochasticModel(p -> X(p[1]), p0)
trace = Float64[]
o = Adam(0.1)
s = Optimisers.setup(o, m)
for i in 1:iterations
    Optimisers.update!(s, m, stochastic_gradient(m))
    push!(trace, m.p[])
end
p_opt = m.p[]
```
Let's plot our results in the same way as before:
```@example optimizations
ps = 10:0.5:50
N = 1000
avg = [mean(X(p) for _ in 1:N) for p in ps]
derivative = [mean(derivative_estimate(X, p) for _ in 1:N) for p in ps]
f = Figure()
ax = f[1, 1] = Axis(f, title = "Estimates", xlabel="Value of p")
lines!(ax, ps, avg, label = "≈ E[X(p)]")
lines!(ax, ps, derivative, label = "≈ d/dp E[X(p)]")
vlines!(ax, [p_opt], label = "p_opt", color = :green, linewidth = 2.0)
hlines!(ax, [0.0], color = :black, linewidth = 1.0)
ylims!(ax, (-2.5, 5))

f[1, 2] = Legend(f, ax, framevisible = false)
ax = f[2, 1:2] = Axis(f, title = "Optimizer trace", ylabel="Value of p", xlabel="Iterations")
lines!(ax, trace, color = :green, linewidth = 2.0)
save("variational.png", f, px_per_unit = 4) # hide
nothing # hide
```
![](variational.png)
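As a ground-truth check for `p_opt` (our own derivation, not part of the original tutorial): with the `Distributions.jl` convention that `NegativeBinomial(r, q)` counts failures before the `r`-th success, the $\log k!$ terms cancel in the objective, and the Poisson identity $\frac{d}{dp}\mathbb{E}[h(k)] = \mathbb{E}[h(k+1) - h(k)]$ gives $\frac{d}{dp} D_{\mathrm{KL}} = \log\frac{p}{1-q} - \mathbb{E}[\log(k + r)]$ with $k \sim \mathrm{Pois}(p)$. The stdlib-only sketch below truncates the Poisson sum and bisects for the root; all helper names are hypothetical.

```julia
# Poisson log-pmf without SpecialFunctions: log(p^k e^-p / k!)
pois_logpmf(k, p) = k * log(p) - p - sum(log, 1:k; init = 0.0)

# d/dp KL(Pois(p) || NBin(r, q)) = log(p / (1 - q)) - E[log(k + r)], k ~ Pois(p);
# the sum over k is truncated at kmax, far beyond the Poisson mass for p ≤ 50.
function dKL(p; r = 10, q = 0.25, kmax = 400)
    Elog = sum(exp(pois_logpmf(k, p)) * log(k + r) for k in 0:kmax)
    return log(p / (1 - q)) - Elog
end

# dKL is negative below the optimum and positive above it; bisect on [lo, hi].
function bisect(f, lo, hi; iters = 60)
    for _ in 1:iters
        mid = (lo + hi) / 2
        f(mid) < 0 ? (lo = mid) : (hi = mid)
    end
    return (lo + hi) / 2
end

p_true = bisect(dKL, 10.0, 50.0)   # compare with p_opt from the optimizer above
```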


================================================
FILE: docs/src/tutorials/particle_filter.md
================================================
# Differentiable particle filter

Using a bootstrap particle sampler, we can approximate the posterior distributions
of the states given noisy and partial observations of the state of a hidden Markov
model by a cloud of `K` weighted particles with weights `W`.

In this tutorial, we are going to:
- implement a differentiable particle filter based on `StochasticAD.jl`.
- visualize the particle filter in ``d = 2`` dimensions.
- compare the gradient based on the differentiable particle filter to a biased
  gradient estimator as well as to the gradient of a differentiable Kalman filter.
- show how to benchmark primal evaluation, forward- and reverse-mode AD of the
  particle filter.

## Setup

We will make use of several Julia packages. For example, we are going to use
`Distributions` and `DistributionsAD` that implement the reparameterization trick
for Gaussian distributions used in the observation and state-transition model, which
we specify below. We also import `GaussianDistributions.jl` to implement the
differentiable Kalman filter.

### Package dependencies

```@setup particle_filter
import Pkg
Pkg.activate("../../../tutorials")
Pkg.develop(path="../../..")
Pkg.instantiate()
```

```@example particle_filter

# load dependencies
using StochasticAD
using Distributions
using DistributionsAD
using Random
using Statistics
using StatsBase
using LinearAlgebra
using Zygote
using ForwardDiff
using GaussianDistributions
using GaussianDistributions: correct, ⊕
using Measurements
using UnPack
using Plots
using LaTeXStrings
using BenchmarkTools
```

### Particle filter

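For convenience, we first introduce a new type `StochasticModel` (local to this tutorial and distinct from
the `StochasticModel` optimization helper exported by `StochasticAD`) with the following fields: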

- `T`: total number of time steps.
- `start`: starting distribution for the initial state. For example, in the form of a narrow
   Gaussian `start(θ) = Gaussian(x0, 0.001 * I(d))`.
- `dyn`: pointwise differentiable stochastic program in the form of Markov transition densities.
   For example, `dyn(x, θ) = MvNormal(reshape(θ, d, d) * x, Q(θ))`, where `Q(θ)` denotes the
   covariance matrix.
- `obs`: observation model having a smooth conditional probability density depending on
   current state `x` and parameters `θ`. For example, `obs(x, θ) = MvNormal(x, R(θ))`,
   where `R(θ)` denotes the covariance matrix.

For parameters `θ`, `rand(start(θ))` gives a sample from the prior distribution of the
initial state. For current state `x` and parameters `θ`, `xnew = rand(dyn(x, θ))`
samples the new state (i.e. `dyn` gives for each `x, θ` a distribution-like object). Finally,
`y = rand(obs(x, θ))` samples an observation.

We can then define the `ParticleFilter` type that wraps the number of particles `m`, a stochastic model
`StochM::StochasticModel`, a sampling strategy (with arguments `p, K, sump=1`), and observational data `ys`.
For simplicity, our implementation assumes that the observation likelihood is available
via `pdf(obs(x, θ), y)`.

```@example particle_filter
struct StochasticModel{TType<:Integer,T1,T2,T3}
    T::TType # time steps
    start::T1 # prior
    dyn::T2 # dynamical model
    obs::T3 # observation model
end

struct ParticleFilter{mType<:Integer,MType<:StochasticModel,yType,sType}
    m::mType # number of particles
    StochM::MType # stochastic model
    ys::yType # observations
    sample_strategy::sType # sampling function
end
```

### Kalman filter

We consider a stochastic program that fulfills the assumptions of a Kalman filter.
We follow [Kalman.jl](https://github.com/mschauer/Kalman.jl/blob/master/README.md) to implement a differentiable version.
Our `KalmanFilter` type wraps a stochastic model `StochM::StochasticModel` and observational data `ys`. It assumes an
observation log-likelihood function `llikelihood(yres, S)`. The Kalman filter
contains the following fields:

- `d`: dimension of the state-transition matrix ``\Phi`` according to ``x = \Phi x + w`` with ``w \sim \operatorname{Normal}(0,Q)``.
- `StochM`: Stochastic model of type `StochasticModel`.
- `H`: linear map from the state space into the observed space according to ``y = H x + \nu`` with ``\nu \sim \operatorname{Normal}(0,R)``.
- `R`: covariance matrix entering the observation model according to ``y = H x + \nu`` with ``\nu \sim \operatorname{Normal}(0,R)``.
- `Q`: covariance matrix entering the state-transition model according to ``x = \Phi x + w`` with ``w \sim \operatorname{Normal}(0,Q)``.
- `ys`: observations.


```@example particle_filter
llikelihood(yres, S) = GaussianDistributions.logpdf(Gaussian(zero(yres), Symmetric(S)), yres)
struct KalmanFilter{dType<:Integer,MType<:StochasticModel,HType,RType,QType,yType}
    # H, R = obs
    # θ, Q = dyn
    d::dType
    StochM::MType # stochastic model
    H::HType # observation model, maps the true state space into the observed space
    R::RType # observation model, covariance matrix
    Q::QType # dynamical model, covariance matrix
    ys::yType # observations
end
```

To get observations `ys` from the latent states `xs` based on the
(true, potentially unknown) parameters `θ`, we simulate a single particle
from the forward model returning a vector of observations (no resampling steps).

```@example particle_filter
function simulate_single(StochM::StochasticModel, θ)
    @unpack T, start, dyn, obs = StochM
    x = rand(start(θ))
    y = rand(obs(x, θ))
    xs = [x]
    ys = [y]
    for t in 2:T
        x = rand(dyn(x, θ))
        y = rand(obs(x, θ))
        push!(xs, x)
        push!(ys, y)
    end
    xs, ys
end
```

A particle filter becomes efficient if resampling steps are included. Resampling
is numerically attractive because particles with small weight are discarded, so
computational resources are not wasted on particles with vanishing weight.

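Here, let us implement a stratified resampling strategy (see, for example,
[Murray (2012)](https://arxiv.org/abs/1202.6163)), where `p` holds the (unnormalized) weights of the current
particles, `sump = sum(p)`, and `K` is the number of particle indices to draw. The implementation below shares a
single uniform offset `U` across all `K` strata, a variant often called systematic resampling.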

```@example particle_filter
function sample_stratified(p, K, sump=1)
    n = length(p)
    U = rand()
    is = zeros(Int, K)
    i = 1
    cw = p[1]
    for k in 1:K
        t = sump * (k - 1 + U) / K
        while cw < t && i < n
            i += 1
            @inbounds cw += p[i]
        end
        is[k] = i
    end
    return is
end
```
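For comparison, a hypothetical variant `sample_stratified_independent` (not part of the tutorial) draws a fresh uniform in each stratum instead of reusing one `U` for every `k`. Because each of the `K` strata then contributes exactly one index, each particle `j` is selected within 2 of `K * p[j] / sump` times, which the example checks on a small weight vector.

```julia
using Random

# Hypothetical variant of `sample_stratified`: stratified resampling with an
# independent uniform draw in each of the K strata (instead of one shared U).
# p holds the unnormalized weights; returns K nondecreasing particle indices.
function sample_stratified_independent(rng::AbstractRNG, p, K, sump = sum(p))
    n = length(p)
    idx = zeros(Int, K)
    i = 1
    cw = p[1]                                  # cumulative weight up to particle i
    for k in 1:K
        t = sump * (k - 1 + rand(rng)) / K     # fresh offset within stratum k
        while cw < t && i < n
            i += 1
            cw += p[i]
        end
        idx[k] = i
    end
    return idx
end

rng = MersenneTwister(1)
w = [0.1, 0.2, 0.3, 0.4]
idx = sample_stratified_independent(rng, w, 1000)
counts = [count(==(j), idx) for j in eachindex(w)]   # copies of each particle
```

Sharing one offset (as `sample_stratified` does) tightens this bound to within 1, at the price of making all strata dependent on a single random number.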

This sampling strategy can be used within a differentiable resampling step in our
particle filter using the `new_weight` function provided by
`StochasticAD.jl`. The `resample` function below returns the states `X_new`
and weights `W_new` of the resampled particles. Its arguments are:

- `m`: number of particles.
- `X`: current particle states.
- `W`: current weight vector of the particles.
- `ω`: total weight of the current particles; `ω == sum(W)` is an invariant.
- `sample_strategy`: specific resampling strategy to be used. For example, `sample_stratified`.
- `use_new_weight=true`: switches between the differentiable resampling step (`true`) and the
   biased stop-gradient method (`false`).

```@example particle_filter
function resample(m, X, W, ω, sample_strategy, use_new_weight=true)
    js = Zygote.ignore(() -> sample_strategy(W, m, ω))
    X_new = X[js]
    if use_new_weight
        # differentiable resampling
        W_chosen = W[js]
        W_new = map(w -> ω * new_weight(w / ω) / m, W_chosen)
    else
        # stop gradient, biased approach
        W_new = fill(ω / m, m)
    end
    X_new, W_new
end
```

Note that we added an `if` condition that allows us to switch between the differentiable
resampling step and the stop-gradient approach.
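The idea behind the differentiable branch, as we understand the Ścibior and Wood scheme referenced in the Bias section below, is to give each resampled particle a weight whose primal value matches the stop-gradient branch (`ω / m`) but whose derivative carries the score of the selected particle's normalized weight, i.e. a factor $x / \bot(x)$ with $x = w / \omega$. We assume `new_weight` plays an analogous role; the sketch below does not use `StochasticAD` and illustrates the factor with a minimal, hypothetical dual-number type.

```julia
# Minimal forward-mode dual number (hypothetical, for illustration only).
struct TinyDual
    v::Float64   # primal value
    d::Float64   # derivative with respect to the parameter
end
Base.:/(a::TinyDual, b::TinyDual) =
    TinyDual(a.v / b.v, (a.d * b.v - a.v * b.d) / b.v^2)  # quotient rule
stop_gradient(a::TinyDual) = TinyDual(a.v, 0.0)           # keep value, drop derivative

# Reweighting factor x / ⊥(x): primal value 1, derivative x'/x (the score).
score_factor(x::TinyDual) = x / stop_gradient(x)

x = TinyDual(0.25, 2.0)     # normalized weight 0.25 with derivative 2.0
r = score_factor(x)         # TinyDual(1.0, 8.0)
```

Multiplying `ω / m` by such a factor leaves the forward pass identical to the stop-gradient branch while restoring the derivative contribution of the resampling step.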

We're now equipped with all primitive operations to set up the particle filter,
which propagates particles with weights `W` preserving the invariant `ω == sum(W)`.
We never normalize `W`, so `ω` in the code below contains likelihood
information. By default (`store_path=false`), the particle filter returns the particle
positions and weights at the final time step `T`. It takes the following input
arguments:

- `θ`: parameters for the stochastic program (state-transition and observation model).
- `store_path=false`: Option to store the path of the particles, e.g. to visualize/inspect
  their trajectories.
- `use_new_weight=true`: Option to switch between the stop-gradient and our differentiable
  resampling step method. Defaults to using differentiable resampling.
- `s=1`: controls how often resampling happens; resampling is performed at time step `t` whenever `t > 1 && t < T && (t % s == 0)`.


```@example particle_filter
function (F::ParticleFilter)(θ; store_path=false, use_new_weight=true, s=1)
    # s controls the number of resampling steps
    @unpack m, StochM, ys, sample_strategy = F
    @unpack T, start, dyn, obs = StochM


    X = [rand(start(θ)) for j in 1:m] # particles
    W = [1 / m for i in 1:m] # weights
    ω = 1 # total weight
    store_path && (Xs = [X])
    for (t, y) in zip(1:T, ys)
        # update weights & likelihood using observations
        wi = map(x -> pdf(obs(x, θ), y), X)
        W = W .* wi
        ω_old = ω
        ω = sum(W)
        # resample particles
        if t > 1 && t < T && (t % s == 0) # && 1 / sum((W / ω) .^ 2) < length(W) ÷ 32
            X, W = resample(m, X, W, ω, sample_strategy, use_new_weight)
        end
        # update particle states
        if t < T
            X = map(x -> rand(dyn(x, θ)), X)
            store_path && Zygote.ignore(() -> push!(Xs, X))
        end
    end
    (store_path ? Xs : X), W
end
```
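The commented-out part of the resampling condition, `1 / sum((W / ω) .^ 2) < length(W) ÷ 32`, is an effective-sample-size (ESS) test, a common criterion for adaptive resampling. A small stdlib sketch of that test as standalone functions (the names `ess` and `should_resample` are hypothetical):

```julia
# Effective sample size (ESS) of unnormalized weights W. With ω = sum(W), this
# equals 1 / sum((W / ω) .^ 2), the quantity in the commented-out condition.
ess(W) = sum(W)^2 / sum(abs2, W)

# Hypothetical adaptive rule: resample only when the ESS drops below
# length(W) ÷ denom particles (denom = 32 matches the comment in the loop).
should_resample(W; denom = 32) = ess(W) < length(W) ÷ denom

ess(fill(0.5, 10))                # 10.0: equal weights, every particle counts
ess([1.0, 0.0, 0.0])              # 1.0: all weight on a single particle
should_resample([1.0; zeros(63)]) # true: ESS 1 < 64 ÷ 32 = 2
```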

Following [Kalman.jl](https://github.com/mschauer/Kalman.jl/blob/master/README.md), we implement
a differentiable Kalman filter to check the ground-truth gradient. Our Kalman filter
returns an updated posterior state estimate and the log-likelihood and takes the
parameters of the stochastic program as an input.

```@example particle_filter
function (F::KalmanFilter)(θ)
    @unpack d, StochM, H, R, Q, ys = F
    @unpack start = StochM

    x = start(θ)
    Φ = reshape(θ, d, d)

    x, yres, S = GaussianDistributions.correct(x, ys[1] + R, H)
    ll = llikelihood(yres, S)
    xs = Any[x]
    for i in 2:length(ys)
        x = Φ * x ⊕ Q
        x, yres, S = GaussianDistributions.correct(x, ys[i] + R, H)
        ll += llikelihood(yres, S)

        push!(xs, x)
    end
    xs, ll
end
```

For both filters, it is straightforward to obtain the log-likelihood via:

```@example particle_filter
function log_likelihood(F::ParticleFilter, θ, use_new_weight=true, s=1)
    _, W = F(θ; store_path=false, use_new_weight=use_new_weight, s=s)
    log(sum(W))
end
```
and
```@example particle_filter
function log_likelihood(F::KalmanFilter, θ)
    _, ll = F(θ)
    ll
end
```

For convenience, we define functions for
- forward-mode AD (and differentiable resampling step) to compute the gradient of
  the log-likelihood of the particle filter.
- reverse-mode AD (and differentiable resampling step) to compute the gradient of
  the log-likelihood of the particle filter.
- forward-mode AD (and stop-gradient method) to compute the gradient of
  the log-likelihood of the particle filter (without the `new_weight` function).
- forward-mode AD to compute the gradient of the log-likelihood of the Kalman filter.

```@example particle_filter

forw_grad(θ, F::ParticleFilter; s=1) = ForwardDiff.gradient(θ -> log_likelihood(F, θ, true, s), θ)
back_grad(θ, F::ParticleFilter; s=1) = Zygote.gradient(θ -> log_likelihood(F, θ, true, s), θ)[1]
forw_grad_biased(θ, F::ParticleFilter; s=1) = ForwardDiff.gradient(θ -> log_likelihood(F, θ, false, s), θ)
forw_grad_Kalman(θ, F::KalmanFilter) = ForwardDiff.gradient(θ -> log_likelihood(F, θ), θ)
```

## Model

Having set up all core functionalities, we can now define the specific stochastic
model.

We consider the following system with a ``d``-dimensional latent process,

```math
\begin{aligned}
x_i &= \Phi x_{i-1} + w_i &\text{ with } w_i \sim \operatorname{Normal}(0,Q),\\
y_i &= x_i + \nu_i &\text{ with } \nu_i \sim \operatorname{Normal}(0,R),
\end{aligned}
```

where ``\Phi`` is a ``d \times d`` rotation matrix.

```@example particle_filter
seed = 423897

### Define model
# here: n-dimensional rotation matrix
Random.seed!(seed)
T = 20 # time steps
d = 2 # dimension
# generate a rotation matrix
M = randn(d, d)
c = 0.3 # scaling
O = exp(c * (M - transpose(M)) / 2)
@assert det(O) ≈ 1
@assert transpose(O) * O ≈ I(d)
θtrue = vec(O) # true parameter

# observation model
R = 0.01 * collect(I(d))
obs(x, θ) = MvNormal(x, R) # y = H x + ν with ν ~ Normal(0, R)

# dynamical model
Q = 0.02 * collect(I(d))
dyn(x, θ) = MvNormal(reshape(θ, d, d) * x, Q) #  x = Φ*x + w with w ~ Normal(0,Q)

# starting position
x0 = randn(d)
# prior distribution
start(θ) = Gaussian(x0, 0.001 * collect(I(d)))

# put it all together
stochastic_model = StochasticModel(T, start, dyn, obs)

# relevant corresponding Kalman filtering defs
H_Kalman = collect(I(d))
R_Kalman = Gaussian(zeros(Float64, d), R)
# Φ_Kalman = O
Q_Kalman = Gaussian(zeros(Float64, d), Q)
###

### simulate model
Random.seed!(seed)
xs, ys = simulate_single(stochastic_model, θtrue)
```

## Visualization

Using `particle_filter(θ; store_path=true)` and `kalman_filter(θ)`, it is
straightforward to visualize both filters for our observed data.

```@example particle_filter
m = 1000
kalman_filter = KalmanFilter(d, stochastic_model, H_Kalman, R_Kalman, Q_Kalman, ys)
particle_filter = ParticleFilter(m, stochastic_model, ys, sample_stratified)
```


```@example particle_filter
### run and visualize filters
Xs, W = particle_filter(θtrue; store_path=true)
fig = plot(getindex.(xs, 1), getindex.(xs, 2), legend=false, xlabel=L"x_1", ylabel=L"x_2") # x1 and x2 are bad names..conflicting notation
scatter!(fig, getindex.(ys, 1), getindex.(ys, 2))
for i in 1:min(m, 100) # note that Xs has obs noise.
    local xs = [Xs[t][i] for t in 1:T]
    scatter!(fig, getindex.(xs, 1), getindex.(xs, 2), marker_z=1:T, color=:cool, alpha=0.1) # color to indicate time step
end

xs_Kalman, ll_Kalman = kalman_filter(θtrue)
plot!(getindex.(mean.(xs_Kalman), 1), getindex.(mean.(xs_Kalman), 2), legend=false, color="red")
png("pf_1") # hide
```
![](pf_1.png)

## Bias

We can also investigate the distribution of the gradients from the particle filter
with and without differentiable resampling step, as compared to the gradient computed
by differentiating the Kalman filter.

```@example particle_filter
### compute gradients
Random.seed!(seed)
X = [forw_grad(θtrue, particle_filter) for i in 1:200] # gradient of the particle filter *with* differentiation of the resampling step
Random.seed!(seed)
Xbiased = [forw_grad_biased(θtrue, particle_filter) for i in 1:200] # Gradient of the particle filter *without* differentiation of the resampling step
# pick an arbitrary coordinate
index = 1 # take derivative with respect to first parameter (2-dimensional example has a rotation matrix with four parameters in total)
# plot histograms for the sampled derivative values
fig = plot(normalize(fit(Histogram, getindex.(X, index), nbins=20), mode=:pdf), legend=false) # ours
plot!(normalize(fit(Histogram, getindex.(Xbiased, index), nbins=20), mode=:pdf)) # biased
vline!([mean(X)[index]], color=1)
vline!([mean(Xbiased)[index]], color=2)
# add derivative of differentiable Kalman filter as a comparison
XK = forw_grad_Kalman(θtrue, kalman_filter)
vline!([XK[index]], color="black")
png("pf_2") # hide
```
![](pf_2.png)

The estimator using the `new_weight` function agrees with the gradient value from
the Kalman filter and the [particle filter AD scheme developed by Ścibior and Wood](https://arxiv.org/abs/2106.10314),
unlike biased estimators that neglect the contribution of the derivative from the
resampling step. However, the biased estimator displays a smaller variance.

## Benchmark

Finally, we can use `BenchmarkTools.jl` to compare the run time of the primal
pass with those of forward-mode and reverse-mode AD of the particle filter. As
expected, forward-mode AD outperforms reverse-mode AD for the small number of
parameters considered here.

```@example particle_filter
# secs for how long the benchmark should run, see https://juliaci.github.io/BenchmarkTools.jl/stable/
secs = 1

suite = BenchmarkGroup()
suite["scaling"] = BenchmarkGroup(["grads"])

suite["scaling"]["primal"] = @benchmarkable log_likelihood(particle_filter, θtrue)
suite["scaling"]["forward"] = @benchmarkable forw_grad(θtrue, particle_filter)
suite["scaling"]["backward"] = @benchmarkable back_grad(θtrue, particle_filter)

tune!(suite)
results = run(suite, verbose=true, seconds=secs)

# mean run time ± standard error of the mean, in nanoseconds
mean_sem(trial) = measurement(mean(trial.times), std(trial.times) / sqrt(length(trial.times)))
t1 = mean_sem(results["scaling"]["primal"])
t2 = mean_sem(results["scaling"]["forward"])
t3 = mean_sem(results["scaling"]["backward"])
@show t1 t2 t3

ts = (t1, t2, t3) ./ 10^6 # ms
@show ts
```


================================================
FILE: docs/src/tutorials/random_walk.md
================================================
# Random walk

```@setup random_walk
import Pkg
Pkg.activate("../../../tutorials")
Pkg.develop(path="../../..")
Pkg.instantiate()
```

In this tutorial, we differentiate a random walk over the integers using `StochasticAD`. We will need the following packages,

```@example random_walk
using Distributions # defines several supported discrete distributions 
using StochasticAD
using StaticArrays # for more efficient small arrays
```

## Setting up the random walk

Let's define a function for simulating the walk.
```@example random_walk
function simulate_walk(probs, steps, n)
    state = 0
    for i in 1:n
        probs_here = probs(state) # transition probabilities for possible steps
        step_index = rand(Categorical(probs_here)) # which step do we take?
        step = steps[step_index] # get size of step 
        state += step
    end
    return state
end
```
Here, `steps` is a (1-indexed) array of the possible steps we can take. Each of these steps has a certain probability. To make things more interesting, we take in a *function* `probs` to produce these probabilities that can depend on the current state of the random walk.

Let's zoom in on the two lines where discrete randomness is involved. 
```
step_index = rand(Categorical(probs_here)) # which step do we take?
step = steps[step_index] # get size of step 
```
This is a cute pattern for making a discrete choice. First, we sample from a `Categorical` distribution from `Distributions.jl`, using the probabilities `probs_here` at our current position. This gives us an index between `1` and `length(steps)`, which we can use to pick the actual step to take. Stochastic triples propagate through both steps!

## Differentiating the random walk

Let's define a toy problem. We consider a random walk with `-1` and `+1` steps, where the probability of a `+1` step is `1` at the origin and decays exponentially with the current position `X`, with decay length `p`. We take `n = 100` steps and set `p = 50`.
```@example random_walk
using StochasticAD

const steps = SA[-1, 1] # move left or move right
make_probs(p) = X -> SA[1 - exp(-X / p), exp(-X / p)]

f(p, n) = simulate_walk(make_probs(p), steps, n)
@show f(50, 100) # let's run a single random walk with p = 50
@show stochastic_triple(p -> f(p, 100), 50) # let's see what a single stochastic triple looks like at p = 50
```
Time to differentiate! For fun, let's differentiate the *square* of the output of the random walk.
```@example random_walk
f_squared(p, n) = f(p, n)^2

samples = [derivative_estimate(p -> f_squared(p, 100), 50) for i in 1:1000] # many samples from derivative program at p = 50
derivative = mean(samples)
uncertainty = std(samples) / sqrt(1000)
println("derivative of 𝔼[f_squared] = $derivative ± $uncertainty")
```
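As a sanity check, the exact derivative of this toy problem can also be computed without sampling. The plain-Julia sketch below (`expected_square` is a hypothetical helper, not part of StochasticAD) assumes the walk starts at `0`, propagates the probability mass function of the position through the `n` steps, and takes a central finite difference in `p`. Under that assumption, its result should agree with the stochastic-triple estimate above up to the reported uncertainty.

```julia
# Exact 𝔼[X_n^2] for the toy walk, assuming X_0 = 0 and
# P(step = +1 | X) = exp(-X / p), P(step = -1 | X) = 1 - exp(-X / p).
function expected_square(p, n)
    pmf = zeros(n + 2)           # pmf[x + 1] = P(X = x) for x in 0:(n + 1); one spare slot keeps indexing in bounds
    pmf[1] = 1.0                 # all mass at X = 0 initially
    for _ in 1:n
        next_pmf = zeros(n + 2)
        for x in 0:n
            mass = pmf[x + 1]
            iszero(mass) && continue
            q = exp(-x / p)                           # probability of a +1 step
            next_pmf[x + 2] += mass * q               # move right
            x > 0 && (next_pmf[x] += mass * (1 - q))  # move left (q == 1 at x == 0)
        end
        pmf = next_pmf
    end
    return sum(pmf[x + 1] * x^2 for x in 0:(n + 1))
end

# Central finite difference of the exact expectation at p = 50, n = 100
h = 1e-4
exact_derivative = (expected_square(50 + h, 100) - expected_square(50 - h, 100)) / (2h)
```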

## Computing variance

A crucial figure of merit for a derivative estimator is its variance. We compute the standard deviation (square root of the variance) of our estimator over a range of `n`.
```@example random_walk
n_range = 10:10:100 # range for testing asymptotic variance behaviour
p_range = 2 .* n_range
nsamples = 10000

stds_triple = Float64[]
for (n, p) in zip(n_range, p_range)
    std_triple = std(derivative_estimate(p -> f_squared(p, n), p)
                     for i in 1:(nsamples))
    push!(stds_triple, std_triple)
end
@show stds_triple
```
For comparison with other unbiased estimators, we also compute `stds_score` and `stds_score_baseline` for the
[score function gradient estimator](https://arxiv.org/pdf/1906.10652.pdf), without and with a variance-reducing batch-average control variate (CV), respectively. (For details, see [`core.jl`](https://github.com/gaurav-arya/StochasticAD.jl/blob/main/tutorials/random_walk/core.jl) and [`compare_score.jl`](https://github.com/gaurav-arya/StochasticAD.jl/blob/main/tutorials/random_walk/compare_score.jl).) We can now plot the standard deviation of each estimator versus $n$, observing that the unbiased derivative estimate produced by stochastic triples has lower variance:

```@raw html
<img src="../images/compare_score.png" width="50%"/>
```
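The score-function comparison itself lives in the tutorial scripts linked above. Purely as an illustrative sketch (not the code in `compare_score.jl`), a score-function estimator for this walk multiplies `X_n^2` by the score `d/dp log P(path)`, accumulated step by step, and the control-variate variant subtracts a baseline from `X_n^2`. Here the baseline is a leave-one-out batch average, which keeps the estimator unbiased because the baseline does not depend on the sample it is paired with. The names `walk_with_score` and `score_estimates` are hypothetical, and the walk is assumed to start at `0`.

```julia
using Random, Statistics

# One run of the toy walk (assumed to start at 0, with P(+1 | X) = exp(-X / p)),
# returning the final position together with the score d/dp log P(path).
function walk_with_score(p, n; rng = Random.default_rng())
    X = 0
    score = 0.0
    for _ in 1:n
        q = exp(-X / p)            # probability of stepping +1
        dq = q * X / p^2           # derivative of q with respect to p
        if rand(rng) < q
            score += dq / q        # d/dp log(q)
            X += 1
        else
            score -= dq / (1 - q)  # d/dp log(1 - q)
            X -= 1
        end
    end
    return X, score
end

# Score-function estimates of d/dp 𝔼[X_n^2], without and with a leave-one-out
# batch-average baseline (requires nsamples >= 2).
function score_estimates(p, n, nsamples; rng = Random.default_rng())
    runs = [walk_with_score(p, n; rng) for _ in 1:nsamples]
    fs = [float(X)^2 for (X, _) in runs]
    scores = [s for (_, s) in runs]
    total = sum(fs)
    plain = fs .* scores
    baseline = [(total - fi) / (nsamples - 1) for fi in fs]  # average of the other samples
    with_cv = (fs .- baseline) .* scores
    return plain, with_cv
end
```

For example, `plain, with_cv = score_estimates(50.0, 100, 10_000)` followed by `std(plain)` and `std(with_cv)` gives per-sample standard deviations analogous in spirit to `stds_score` and `stds_score_baseline`.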



================================================
FILE: docs/src/tutorials/reverse_demo.md
================================================
```@meta
EditURL = "../../../tutorials/reverse_example/reverse_demo.jl"
```

# Simple reverse mode example

```@setup reverse_demo
import Pkg
Pkg.activate("../../../tutorials")
Pkg.develop(path="../../..")
Pkg.instantiate()

import Random
Random.seed!(1234)
```

Load our packages

````@example reverse_demo
using StochasticAD
using Distributions
using Enzyme
using LinearAlgebra
````

Let us define our target function.

````@example reverse_demo
# Define a toy `StochasticAD`-differentiable function for computing an integer value from a string.
string_value(strings, index) = Int(sum(codepoint, strings[index]))
string_value(strings, index::StochasticTriple) = StochasticAD.propagate(index -> string_value(strings, index), index)

function f(θ; derivative_coupling = StochasticAD.InversionMethodDerivativeCoupling())
    strings = ["cat", "dog", "meow", "woofs"]
    index = randst(Categorical(θ); derivative_coupling)
    return string_value(strings, index)
end

θ = [0.1, 0.5, 0.3, 0.1]
@show f(θ)
nothing
````

First, let's compute the sensitivity of `f` in a particular direction via forward-mode Stochastic AD.

````@example reverse_demo
u = [1.0, 2.0, 4.0, -7.0]
@show derivative_estimate(f, θ, StochasticAD.ForwardAlgorithm(PrunedFIsBackend()); direction = u)
nothing
````

Now, let's do the same with reverse-mode.

````@example reverse_demo
@show derivative_estimate(f, θ, StochasticAD.EnzymeReverseAlgorithm(PrunedFIsBackend(Val(:wins))))
````

Let's verify that our reverse-mode gradient is consistent with our forward-mode directional derivative.

````@example reverse_demo
forward() = derivative_estimate(f, θ, StochasticAD.ForwardAlgorithm(PrunedFIsBackend()); direction = u)
reverse() = derivative_estimate(f, θ, StochasticAD.EnzymeReverseAlgorithm(PrunedFIsBackend(Val(:wins))))

N = 40000
directional_derivs_fwd = [forward() for i in 1:N]
derivs_bwd = [reverse() for i in 1:N]
directional_derivs_bwd = [dot(u, δ) for δ in derivs_bwd]
println("Forward mode: $(mean(directional_derivs_fwd)) ± $(std(directional_derivs_fwd) / sqrt(N))")
println("Reverse mode: $(mean(directional_derivs_bwd)) ± $(std(directional_derivs_bwd) / sqrt(N))")
@assert isapprox(mean(directional_derivs_fwd), mean(directional_derivs_bwd), rtol = 3e-2)

nothing
````
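Because `f` takes only four values, its expectation is linear in `θ`, which gives an exact value to compare both estimates against. The plain-Julia sketch below does not use StochasticAD, and `vals`, `expected_f` and `exact_directional` are illustrative names. Since the entries of `u` sum to zero and every entry of `θ` is positive, `θ + εu` stays a valid probability vector for small `ε`, so the true directional derivative is `dot(u, vals)`; the forward- and reverse-mode averages printed above should both be statistically consistent with it.

```julia
using LinearAlgebra

# Apply the same mapping as `string_value` to every possible outcome.
strings = ["cat", "dog", "meow", "woofs"]
vals = [Int(sum(codepoint, s)) for s in strings]  # [312, 314, 440, 558]
θ = [0.1, 0.5, 0.3, 0.1]
u = [1.0, 2.0, 4.0, -7.0]

@assert sum(u) == 0  # θ + εu remains a probability vector for small ε
expected_f = dot(θ, vals)         # 𝔼[f(θ)] = Σᵢ θᵢ vᵢ
exact_directional = dot(u, vals)  # d/dε 𝔼[f(θ + εu)] at ε = 0, here -1206.0
```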

---

*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*



================================================
FILE: ext/StochasticADEnzymeExt.jl
================================================
module StochasticADEnzymeExt

using StochasticAD
using Enzyme

function enzyme_target(u, X, p, backend)
    # equivalent to derivative_estimate(X, p; backend, direction = u), but specialize to real output to make Enzyme happier
    st = StochasticAD.stochastic_triple_direction(X, p, u; backend)
    if !(StochasticAD.valtype(st) <: Real)
        error("EnzymeReverseAlgorithm only supports real-valued outputs.")
    end
    return derivative_contribution(st)
end

function StochasticAD.derivative_estimate(X, p, alg::StochasticAD.EnzymeReverseAlgorithm;
        direction = nothing, alg_data = (; forward_u = nothing))
    if !isnothing(direction)
        error("EnzymeReverseAlgorithm does not support keyword argument `direction`")
    end
    if p isa AbstractVector
        Δu = zeros(float(eltype(p)), length(p))
        u = isnothing(alg_data.forward_u) ?
            rand(StochasticAD.RNG, float(eltype(p)), length(p)) : alg_data.forward_u
        autodiff(Enzyme.Reverse, enzyme_target, Active, Duplicated(u, Δu),
            Const(X), Const(p), Const(alg.backend))
        return Δu
    elseif p isa Real
        u = isnothing(alg_data.forward_u) ? rand(StochasticAD.RNG, float(typeof(p))) :
            alg_data.forward_u
        ((du, _, _, _),) = autodiff(Enzyme.Reverse, enzyme_target, Active, Active(u),
            Const(X), Const(p), Const(alg.backend))
        return du
    else
        error("EnzymeReverseAlgorithm only supports p::Real or p::AbstractVector")
    end
end

end


================================================
FILE: src/StochasticAD.jl
================================================
module StochasticAD

### Public API

export stochastic_triple, derivative_contribution, perturbations, smooth_triple,
       dual_number, StochasticTriple # For working with stochastic triples
export derivative_estimate, StochasticModel, stochastic_gradient # Higher level functionality
export new_weight # Particle resampling
export PrunedFIsBackend,
       PrunedFIsAggressiveBackend, DictFIsBackend, SmoothedFIsBackend,
       StrategyWrapperFIsBackend
export PrunedFIs, PrunedFIsAggressive, DictFIs, SmoothedFIs, StrategyWrapperFIs
export randst
export InversionMethodDerivativeCoupling

### Imports

using Random
using Distributions
using DistributionsAD
using ChainRulesCore
using ChainRulesOverloadGeneration
using ExprTools
using ForwardDiff
using Functors
import ChainRulesCore
# resolve conflicts while this code exists in both.
const on_new_rule = ChainRulesOverloadGeneration.on_new_rule
const refresh_rules = ChainRulesOverloadGeneration.refresh_rules

const RNG = copy(Random.default_rng())

### Files responsible for backends

include("finite_infinitesimals.jl")
include("backends/pruned.jl")
include("backends/pruned_aggressive.jl")
include("backends/dict.jl")
include("backends/smoothed.jl")
include("backends/abstract_wrapper.jl")
include("backends/strategy_wrapper.jl")
using .PrunedFIsModule
using .PrunedFIsAggressiveModule
using .DictFIsModule
using .SmoothedFIsModule
using .AbstractWrapperFIsModule
using .StrategyWrapperFIsModule

include("prelude.jl") # Defines global constants
include("smoothing.jl") # Smoothing rules. Placed before general rules so that new_weight frule is caught by overload generation.
include("stochastic_triple.jl") # Defines stochastic triple object and higher level functions
include("general_rules.jl") # Defines rules for propagation through deterministic functions
include("discrete_randomness.jl") # Defines rules for propagation through discrete random functions
include("propagate.jl") # Experimental generalized forward propagation functionality
include("algorithms.jl") # Add algorithm-based higher-level interface 
include("misc.jl") # Miscellaneous functions that do not fit in the usual flow

end


================================================
FILE: src/algorithms.jl
================================================
abstract type AbstractStochasticADAlgorithm end

"""
    ForwardAlgorithm(backend::StochasticAD.AbstractFIsBackend) <: AbstractStochasticADAlgorithm
    
A differentiation algorithm relying on forward propagation of stochastic triples.

The `backend` argument controls the algorithm used by the third component of the stochastic triples.

!!! note 
    The required computation time for forward-mode AD scales linearly with the number of 
    parameters in `p` (but is unaffected by the number of parameters in `X(p)`).
"""
struct ForwardAlgorithm{B <: StochasticAD.AbstractFIsBackend} <:
       AbstractStochasticADAlgorithm
    backend::B
end

"""
    EnzymeReverseAlgorithm(backend::StochasticAD.AbstractFIsBackend) <: AbstractStochasticADAlgorithm

A differentiation algorithm relying on transposing the propagation of stochastic triples to
produce a reverse-mode algorithm. The transposition is performed by [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl),
which must be loaded for the algorithm to run.

Currently, only real- and vector-valued inputs are supported, and only real-valued outputs are supported.

The `backend` argument controls the algorithm used by the third component of the stochastic triples.

In the call to `derivative_estimate`, this algorithm optionally accepts `alg_data` with the field `forward_u`,
which specifies the direction `u` of the forward-mode directional derivative that is transposed.
If `forward_u` is not provided, a random direction is generated.

!!! warning
    For the reverse-mode algorithm to yield correct results, the employed `backend` cannot use input-dependent pruning  
    strategies. A suggested reverse-mode compatible backend is `PrunedFIsBackend(Val(:wins))`.
    
    Additionally, this algorithm relies on the ability of `Enzyme.jl` to differentiate the forward stochastic triple run.
    It is recommended to check that the primal function `X` is type stable for its input `p` using a tool such as
    [JET.jl](https://github.com/aviatesk/JET.jl), with all code executed in a function with no global state. 
    In addition, sometimes `X` may be type stable but stochastic triples introduce additional type instabilities.
    This can be debugged by checking type stability of Enzyme's target, which is
    `Base.get_extension(StochasticAD, :StochasticADEnzymeExt).enzyme_target(u, X, p, backend)`,
    where `u` is a test direction.
    
!!! note
    For more details on the reverse-mode approach, see the following papers and talks:
    
    * ["You Only Linearize Once: Tangents Transpose to Gradients"](https://arxiv.org/abs/2204.10923), Radul et al. 2022.
    * ["Reverse mode ADEV via YOLO: tangent estimators transpose to gradient estimators"](https://www.youtube.com/watch?v=pnPmk-leSsE), Becker et al. 2024
    * ["Probabilistic Programming with Programmable Variational Inference"](https://pldi24.sigplan.org/details/pldi-2024-papers/87/Probabilistic-Programming-with-Programmable-Variational-Inference), Becker et al. 2024
"""
struct EnzymeReverseAlgorithm{B <: StochasticAD.AbstractFIsBackend} <:
       AbstractStochasticADAlgorithm
    backend::B
end

function derivative_estimate(
        X, p, alg::ForwardAlgorithm; direction = nothing, alg_data::NamedTuple = (;))
    return derivative_estimate(X, p; backend = alg.backend, direction)
end

@doc raw"""
    derivative_estimate(X, p, alg::AbstractStochasticADAlgorithm = ForwardAlgorithm(PrunedFIsBackend()); direction=nothing, alg_data::NamedTuple = (;))

Compute an unbiased estimate of ``\frac{\mathrm{d}\mathbb{E}[X(p)]}{\mathrm{d}p}``, 
the derivative of the expectation of the random function `X(p)` with respect to its input `p`.

Both `p` and `X(p)` can be any object supported by [`Functors.jl`](https://fluxml.ai/Functors.jl/stable/),
e.g. scalars or abstract arrays. 
The output of `derivative_estimate` has the same outer structure as `p`, but with each
scalar in `p` replaced by a derivative estimate of `X(p)` with respect to that entry.
For example, if `X(p) isa AbstractMatrix` and `p isa Real`, then the output would be a matrix.

The `alg` argument specifies the [algorithm](public_api.md#Algorithms) used to compute the derivative estimate.
For backward compatibility, an additional signature `derivative_estimate(X, p; backend, direction=nothing)`
is supported, which is equivalent to using `ForwardAlgorithm(backend)`.
The `alg_data` keyword argument can specify any additional data that specific algorithms accept or require.

When `direction` is provided, the output is only differentiated with respect to a perturbation
of `p` in that direction.

# Example
```jldoctest
julia> using Distributions, Random, StochasticAD; Random.seed!(4321);

julia> derivative_estimate(rand ∘ Bernoulli, 0.5) # A random quantity that averages to the true derivative.
2.0

julia> derivative_estimate(x -> [rand(Bernoulli(x * i/4)) for i in 1:3], 0.5)
3-element Vector{Float64}:
 0.2857142857142857
 0.6666666666666666
 0.0
```
"""
derivative_estimate


================================================
FILE: src/backends/abstract_wrapper.jl
================================================
module AbstractWrapperFIsModule

import ..StochasticAD

export AbstractWrapperFIs

"""
    AbstractWrapperFIs{V, FIs} <: StochasticAD.AbstractFIs{V}

A convenience type for backend strategies that wrap another backend. A subtype `WrapperFIs <: AbstractWrapperFIs`
should have a field called `Δs` containing the wrapped backend, and should also define the following methods:
* `StochasticAD.similar_type(::Type{<:WrapperFIs}, V, FIs)`: return the type of a new
  `WrapperFIs` with value type `V` and wrapped backend type `FIs`.
* `AbstractWrapperFIsModule.reconstruct_wrapper(wrapper_Δs::WrapperFIs, Δs::AbstractFIs)`: construct
  a new `WrapperFIs` wrapping `Δs`, given an existing wrapper instance `wrapper_Δs`.
* `AbstractWrapperFIsModule.reconstruct_wrapper(::Type{<:WrapperFIs}, Δs::AbstractFIs)`: construct
  a new `WrapperFIs` wrapping `Δs`, given the type of an existing `WrapperFIs`.

Then, all other methods will generically be forwarded to the inner backend, except those overloaded by the
specific wrapper type.
"""
abstract type AbstractWrapperFIs{V, FIs} <: StochasticAD.AbstractFIs{V} end

function reconstruct_wrapper end

function StochasticAD.similar_new(Δs::AbstractWrapperFIs, Δ, w)
    reconstruct_wrapper(Δs, StochasticAD.similar_new(Δs.Δs, Δ, w))
end
function StochasticAD.similar_empty(Δs::AbstractWrapperFIs, V)
    reconstruct_wrapper(Δs, StochasticAD.similar_empty(Δs.Δs, V))
end

function StochasticAD.similar_type(WrapperFIs::Type{<:AbstractWrapperFIs{V0, FIs}},
        V) where {V0, FIs}
    return StochasticAD.similar_type(WrapperFIs, V, StochasticAD.similar_type(FIs, V))
end

StochasticAD.valtype(Δs::AbstractWrapperFIs) = StochasticAD.valtype(Δs.Δs)

function StochasticAD.couple(WrapperFIs::Type{<:AbstractWrapperFIs{V, FIs}},
        Δs_all;
        rep = nothing,
        kwargs...) where {V, FIs}
    _Δs_all = StochasticAD.structural_map(Δs -> Δs.Δs, Δs_all)
    _rep_kwarg = !isnothing(rep) ? (; rep = rep.Δs) : (;)
    return reconstruct_wrapper(StochasticAD.get_any(Δs_all),
        StochasticAD.couple(FIs, _Δs_all; _rep_kwarg..., kwargs...))
end

function StochasticAD.combine(WrapperFIs::Type{<:AbstractWrapperFIs{V, FIs}},
        Δs_all;
        rep = nothing,
        kwargs...) where {V, FIs}
    _Δs_all = StochasticAD.structural_map(Δs -> Δs.Δs, Δs_all)
    _rep_kwarg = !isnothing(rep) ? (; rep = rep.Δs) : (;)
    return reconstruct_wrapper(StochasticAD.get_any(Δs_all),
        StochasticAD.combine(FIs, _Δs_all; _rep_kwarg..., kwargs...))
end

function StochasticAD.get_rep(WrapperFIs::Type{<:AbstractWrapperFIs{V, FIs}},
        Δs_all;
        kwargs...) where {V, FIs}
    _Δs_all = StochasticAD.structural_map(Δs -> Δs.Δs, Δs_all)
    return reconstruct_wrapper(StochasticAD.get_any(Δs_all),
        StochasticAD.get_rep(FIs, _Δs_all; kwargs...))
end

function StochasticAD.scalarize(Δs::AbstractWrapperFIs; rep = nothing, kwargs...)
    _rep_kwarg = !isnothing(rep) ? (; rep = rep.Δs) : (;)
    return StochasticAD.structural_map(StochasticAD.scalarize(
        Δs.Δs; _rep_kwarg..., kwargs...)) do _Δs
        reconstruct_wrapper(Δs, _Δs)
    end
end

function StochasticAD.derivative_contribution(Δs::AbstractWrapperFIs, Δs_all; kwargs...)
    StochasticAD.derivative_contribution(Δs.Δs, Δs_all; kwargs...)
end

StochasticAD.alltrue(f, Δs::AbstractWrapperFIs) = StochasticAD.alltrue(f, Δs.Δs)

StochasticAD.perturbations(Δs::AbstractWrapperFIs) = StochasticAD.perturbations(Δs.Δs)

function StochasticAD.filter_state(Δs::AbstractWrapperFIs, state)
    StochasticAD.filter_state(Δs.Δs, state)
end

function StochasticAD.weighted_map_Δs(f, Δs::AbstractWrapperFIs; kwargs...)
    reconstruct_wrapper(Δs, StochasticAD.weighted_map_Δs(f, Δs.Δs; kwargs...))
end

StochasticAD.new_Δs_strategy(Δs::AbstractWrapperFIs) = StochasticAD.new_Δs_strategy(Δs.Δs)

function Base.empty(WrapperFIs::Type{<:AbstractWrapperFIs{V, FIs}}) where {V, FIs}
    return reconstruct_wrapper(WrapperFIs, empty(FIs))
end

Base.empty(Δs::AbstractWrapperFIs) = reconstruct_wrapper(Δs, empty(Δs.Δs))
Base.isempty(Δs::AbstractWrapperFIs) = isempty(Δs.Δs)
Base.length(Δs::AbstractWrapperFIs) = length(Δs.Δs)
Base.iszero(Δs::AbstractWrapperFIs) = iszero(Δs.Δs)

function StochasticAD.derivative_contribution(Δs::AbstractWrapperFIs)
    StochasticAD.derivative_contribution(Δs.Δs)
end

function Base.convert(::Type{<:AbstractWrapperFIs{V}}, Δs::AbstractWrapperFIs) where {V}
    reconstruct_wrapper(Δs, convert(StochasticAD.similar_type(typeof(Δs.Δs), V), Δs.Δs))
end

function StochasticAD.send_signal(
        Δs::AbstractWrapperFIs, signal::StochasticAD.AbstractPerturbationSignal)
    reconstruct_wrapper(Δs, StochasticAD.send_signal(Δs.Δs, signal))
end

function Base.show(io::IO, Δs::AbstractWrapperFIs)
    return show(io, Δs.Δs)
end

end


================================================
FILE: src/backends/dict.jl
================================================
module DictFIsModule

export DictFIsBackend, DictFIs

import ..StochasticAD
using Dictionaries

"""
    DictFIsBackend <: StochasticAD.AbstractFIsBackend

A dictionary backend algorithm which keeps entries for each perturbation that has occurred without pruning. 
Currently very unoptimized.
"""
struct DictFIsBackend <: StochasticAD.AbstractFIsBackend end

"""
    DictFIsState    

State maintained by dictionary backend.
"""
mutable struct DictFIsState
    tag_count::Int64
    valid::Bool
    DictFIsState(valid = true) = new(0, valid)
end

struct InfinitesimalEvent
    tag::Any # unique identifier
    w::Float64 # weight (infinitesimal probability wε) 
end

Base.:<(event1::InfinitesimalEvent, event2::InfinitesimalEvent) = event1.tag < event2.tag
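function Base.:(==)(event1::InfinitesimalEvent, event2::InfinitesimalEvent)
    event1.tag == event2.tag
end
# keep `hash` consistent with `==`, which ignores the weight `w`
Base.hash(event::InfinitesimalEvent, h::UInt) = hash(event.tag, h)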
Base.:isless(event1::InfinitesimalEvent, event2::InfinitesimalEvent) = event1 < event2

"""
    DictFIs{V} <: StochasticAD.AbstractFIs{V}

The implementing backend structure for DictFIsBackend.
"""
struct DictFIs{V} <: StochasticAD.AbstractFIs{V}
    dict::Dictionary{InfinitesimalEvent, V}
    state::DictFIsState
end

state(Δs::DictFIs) = Δs.state

### Empty / no perturbation

function DictFIs{V}(state::DictFIsState) where {V}
    DictFIs{V}(Dictionary{InfinitesimalEvent, V}(), state)
end
StochasticAD.similar_empty(Δs::DictFIs, V::Type) = DictFIs{V}(Δs.state)
Base.empty(Δs::DictFIs{V}) where {V} = StochasticAD.similar_empty(Δs::DictFIs, V::Type)
function Base.empty(::Type{<:DictFIs{V}}) where {V}
    DictFIs{V}(DictFIsState(false))
end

### Create a new perturbation with infinitesimal probability

function new_perturbation(Δ::V, w::Real, state::DictFIsState) where {V}
    state.tag_count += 1
    event = InfinitesimalEvent(state.tag_count, w)
    DictFIs{V}(Dictionary([event], [Δ]), state)
end
function StochasticAD.similar_new(Δs::DictFIs, Δ::V, w::Real) where {V}
    new_perturbation(Δ, w, Δs.state)
end

### Create Δs backend for the first stochastic triple of computation

StochasticAD.create_Δs(::DictFIsBackend, V) = DictFIs{V}(DictFIsState())

### Convert type of a backend

function Base.convert(::Type{DictFIs{V}}, Δs::DictFIs) where {V}
    DictFIs{V}(convert(Dictionary{InfinitesimalEvent, V}, Δs.dict), Δs.state)
end

### Getting information about Δs

Base.isempty(Δs::DictFIs) = isempty(Δs.dict)
Base.length(Δs::DictFIs) = length(Δs.dict)
Base.iszero(Δs::DictFIs) = isempty(Δs) || all(iszero.(Δs.dict))
function StochasticAD.derivative_contribution(Δs::DictFIs{V}) where {V}
    sum((Δ * event.w for (event, Δ) in pairs(Δs.dict)), init = zero(V) * 0.0)
end

function StochasticAD.perturbations(Δs::DictFIs)
    [(; Δ, weight = event.w, state = event) for (event, Δ) in pairs(Δs.dict)]
end

### Unary propagation

function StochasticAD.weighted_map_Δs(f, Δs::DictFIs; kwargs...)
    # Pass key as state in map
    mapped_values_and_weights = map(f, collect(Δs.dict), keys(Δs.dict))
    mapped_values = first.(mapped_values_and_weights)
    mapped_weights = last.(mapped_values_and_weights)
    scaled_events = map((event, a) -> InfinitesimalEvent(event.tag, event.w * a),
        keys(Δs.dict),
        mapped_weights) # TODO: should original events (with old tag) also be modified?
    dict = Dictionary(scaled_events, mapped_values)
    DictFIs(dict, Δs.state)
end

StochasticAD.alltrue(f, Δs::DictFIs) = all(map(f, collect(Δs.dict)))

### Coupling

function StochasticAD.get_rep(::Type{<:DictFIs}, Δs_all)
    for Δs in StochasticAD.structural_iterate(Δs_all)
        if Δs.state.valid
            return Δs
        end
    end
    return first(Δs_all)
end

function StochasticAD.couple(FIs::Type{<:DictFIs}, Δs_all;
        rep = StochasticAD.get_rep(FIs, Δs_all),
        out_rep = nothing,
        kwargs...)
    all_keys = Iterators.map(StochasticAD.structural_iterate(Δs_all)) do Δs
        keys(Δs.dict)
    end
    distinct_keys = unique(all_keys |> Iterators.flatten)
    Δs_coupled_dict = [StochasticAD.structural_map(
                           Δs -> isassigned(Δs.dict, key) ?
                                 Δs.dict[key] :
                                 zero(eltype(Δs.dict)),
                           Δs_all)
                       for key in distinct_keys]
    DictFIs(Dictionary(distinct_keys, Δs_coupled_dict), rep.state)
end

function StochasticAD.combine(FIs::Type{<:DictFIs}, Δs_all;
        rep = StochasticAD.get_rep(FIs, Δs_all), kwargs...)
    Δs_dicts = Iterators.map(Δs -> Δs.dict, StochasticAD.structural_iterate(Δs_all))
    Δs_combined_dict = reduce(Δs_dicts) do Δs_dict1, Δs_dict2
        mergewith((x, y) -> StochasticAD.structural_map(+, x, y), Δs_dict1, Δs_dict2)
    end
    DictFIs(Δs_combined_dict, rep.state)
end

function StochasticAD.scalarize(Δs::DictFIs; out_rep = nothing)
    # TODO: use vcat here?
    tupleify(Δ1, Δ2) = StochasticAD.structural_map(tuple, Δ1, Δ2)
    Δ_all_allkeys = foldl(tupleify, values(Δs.dict))
    Δ_all_rep = first(values(Δs.dict))
    _keys = keys(Δs.dict)
    return StochasticAD.structural_map(Δ_all_rep, Δ_all_allkeys) do _, Δ_allkeys
        return DictFIs(Dictionary(_keys, Δ_allkeys), Δs.state)
    end
end

function StochasticAD.filter_state(Δs::DictFIs{V}, key) where {V}
    haskey(Δs.dict, key) ? Δs.dict[key] : zero(V)
end

### Miscellaneous

StochasticAD.similar_type(::Type{<:DictFIs}, V::Type) = DictFIs{V}
StochasticAD.valtype(::Type{<:DictFIs{V}}) where {V} = V

end


================================================
FILE: src/backends/pruned.jl
================================================
module PrunedFIsModule

import ..StochasticAD

export PrunedFIsBackend, PrunedFIs

"""
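    PrunedFIsBackend(pruning_mode::Val = Val(:weights)) <: StochasticAD.AbstractFIsBackend

A backend algorithm that prunes between perturbations as soon as they clash (e.g., when they are added together).
Each pruning step keeps one of the two clashing perturbations, chosen at random with probability proportional
to its absolute weight (`pruning_mode = Val(:weights)`, the default) or to its win count, which starts at 1 and
increments each time the perturbation survives a pruning step (`pruning_mode = Val(:wins)`).
The weight of the surviving perturbation is divided by its selection probability.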
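
For illustration, a plain-Julia sketch of a single pruning step between two perturbations in the default
weight-based mode, independent of this module (`prune_pair` is a hypothetical helper). It checks that the
expected derivative contribution `Δ * w` is preserved:

```julia
using Random, Statistics

# Keep one of two perturbations (Δ, w) with probability proportional to |w|,
# dividing the survivor's weight by its selection probability.
function prune_pair(rng, (Δ1, w1), (Δ2, w2))
    p = abs(w2) / (abs(w1) + abs(w2))  # probability of keeping the second perturbation
    return rand(rng) < p ? (Δ2, w2 / p) : (Δ1, w1 / (1 - p))
end

rng = MersenneTwister(1)
kept = [prune_pair(rng, (3.0, 0.2), (1.0, 0.6)) for _ in 1:100_000]
mean(Δ * w for (Δ, w) in kept)  # close to the exact expectation 3.0 * 0.2 + 1.0 * 0.6 = 1.2
```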
"""
struct PrunedFIsBackend{M <: Val} <: StochasticAD.AbstractFIsBackend
    pruning_mode::M
    function PrunedFIsBackend(pruning_mode::M = Val(:weights)) where {M}
        if pruning_mode isa Val{:weights} || pruning_mode isa Val{:wins}
            return new{M}(pruning_mode)
        else
            error("Unsupported pruning_mode $pruning_mode for `PrunedFIsBackend`.")
        end
    end
end

"""
    PrunedFIsState

State maintained by pruning backend.
"""
mutable struct PrunedFIsState{M, W}
    tag::Int32
    weight::Float64
    valid::Bool
    # TODO: generalize (wins, pruning_mode) into a general interface for accumulating state
    # that informs future pruning decisions.
    wins::W
    pruning_mode::M
    function PrunedFIsState(pruning_mode::M, valid = true) where {M <: Val}
        wins = pruning_mode isa Val{:wins} ? (valid ? 1 : 0) : nothing
        state::PrunedFIsState = new{M, typeof(wins)}(0, 0.0, valid, wins, pruning_mode)
        state.tag = objectid(state) % typemax(Int32)
        return state
    end
end

Base.:(==)(state1::PrunedFIsState, state2::PrunedFIsState) = state1.tag == state2.tag
# c.f. https://github.com/JuliaLang/julia/blob/61c3521613767b2af21dfa5cc5a7b8195c5bdcaf/base/hashing.jl#L38C45-L38C51
Base.hash(state::PrunedFIsState, h::UInt) = hash(state.tag, h)

"""
    PrunedFIs{V} <: StochasticAD.AbstractFIs{V}

The implementing backend structure for PrunedFIsBackend.
"""
struct PrunedFIs{V, S <: PrunedFIsState} <: StochasticAD.AbstractFIs{V}
    Δ::V
    state::S
end

### Empty / no perturbation

PrunedFIs{V}(Δ::V, state::S) where {V, S <: PrunedFIsState} = PrunedFIs{V, S}(Δ, state)
PrunedFIs{V}(state::PrunedFIsState) where {V} = PrunedFIs{V}(zero(V), state)
# TODO: avoid allocations here
function StochasticAD.similar_empty(Δs::PrunedFIs, V::Type)
    PrunedFIs{V}(PrunedFIsState(Δs.state.pruning_mode, false))
end
Base.empty(Δs::PrunedFIs{V}) where {V} = StochasticAD.similar_empty(Δs::PrunedFIs, V::Type)
# we truly have no clue what the state is here, so use an invalidated state
function Base.empty(::Type{<:PrunedFIs{V, S}}) where {V, M, S <: PrunedFIsState{M}}
    PrunedFIs{V}(PrunedFIsState(M(), false))
end

### Create a new perturbation with infinitesimal probability

function StochasticAD.similar_new(Δs::PrunedFIs, Δ::V, w::Real) where {V}
    if iszero(w)
        return StochasticAD.similar_empty(Δs, V)
    end
    state = PrunedFIsState(Δs.state.pruning_mode)
    state.weight += w
    Δs = PrunedFIs{V}(Δ, state)
    return Δs
end

### Create Δs backend for the first stochastic triple of computation

function StochasticAD.create_Δs(backend::PrunedFIsBackend, V)
    PrunedFIs{V}(PrunedFIsState(backend.pruning_mode, false))
end

### Convert type of a backend

function Base.convert(::Type{<:PrunedFIs{V}}, Δs::PrunedFIs) where {V}
    PrunedFIs{V}(convert(V, Δs.Δ), Δs.state)
end

### Getting information about perturbations

# "empty" here means no perturbation or a perturbation that has been pruned away
Base.isempty(Δs::PrunedFIs) = !Δs.state.valid
Base.length(Δs::PrunedFIs) = isempty(Δs) ? 0 : 1
function Base.iszero(Δs::PrunedFIs)
    isempty(Δs) || all(iszero, StochasticAD.structural_iterate(Δs.Δ))
end
Base.iszero(Δs::PrunedFIs{<:Real}) = isempty(Δs) || iszero(Δs.Δ)
Base.iszero(Δs::PrunedFIs{<:Tuple}) = isempty(Δs) || all(iszero.(Δs.Δ))
isapproxzero(Δs::PrunedFIs) = isempty(Δs) || isapprox(Δs.Δ, zero(Δs.Δ))

# we lazily prune, so check if empty first
function pruned_value(Δs::PrunedFIs{V}) where {V}
    isempty(Δs) ? StochasticAD.structural_map(zero, Δs.Δ) : Δs.Δ
end
pruned_value(Δs::PrunedFIs{<:Real}) = isempty(Δs) ? zero(Δs.Δ) : Δs.Δ
pruned_value(Δs::PrunedFIs{<:Tuple}) = isempty(Δs) ? zero.(Δs.Δ) : Δs.Δ
pruned_value(Δs::PrunedFIs{<:AbstractArray}) = isempty(Δs) ? zero.(Δs.Δ) : Δs.Δ

StochasticAD.derivative_contribution(Δs::PrunedFIs) = pruned_value(Δs) * Δs.state.weight
function StochasticAD.perturbations(Δs::PrunedFIs)
    ((; Δ = pruned_value(Δs), weight = Δs.state.weight, state = Δs.state),)
end

### Unary propagation

function StochasticAD.weighted_map_Δs(f, Δs::PrunedFIs; kwargs...)
    Δ_out, weight_out = f(pruned_value(Δs), Δs.state)
    # TODO: we could add a direct overload for map_Δs that elides the below line
    Δs.state.weight *= weight_out
    PrunedFIs(Δ_out, Δs.state)
end

StochasticAD.alltrue(f, Δs::PrunedFIs) = f(pruned_value(Δs))

### Coupling

function StochasticAD.get_rep(FIs::Type{<:PrunedFIs}, Δs_all)
    return empty(FIs) #StochasticAD.get_any(Δs_all)
end

function get_pruned_state(Δs_all; Δ_func = nothing, rep, out_rep = nothing)
    if !isnothing(Δ_func) && isnothing(out_rep)
        error("Specifying Δ_func requires out_rep to be specified.")
    end
    function op(cur_state, Δs)
        # lazy pruning optimization temporarily disabled with custom Δ_func 
        # (because custom Δ_func's may prefer not to lazily prune)
        (isnothing(Δ_func) && isapproxzero(Δs)) && return cur_state
        candidate_state = Δs.state
        if !candidate_state.valid ||
           (candidate_state == cur_state)
            return cur_state
        end
        if !cur_state.valid
            return candidate_state
        end

        # Compute "strength" of each perturbation for pruning proposal
        if !isnothing(Δ_func)
            # TODO: structural_map for each state can take asymptotically more time than necessary when combining many distinct states
            candidate_Δ = StochasticAD.structural_map(
                Base.Fix2(StochasticAD.filter_state, candidate_state), Δs_all)
            candidate_Δ_func::Float64 = Δ_func(candidate_Δ, candidate_state, out_rep)
            cur_Δ = StochasticAD.structural_map(
                Base.Fix2(StochasticAD.filter_state, cur_state), Δs_all)
            cur_Δ_func::Float64 = Δ_func(cur_Δ, cur_state, out_rep)
        else
            candidate_Δ_func = 1.0
            cur_Δ_func = 1.0
        end
        candidate_intrinsic_strength = Δs.state.pruning_mode isa Val{:wins} ?
                                       candidate_state.wins : abs(candidate_state.weight)
        cur_intrinsic_strength = Δs.state.pruning_mode isa Val{:wins} ? cur_state.wins :
                                 abs(cur_state.weight)
        candidate_strength = candidate_intrinsic_strength * candidate_Δ_func
        cur_strength = cur_intrinsic_strength * cur_Δ_func

        both_states_bad = iszero(candidate_strength) && iszero(cur_strength)
        if both_states_bad
            cur_state.valid = false
            candidate_state.valid = false
            return cur_state
        end

        # Prune between perturbations
        total_strength = cur_strength + candidate_strength
        p = candidate_strength / total_strength
        if isone(p) || (rand(StochasticAD.RNG) < p)
            cur_state.valid = false
            if Δs.state.pruning_mode isa Val{:wins}
                candidate_state.wins += 1
            end
            candidate_state.weight *= 1 / p
            return candidate_state
        else
            candidate_state.valid = false
            if Δs.state.pruning_mode isa Val{:wins}
                cur_state.wins += 1
            end
            cur_state.weight *= 1 / (1 - p)
            return cur_state
        end
    end
    dummy_state = PrunedFIsState(rep.state.pruning_mode, false) # For type stability, as well as retval if no better state found. TODO: can this be avoided?
    _new_state = foldl(op, StochasticAD.structural_iterate(Δs_all); init = dummy_state)
    return _new_state::PrunedFIsState
end

# for pruning, coupling amounts to getting rid of perturbed values that have been
# lazily kept around even after (aggressive or lazy) pruning made the perturbation invalid.
function StochasticAD.couple(
        FIs::Type{<:PrunedFIs}, Δs_all; rep = StochasticAD.get_rep(FIs, Δs_all),
        out_rep = nothing, Δ_func = nothing, kwargs...)
    state = get_pruned_state(Δs_all; rep, out_rep, Δ_func)
    Δ_coupled = StochasticAD.structural_map(pruned_value, Δs_all) # TODO: perhaps a performance optimization possible here
    PrunedFIs(Δ_coupled, state)
end

# basically couple combined with a sum.
function StochasticAD.combine(
        FIs::Type{<:PrunedFIs}, Δs_all; rep = StochasticAD.get_rep(FIs, Δs_all),
        Δ_func = nothing, out_rep = nothing, kwargs...)
    state = get_pruned_state(Δs_all;
        rep,
        out_rep,
        Δ_func = !isnothing(Δ_func) ? (Δ, state, val) -> Δ_func(sum(Δ), state, val) :
                 Δ_func)
    Δ_combined = sum(pruned_value, StochasticAD.structural_iterate(Δs_all))
    PrunedFIs(Δ_combined, state)
end

function StochasticAD.scalarize(Δs::PrunedFIs; out_rep = nothing)
    return StochasticAD.structural_map(Δs.Δ) do Δ
        return PrunedFIs(Δ, Δs.state)
    end
end

function StochasticAD.filter_state(Δs::PrunedFIs{V}, state) where {V}
    Δs.state == state ? pruned_value(Δs) : zero(V)
end

### Miscellaneous

function StochasticAD.similar_type(::Type{<:PrunedFIs{V0, M}}, V::Type) where {V0, M}
    PrunedFIs{V, M}
end
StochasticAD.valtype(::Type{<:PrunedFIs{V}}) where {V} = V

function Base.show(io::IO, Δs::PrunedFIs{V}) where {V}
    print(io, "$(pruned_value(Δs)) with probability $(Δs.state.weight)ε")
end

end


================================================
FILE: src/backends/pruned_aggressive.jl
================================================
module PrunedFIsAggressiveModule

import ..StochasticAD

export PrunedFIsAggressiveBackend, PrunedFIsAggressive

"""
    PrunedFIsAggressiveBackend <: StochasticAD.AbstractFIsBackend

A backend algorithm that aggressively prunes between perturbations as soon as they are created.
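
Each call to `new_perturbation` acts like one step of single-slot weighted reservoir sampling. A new perturbation of weight `w` replaces the active one with probability `w / (state.weight + w)`, and `state.weight` accumulates the total. The standalone sketch below reproduces only that selection rule, outside of StochasticAD. `reservoir_step` and `survival_frequencies` are hypothetical helpers, not package API.

```julia
using Random

# Hypothetical helper: one step of single-slot weighted reservoir sampling,
# mirroring the branch in `new_perturbation` (not StochasticAD API).
function reservoir_step(rng, active, total_weight, candidate, w)
    new_total = total_weight + w
    if rand(rng) * new_total < total_weight
        return active, new_total     # keep the currently active perturbation
    else
        return candidate, new_total  # the candidate becomes the active one
    end
end

# Feed perturbations 1, 2, ... with the given weights and record which one
# survives at the end of each trial.
function survival_frequencies(weights; trials = 60_000, rng = MersenneTwister(1))
    counts = zeros(Int, length(weights))
    for _ in 1:trials
        active, total = 0, 0.0
        for (i, w) in enumerate(weights)
            active, total = reservoir_step(rng, active, total, i, w)
        end
        counts[active] += 1
    end
    return counts ./ trials
end

freqs = survival_frequencies([1.0, 2.0, 3.0])
```

With weights `1, 2, 3`, the survival frequencies should be close to `1/6, 2/6, 3/6`. Reporting the survivor with the accumulated total weight therefore keeps `Δ * weight` unbiased for the sum of the individual `Δ_i * w_i` contributions.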
"""
struct PrunedFIsAggressiveBackend <: StochasticAD.AbstractFIsBackend end

"""
    PrunedFIsAggressiveState

State maintained by aggressive pruning backend.
"""
mutable struct PrunedFIsAggressiveState
    active_tag::Int64 # 0 is always a dummy tag
    weight::Float64
    tag_count::Int64
    valid::Bool
    PrunedFIsAggressiveState(valid = true) = new(0, 0.0, 0, valid)
end

"""
    PrunedFIsAggressive{V} <: StochasticAD.AbstractFIs{V}

The implementing backend structure for PrunedFIsAggressiveBackend.
"""
struct PrunedFIsAggressive{V} <: StochasticAD.AbstractFIs{V}
    Δ::V
    tag::Int
    state::PrunedFIsAggressiveState
    # The default constructor is called directly when propagating an existing perturbation.
end

### Empty / no perturbation

function PrunedFIsAggressive{V}(state::PrunedFIsAggressiveState) where {V}
    PrunedFIsAggressive{V}(zero(V), -1, state)
end
function StochasticAD.similar_empty(Δs::PrunedFIsAggressive, V::Type)
    PrunedFIsAggressive{V}(Δs.state)
end
function Base.empty(Δs::PrunedFIsAggressive{V}) where {V}
    StochasticAD.similar_empty(Δs, V)
end
# we truly have no clue what the state is here, so use an invalidated state
function Base.empty(::Type{<:PrunedFIsAggressive{V}}) where {V}
    PrunedFIsAggressive{V}(PrunedFIsAggressiveState(false))
end

### Create a new perturbation with infinitesimal probability

function new_perturbation(Δ::V, w::Real, state::PrunedFIsAggressiveState) where {V}
    total_weight = state.weight + w
    if rand(StochasticAD.RNG) * total_weight < state.weight
        state.weight += w
        return PrunedFIsAggressive{V}(state)
    else
        state.tag_count += 1
        state.active_tag = state.tag_count
        state.weight += w
        return PrunedFIsAggressive{V}(Δ, state.active_tag, state)
    end
end
function StochasticAD.similar_new(Δs::PrunedFIsAggressive, Δ::V, w::Real) where {V}
    new_perturbation(Δ, w, Δs.state)
end

### Create Δs backend for the first stochastic triple of computation

function StochasticAD.create_Δs(::PrunedFIsAggressiveBackend, V)
    PrunedFIsAggressive{V}(PrunedFIsAggressiveState())
end

### Convert type of a backend

function Base.convert(::Type{PrunedFIsAggressive{V}}, Δs::PrunedFIsAggressive) where {V}
    PrunedFIsAggressive{V}(convert(V, Δs.Δ), Δs.tag, Δs.state)
end

### Getting information about perturbations

# "empty" here means no perturbation or a perturbation that has been pruned away
Base.isempty(Δs::PrunedFIsAggressive) = Δs.tag != Δs.state.active_tag
Base.length(Δs::PrunedFIsAggressive) = isempty(Δs) ? 0 : 1
Base.iszero(Δs::PrunedFIsAggressive) = isempty(Δs) || iszero(Δs.Δ)

# we lazily prune, so check if empty first
pruned_value(Δs::PrunedFIsAggressive{V}) where {V} = isempty(Δs) ? zero(V) : Δs.Δ

function StochasticAD.derivative_contribution(Δs::PrunedFIsAggressive)
    pruned_value(Δs) * Δs.state.weight
end

function StochasticAD.perturbations(Δs::PrunedFIsAggressive)
    ((; Δ = pruned_value(Δs), weight = Δs.state.weight, state = Δs.state),)
end

### Unary propagation

function StochasticAD.weighted_map_Δs(f, Δs::PrunedFIsAggressive; kwargs...)
    Δ_out, weight_out = f(Δs.Δ, nothing)
    Δs.state.weight *= weight_out
    PrunedFIsAggressive(Δ_out, Δs.tag, Δs.state)
end

StochasticAD.alltrue(f, Δs::PrunedFIsAggressive) = f(Δs.Δ)

### Coupling

function StochasticAD.get_rep(::Type{<:PrunedFIsAggressive}, Δs_all)
    # Get some Δs with a valid state, or any if all are invalid.
    return reduce((Δs1, Δs2) -> Δs1.state.valid ? Δs1 : Δs2,
        StochasticAD.structural_iterate(Δs_all))
end

# for pruning, coupling amounts to getting rid of perturbed values that have been
# lazily kept around even after (aggressive or lazy) pruning made the perturbation invalid.
function StochasticAD.couple(FIs::Type{<:PrunedFIsAggressive}, Δs_all;
        rep = StochasticAD.get_rep(FIs, Δs_all),
        out_rep = nothing, kwargs...)
    state = rep.state
    Δ_coupled = StochasticAD.structural_map(pruned_value, Δs_all) # TODO: perhaps a performance optimization possible here
    PrunedFIsAggressive(Δ_coupled, state.active_tag, state)
end

# basically couple combined with a sum.
function StochasticAD.combine(FIs::Type{<:PrunedFIsAggressive}, Δs_all;
        rep = StochasticAD.get_rep(FIs, Δs_all), kwargs...)
    state = rep.state
    Δ_combined = sum(pruned_value, StochasticAD.structural_iterate(Δs_all))
    PrunedFIsAggressive(Δ_combined, state.active_tag, state)
end

function StochasticAD.scalarize(Δs::PrunedFIsAggressive; out_rep = nothing)
    return StochasticAD.structural_map(Δs.Δ) do Δ
        return PrunedFIsAggressive(Δ, Δs.tag, Δs.state)
    end
end

StochasticAD.filter_state(Δs::PrunedFIsAggressive, _) = pruned_value(Δs)

### Miscellaneous

StochasticAD.similar_type(::Type{<:PrunedFIsAggressive}, V::Type) = PrunedFIsAggressive{V}
StochasticAD.valtype(::Type{<:PrunedFIsAggressive{V}}) where {V} = V

# should I have a mime input?
function Base.show(io::IO, mime::MIME"text/plain",
        Δs::PrunedFIsAggressive{V}) where {V}
    print(io, "$(pruned_value(Δs)) with probability $(Δs.state.weight)ε, tag $(Δs.tag)")
end

function Base.show(io::IO, Δs::PrunedFIsAggressive{V}) where {V}
    print(io, "$(pruned_value(Δs)) with probability $(Δs.state.weight)ε, tag $(Δs.tag)")
end

end


================================================
FILE: src/backends/smoothed.jl
================================================
module SmoothedFIsModule

import ..StochasticAD

export SmoothedFIsBackend, SmoothedFIs

"""
    SmoothedFIsBackend <: StochasticAD.AbstractFIsBackend

A backend algorithm that smooths perturbations together.
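
Rather than tracking a finite perturbation `Δ` with weight `w`, this backend stores only the product `Δ * w` (see `similar_new`), and it uses the `TwoSidedStrategy` by default. The standalone sketch below works out by hand the smoothed value this produces for one `Bernoulli(p)` draw. It assumes the default `InversionMethodDerivativeCoupling()` (mode `Val(:positive_weight)`) and a unit change `δ = 1` in the parameter. `smoothed_bernoulli_two_sided` is a hypothetical helper, not package API.

```julia
using Random, Statistics

# Hypothetical helper: smoothed perturbation for one Bernoulli(p) draw `val`
# under a two-sided inversion coupling with δ = 1. Each one-sided perturbation
# Δ with weight w is collapsed to Δ * w, and the two sides are combined as
# 0.5 * (result for +δ) - 0.5 * (result for -δ).
function smoothed_bernoulli_two_sided(val, p)
    plus_side = val == 0 ? (+1) * (1 / (1 - p)) : 0.0   # +δ can flip 0 to 1: Δ = +1, w = 1/(1-p)
    minus_side = val == 1 ? (-1) * (1 / p) : 0.0        # -δ can flip 1 to 0: Δ = -1, w = 1/p
    return 0.5 * plus_side - 0.5 * minus_side
end

p = 0.3
rng = MersenneTwister(3)
draws = [rand(rng) < p ? 1 : 0 for _ in 1:100_000]
estimate = mean(smoothed_bernoulli_two_sided(v, p) for v in draws)  # ≈ d/dp E[X] = 1
```

Both outcomes contribute a positive amount, `0.5 / (1 - p)` for `0` and `0.5 / p` for `1`, and their expectation under `Bernoulli(p)` is exactly `1`.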
"""
struct SmoothedFIsBackend <: StochasticAD.AbstractFIsBackend end

"""
    SmoothedFIs{V} <: StochasticAD.AbstractFIs{V}

The implementing backend structure for SmoothedFIsBackend.
"""
# TODO: make type of δ generic
struct SmoothedFIs{V, V_float} <: StochasticAD.AbstractFIs{V}
    δ::V_float
    function SmoothedFIs{V}(δ) where {V}
        # hardcode Float64 representation for now, for simplicity.
        δ_f64 = StochasticAD.structural_map(Base.Fix1(convert, Float64), δ)
        return new{V, typeof(δ_f64)}(δ_f64)
    end
end

### Empty / no perturbation

StochasticAD.similar_empty(::SmoothedFIs, V::Type) = SmoothedFIs{V}(0.0)
Base.empty(::Type{<:SmoothedFIs{V}}) where {V} = SmoothedFIs{V}(0.0)
Base.empty(Δs::SmoothedFIs) = empty(typeof(Δs))

### Create a new perturbation with infinitesimal probability

function StochasticAD.similar_new(::SmoothedFIs, Δ::V, w::Real) where {V}
    SmoothedFIs{V}(Δ * w)
end

StochasticAD.new_Δs_strategy(::SmoothedFIs) = StochasticAD.TwoSidedStrategy()

### Create Δs backend for the first stochastic triple of computation

StochasticAD.create_Δs(::SmoothedFIsBackend, V) = SmoothedFIs{V}(0.0)

### Convert type of a backend

function Base.convert(FIs::Type{<:SmoothedFIs{V}}, Δs::SmoothedFIs) where {V}
    SmoothedFIs{V}(Δs.δ)::FIs
end

### Getting information about perturbations

Base.isempty(Δs::SmoothedFIs) = false
Base.iszero(Δs::SmoothedFIs) = iszero(Δs.δ)
Base.iszero(Δs::SmoothedFIs{<:Tuple}) = all(iszero.(Δs.δ))
StochasticAD.derivative_contribution(Δs::SmoothedFIs) = Δs.δ

### Unary propagation

function StochasticAD.weighted_map_Δs(f, Δs::SmoothedFIs; deriv, out_rep, kwargs...)
    SmoothedFIs{typeof(out_rep)}(deriv(Δs.δ))
end

StochasticAD.alltrue(f, Δs::SmoothedFIs) = true

### Coupling

StochasticAD.get_rep(::Type{<:SmoothedFIs}, Δs_all) = StochasticAD.get_any(Δs_all)

function StochasticAD.couple(
        ::Type{<:SmoothedFIs}, Δs_all; rep = nothing, out_rep, kwargs...)
    SmoothedFIs{typeof(out_rep)}(StochasticAD.structural_map(Δs -> Δs.δ, Δs_all))
end

function StochasticAD.combine(::Type{<:SmoothedFIs}, Δs_all; rep = nothing, kwargs...)
    V_out = StochasticAD.valtype(first(StochasticAD.structural_iterate(Δs_all)))
    Δ_combined = sum(Δs -> Δs.δ, StochasticAD.structural_iterate(Δs_all))
    SmoothedFIs{V_out}(Δ_combined)
end

function StochasticAD.scalarize(Δs::SmoothedFIs; out_rep)
    return StochasticAD.structural_map(out_rep, Δs.δ) do out, δ
        return SmoothedFIs{typeof(out)}(δ)
    end
end

### Miscellaneous

StochasticAD.similar_type(::Type{<:SmoothedFIs}, V::Type) = SmoothedFIs{V, Float64}
StochasticAD.valtype(::Type{<:SmoothedFIs{V}}) where {V} = V

function Base.show(io::IO, Δs::SmoothedFIs)
    print(io, "$(Δs.δ)ε")
end

end


================================================
FILE: src/backends/strategy_wrapper.jl
================================================
module StrategyWrapperFIsModule

using ..StochasticAD
using ..StochasticAD.AbstractWrapperFIsModule

export StrategyWrapperFIsBackend, StrategyWrapperFIs

struct StrategyWrapperFIsBackend{
    B <: StochasticAD.AbstractFIsBackend,
    S <: StochasticAD.AbstractPerturbationStrategy
} <:
       StochasticAD.AbstractFIsBackend
    backend::B
    strategy::S
end

struct StrategyWrapperFIs{
    V,
    FIs <: StochasticAD.AbstractFIs{V},
    S <: StochasticAD.AbstractPerturbationStrategy
} <:
       AbstractWrapperFIs{V, FIs}
    Δs::FIs
    strategy::S
end

function StochasticAD.create_Δs(backend::StrategyWrapperFIsBackend, V)
    return StrategyWrapperFIs(StochasticAD.create_Δs(backend.backend, V), backend.strategy)
end

function StochasticAD.similar_type(::Type{<:StrategyWrapperFIs{V0, FIs0, S}},
        V,
        FIs) where {V0, FIs0, S}
    return StrategyWrapperFIs{V, FIs, S}
end

function AbstractWrapperFIsModule.reconstruct_wrapper(wrapper_Δs::StrategyWrapperFIs, Δs)
    return StrategyWrapperFIs(Δs, wrapper_Δs.strategy)
end

function AbstractWrapperFIsModule.reconstruct_wrapper(
        ::Type{
            <:StrategyWrapperFIs{V, FIs, S},
        },
        Δs) where {V, FIs, S}
    return StrategyWrapperFIs(Δs, S())
end

StochasticAD.new_Δs_strategy(Δs::StrategyWrapperFIs) = Δs.strategy

end


================================================
FILE: src/discrete_randomness.jl
================================================
## Helper functions for discrete distributions 

# index of the parameter p
_param_index(::Geometric) = 1
_param_index(::Bernoulli) = 1
_param_index(::Binomial) = 2
_param_index(::Poisson) = 1
_param_index(::Categorical) = 1

_get_parameter(d) = params(d)[_param_index(d)]

# constructors
for dist in [:Geometric, :Bernoulli, :Binomial, :Poisson, :Categorical]
    @eval _constructor(::$dist) = $dist
end

# reconstruct probability distribution with new parameter value
function _reconstruct(d, p)
    i = _param_index(d)
    return _constructor(d)(params(d)[1:(i - 1)]..., p, params(d)[(i + 1):end]...)
end

# support of probability distribution
_has_finite_support(d) = false
_has_finite_support(d::Union{Bernoulli, Binomial, Categorical}) = true

_get_support(d::Union{Bernoulli, Binomial, Categorical}) = minimum(d):maximum(d)
# manual overloads to ensure that static-ness is preserved for Bernoulli's and Categoricals with static arrays.
# since mapping over the range above could result in allocating vectors.
_get_support(::Bernoulli) = (0, 1)
# the map below looks a bit silly, but it gives us a collection of the categories with the same structure as probs(d). 
_get_support(d::Categorical) = map((val, prob) -> val, 1:ncategories(d), probs(d))

## Derivative couplings

# Derivative coupling approaches, determining which weighted perturbations to consider
abstract type AbstractDerivativeCoupling end

"""
    InversionMethodDerivativeCoupling(; mode::Val = Val(:positive_weight), handle_zeroprob::Val = Val(true))

Specifies an inversion method coupling for generating perturbations from a univariate distribution.
Valid choices of `mode` are `Val(:positive_weight)`, `Val(:always_right)`, and `Val(:always_left)`.

# Example
```jldoctest
julia> using StochasticAD, Distributions, Random; Random.seed!(4321);

julia> function X(p)
           return randst(Bernoulli(1 - p); derivative_coupling = InversionMethodDerivativeCoupling(; mode = Val(:always_right)))
       end
X (generic function with 1 method)

julia> stochastic_triple(X, 0.5)
StochasticTriple of Int64:
0 + 0ε + (1 with probability -2.0ε)
```
"""
Base.@kwdef struct InversionMethodDerivativeCoupling{M, HZP}
    mode::M = Val(:positive_weight)
    handle_zeroprob::HZP = Val(true)
end

# Strategies for precisely which perturbations to form given a derivative coupling
struct SingleSidedStrategy <: AbstractPerturbationStrategy end
struct TwoSidedStrategy <: AbstractPerturbationStrategy end
struct SmoothedStraightThroughStrategy <: AbstractPerturbationStrategy end
struct StraightThroughStrategy <: AbstractPerturbationStrategy end
struct IgnoreDiscreteStrategy <: AbstractPerturbationStrategy end

new_Δs_strategy(Δs) = SingleSidedStrategy()

# Derivative coupling high-level interface

"""
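    δtoΔs(d, val, δ, Δs::AbstractFIs, derivative_coupling)

Given a value `val` sampled from a distribution `d` and an infinitesimal change `δ` to the parameter of `d`,
return the discrete change in the output, with a similar representation to `Δs`. The perturbations formed
are determined by `derivative_coupling` and by the strategy returned by `new_Δs_strategy(Δs)`.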
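
As a worked case, consider `Geometric(p)` (the number of failures before the first success) and the `δ > 0` branch of the inversion coupling: a draw `val > 0` gets a perturbation `Δ = -1` with weight `δ * val / p / (1 - p)`. The standalone Monte Carlo sketch below uses plain Julia rather than StochasticAD, and `sample_geometric` and `geometric_contribution` are hypothetical helpers. It checks that the expected contribution `Δ * weight` matches `d/dp E[X] = -1/p^2` for `δ = 1`.

```julia
using Random, Statistics

# Hypothetical helper: number of failures before the first success.
function sample_geometric(rng, p)
    k = 0
    while rand(rng) >= p   # each trial fails with probability 1 - p
        k += 1
    end
    return k
end

# Hypothetical helper: Δ * weight for the δ > 0 branch of the Geometric coupling.
function geometric_contribution(val, p, δ)
    return val > 0 ? (-1) * (δ * val / p / (1 - p)) : 0.0
end

p = 0.5
rng = MersenneTwister(4)
contribs = [geometric_contribution(sample_geometric(rng, p), p, 1.0) for _ in 1:200_000]
estimate = mean(contribs)   # ≈ -1 / p^2 = -4
```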
"""
δtoΔs(d, val, δ, Δs, derivative_coupling) = δtoΔs(
    d, val, δ, Δs, derivative_coupling, new_Δs_strategy(Δs))
function δtoΔs(d, val, δ, Δs, derivative_coupling, ::SingleSidedStrategy)
    _δtoΔs(d, val, δ, Δs, derivative_coupling)
end
function δtoΔs(d, val, δ, Δs, derivative_coupling, ::TwoSidedStrategy)
    Δs1 = _δtoΔs(d, val, δ, Δs, derivative_coupling)
    Δs2 = _δtoΔs(d, val, -δ, Δs, derivative_coupling)
    return combine((scale(Δs1, 0.5), scale(Δs2, -0.5)))
end
# TODO: implement this ST for other distributions and couplings, if meaningful?
function δtoΔs(d::Union{Bernoulli, Binomial},
        val,
        δ,
        Δs,
        derivative_coupling::InversionMethodDerivativeCoupling,
        ::StraightThroughStrategy)
    p = succprob(d)
    Δs1 = _δtoΔs(d, val, δ, Δs, derivative_coupling)
    Δs2 = _δtoΔs(d, val, -δ, Δs, derivative_coupling)
    return combine((scale(Δs1, 1 - p), scale(Δs2, -p)))
end
function δtoΔs(d, val::V, δ, Δs, derivative_coupling, ::IgnoreDiscreteStrategy) where {V}
    similar_empty(Δs, V)
end

# Smoothed straight-through strategy: defined for all distributions, but only meaningful
# for smoothed backends, since it emits a unit perturbation one(val) weighted by the derivative of the mean.
function δtoΔs(d, val, δ, Δs, derivative_coupling, ::SmoothedStraightThroughStrategy)
    p = _get_parameter(d)
    δout = ForwardDiff.derivative(a -> mean(_reconstruct(d, p + a * δ)), 0.0)
    return similar_new(Δs, one(val), δout)
end

# Derivative coupling low-level implementations 

function _δtoΔs(d::Geometric,
        val::V,
        δ::Real,
        Δs::AbstractFIs,
        derivative_coupling::InversionMethodDerivativeCoupling) where {V <: Signed}
    p = succprob(d)
    if (derivative_coupling.mode isa Val{:positive_weight} && δ > 0) ||
       (derivative_coupling.mode isa Val{:always_right})
        return val > 0 ? similar_new(Δs, -one(V), δ * val / p / (1 - p)) :
               similar_empty(Δs, V)
    elseif (derivative_coupling.mode isa Val{:positive_weight} && δ < 0) ||
           (derivative_coupling.mode isa Val{:always_left})
        return similar_new(Δs, one(V), -δ * (val + 1) / p)
    else
        return similar_empty(Δs, V)
    end
end

function _δtoΔs(d::Bernoulli,
        val::V,
        δ::Real,
        Δs::AbstractFIs,
        derivative_coupling::InversionMethodDerivativeCoupling) where {V <: Signed}
    p = succprob(d)
    if (derivative_coupling.mode isa Val{:positive_weight} && δ > 0) ||
       (derivative_coupling.mode isa Val{:always_right})
        return isone(val) ? similar_empty(Δs, V) : similar_new(Δs, one(V), δ / (1 - p))
    elseif (derivative_coupling.mode isa Val{:positive_weight} && δ < 0) ||
           (derivative_coupling.mode isa Val{:always_left})
        return isone(val) ? similar_new(Δs, -one(V), -δ / p) : similar_empty(Δs, V)
    else
        return similar_empty(Δs, V)
    end
end

function _δtoΔs(d::Binomial,
        val::V,
        δ::Real,
        Δs::AbstractFIs,
        derivative_coupling::InversionMethodDerivativeCoupling) where {V <: Signed}
    p = succprob(d)
    n = ntrials(d)
    if (derivative_coupling.mode isa Val{:positive_weight} && δ > 0) ||
       (derivative_coupling.mode isa Val{:always_right})
        return val == n ? similar_empty(Δs, V) :
               similar_new(Δs, one(V), δ * (n - val) / (1 - p))
    elseif (derivative_coupling.mode isa Val{:positive_weight} && δ < 0) ||
           (derivative_coupling.mode isa Val{:always_left})
        return !iszero(val) ? similar_new(Δs, -one(V), -δ * val / p) : similar_empty(Δs, V)
    else
        return similar_empty(Δs, V)
    end
end

function _δtoΔs(d::Poisson,
        val::V,
        δ::Real,
        Δs::AbstractFIs,
        derivative_coupling::InversionMethodDerivativeCoupling) where {V <: Signed}
    p = mean(d) # rate
    if (derivative_coupling.mode isa Val{:positive_weight} && δ > 0) ||
       (derivative_coupling.mode isa Val{:always_right})
        return similar_new(Δs, 1, δ)
    elseif (derivative_coupling.mode isa Val{:positive_weight} && δ < 0) ||
           (derivative_coupling.mode isa Val{:always_left})
        return val > 0 ? similar_new(Δs, -1, -δ * val / p) : similar_empty(Δs, V)
    else
        return similar_empty(Δs, V)
    end
end

function _δtoΔs(d::Categorical,
        val::V,
        δs,
        Δs::AbstractFIs,
        derivative_coupling::InversionMethodDerivativeCoupling) where {V <: Signed}
    p = params(d)[1]
    # NB: Although we might expect sum(δs) = 0, it is useful to handle things more generally, viewing δs
    # as perturbing the Categorical distribution locally along some direction in the space of general measures.
    # The below formulation gets things right in this case too. 
    left_sum = sum(δs[1:(val - 1)], init = zero(eltype(δs)))
    right_sum = sum(δs[1:val], init = zero(eltype(δs)))

    if (derivative_coupling.mode isa Val{:positive_weight} && left_sum > 0) ||
       (derivative_coupling.mode isa Val{:always_left} && !iszero(left_sum))
        # compute left_nonzero
        if derivative_coupling.handle_zeroprob isa Val{true}
            stop = rand() * left_sum
            upto = zero(eltype(δs)) # The "upto" logic handles an edge case of probability 0 events that have non-zero derivative.
            # It's a lot of logic to handle an edge case, but hopefully it's optimized away.
            left_nonzero = val
            for i in (val - 1):-1:1
                if !iszero(p[i]) || ((upto += δs[i]) > stop)
                    left_nonzero = i
                    break
                end
            end
        else
            left_nonzero = val - 1
        end
        Δs_left = similar_new(Δs, left_nonzero - val, left_sum / p[val])
    else
        Δs_left = similar_empty(Δs, typeof(val))
    end

    if (derivative_coupling.mode isa Val{:positive_weight} && right_sum < 0) ||
       (derivative_coupling.mode isa Val{:always_right} && !iszero(right_sum))
        # compute right_nonzero
        if derivative_coupling.handle_zeroprob isa Val{true}
            stop = -rand() * right_sum
            upto = zero(eltype(δs))
            right_nonzero = val
            for i in (val + 1):length(p)
                if !iszero(p[i]) || ((upto += δs[i]) > stop)
                    right_nonzero = i
                    break
                end
            end
        else
            right_nonzero = val + 1
        end
        Δs_right = similar_new(Δs, right_nonzero - val, -right_sum / p[val])
    else
        Δs_right = similar_empty(Δs, typeof(val))
    end

    return combine((Δs_left, Δs_right); rep = Δs)
end

## Propagation couplings

abstract type AbstractPropagationCoupling end

"""
    InversionMethodPropagationCoupling 

Specifies an inversion method coupling for propagating perturbations.
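
In `_map_func` below, the sampled `val` pins the underlying uniform draw `ω` to the interval `(cdf(d, val - 1), cdf(d, val)]`. A fresh `ω` is then drawn inside that interval and pushed through the quantile function of the perturbed distribution. The sketch below replays that step for a Bernoulli distribution. So that the example needs no packages, it uses hand-written stand-ins for `cdf` and `quantile`, which are hypothetical helpers.

```julia
using Random

# Hypothetical stand-ins for cdf/quantile of Bernoulli(p) on the support {0, 1}.
bern_cdf(p, k) = k < 0 ? 0.0 : (k < 1 ? 1 - p : 1.0)
bern_quantile(p, u) = u <= 1 - p ? 0 : 1

# Mirror of the inversion-method propagation step: recover the interval of
# uniform draws ω consistent with `val` under parameter `p`, resample ω inside
# it, and map ω through the quantile function for the perturbed parameter p + Δ.
function propagate_bernoulli(rng, p, val, Δ)
    low = bern_cdf(p, val - 1)
    high = bern_cdf(p, val)
    ω = rand(rng) * (high - low) + low
    return bern_quantile(p + Δ, ω) - val
end

rng = MersenneTwister(2)
shifts = [propagate_bernoulli(rng, 0.3, 0, 0.2) for _ in 1:100_000]
frac_flipped = count(==(1), shifts) / length(shifts)  # ≈ (0.7 - 0.5) / 0.7
```

Under this coupling, a draw of `0` at `p = 0.3` moves to `1` at `p = 0.5` only when `ω` exceeds `0.5`, which happens with probability `(0.7 - 0.5) / 0.7`.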
"""
struct InversionMethodPropagationCoupling <: AbstractPropagationCoupling end

function _map_func(d, val, Δ, ::InversionMethodPropagationCoupling)
    # construct alternative distribution
    p = _get_parameter(d)
    alt_d = _reconstruct(d, p + Δ)
    # compute bounds on original ω
    low = cdf(d, val - 1)
    high = cdf(d, val)
    # sample alternative value
    alt_val = quantile(alt_d, rand(RNG) * (high - low) + low)
    return convert(Signed, alt_val - val)
end

function _map_enumeration(d, val, Δ, ::InversionMethodPropagationCoupling)
    # construct alternative distribution
    p = _get_parameter(d)
    alt_d = _reconstruct(d, p + Δ)
    # compute bounds on original ω
    low = cdf(d, val - 1)
    high = cdf(d, val)
    if _has_finite_support(alt_d)
        map(_get_support(alt_d)) do alt_val
            # interval intersect of (cdf(alt_d, alt_val - 1), cdf(alt_d, alt_val)) and (low, high)
            alt_low = cdf(alt_d, alt_val - 1)
            alt_high = cdf(alt_d, alt_val)
            prob_alt = max(0.0, min(alt_high, high) - max(alt_low, low)) /
                       (high - low)
            return (alt_val - val, prob_alt)
        end
    else
        error("enumeration not supported for distribution $d. Does $d have finite support?")
    end
end

## Overloading of random sampling 

# Define randst interface

"""
    randst(rng, d::Distributions.Sampleable; kwargs...)

When no keyword arguments are provided, `randst` behaves identically to `rand(rng, d)`, both in ordinary computation
and in stochastic triple dispatches. However, `randst` also allows the user to provide various keyword arguments
for customizing the differentiation logic. The set of allowed keyword arguments depends on the type of `d`; two
common ones are `derivative_coupling` and `propagation_coupling`.

For developers: if you wish to accept custom keyword arguments in a stochastic triple dispatch, you should overload
`randst`, and redirect `rand` to your `randst` method. If you do not, it suffices to just overload `rand`.
"""
randst(rng, d::Distributions.Sampleable; kwargs...) = rand(rng, d)
randst(d::Distributions.Sampleable; kwargs...) = randst(Random.default_rng(), d; kwargs...)

# Define stochastic triple rules

for dist in [:Geometric, :Bernoulli, :Binomial, :Poisson]
    @eval function Base.rand(rng::AbstractRNG,
            d_st::$dist{StochasticTriple{T, V, FIs}}) where {T, V, FIs}
        return randst(rng, d_st)
    end
    @eval function randst(rng::AbstractRNG,
            d_st::$dist{StochasticTriple{T, V, FIs}};
            Δ_kwargs = (;),
            derivative_coupling = InversionMethodDerivativeCoupling(),
            propagation_coupling = InversionMethodPropagationCoupling()) where {T, V, FIs}
        st = _get_parameter(d_st)
        d = _reconstruct(d_st, st.value)
        val = convert(Signed, rand(rng, d))
        Δs1 = δtoΔs(d, val, st.δ, st.Δs, derivative_coupling)

        Δs2 = map(Δ -> _map_func(d, val, Δ, propagation_coupling),
            st.Δs;
            enumeration = (Δ, _) -> _map_enumeration(d, val, Δ, propagation_coupling),
            deriv = δ -> smoothed_delta(d, val, δ, derivative_coupling),
            out_rep = val,
            Δ_kwargs...)

        StochasticTriple{T}(val, zero(val), combine((Δs2, Δs1); rep = Δs1)) # ensure that tags are in order in combine, in case backend wishes to exploit this 
    end
end

# currently handle Categorical separately since parameter is a vector
# what if some elements in vector are not stochastic triples... promotion should take care of that?
function Base.rand(rng::AbstractRNG,
        d_st::Categorical{StochasticTriple{T, V, FIs}}) where {T, V, FIs}
    return randst(rng, d_st)
end
function randst(rng::AbstractRNG,
        d_st::Categorical{<:StochasticTriple{T},
            <:AbstractVector{<:StochasticTriple{T, V}}};
        Δ_kwargs = (;),
        derivative_coupling = InversionMethodDerivativeCoupling(),
        propagation_coupling = InversionMethodPropagationCoupling()) where {T, V}
    sts = _get_parameter(d_st) # stochastic triple for each probability
    p = map(st -> st.value, sts) # try to keep the same type. e.g. static array -> static array. TODO: avoid allocations 
    d = _reconstruct(d_st, p)
    val = convert(Signed, rand(rng, d))

    Δs_all = map(st -> st.Δs, sts)
    Δs_rep = get_rep(Δs_all)

    Δs1 = δtoΔs(d, val, map(st -> st.δ, sts), Δs_rep, derivative_coupling)

    Δs_coupled = couple(Δs_all; rep = Δs_rep, out_rep = p) # TODO: again, there are possible allocations here
    Δs2 = map(Δ -> _map_func(d, val, Δ, propagation_coupling),
        Δs_coupled;
        enumeration = (Δ, _) -> _map_enumeration(d, val, Δ, propagation_coupling),
        deriv = δ -> smoothed_delta(d, val, δ, derivative_coupling),
        out_rep = val,
        Δ_kwargs...)

    Δs = combine((Δs2, Δs1); rep = Δs1, out_rep = val, Δ_kwargs...)

    StochasticTriple{T}(val, zero(val), Δs)
end

## Handling finite perturbation to Binomial number of trials

"""
    DiscreteDeltaStochasticTriple{T, V, FIs <: AbstractFIs}

An experimental discrete stochastic triple type used internally for representing perturbations
to non-real quantities. Currently only used to represent a finite perturbation to the Binomial 
parameter n.

## Constructor

- `value`: the primal value.
- `Δs`: some representation of the perturbation to the primal, which can have an unconventional
         interpretation depending on `T`.
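
## Example

In `randst` below, a finite change `Δ` to the number of trials is propagated in one of two ways. When `Δ >= 0`, `Δ` fresh Bernoulli trials are added. Otherwise, `-Δ` of the existing trials are removed, and the lost successes follow `Hypergeometric(val, n - val, -Δ)`. The plain-Julia sketch below simulates the individual trials explicitly instead of sampling those distributions. `perturbed_binomial` is a hypothetical helper.

```julia
using Random, Statistics

# Hypothetical helper: returns (val, perturbed value) for Binomial(n, p) when the
# number of trials changes by Δ, simulating individual trials explicitly.
function perturbed_binomial(rng, n, p, Δ)
    trials = [rand(rng) < p for _ in 1:n]          # n Bernoulli(p) outcomes
    val = count(trials)
    if Δ >= 0
        extra = count(rand(rng) < p for _ in 1:Δ)  # Δ additional trials
        return val, val + extra
    else
        kept = shuffle(rng, trials)[1:(n + Δ)]     # drop -Δ trials uniformly at random
        return val, count(kept)
    end
end

rng = MersenneTwister(5)
n, p = 10, 0.3
down = [perturbed_binomial(rng, n, p, -4) for _ in 1:50_000]
up = [perturbed_binomial(rng, n, p, 3) for _ in 1:50_000]
```

The perturbed values should average close to `(n - 4) * p` and `(n + 3) * p`. They never move in the wrong direction: removing trials cannot add successes, and adding trials cannot remove them.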
"""
struct DiscreteDeltaStochasticTriple{T, V, FIs <: AbstractFIs}
    value::V
    Δs::FIs
    function DiscreteDeltaStochasticTriple{T, V, FIs}(value::V,
            Δs::FIs) where {T, V,
            FIs <: AbstractFIs}
        new{T, V, FIs}(value, Δs)
    end
end

function DiscreteDeltaStochasticTriple{T}(val::V, Δs::FIs) where {T, V, FIs <: AbstractFIs}
    DiscreteDeltaStochasticTriple{T, V, FIs}(val, Δs)
end

function Distributions.Binomial(n::StochasticTriple{T}, p::Real) where {T}
    return DiscreteDeltaStochasticTriple{T}(Binomial(n.value, p), n.Δs)
end

# TODO: Support functions other than `rand` called on a perturbed Binomial.
function Base.rand(rng::AbstractRNG,
        d_st::DiscreteDeltaStochasticTriple{T, <:Binomial}) where {T}
    return randst(rng, d_st)
end
function randst(rng::AbstractRNG,
        d_st::DiscreteDeltaStochasticTriple{T, <:Binomial}) where {T}
    d = d_st.value
    val = rand(rng, d)
    function map_func(Δ)
        if Δ >= 0
            return rand(StochasticAD.RNG, Binomial(Δ, value(succprob(d))))
        else
            return -rand(StochasticAD.RNG,
                Hypergeometric(value(val), ntrials(d) - value(val), -Δ))
        end
    end
    Δs = map(map_func, d_st.Δs)
    if val isa StochasticTriple
        return StochasticTriple{T}(val.value, val.δ, combine((Δs, val.Δs); rep = Δs))
    else
        return StochasticTriple{T}(val, zero(val), Δs)
    end
end


================================================
FILE: src/finite_infinitesimals.jl
================================================
# TODO: make this a module, with the interface exported?

## 
"""
    AbstractFIsBackend

An abstract type for backend strategies of Finite perturbations that occur with Infinitesimal probability (FIs).
"""
abstract type AbstractFIsBackend end

"""
    AbstractFIs{V}

An abstract type for concrete backend representations of FIs (finite perturbations that occur with infinitesimal probability).
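
For intuition, consider a hypothetical toy representation that holds a single perturbation `Δ` occurring with probability `weight * ε`. It is not a real backend and not StochasticAD API. It mirrors how the generic `map_Δs`, `map` and `scale` defined in this file reduce to the backend primitive `weighted_map_Δs`, whose function argument returns the new `Δ` together with a multiplicative weight factor.

```julia
# Hypothetical toy FIs: one perturbation Δ carried with weight w (probability wε).
struct ToyFIs{V}
    Δ::V
    weight::Float64
end

# Backend-specific primitive: f returns (new Δ, multiplicative weight factor).
function toy_weighted_map(f, Δs::ToyFIs)
    Δ_out, factor = f(Δs.Δ, nothing)
    return ToyFIs(Δ_out, Δs.weight * factor)
end

# Generic layers, mirroring map_Δs / map / scale in this file.
toy_map_Δs(f, Δs::ToyFIs) = toy_weighted_map((Δ, state) -> (f(Δ, state), 1.0), Δs)
toy_map(f, Δs::ToyFIs) = toy_map_Δs((Δ, _) -> f(Δ), Δs)
toy_scale(Δs::ToyFIs, a::Real) = toy_weighted_map((Δ, state) -> (Δ, a), Δs)

Δs = ToyFIs(1, 2.0)             # perturbation +1 with probability 2.0ε
doubled = toy_map(Δ -> 2Δ, Δs)  # Δ becomes 2, weight unchanged
halved = toy_scale(Δs, 0.5)     # Δ unchanged, weight becomes 1.0
```

Scaling acts on the weight rather than on `Δ`, so the derivative contribution `Δ * weight` is scaled by `a` without changing the finite jump itself.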
"""
abstract type AbstractFIs{V} end

### Interface functions for FIs backends (notes below are incomplete).
# TODO: document

function create_Δs end

function similar_new end
function similar_empty end
function similar_type end

valtype(Δs::AbstractFIs) = valtype(typeof(Δs))

# TODO: typeof ∘ first is a loose check, should make more robust.
# TODO: perhaps deprecate these methods in favor of an explicit first argument?
couple(Δs_all; kwargs...) = couple(typeof(first(Δs_all)), Δs_all; kwargs...)
combine(Δs_all; kwargs...) = combine(typeof(first(Δs_all)), Δs_all; kwargs...)
get_rep(Δs_all; kwargs...) = get_rep(typeof(first(Δs_all)), Δs_all; kwargs...)
function scalarize end

function derivative_contribution end

function alltrue end

function perturbations end

function filter_state end

function weighted_map_Δs end
function map_Δs(f, Δs::AbstractFIs; kwargs...)
    StochasticAD.weighted_map_Δs((Δs, state) -> (f(Δs, state), 1.0), Δs; kwargs...)
end
function Base.map(f, Δs::AbstractFIs; kwargs...)
    StochasticAD.map_Δs((Δs, _) -> f(Δs), Δs; kwargs...)
end
# `scale` multiplies the weight by `a`, and also passes `deriv` so that smoothed backends scale their δ by `a`.
function scale(Δs::AbstractFIs, a::Real)
    StochasticAD.weighted_map_Δs((Δ, state) -> (Δ, a),
        Δs;
        deriv = Base.Fix1(*, a),
        out_rep = Δs)
end

function new_Δs_strategy end

# utility function useful e.g. for get_rep in some backends
function get_any(Δs_all)
    # The code below is a bit ridiculous, but it's faster than `first` for small structures:)
    foldl((Δs1, Δs2) -> Δs1, StochasticAD.structural_iterate(Δs_all))
end

abstract type AbstractPerturbationStrategy end

abstract type AbstractPerturbationSignal end

function send_signal end

# Ignore signals by default since they do not change semantics.
function StochasticAD.send_signal(
        Δs::StochasticAD.AbstractFIs, ::StochasticAD.AbstractPerturbationSignal)
    return Δs
end


================================================
FILE: src/general_rules.jl
================================================
"""
Operators which have already been overloaded by StochasticAD. 
"""
const handled_ops = Tuple{DataType, Int}[]

"""
    define_triple_overload(sig)

Given the signature type `sig` of the primal function (a `Tuple` type of the function's type followed by the argument types), define operator
overloading rules for stochastic triples.
Currently supports functions with all-real inputs and one real output.
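
For unary predicates, the overload generated below evaluates the predicate on the primal value. It then errors unless every finite perturbation `Δ` leaves the result unchanged, which it checks via `alltrue`. The following plain-Julia sketch mirrors that check using an ordinary vector of perturbations. `checked_predicate` is a hypothetical helper.

```julia
# Hypothetical helper mirroring the generated check for unary predicates:
# the output must not change under any finite perturbation Δ of the input.
function checked_predicate(f, val, perturbations)
    out = f(val)
    if !all(Δ -> f(val + Δ) == out, perturbations)
        error("Output of boolean predicate cannot depend on input")
    end
    return out
end

safe = checked_predicate(x -> x > 0, 3, [1, -1])   # true: 4 > 0 and 2 > 0
unsafe_threw = try
    checked_predicate(x -> x > 0, 1, [-1])         # 0 > 0 would flip the result
    false
catch
    true
end
```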
"""
# TODO: special case optimizations
# TODO: generalizations to not-all-real inputs and/or not-one-real output
function define_triple_overload(sig)
    opT, argTs = Iterators.peel(ExprTools.parameters(sig))
    opT <: Type{<:Type} && return  # not handling constructors
    sig <: Tuple{Type, Vararg{Any}} && return
    opT <: Core.Builtin && return false  # can't do operator overloading for builtins

    isabstracttype(opT) || fieldcount(opT) == 0 || return false  # not handling functors
    isempty(argTs) && return false  # we are an operator overloading AD, need operands
    all(argT isa Type && Real <: argT for argT in argTs) || return

    N = length(ExprTools.parameters(sig)) - 1  # skip the op

    # Skip already-handled ops, as well as ops that will be handled manually later (and more correctly, see #79).
    if (opT, N) in handled_ops || (opT.instance in UNARY_TYPEFUNCS_WRAP)
        return
    end

    push!(handled_ops, (opT, N))

    if opT.instance in UNARY_PREDICATES && (N == 1)
        @eval function (f::$opT)(st::StochasticTriple)
            val = value(st)
            out = f(val)
            if !alltrue(Δ -> (f(val + Δ) == out), st.Δs)
                error("Output of boolean predicate cannot depend on input (unsupported by StochasticAD)")
            end
            return out
        end
    elseif opT.instance in BINARY_PREDICATES && (N == 2)
        # Special case equality comparisons as in https://github.com/JuliaDiff/ForwardDiff.jl/pull/481
        if opT.instance == Base.:(==)
            return_value_real = quote
                out && iszero(delta(st))
            end
            return_value_st = quote
                out2 = out && (delta(st1) == delta(st2))
            end
        else
            return_value_real = quote
                out
            end
            return_value_st = quote
                out
            end
        end
        @eval function (f::$opT)(st::StochasticTriple, x::Real)
            val = value(st)
            out = f(val, x)
            if !alltrue(Δ -> (f(val + Δ, x) == out), st.Δs)
                error("Output of boolean predicate cannot depend on input (unsupported by StochasticAD)")
            end
            return $return_value_real
        end
        @eval function (f::$opT)(x::Real, st::StochasticTriple)
            val = value(st)
            out = f(x, val)
            if !alltrue(Δ -> (f(x, val + Δ) == out), st.Δs)
                error("Output of boolean predicate cannot depend on input (unsupported by StochasticAD)")
            end
            return $return_value_real
        end
        @eval function (f::$opT)(st1::StochasticTriple, st2::StochasticTriple)
            val1 = value(st1)
            val2 = value(st2)
            out = f(val1, val2)

            Δs_coupled = couple((st1.Δs, st2.Δs); out_rep = (val1, val2))
            safe_perturb = alltrue(Δs -> f(val1 + Δs[1], val2 + Δs[2]) == out, Δs_coupled)
            if !safe_perturb
                error("Output of boolean predicate cannot depend on input (unsupported by StochasticAD)")
            end
            return $return_value_st
        end
    elseif N == 1
        if Base.return_types(frule, (Tuple{NoTangent, Real}, opT, Real))[1] <:
           Tuple{Any, NoTangent}
            return
        end
        @eval function (f::$opT)(st::StochasticTriple{T}; kwargs...) where {T}
            run_frule = δ -> begin
                args_tangent = (NoTangent(), δ)
                return frule(args_tangent, f, value(st); kwargs...)
            end
            val, δ0 = run_frule(delta(st))
            δ::typeof(val) = (δ0 isa ZeroTangent || δ0 isa NoTangent) ? zero(value(st)) : δ0
            if !iszero(st.Δs)
                Δs = map(Δ -> f(st.value + Δ; kwargs...) - val, st.Δs;
                    deriv = last ∘ run_frule, out_rep = val)
            else
                Δs = similar_empty(st.Δs, typeof(val))
            end
            return StochasticTriple{T}(val, δ, Δs)
        end
    elseif N == 2
        if Base.return_types(frule, (Tuple{NoTangent, Real, Real}, opT, Real, Real))[1] <:
           Tuple{Any, NoTangent}
            return
        end
        for R in AMBIGUOUS_TYPES
            @eval function (f::$opT)(st::StochasticTriple{T}, x::$R; kwargs...) where {T}
                run_frule = δ -> begin
                    args_tangent = (NoTangent(), δ, zero(x))
                    return frule(args_tangent, f, value(st), x; kwargs...)
                end
                val, δ0 = run_frule(delta(st))
                δ::typeof(val) = (δ0 isa ZeroTangent || δ0 isa NoTangent) ?
                                 zero(value(st)) : δ0
                if !iszero(st.Δs)
                    Δs = map(Δ -> f(st.value + Δ, x; kwargs...) - val, st.Δs;
                        deriv = last ∘ run_frule, out_rep = val)
                else
                    Δs = similar_empty(st.Δs, typeof(val))
                end
                return StochasticTriple{T}(val, δ, Δs)
            end
            @eval function (f::$opT)(x::$R, st::StochasticTriple{T}; kwargs...) where {T}
                run_frule = δ -> begin
                    args_tangent = (NoTangent(), zero(x), δ)
                    return frule(args_tangent, f, x, value(st); kwargs...)
                end
                val, δ0 = run_frule(delta(st))
                δ::typeof(val) = (δ0 isa ZeroTangent || δ0 isa NoTangent) ?
                                 zero(value(st)) : δ0
                if !iszero(st.Δs)
                    Δs = map(Δ -> f(x, st.value + Δ; kwargs...) - val, st.Δs;
                        deriv = last ∘ run_frule, out_rep = val)
                else
                    Δs = similar_empty(st.Δs, typeof(val))
                end
                return StochasticTriple{T}(val, δ, Δs)
            end
        end
        @eval function (f::$opT)(sts::Vararg{StochasticTriple{T}, 2}; kwargs...) where {T}
            run_frule = δs -> begin
                args_tangent = (NoTangent(), δs...)
                args = (f, value.(sts)...)
                return frule(args_tangent, args...; kwargs...)
            end
            val, δ0 = run_frule(delta.(sts))
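            # `st` is not defined in this method (the arguments are `sts`), so take the zero from the first triple.
            δ::typeof(val) = (δ0 isa ZeroTangent || δ0 isa NoTangent) ?
                             zero(value(first(sts))) : δ0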

            Δs_all = map(st -> getfield(st, :Δs), sts)
            if all(iszero.(Δs_all))
                Δs = similar_empty(first(sts).Δs, typeof(val))
            else
                vals_in = value.(sts)
                Δs_coupled = couple(Tuple(Δs_all); out_rep = vals_in)
                mapfunc = let vals_in = vals_in
                    Δ -> (f((vals_in .+ Δ)...; kwargs...) - val)
                end
                Δs = map(mapfunc, Δs_coupled; deriv = last ∘ run_frule, out_rep = val)
            end
            return StochasticTriple{T}(val, δ, Δs)
        end
    end
end

on_new_rule(define_triple_overload, frule)
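#=
Illustrative sketch only, not part of StochasticAD. For a unary function, the rule above
computes the dual part with the chain rule (via `frule`). It also maps each finite
perturbation Δ to f(value + Δ) - f(value). The type `ToyTriple` and the helper `toy_apply`
are hypothetical. They assume perturbations are stored as plain (Δ, weight) pairs, which is
a simplification of the real backends.

```julia
# Hypothetical stand-in for a stochastic triple: primal value, dual part δ,
# and finite perturbations stored as (Δ, weight) pairs.
struct ToyTriple
    value::Float64
    δ::Float64
    Δs::Vector{Tuple{Float64, Float64}}
end

# Apply a scalar function `f` with known derivative `df`: the dual part follows
# the chain rule, and each finite perturbation is re-evaluated through `f`.
function toy_apply(f, df, st::ToyTriple)
    val = f(st.value)
    δ = df(st.value) * st.δ
    Δs = [(f(st.value + Δ) - val, w) for (Δ, w) in st.Δs]
    return ToyTriple(val, δ, Δs)
end

st = ToyTriple(2.0, 1.0, [(1.0, 0.5)])
out = toy_apply(x -> x^2, x -> 2x, st)  # value 4.0, δ 4.0, perturbation (5.0, 0.5)
```

The weights pass through unchanged. Only the size of each jump is recomputed, which is why
the real rule needs `f` itself (not only its derivative) for the perturbation component.
=#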

### Extra overloads

# TODO: generalize the below logic to compactly handle a wider range of functions.
# See also https://github.com/JuliaDiff/ForwardDiff.jl/blob/master/src/dual.jl.

function Base.hash(st::StochasticTriple, hsh::UInt)
    if !isempty(st.Δs)
        error("Hashing a stochastic triple with perturbations not yet supported.")
    end
    hash(StochasticAD.value(st), hsh)
end

#=
This is a hacky experimental way to convert a float-like stochastic triple
into an integer-like one, to facilitate generic coding.
=#
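function Base.round(I::Type{<:Integer}, st::StochasticTriple{T, V}) where {T, V}
    val = round(I, st.value)
    # Perturbations are stored as differences from the primal, so subtract the rounded primal.
    return StochasticTriple{T}(val, map(Δ -> round(I, st.value + Δ) - val, st.Δs))
end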

for op in UNARY_TYPEFUNCS_NOWRAP
    function (::typeof(op))(::Type{<:StochasticTriple{T, V, FIs}}) where {T, V, FIs}
        return op(V)
    end
end

for op in UNARY_TYPEFUNCS_WRAP
    function (::typeof(op))(::Type{StochasticTriple{T, V, FIs}}) where {T, V, FIs}
        return StochasticTriple{T, V, FIs}(op(V), zero(V), empty(FIs))
    end
    function (::typeof(op))(st::StochasticTriple)
        return op(typeof(st))
    end
end

for op in RNG_TYPEFUNCS_WRAP
    function (::typeof(op))(rng::AbstractRNG,
            ::Type{StochasticTriple{T, V, FIs}}) where {T, V, FIs}
        return StochasticTriple{T, V, FIs}(op(rng, V), zero(V), empty(FIs))
    end
end

#=
The short-circuit `x == y` case in Base.isapprox is bad for us
because it can needlessly trigger the "output of boolean predicate cannot
depend on input" error in cases where StochasticAD cannot prove correctness.
We patch the rule by removing the short-circuit, which allows some common
cases to work.

In the future, we will ideally handle the overloading rule in a more general
way (e.g. by catching the chain rule for isapprox and recursively calling isapprox
on the values).
=#
function Base.isapprox(st1::StochasticTriple, st2::StochasticTriple;
        atol::Real = 0, rtol::Real = Base.rtoldefault(st1, st2, atol),
        nans::Bool = false, norm::Function = abs)
    (isfinite(st1) && isfinite(st2) &&
     norm(st1 - st2) <= max(atol, rtol * max(norm(st1), norm(st2)))) ||
        (nans && isnan(st1) && isnan(st2))
end
function Base.isapprox(st1::StochasticTriple, x::Real;
        atol::Real = 0, rtol::Real = Base.rtoldefault(st1, x, atol),
        nans::Bool = false, norm::Function = abs)
    (isfinite(st1) && isfinite(x) &&
     norm(st1 - x) <= max(atol, rtol * max(norm(st1), norm(x)))) ||
        (nans && isnan(st1) && isnan(x))
end
function Base.isapprox(x::Real, st::StochasticTriple; kwargs...)
    return Base.isapprox(st, x; kwargs...)
end
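#=
Illustrative sketch only: the patched rules above use the same tolerance test as
Base.isapprox but without the leading `x == y ||` short-circuit. The helper
`isapprox_no_shortcircuit` is hypothetical and works on plain reals. It shows one visible
consequence of dropping the short-circuit: equal infinities are no longer considered
approximately equal.

```julia
# Hypothetical helper mirroring the patched rule on plain reals:
# finiteness check plus tolerance comparison, with no `x == y` short-circuit.
function isapprox_no_shortcircuit(x::Real, y::Real; atol::Real = 0,
        rtol::Real = Base.rtoldefault(x, y, atol),
        nans::Bool = false, norm::Function = abs)
    return (isfinite(x) && isfinite(y) &&
            norm(x - y) <= max(atol, rtol * max(norm(x), norm(y)))) ||
           (nans && isnan(x) && isnan(y))
end

isapprox_no_shortcircuit(1.0, 1.0 + 1e-12)  # within the default rtol
isapprox_no_shortcircuit(Inf, Inf)          # false, whereas isapprox(Inf, Inf) is true
```
=#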

# Alternate version of _isassigned that does not fall back on try/catch.
_isassigned(C, i) = (i in eachindex(C))

"""
    Base.getindex(C::AbstractArray, st::StochasticTriple{T})

A simple prototype rule for array indexing. Assumes that the underlying value of `st` is a valid index into the collection `C`.
"""
# TODO: support multiple indices, cartesian indices, non abstract array indexables, other use cases...
# Example to fix: A[:, :, st]
function Base.getindex(C::AbstractArray, st::StochasticTriple{T, V, FIs}) where {T, V, FIs}
    val = C[st.value]
    do_map = (Δ, state) -> begin
        return value(C[st.value + Δ], state) - value(val, state)
    end

    # TODO: below doesn't support sparse arrays, use something like nextind
    deriv = δ -> begin
        scale = if _isassigned(C, st.value + 1) && _isassigned(C, st.value - 1)
            1 / 2 * (value(C[st.value + 1]) - value(C[st.value - 1]))
        elseif _isassigned(C, st.value + 1)
            value(C[st.value + 1]) - value(C[st.value])
        elseif _isassigned(C, st.value - 1)
            value(C[st.value]) - value(C[st.value - 1])
        else
            zero(eltype(C))
        end
        return scale * δ
    end

    Δs = StochasticAD.map_Δs(do_map, st.Δs; deriv, out_rep = value(val))
    if val isa StochasticTriple
        Δs = combine((Δs, val.Δs))
    end
    return StochasticTriple{T}(value(val), delta(val), Δs)
end
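#=
Illustrative sketch only: the `deriv` closure in the indexing rule above scales δ by a
finite-difference slope of `C` around the current index. It uses a central difference when
both neighbours exist, a one-sided difference at the edges, and zero for a single element.
The helper `index_slope` is hypothetical and restricted to plain vectors.

```julia
# Hypothetical helper: the slope that the indexing rule multiplies δ by.
function index_slope(C::AbstractVector, i::Integer)
    has_next = (i + 1) in eachindex(C)
    has_prev = (i - 1) in eachindex(C)
    if has_next && has_prev
        return (C[i + 1] - C[i - 1]) / 2  # central difference
    elseif has_next
        return C[i + 1] - C[i]            # forward difference at the first index
    elseif has_prev
        return C[i] - C[i - 1]            # backward difference at the last index
    else
        return zero(eltype(C))            # no neighbours: no continuous dependence
    end
end

C = [1.0, 4.0, 9.0]
index_slope(C, 2)  # (9.0 - 1.0) / 2
```
=#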


================================================
FILE: src/misc.jl
================================================
@doc raw"""
    StochasticModel(X, p)

Combine stochastic program `X` with parameter `p` into
a trainable model using [Functors](https://fluxml.ai/Functors.jl/stable/), where
`p` is an `AbstractArray`.
Formulate as a minimization problem, i.e. find ``p`` that minimizes ``\mathbb{E}[X(p)]``.
"""
struct StochasticModel{S <: AbstractArray, T}
    X::T
    p::S
end
@functor StochasticModel (p,)

"""
    stochastic_gradient(m::StochasticModel)

Compute gradient with respect to the trainable parameter `p` of `StochasticModel(X, p)`.
"""
function stochastic_gradient(m::StochasticModel)
    fmap(p -> derivative_estimate(m.X, p), m)
end


================================================
FILE: src/prelude.jl
================================================
const AMBIGUOUS_TYPES = (AbstractFloat, Irrational, Integer, Rational, Real, RoundingMode)

const UNARY_PREDICATES = [isinf, isnan, isfinite, iseven, isodd, isreal, isinteger]

const BINARY_PREDICATES = [
    isequal,
    isless,
    <,
    >,
    ==,
    !=,
    <=,
    >=
]

const UNARY_TYPEFUNCS_NOWRAP = [Base.rtoldefault]
const UNARY_TYPEFUNCS_WRAP = [
    typemin,
    typemax,
    floatmin,
    floatmax,
    zero,
    one
]
const RNG_TYPEFUNCS_WRAP = [rand, randn, randexp]

"""
    structural_iterate(args)

Internal helper function for iterating through the scalar values of a functor, 
where AbstractFIs are also counted as scalars.
"""
function structural_iterate(args)
    make_iterator(x) = x isa AbstractArray ? x : (x,)
    exclude(x) = Functors.isleaf(x) || (x isa AbstractFIs)
    iter = fmap(make_iterator, args; walk = Functors.IterateWalk(), cache = nothing,
        exclude)
    return iter
end
structural_iterate(args::NTuple{N, Union{Real, AbstractFIs}}) where {N} = args
structural_iterate(args::AbstractArray{T}) where {T <: Union{Real, AbstractFIs}} = args
structural_iterate(args::T) where {T <: Real} = (args,)

"""
    structural_map(f, args...; only_vals = nothing)

Internal helper function for a structure-preserving map,
often used on a function's input/output arguments.
Currently uses [fmap](https://fluxml.ai/Functors.jl/stable/api/#Functors.fmap)
from Functors.jl as a backend.
"""
function structural_map(f, args...; only_vals = nothing)
    walk = if only_vals isa Val{true}
        Functors.StructuralWalk()
    elseif (only_vals isa Val{false}) || isnothing(only_vals)
        Functors.DefaultWalk()
    else
        error("Unsupported argument only_vals = $only_vals")
    end
    fmap((args...) -> args[1] isa AbstractArray ? f.(args...) : f(args...), args...;
        cache = nothing,
        walk)
end


================================================
FILE: src/propagate.jl
================================================
"""
A version of `value` that lets unrecognized arguments pass through unchanged.
"""
function get_value(arg)
    if arg isa StochasticTriple
        return value(arg)
    else
        # potentially dangerous, see also note in get_Δs
        return arg
    end
end

function get_Δs(arg, FIs)
    if arg isa StochasticTriple
        return arg.Δs
    else
        #=
        This case is a bit dangerous: perturbations could be silently dropped here
        if a leaf of a functor contains a perturbable type that is not a
        StochasticTriple.
        =#
        return empty(similar_type(FIs, typeof(arg)))
    end
end

function strip_Δs(arg; use_dual = Val(true))
    if arg isa StochasticTriple
        # TODO: replace check below with a more robust notion of discreteness.
        if valtype(arg) <: Integer
            return value(arg)
        else
            if use_dual isa Val{true}
                return ForwardDiff.Dual{tag(arg)}(value(arg), delta(arg))
            else
                return StochasticAD.StochasticTriple{tag(arg)}(
                    value(arg), delta(arg), empty(arg.Δs))
            end
        end
    else
        return arg
    end
end

"""
    propagate(f, args...; keep_deltas = Val(false))

Propagates `args` through a function `f`, handling stochastic triples by independently running `f` on the primal
and the alternatives, rather than by inspecting the internals of `f` (which may possibly be unsupported by `StochasticAD`).
Currently handles deterministic functions `f` with any input and output that is `fmap`-able by `Functors.jl`.
If `f` has a continuously differentiable component, provide `keep_deltas = Val(true)`.

This functionality is orthogonal to dispatch: the idea is for this function to be the "backend" for operator 
overloading rules based on dispatch. For example:

```jldoctest
using StochasticAD, Distributions
import Random # hide
Random.seed!(4321) # hide

function mybranch(x)
    str = repr(x) # string-valued intermediate!
    if length(str) < 2
        return 3
    else
        return 7
    end
end

function f(x)
    return mybranch(9 + rand(Bernoulli(x)))
end

# stochastic_triple(f, 0.5) # this would fail

# Add a dispatch rule for mybranch using StochasticAD.propagate
mybranch(x::StochasticAD.StochasticTriple) = StochasticAD.propagate(mybranch, x)

stochastic_triple(f, 0.5) # now works

# output

StochasticTriple of Int64:
3 + 0ε + (4 with probability 2.0ε)
```

!!! warning
    This function is experimental and subject to change.
"""
function propagate(f,
        args...;
        keep_deltas = Val(false),
        provided_st_rep = nothing,
        deriv = nothing)
    # TODO: support kwargs to f (or just use kwfunc in macro)
    #= 
    TODO: maybe don't iterate through every scalar of array below, 
    but rather have special array dispatch
    =#
    st_rep = if provided_st_rep === nothing
        args_iter = structural_iterate(args)
        function args_fold(arg1, arg2)
            if arg1 isa StochasticTriple
                if (arg2 isa StochasticTriple) && (tag(arg1) !== tag(arg2))
                    throw(ArgumentError("Tags of combined stochastic triples do not match!"))
                end
                return arg1
            else
                return arg2
            end
        end
        foldl(args_fold, args_iter)
    else
        provided_st_rep
    end

    if !(st_rep isa StochasticTriple)
        return f(args...)
    end

    primal_args = structural_map(get_value, args)
    input_args = keep_deltas isa Val{false} ? primal_args : structural_map(strip_Δs, args)
    #= 
    TODO: the below is dangerous in general.
    It should be safe so long as f does not close over stochastic triples.
    (If f is a closure, the parameters of f should be treated like any other parameters;
    if they are stochastic triples and we are ignoring them, dangerous in general.)
    =#
    out = f(input_args...)
    val = structural_map(value, out)
    # TODO: what does the only_vals do in the below and why?
    Δs_all = structural_map(Base.Fix2(get_Δs, backendtype(st_rep)), args;
        only_vals = Val{true}())
    # TODO: Coupling approach below needs to handle non-perturbable objects.
    Δs_coupled = couple(backendtype(st_rep), Δs_all; rep = st_rep.Δs, out_rep = val)

    function map_func(Δ_coupled)
        perturbed_args = structural_map(+, primal_args, Δ_coupled)
        #= 
        TODO: for f discrete random with randomness independent of params,
        could couple here. But difficult without a splittable RNG. 
        =#
        alt = f(perturbed_args...)
        return structural_map((x, y) -> value(x) - y, alt, val)
    end
    Δs = map(map_func, Δs_coupled; out_rep = val, deriv)
    # TODO: make sure all FI backends support interface needed below
    new_out = structural_map(out, scalarize(Δs; out_rep = val)) do leaf_out, leaf_Δs
        StochasticAD.StochasticTriple{tag(st_rep)}(value(leaf_out), delta(leaf_out),
            leaf_Δs)
    end
    return new_out
end
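#=
Illustrative sketch only: the core idea of `propagate` is to run `f` once on the primal
inputs and once per coupled perturbation, then store the output difference. Nothing inside
`f` is inspected. The helper `toy_propagate` is hypothetical. It handles a single scalar
input with perturbations given as (Δ, weight) pairs. The weight 2.0 is taken from the output
of the docstring example above and used here as an assumed input.

```julia
# Same branch as in the docstring: depends on its input only through a string.
function mybranch(x)
    str = repr(x)
    return length(str) < 2 ? 3 : 7
end

# Hypothetical helper: push a primal value and its finite perturbations
# (Δ, weight) through an arbitrary deterministic `f` by re-running `f`.
function toy_propagate(f, x, Δs)
    val = f(x)
    return val, [(f(x + Δ) - val, w) for (Δ, w) in Δs]
end

val, out_Δs = toy_propagate(mybranch, 9, [(1, 2.0)])  # 3 and [(4, 2.0)]
```

This matches the docstring's "3 + 0ε + (4 with probability 2.0ε)": primal 9 gives 3,
the alternative 10 gives 7, and the jump of 4 keeps its weight.
=#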


================================================
FILE: src/smoothing.jl
================================================
### Particle resampling

@doc raw"""
    new_weight(p::Real)

Simulate a Bernoulli variable whose primal output is always 1.
Uses a smoothing rule for forward- and reverse-mode AD that is exactly unbiased when the quantity is only
used in linear functions (e.g. as an [importance weight](https://en.wikipedia.org/wiki/Importance_sampling)).
"""
new_weight(p::Real) = 1

function new_weight(p::ForwardDiff.Dual{T}) where {T}
    Δp = ForwardDiff.partials(p)
    val_p = ForwardDiff.value(p)
    val_p = max(1e-5, val_p) # TODO: is this necessary?
    ForwardDiff.Dual{T}(one(val_p), Δp / val_p)
end

function ChainRulesCore.frule((_, Δp), ::typeof(new_weight), p::Real)
    val_p = max(1e-5, p) # TODO: is this necessary?
    return one(p), Δp / val_p
end

function ChainRulesCore.rrule(::typeof(new_weight), p)
    function new_weight_pullback(∇Ω)
        return (ChainRulesCore.NoTangent(), ∇Ω / p)
    end
    return (one(p), new_weight_pullback)
end
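#=
Illustrative sketch only: one way to read the `new_weight` rules is as the likelihood
ratio w(p) = p / p0, evaluated at p = p0. There it has value 1 and derivative 1 / p0, which
matches the primal `one(p)` and tangent `Δp / p` used above. The names `p0`, `w` and `h` are
hypothetical. The check uses a central finite difference rather than any AD package.

```julia
# Hypothetical likelihood ratio around the nominal parameter p0.
p0 = 0.3
w(p) = p / p0

# Central finite difference of w at p0; should be close to 1 / p0.
h = 1e-6
fd_slope = (w(p0 + h) - w(p0 - h)) / (2h)
```
=#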

# Smoothed rules for discrete univariate distributions, differentiated with respect to a single parameter.

function smoothed_delta(d, val, δ, derivative_coupling)
    Δs_empty = SmoothedFIs{typeof(val)}(0.0)
    return derivative_contribution(δtoΔs(d, val, δ, Δs_empty, derivative_coupling))
end

for (dist, i, field) in [
    (:Geometric, :1, :p),
    (:Bernoulli, :1, :p),
    (:Binomial, :2, :p),
    (:Poisson, :1, :λ),
    (:Categorical, :1, :p)
] # i = index of the differentiated parameter in params(d); field = its field name
    # dual overloading 
    @eval function Base.rand(rng::AbstractRNG,
            d_dual::$dist{<:ForwardDiff.Dual{T}}) where {T}
        return randst(rng, d_dual)
    end
    @eval function randst(rng::AbstractRNG,
            d_dual::$dist{<:ForwardDiff.Dual{T}};
            derivative_coupling = InversionMethodDerivativeCoupling()) where {T}
        dual = params(d_dual)[$i]
        # dual could represent an array of duals or a single one; map handles both cases.
        p = map(value, dual)
        # Generate a δ for each partial component.
        partials_indices = ntuple(identity, length(first(dual).partials))
        δs = map(i -> map(d -> ForwardDiff.partials(d)[i], dual), partials_indices)
        d = $dist(params(d_dual)[1:($i - 1)]..., p,
            params(d_dual)[($i + 1):end]...)
        val = convert(Signed, rand(rng, d))
        partials = ForwardDiff.Partials(map(
            δ -> smoothed_delta(d, val, δ, derivative_coupling), δs))
        ForwardDiff.Dual{T}(val, partials)
    end
    # frule
    @eval function ChainRulesCore.frule(Δargs, ::typeof(rand), rng::AbstractRNG,
            d::$dist)
        return frule(Δargs, randst, rng, d)
    end
    @eval function ChainRulesCore.frule((_, _, Δd), ::typeof(randst), rng::AbstractRNG,
            d::$dist; derivative_coupling = InversionMethodDerivativeCoupling())
        val = convert(Signed, rand(rng, d))
        Δval = smoothed_delta(d, val, Δd, derivative_coupling)
        return (val, Δval)
    end
    # rrule
    @eval function ChainRulesCore.rrule(::typeof(rand), rng::AbstractRNG, d::$dist)
        return rrule(randst, rng, d)
    end
    @eval function ChainRulesCore.rrule(::typeof(randst),
            rng::AbstractRNG,
            d::$dist;
            derivative_coupling = InversionMethodDerivativeCoupling())
        val = convert(Signed, rand(rng, d))
        function rand_pullback(∇out)
            p = params(d)[$i]
            if p isa Real
                Δp = smoothed_delta(d, val, one(val), derivative_coupling)
            else
                # TODO: this rule is O(length(p)^2), whereas we should be able to do O(length(p)) by reversing through δtoΔs.
                I = eachindex(p)
                V = eltype(p)
                onehot(i) = map(j -> j == i ? one(V) : zero(V), I)
                Δp = map(i -> smoothed_delta(d, val, onehot(i), derivative_coupling), I)
            end
            # rrule_via_ad approach below not used because slow.
            # Δp = rrule_via_ad(config, smoothed_delta, d, val, map(one, p))[2](∇out)[4]
            Δd = ChainRulesCore.Tangent{typeof(d)}(; $field = ∇out * Δp)
            return (ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), Δd)
        end
        return (val, rand_pullback)
    end
end


================================================
FILE: src/stochastic_triple.jl
================================================
""" 
    StochasticTriple{T, V <: Real, FIs <: AbstractFIs{V}}

Stores the primal value of the computation, alongside a "dual" component
representing an infinitesimal change, and a "triple" component that tracks
discrete change(s) with infinitesimal probability. 

Pretty printed as "value + δε + ({pretty print of Δs})".

## Constructor

- `value`: the primal value.
- `δ`: the value of the almost-sure derivative, i.e. the rate of "infinitesimal" change.
- `Δs`: finite perturbations of the value with associated weights, i.e. Finite perturbations with Infinitesimal probability,
        represented by a backend `FIs <: AbstractFIs`.
"""
struct StochasticTriple{T, V <: Real, FIs <: AbstractFIs{V}} <: Real
    value::V
    δ::V # infinitesimal change
    Δs::FIs # finite changes with infinitesimal probabilities # (Δ = 3, p = 1*h)
    function StochasticTriple{T, V, FIs}(value::V, δ::V,
            Δs::FIs) where {T, V, FIs <: AbstractFIs{V}}
        new{T, V, FIs}(value, δ, Δs)
    end
end

"""
    value(st::StochasticTriple)

Return the primal value of `st`.
"""
value(x::Real, state = nothing) = x
value(st::StochasticTriple) = st.value
# Experimental method for obtaining the alternate value of a stochastic triple associated with a certain backend state.
function value(st::StochasticTriple, state)
    st.value + filter_state(st.Δs, state)
end
#=
Support ForwardDiff.Dual for internal usage.
Assumes batch size is 1.
=#
value(d::ForwardDiff.Dual, state = nothing) = ForwardDiff.value(d)

"""
    delta(st::StochasticTriple)

Return the almost-sure derivative of `st`, i.e. the rate of infinitesimal change.
"""
delta(x::Real) = zero(x)
delta(st::StochasticTriple) = st.δ
# Support ForwardDiff.Dual for internal usage.
delta(d::ForwardDiff.Dual) = ForwardDiff.partials(d)[1]

"""
    perturbations(st::StochasticTriple)

Return the finite perturbation(s) of `st`, in a format dependent on the [backend](devdocs.md) used for storing perturbations.
"""
perturbations(x::Real) = ()
perturbations(st::StochasticTriple) = perturbations(st.Δs)

"""
    send_signal(st::StochasticTriple, signal::AbstractPerturbationSignal)
    send_signal(Δs::StochasticAD.AbstractFIs, signal::AbstractPerturbationSignal)

Send a signal to a stochastic triple's perturbation collection `st.Δs` (or directly to a `Δs`),
which the backend may process as it wishes. Semantically, sending a signal should not affect
unbiasedness. The updated version of the first argument (`st` or `Δs`) is returned.
"""
send_signal(st::Real, ::AbstractPerturbationSignal) = st
function send_signal(st::StochasticTriple{T}, signal::AbstractPerturbationSignal) where {T}
    new_Δs = send_signal(st.Δs, signal)
    return StochasticTriple{T}(st.value, st.δ, new_Δs)
end

"""
    derivative_contribution(st::StochasticTriple)

Return the derivative estimate given by combining the dual and triple components of `st`.
"""
derivative_contribution(x::Real) = zero(x)
derivative_contribution(d::ForwardDiff.Dual) = delta(d)
derivative_contribution(st::StochasticTriple) = st.δ + derivative_contribution(st.Δs)
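#=
Illustrative sketch only: in a simplified model, the derivative estimate adds the dual part
δ to the weighted sum of the finite perturbations. Here each weight is the coefficient of ε
in that perturbation's probability. The helper `toy_derivative_contribution` is
hypothetical. How each backend actually combines its perturbations is up to its own
`derivative_contribution` method.

```julia
# Hypothetical model: perturbations as (Δ, weight) pairs.
toy_derivative_contribution(δ, Δs) = δ + sum((Δ * w for (Δ, w) in Δs); init = 0.0)

# Dual part 0 and one perturbation "1 with probability 2.0ε",
# as printed for `stochastic_triple(rand ∘ Bernoulli, 0.5)`.
est = toy_derivative_contribution(0.0, [(1, 2.0)])  # 2.0
```

A single sample like this is not the derivative itself. It is one realization of an
estimator whose expectation is the derivative.
=#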

"""
    tag(st::StochasticTriple)
    tag(::Type{<:StochasticTriple})

Get the tag of a stochastic triple.
"""
tag(::Type{<:StochasticTriple{T}}) where {T} = T
tag(::Type{<:ForwardDiff.Dual{T}}) where {T} = T
tag(::StochasticTriple{T}) where {T} = T
tag(::ForwardDiff.Dual{T}) where {T} = T

"""
    valtype(st::StochasticTriple)
    valtype(st::Type{<:StochasticTriple})

Get the underlying type of the value tracked by a stochastic triple.
"""
valtype(st::StochasticTriple) = valtype(typeof(st))
valtype(::Type{<:StochasticTriple{T, V}}) where {T, V} = V

"""
    backendtype(st::StochasticTriple)
    backendtype(st::Type{<:StochasticTriple})

Get the backend type of a stochastic triple.
"""
backendtype(st::StochasticTriple) = backendtype(typeof(st))
backendtype(::Type{<:StochasticTriple{T, V, FIs}}) where {T, V, FIs} = FIs

"""
    smooth_triple(st::StochasticTriple)

Smooth the dual and triple components of a stochastic triple into a single dual component.
Useful for avoiding unnecessary pruning when running multilinear functions on triples.
"""
smooth_triple(x::Real) = x
function smooth_triple(st::StochasticTriple{T, V, FIs}) where {T, V, FIs}
    return StochasticTriple{T}(value(st), derivative_contribution(st), empty(FIs))
end

### Extra constructors of stochastic triples

function StochasticTriple{T}(value::V, δ::V, Δs::FIs) where {T, V, FIs <: AbstractFIs{V}}
    StochasticTriple{T, V, FIs}(value, δ, Δs)
end

function StochasticTriple{T}(value::V, Δs::FIs) where {T, V, FIs <: AbstractFIs{V}}
    StochasticTriple{T}(value, zero(value), Δs)
end

function StochasticTriple{T}(value::A, δ::B,
        Δs::FIs) where {T, A, B, C, FIs <: AbstractFIs{C}}
    V = promote_type(A, B, C)
    StochasticTriple{T}(convert(V, value), convert(V, δ), convert(similar_type(FIs, V), Δs))
end

### Conversion rules

# TODO: is this the right thing to do? Maybe, different from the promote case because there V was guaranteed to be an ancestor. 
# Also, bad to do when already same type?
function Base.convert(::Type{StochasticTriple{T1, V, FIs}},
        x::StochasticTriple{T2}) where {T1, T2, V, FIs}
    (T1 !== T2) && throw(ArgumentError("Tags of combined stochastic triples do not match."))
    StochasticTriple{T1, V, FIs}(convert(V, x.value), convert(V, x.δ), convert(FIs, x.Δs))
end

# TODO: ForwardDiff's promotion rules are a little more complicated, see https://github.com/JuliaDiff/ForwardDiff.jl/issues/322
# May need to look into why and possibly use them here too.
function Base.promote_rule(::Type{StochasticTriple{T, V1, FIs}},
        ::Type{StochasticTriple{T, V2, FIs2}}) where {T, V1, FIs, V2,
        FIs2}
    V = promote_type(V1, V2)
    StochasticTriple{T, V, similar_type(FIs, V)}
end

function Base.promote_rule(::Type{StochasticTriple{T, V1, FIs}},
        ::Type{V2}) where {T, V1, FIs, V2 <: Real}
    V = promote_type(V1, V2)
    StochasticTriple{T, V, similar_type(FIs, V)}
end

function Base.convert(::Type{StochasticTriple{T, V, FIs}}, x::Real) where {T, V, FIs}
    StochasticTriple{T, V, FIs}(convert(V, x), zero(V), empty(FIs))
end

### Creating the first stochastic triple in a computation

function StochasticTriple{T}(value::V, δ::V, backend::AbstractFIsBackend) where {T, V}
    StochasticTriple{T}(value, δ, create_Δs(backend, V))
end

function StochasticTriple{T}(value::V, backend::AbstractFIsBackend) where {T, V}
    StochasticTriple{T}(value, zero(V), backend)
end

function StochasticTriple{T}(value::A, δ::B, backend::AbstractFIsBackend) where {T, A, B}
    V = promote_type(A, B)
    StochasticTriple{T}(convert(V, value), convert(V, δ), backend)
end

### Showing a stochastic triple

function Base.summary(::StochasticTriple{T, V}) where {T, V}
    return "StochasticTriple of $V"
end

function Base.show(io::IO, ::MIME"text/plain", st::StochasticTriple)
    println(io, "$(summary(st)):")
    show(io, st)
end

function Base.show(io::IO, st::StochasticTriple)
    print(io, "$(st.value) + $(st.δ)ε")
    if (!isempty(st.Δs))
        print(io, " + ($(repr(st.Δs)))")
    end
end

### Higher level functions

struct Tag{F, V}
end

function stochastic_triple_direction(f, p::V, direction; backend) where {V}
    Δs = create_Δs(backend, Int) # TODO: necessity of hardcoding some type here suggests interface improvements
    sts = structural_map(p, direction) do p_i, direction_i
        StochasticTriple{Tag{typeof(f), V}}(p_i, direction_i,
            similar_empty(Δs, typeof(p_i)))
    end
    return f(sts)
end

"""
    stochastic_triple(X, p; backend=PrunedFIsBackend(), direction=nothing)
    stochastic_triple(p; backend=PrunedFIsBackend(), direction=nothing)

For any `p` that is supported by [`Functors.jl`](https://fluxml.ai/Functors.jl/stable/),
e.g. scalars or abstract arrays,
differentiate the output with respect to each value of `p`,
returning an output of similar structure to `p`, where a particular value contains
the stochastic-triple output of `X` when perturbing the corresponding value in `p`
(i.e. replacing the original value `x` with `x + ε`).

When `direction` is provided, return only the stochastic-triple output of `X` with respect to a perturbation
of `p` in that particular direction.
When `X` is not provided, the identity function is used. 

The `backend` keyword argument selects the algorithm used for the third component
of the stochastic triple; see the [technical details](devdocs.md).

# Example
```jldoctest
julia> using Distributions, Random, StochasticAD; Random.seed!(4321);

julia> stochastic_triple(rand ∘ Bernoulli, 0.5)
StochasticTriple of Int64:
0 + 0ε + (1 with probability 2.0ε)
```
"""
function stochastic_triple(
        f, p; direction = nothing, backend::AbstractFIsBackend = PrunedFIsBackend())
    if direction !== nothing
        return stochastic_triple_direction(f, p, direction; backend)
    end
    counter = begin
        c = 0
        (_) -> begin
            c += 1
            return c
        end
    end
    indices = structural_map(counter, p)
    map_func = perturbed_index -> begin
        direction = structural_map(indices, p) do i, p_i
            i == perturbed_index ? one(p_i) : zero(p_i)
        end
        stochastic_triple_direction(f, p, direction; backend)
    end
    return structural_map(map_func, indices)
end

stochastic_triple(p; kwargs...) = stochastic_triple(identity, p; kwargs...)
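#=
Illustrative sketch only: when no `direction` is given, `stochastic_triple` numbers every
scalar in `p` and builds one one-hot direction per scalar. It then runs `X` once per
direction. The helper `onehot_directions` is hypothetical and covers only the special case
of a plain vector `p`. The real code handles any Functors.jl structure.

```julia
# Hypothetical helper: one one-hot direction per entry of a plain vector `p`.
function onehot_directions(p::AbstractVector)
    return [[j == i ? one(x) : zero(x) for (j, x) in enumerate(p)] for i in eachindex(p)]
end

dirs = onehot_directions([0.5, 2.0, 3.0])  # the three standard basis vectors
```

This is also why the cost of the call without `direction` grows linearly with the
number of scalars in `p`.
=#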

"""
    dual_number(X, p; backend=PrunedFIsBackend(), direction=nothing)
    dual_number(p; backend=PrunedFIsBackend(), direction=nothing)

A lightweight wrapper around [`stochastic_triple`](#StochasticAD.stochastic_triple) that entirely ignores the
derivative contribution of all discrete random components, so that it behaves like a regular dual number.
Mostly for fun -- this, of course, leads to a useless derivative estimate for discrete random functions!
"""
function dual_number(f, p; backend = PrunedFIsBackend(), kwargs...)
    backend = StrategyWrapperFIsBackend(backend, IgnoreDiscreteStrategy())
    stochastic_triple(f, p; backend, kwargs...)
end
dual_number(p; kwargs...) = dual_number(identity, p; kwargs...)

function derivative_estimate(f, p; kwargs...)
    StochasticAD.structural_map(derivative_contribution, stochastic_triple(f, p; kwargs...))
end


================================================
FILE: test/game_of_life.jl
================================================
using StochasticAD
using Test
using Statistics

include("../tutorials/game_of_life/core.jl")
using .GoLCore: fd_clever, play, p, nsamples

@testset "AD and Finite Differences" begin
    samples_fd_clever = [fd_clever(p) for i in 1:nsamples]
    samples_st = [derivative_estimate(play, p) for i in 1:nsamples]

    @test mean(samples_st)≈mean(samples_fd_clever) rtol=5e-2
end


================================================
FILE: test/random_walk.jl
================================================
using StochasticAD
using Test
using Statistics
using ForwardDiff: derivative

include("../tutorials/random_walk/core.jl")
using .RandomWalkCore: n, p, nsamples
using .RandomWalkCore: fX, get_dfX

@testset "Check unbiasedness" begin
    fX_deriv = derivative(p -> get_dfX(p, n), p)
    fX_deriv_estimate = mean(derivative_estimate(fX, p) for i in 1:nsamples)
    @test isapprox(fX_deriv, fX_deriv_estimate; rtol = 1e-2)
end


================================================
FILE: test/resampling.jl
================================================
using StochasticAD
using Random, Test
using Distributions
using LinearAlgebra
using ForwardDiff

# test forward-mode AD and reverse-mode AD on the particle filter

### Particle Filter Functions
include("../tutorials/particle_filter/core.jl")
seed = 237347578

### Define model
Random.seed!(seed)

T = 3
d = 2
A(θ, a = 0.01) = [exp(-a)*cos(θ[]) exp(-a)*sin(θ[])
                  -exp(-a)*sin(θ[]) exp(-a)*cos(θ[])]
obs(x, θ) = MvNormal(x, 0.01 * collect(I(d)))
dyn(x, θ) = MvNormal(A(θ) * x, 0.02 * collect(I(d)))
x0 = [2.0, 0.0] # start value of the simulation
start(θ) = Dirac(x0)
θtrue = [0.20]
# put it all together
stochastic_model = ParticleFilterCore.StochasticModel(T, start, dyn, obs)

### simulate model
Random.seed!(seed)
xs, ys = ParticleFilterCore.simulate_single(stochastic_model, θtrue)
###

### initialize sampler
m = 1000
particle_filter = ParticleFilterCore.ParticleFilter(m, stochastic_model, ys,
    ParticleFilterCore.sample_stratified)
###

@testset "new weight" begin
    p = 0.5
    st = stochastic_triple(p)
    d = ForwardDiff.Dual(p, (1.0, 2.0))
    @test new_weight(p) == one(p)
    @test StochasticAD.value(new_weight(st)) == one(p)
    @test StochasticAD.delta(new_weight(st)) == 1.0 / p
    @test ForwardDiff.value(new_weight(d)) == one(p)
    @test collect(ForwardDiff.partials(new_weight(d))) == [1.0 / p, 2.0 / p]
end

@testset "forward-mode and reverse-mode AD: single run" begin
    Random.seed!(seed)
    grad_forw = ParticleFilterCore.forw_grad(θtrue, particle_filter)
    Random.seed!(seed)
    grad_back = ParticleFilterCore.back_grad(θtrue, particle_filter)
    @test grad_forw ≈ grad_back
end

@testset "AD and Finite Differences" begin
    h = 0.02 # finite-difference step size
    N = 500 # number of samples
    grad_fw = [ParticleFilterCore.forw_grad(θtrue, particle_filter)[1] for i in 1:N]
    # grad_bw = @time [back_grad(θtrue, particle_filter) for i in 1:N]
    grad_fd = [(ParticleFilterCore.log_likelihood(particle_filter, θtrue .+ h) -
                ParticleFilterCore.log_likelihood(particle_filter, θtrue .- h)) / (2h)
               for i in 1:N]

    @test mean(grad_fd)≈mean(grad_fw) rtol=5e-2
end


================================================
FILE: test/runtests.jl
================================================
using SafeTestsets
using Test, Pkg
import Random

Random.seed!(1234)

const GROUP = get(ENV, "GROUP", "All")
const is_APPVEYOR = Sys.iswindows() && haskey(ENV, "APPVEYOR")

@time begin
    if GROUP == "All"
        @time @safetestset "Triples" begin
            include("triples.jl")
        end
        @time @safetestset "Game of life" begin
            include("game_of_life.jl")
        end
        @time @safetestset "Random walk" begin
            include("random_walk.jl")
        end
        @time @safetestset "Resampling" begin
            include("resampling.jl")
        end
    end
end


================================================
FILE: test/triples.jl
================================================
using StochasticAD
using Test
using Distributions
using ForwardDiff
using OffsetArrays
using ChainRulesCore
using Random
using Zygote

const backends = [
    PrunedFIsBackend(),
    PrunedFIsAggressiveBackend(),
    DictFIsBackend()
]

const backends_smoothed = [
    SmoothedFIsBackend(),
    StrategyWrapperFIsBackend(SmoothedFIsBackend(), StochasticAD.TwoSidedStrategy())
]

@testset "Distributions w.r.t. continuous parameter" begin
    for backend in vcat(backends,
        backends_smoothed,
        :smoothing_autodiff)
        MAX = 10000
        nsamples = 100000
        rtol = 5e-2 # friendly tolerance for stochastic comparisons. TODO: more motivated choice of tolerance.

        ### Make test cases

        distributions = [
            Bernoulli,
            Geometric,
            Poisson,
            (p -> Categorical([p^2, 1 - p^2])),
            (p -> Categorical([0, p^2, 0, 0, 1 - p^2])), # check that 0's are skipped over
            (p -> Categorical([1.0, exp(p)] ./ (1.0 + exp(p)))), # test fix for #38 (floating point comparisons in Categorical logic)
            (p -> Binomial(3, p)),
            (p -> Binomial(20, p))
        ]
        p_ranges = [(0.2, 0.8) for _ in 1:8]
        out_ranges = [0:1, 0:MAX, 0:MAX, 1:2, 1:5, 1:2, 0:3, 0:20]
        test_cases = collect(zip(distributions, p_ranges, out_ranges))
        test_funcs = [x -> 7 * x - 3, x -> (x + 1)^2, x -> sqrt(x + 1)]

        if backend isa DictFIsBackend
            # Only test dictionary backend on Bernoulli to speed things up. Should still cover interface.
            test_cases = test_cases[1:1]
        elseif backend == :smoothing_autodiff || backend in backends_smoothed
            # Only test smoothing backend on each unique distribution once to speed tests up.
            test_cases = vcat(test_cases[1:4], test_cases[7])
            # Only test unbiasedness of smoothing for linear function
            test_funcs = test_funcs[1:1]
        end

        for (distr, p_range, out_range) in test_cases
            for f in test_funcs
                function get_mean(p)
                    dp = distr(p)
                    sum(pdf(dp, i) * f(i) for i in out_range)
                end

                low_p, high_p = p_range
                for g in [p -> p, p -> high_p + low_p - p] # test both sides of derivative
                    full_func = f ∘ rand ∘ distr ∘ g
                    p = low_p + (high_p - low_p) * rand()
                    exact_deriv = ForwardDiff.derivative(p -> get_mean(g(p)), p)
                    if backend == :smoothing_autodiff
                        batched_full_func(p) = mean([full_func(p) for i in 1:nsamples])
                        # The array input used for ForwardDiff below is a trick to test multiple partials
                        triple_deriv_forward = mean(ForwardDiff.gradient(
                            arr -> batched_full_func(sum(arr)),
                            [2 * p, -p]))
                        triple_deriv_backward = Zygote.gradient(batched_full_func, p)[1]
                        @test isapprox(triple_deriv_forward, exact_deriv, rtol = rtol)
                        @test isapprox(triple_deriv_backward, exact_deriv, rtol = rtol)
                    else
                        get_deriv = () -> derivative_estimate(full_func, p; backend)
                        triple_deriv = mean(get_deriv() for i in 1:nsamples)
                        @test isapprox(triple_deriv, exact_deriv, rtol = rtol)
                    end
                end
            end
        end
    end
end

@testset "Perturbing n of binomial" begin
    function get_triple_deriv(Δ)
        # Manually create a finite perturbation to avoid any randomness in its creation
        Δs = StochasticAD.similar_new(StochasticAD.create_Δs(PrunedFIsBackend(), Int), Δ,
            3.5)
        st = StochasticAD.StochasticTriple{0}(5, 0, Δs)
        st_continuous = stochastic_triple(0.5)
        return derivative_contribution(rand(Binomial(st, st_continuous)))
    end
    for Δ in -2:2
        triple_deriv = mean(get_triple_deriv(Δ) for i in 1:100000)
        exact_deriv = 3.5 * 0.5 * Δ + 5
        @test isapprox(triple_deriv, exact_deriv, rtol = 5e-2)
    end
end

@testset "Nested binomials" begin
    binbin = p -> rand(Binomial(rand(Binomial(10, p)), p)) # ∼ Binomial(10, p^2)
    for p in [0.3, 0.7]
        triple_deriv = mean(derivative_estimate(binbin, p) for i in 1:100000)
        exact_deriv = 10 * 2 * p
        @test isapprox(triple_deriv, exact_deriv, rtol = 5e-2)
    end
end

@testset "Boolean comparisons" begin
    for backend in backends
        tested = falses(2)
        while !(all(tested))
            st = stochastic_triple(rand ∘ Bernoulli, 0.5; backend)
            x = StochasticAD.value(st)
            if x == 0
                # Ensure errors on unsafe/unsupported boolean comparisons
                @test_throws Exception st>0.5
                @test_throws Exception 0.5<st
                @test_throws Exception st==0
            else
                @test st > 0.5
                @test 0.5 < st
                @test st == 1
            end
            tested[x + 1] = true
        end
        @test stochastic_triple(1.0; backend) != 1
    end
end

@testset "Array indexing" begin
    for backend in vcat(backends, backends_smoothed)
        p = 0.3
        # Test indexing into array of floats with stochastic triple index
        arr = [3.5, 5.2, 8.4]
        (backend in backends_smoothed) && (arr[3] = 6.9) # make linear for smoothing test
        function array_index(p)
            index = rand(Categorical([p / 2, p / 2, 1 - p]))
            return arr[index]
        end
        array_index_mean(p) = sum([p / 2, p / 2, (1 - p)] .* arr)
        triple_array_index_deriv = mean(derivative_estimate(array_index, p; backend)
        for i in 1:50000)
        exact_array_index_deriv = ForwardDiff.derivative(array_index_mean, p)
        @test isapprox(triple_array_index_deriv, exact_array_index_deriv, rtol = 5e-2)
        # Don't run subsequent tests with smoothing backend
        (backend in backends_smoothed) && continue
        # Test indexing into array of stochastic triples with stochastic triple index
        function array_index2(p)
            arr2 = [rand(Bernoulli(p)), rand(Bernoulli(p)), rand(Bernoulli(p))] .* arr
            index = rand(Categorical([p / 2, p / 2, 1 - p]))
            return arr2[index]
        end
        array_index2_mean(p) = sum([p / 2 * p, p / 2 * p, (1 - p) * p] .* arr)
        triple_array_index2_deriv = mean(derivative_estimate(array_index2, p; backend)
        for i in 1:50000)
        exact_array_index2_deriv = ForwardDiff.derivative(array_index2_mean, p)
        @test isapprox(triple_array_index2_deriv, exact_array_index2_deriv, rtol = 5e-2)
        # Test case where triple and alternate array value are coupled
        function array_index3(p)
            st = rand(Bernoulli(p))
            arr2 = [-5, st]
            return arr2[st + 1]
        end
        array_index3_mean(p) = -5 * (1 - p) + 1 * p
        triple_array_index3_deriv = mean(derivative_estimate(array_index3, p; backend)
        for i in 1:50000)
        exact_array_index3_deriv = ForwardDiff.derivative(array_index3_mean, p)
        @test isapprox(triple_array_index3_deriv, exact_array_index3_deriv, rtol = 5e-2)
    end
end

@testset "Array/functor inputs to higher level functions" begin
    for backend in backends
        # Try a deterministic test function to compare to ForwardDiff
        f(x) = (x[1] * x[2] * sin(x[3]) + exp(x[1] * x[2])) / x[3]
        x = [1, 2, π / 2]

        stochastic_ad_grad = derivative_estimate(f, x; backend)
        stochastic_ad_grad2 = derivative_contribution.(stochastic_triple(f, x; backend))
        stochastic_ad_grad_firsttwo = derivative_estimate(
            f, x; direction = [1.0, 1.0, 0.0],
            backend)
        fd_grad = ForwardDiff.gradient(f, x)
        @test stochastic_ad_grad ≈ fd_grad
        @test stochastic_ad_grad ≈ stochastic_ad_grad2
        @test stochastic_ad_grad[1] + stochastic_ad_grad[2] ≈ stochastic_ad_grad_firsttwo

        # Try an OffsetArray too
        f_off(x) = (x[0] * x[1] * sin(x[2]) + exp(x[0] * x[1])) / x[2]
        x_off = OffsetArray([1, 2, π / 2], 0:2)
        stochastic_ad_grad_off = derivative_estimate(f_off, x_off)
        @test stochastic_ad_grad_off ≈ OffsetArray(stochastic_ad_grad, 0:2)

        # Try a Functor
        f_func(x) = (x[1] * x[2][1] * sin(x[2][2]) + exp(x[1] * x[2][1])) / x[2][2]
        x_func = (1, [2, π / 2])
        stochastic_ad_grad_func = derivative_estimate(f_func, x_func)
        stochastic_ad_grad_func_expected = (stochastic_ad_grad[1], stochastic_ad_grad[2:3])
        compare_grad_funcs = StochasticAD.structural_map(≈, stochastic_ad_grad_func,
            stochastic_ad_grad_func_expected)
        @test all(compare_grad_funcs |> StochasticAD.structural_iterate)

        # Test StochasticModel + stochastic_gradient combination
        m = StochasticModel(f, x)
        @test stochastic_gradient(m).p ≈ stochastic_ad_grad
    end
end

@testset "Propagation using frule with ZeroTangent" begin
    st = stochastic_triple(0.5)

    # Verify that the rule for imag indeed gives a ZeroTangent
    value = StochasticAD.value(st)
    δ = StochasticAD.delta(st)
    @test frule((NoTangent(), δ), imag, value)[2] isa ZeroTangent
    # Test that stochastic triples flow through this rule
    out_st = imag(st)
    @test StochasticAD.value(out_st) ≈ 0
    @test StochasticAD.delta(out_st) ≈ 0
    @test isempty(out_st.Δs)
end

@testset "Unary functions converting type to fixed instance" begin
    for val in [0.5, 1]
        st = stochastic_triple(val)
        for op in StochasticAD.UNARY_TYPEFUNCS_WRAP
            f = getfield(Base, Symbol(op))
            out_st = f(st)
            @test out_st isa StochasticAD.StochasticTriple
            @test StochasticAD.value(out_st) ≈ f(val) ≈ f(typeof(val))
            @test StochasticAD.delta(out_st) ≈ 0
            @test isempty(out_st.Δs)
            @test f(typeof(st)) == out_st
        end
        #=
        It so happens that the UNARY_TYPEFUNCS_WRAP funcs all support both instances and types
        whereas UNARY_TYPEFUNCS_NOWRAP only supports types, so we only test types in the below,
        but this is a coincidence that may not hold in the future.
        =#
        for op in StochasticAD.UNARY_TYPEFUNCS_NOWRAP
            f = getfield(Base, Symbol(op))
            out = f(typeof(st))
            @test out isa typeof(val)
            @test out ≈ f(typeof(val))
        end
        RNG = copy(Random.GLOBAL_RNG)
        for op in StochasticAD.RNG_TYPEFUNCS_WRAP
            f = getfield(Random, Symbol(op))
            out_st = f(copy(RNG), typeof(st))
            @test out_st isa StochasticAD.StochasticTriple
            @test StochasticAD.value(out_st) ≈ f(copy(RNG), typeof(val))
            @test StochasticAD.delta(out_st) ≈ 0
            @test isempty(out_st.Δs)
        end
    end
end

@testset "Hashing" begin
    st = stochastic_triple(3.0)
    @test_nowarn hash(st)
    @test_nowarn hash(st, UInt(5))
    d = Dict()
    @test_nowarn d[st] = 5
    @test d[st] == 5
    @test d[3] == 5
    # Test that we get an error with discrete random dictionary indices,
    # since this isn't supported and we want to avoid silent failures.
    Δs = StochasticAD.similar_new(StochasticAD.create_Δs(PrunedFIsBackend(), Int), 1.0, 1.0)
    st = StochasticAD.StochasticTriple{0}(1.0, 0, Δs)
    @test_throws ErrorException d[rand(Bernoulli(st))]
end

@testset "Coupled comparison" begin
    Δs_1 = StochasticAD.similar_new(StochasticAD.create_Δs(PrunedFIsBackend(), Int), 1.0,
        1.0)
    Δs_2 = StochasticAD.similar_new(StochasticAD.create_Δs(PrunedFIsBackend(), Int), 1.0,
        1.0)
    st_1 = StochasticAD.StochasticTriple{0}(1.0, 0, Δs_1)
    st_2 = StochasticAD.StochasticTriple{0}(1.0, 0, Δs_2)
    @test st_1 == st_1
    @test_throws ErrorException st_1==st_2
end

@testset "Converting float stochastic triples to integer triples" begin
    st = stochastic_triple(0.6)
    @test round(Int, st) isa StochasticAD.StochasticTriple
    @test StochasticAD.delta(round(Int, st)) ≈ 0
    @test round(Int, st) ≈ 1
end

@testset "Approximate comparisons" begin
    st = stochastic_triple(0.5)
    @test st ≈ st
    # Check that the rtol is indeed reasonable
    @test st ≈ st + 1e-14
    @test !(st ≈ st + 1)
    @test_broken stochastic_triple(Inf) ≈ stochastic_triple(Inf)
end

@testset "Error on unmatched tags" begin
    st1 = stochastic_triple(0.5)
    st2 = stochastic_triple(x -> x^2, 0.5)
    @test_throws ArgumentError convert(typeof(st1), st2)
end

@testset "Finite perturbation backend interface" begin
    for backend in vcat(backends,
        backends_smoothed)
        # this boolean may need to become more fine-grained in the future
        is_smoothed_backend = backend in backends_smoothed
        #=
        Test the backend interface across the finite perturbation backends,
        which is currently a bit implicitly defined.
        =#
        V0 = Int
        V1 = Float64
        #=
        All four of the below approaches should create an empty collection of finite
        perturbations, although the backend's internal state management may differ.
        =#
        Δs0 = StochasticAD.create_Δs(backend, V0) # used to create first triple in computation
        FIs = typeof(Δs0)
        Δs1 = empty(Δs0)
        Δs2 = empty(typeof(Δs0))
        Δs3 = StochasticAD.similar_empty(Δs0, V1)
        for (Δs, V) in ((Δs0, V0), (Δs1, V0), (Δs2, V0), (Δs3, V1))
            @test StochasticAD.valtype(Δs) === V
            @test Δs isa StochasticAD.similar_type(FIs, V)
            !is_smoothed_backend && @test isempty(Δs)
            @test iszero(derivative_contribution(Δs))
        end
        # Test creation of a single perturbation
        for Δ in (1, 3.0)
            Δs0 = StochasticAD.create_Δs(backend, V0)
            Δs1 = StochasticAD.similar_new(Δs0, Δ, 3.0)
            @test StochasticAD.valtype(Δs1) === typeof(Δ)
            @test Δs1 isa StochasticAD.similar_type(FIs, typeof(Δ))
            !is_smoothed_backend && @test !isempty(Δs1)
            @test derivative_contribution(Δs1) == 3Δ
            # Test StochasticAD.alltrue
            @test StochasticAD.alltrue(_Δ -> true, Δs1)
            @test !StochasticAD.alltrue(_Δ -> false, Δs1) || is_smoothed_backend
            # Test map
            # We use a dummy deriv here and below. TODO: use a more interesting dummy for better testing.
            Δs1_map = Base.map(Δ -> Δ^2, Δs1; deriv = identity, out_rep = Δ)
            !is_smoothed_backend && @test derivative_contribution(Δs1_map) ≈ Δ^2 * 3.0
            # Test map with weight (make a new copy so that original does not get reweighted)
            Δs2 = StochasticAD.similar_new(StochasticAD.create_Δs(backend, V0), Δ, 3.0)
            Δs2_weight_map = StochasticAD.weighted_map_Δs((Δ, _) -> (Δ^2, 2.0),
                Δs2;
                deriv = identity,
                out_rep = Δ)
            !is_smoothed_backend &&
                @test derivative_contribution(Δs2_weight_map) ≈ Δ^2 * 3.0 * 2.0
            # Also test scale
            w2 = derivative_contribution(Δs2)
            Δs2_scaled = StochasticAD.scale(Δs2, 2.0)
            w2_scaled = derivative_contribution(Δs2_scaled)
            @test w2_scaled ≈ 2.0 * w2
            # Test map_Δs with filter state
            if !is_smoothed_backend
                Δs1_plus_Δs0 = StochasticAD.map_Δs(
                    (Δ, state) -> Δ +
                                  StochasticAD.filter_state(Δs0,
                        state),
                    Δs1)
                @test derivative_contribution(Δs1_plus_Δs0) ≈ Δ * 3.0
                Δs1_plus_mapped = StochasticAD.map_Δs(
                    (Δ, state) -> Δ +
                                  StochasticAD.filter_state(Δs1,
                        state),
                    Δs1_map)
                @test derivative_contribution(Δs1_plus_mapped) ≈ Δ * 3.0 + Δ^2 * 3.0
            end
        end
        # Test coupling
        Δ_coupleds = (3, [4.0, 5.0], (2, [3.0, 4.0]))
        for Δ_coupled in Δ_coupleds
            function get_Δs_coupled(; do_combine = false, use_get_rep = false)
                Δs0 = StochasticAD.create_Δs(backend, Int)
                Δs1 = StochasticAD.similar_new(Δs0, 1, 3.0) # perturbation 1
                Δs2 = StochasticAD.similar_new(Δs0, 1, 2.0) # perturbation 2
                # A group of perturbations that all stem from perturbation 1. 
                Δs_all1 = StochasticAD.structural_map(Δ_coupled) do Δ
                    Base.map(_Δ -> Δ, Δs1; deriv = identity, out_rep = Δ)
                end
                # A group of perturbations that all stem from perturbation 2. 
                Δs_all2 = StochasticAD.structural_map(Δ_coupled) do Δ
                    Base.map(_Δ -> 2 * Δ, Δs2; deriv = (δ -> 2δ), out_rep = Δ)
                end
                # Join them into a single structure that should be coupled
                Δs_all = (Δs_all1, Δs_all2)
                kwargs = use_get_rep ? (; rep = StochasticAD.get_rep(FIs, Δs_all)) : (;)
                if do_combine
                    return StochasticAD.combine(FIs, Δs_all; kwargs...)
                else
                    return StochasticAD.couple(FIs, Δs_all;
                        out_rep = (Δ_coupled, Δ_coupled),
                        kwargs...)
                end
            end
            #=
            As a test function to apply to the coupled perturbation, we apply
            a matmul followed by a sigmoid activation function and a sum.
            =#
            l = 2 * length(collect(StochasticAD.structural_iterate(Δ_coupled)))
            A = rand(l, l)
            function mapfunc(Δ_coupled)
                arr = collect(StochasticAD.structural_iterate(Δ_coupled))
                sum(x -> 1 / (1 + exp(-x)), A * arr)
            end
            # Test the above function, and also a simple sum.
            for use_get_rep in (false, true)
                Δs_coupled = get_Δs_coupled(; use_get_rep)
                @test StochasticAD.valtype(Δs_coupled) == typeof((Δ_coupled, Δ_coupled))
                for (mapfunc, check_combine) in ((mapfunc, false),
                    (Δ_coupled -> sum(StochasticAD.structural_iterate(Δ_coupled)),
                        true))
                    function get_contribution()
                        Δs_coupled = get_Δs_coupled(; use_get_rep)
                        Δs_coupled_mapped = map(mapfunc, Δs_coupled; deriv = (δ -> 1.0),
                            out_rep = 0.0)
                        return derivative_contribution(Δs_coupled_mapped)
                    end
                    zero_Δ_coupled = StochasticAD.structural_map(zero, Δ_coupled)
                    expected_contribution1 = 3.0 * mapfunc((Δ_coupled, zero_Δ_coupled))
                    expected_contribution2 = 2.0 * mapfunc((zero_Δ_coupled,
                        StochasticAD.structural_map(x -> 2x,
                            Δ_coupled)))
                    expected_contribution = expected_contribution1 + expected_contribution2
                    if !is_smoothed_backend
                        @test isapprox(mean(get_contribution() for i in 1:1000),
                            expected_contribution; rtol = 5e-2)
                    end
                    # For a simple sum, this should be equivalent to the combine behaviour.
                    if check_combine && !is_smoothed_backend
                        @test isapprox(
                            mean(derivative_contribution(get_Δs_coupled(;
                                     do_combine = true))
                            for i in 1:1000),
                            expected_contribution;
                            rtol = 5e-2)
                    end
                    # Check scalarize
                    Δs_coupled2 = StochasticAD.couple(FIs,
                        StochasticAD.scalarize(Δs_coupled;
                            out_rep = (Δ_coupled,
                                Δ_coupled)),
                        out_rep = (Δ_coupled, Δ_coupled))
                    @test derivative_contribution(map(mapfunc, Δs_coupled;
                        deriv = (δ -> 1.0),
                        out_rep = 0.0)) ≈
                          derivative_contribution(map(mapfunc, Δs_coupled2;
                        deriv = (δ -> 1.0),
                        out_rep = 0.0))
                end
            end
        end
    end
end

@testset "Getting information about stochastic triples" begin
    for backend in vcat(backends,
        backends_smoothed)
        Random.seed!(4321)
        f(x) = rand(Bernoulli(x)) + x
        st = stochastic_triple(f, 0.5; backend)
        # Expected: 0.5 + 1.0ε + (1.0 with probability 2.0ε)
        dual = ForwardDiff.Dual(0.5, 1.0)

        @test StochasticAD.value(0.5) == 0.5
        @test StochasticAD.value(st) == 0.5
        @test StochasticAD.value(dual) == 0.5

        @test iszero(StochasticAD.delta(0.5))
        @test StochasticAD.delta(st) == 1.0
        @test StochasticAD.delta(dual) == 1.0

        if !(backend in backends_smoothed)
            #= 
            NB: since the implementation of perturbations can be backend-specific, the
            below property need not hold in general, but does for the current non-smoothed backends.
            =#
            p = only(perturbations(st))
            @test p.Δ == 1 && p.weight == 2.0
            @test derivative_contribution(st) == 3.0
        else
            # Since smoothed algorithm uses the two-sided strategy, we get a different derivative contribution.
            @test derivative_contribution(st) == 2.0
        end

        @test StochasticAD.tag(st) === StochasticAD.Tag{typeof(f), Float64}
        @test StochasticAD.valtype(st) === Float64
        @test StochasticAD.valtype(st.Δs) === Float64
    end
end

@testset "Propagation via StochasticAD.propagate" begin
    for backend in backends
        function form_triple(primal, δ, Δ, Δs_base)
            Δs = map(_Δ -> Δ, Δs_base)
            return StochasticAD.StochasticTriple{0}(primal, δ, Δs)
        end

        function test_propagate(f, primals, Δs; test_deltas = false)
            Δs_base = StochasticAD.similar_new(StochasticAD.create_Δs(backend, Int),
                0, 1.0)
            _form_triple(x, δ, Δ) = form_triple(x, δ, Δ, Δs_base)
            out = f(primals...)
            out_Δ_expected = StochasticAD.structural_map(-,
                f(StochasticAD.structural_map(+,
                    primals,
                    Δs)...),
                f(primals...))
            if test_deltas
                duals = StochasticAD.structural_map(primals) do x
                    x isa AbstractFloat ? ForwardDiff.Dual{0}(x, rand(typeof(x))) : x
                end
                δs = StochasticAD.structural_map(StochasticAD.delta, duals)
                out_δ_expected = StochasticAD.structural_map(StochasticAD.delta,
                    f(duals...))
            else
                δs = StochasticAD.structural_map(zero, primals)
                out_δ_expected = StochasticAD.structural_map(zero, out)
            end
            input_sts = StochasticAD.structural_map(_form_triple, primals, δs, Δs)
            out_st = StochasticAD.propagate(f, input_sts...; keep_deltas = Val{test_deltas})
            # Test type
            StochasticAD.structural_map(out_st, out, out_δ_expected,
                out_Δ_expected) do x_st, x, δ, Δ
                @test x_st isa StochasticAD.StochasticTriple{0, typeof(x)}
                @test StochasticAD.value(x_st) == x
                @test StochasticAD.delta(x_st) ≈ δ
                p = only(perturbations(x_st))
                @test p.Δ == Δ && p.weight == 1.0
            end
        end

        #=
        Test propagation through some simple functions. 
            f1: a simple if statement.
            f2: involves array-containing-functor input and output.
            f3: involves array-containing-functor input, but real output.
            f4: length ∘ repr (real or array input, real output).
            f5: mutates input array! Broken since unsupported.
            f6: the first-arg (blob) should just be passed through without attempting
                to perturb. Broken since unsupported.
            f7: involves matrix-containing-functor input and output.
        =#
        function f1(x)
            if x == 0
                return 1
            elseif x == 3
                return 2
            else
                return 5
            end
        end

        @test StochasticAD.propagate(f1, 0) === f1(0)
        for (primal, Δ) in [(0, 3), (0, 4), (3, -1)]
            test_propagate(f1, (primal,), (Δ,))
        end

        function f2(arr, scalar)
            if sum(arr) + scalar <= 5
                return arr .* scalar, sum(arr) * scalar
            else
                return arr .- scalar, sum(arr) - scalar
            end
        end
        f3(arr, scalar) = f2(arr, scalar)[2]

        primals1 = ([1, 1], 2)
        Δs1 = ([2, 3], 5)
        primals2 = ([1, 2], 1)
        Δs2 = ([1, -2], 1)
        primals3 = ([5, 2], -1)
        Δs3 = ([-3, 1], 0)

        for (primals, Δs) in [(primals1, Δs1), (primals2, Δs2), (primals3, Δs3)]
            for test_deltas in (false, true)
                if test_deltas
                    primals = StochasticAD.structural_map(float, primals)
                    Δs = StochasticAD.structural_map(float, Δs)
                end
                test_propagate(f2, primals, Δs; test_deltas)
                test_propagate(f3, primals, Δs; test_deltas)
            end
        end

        f4(x) = Base.length(repr(x))

        for (primals, Δs) in [(2, 11), (([3, 14],), ([14, -152],))]
            test_propagate(f4, primals, Δs)
        end

        function f5(arr)
            if arr == [1, 2]
                arr .+= 1
            else
                arr .-= 1
            end
        end

        # Tests for f5 skipped (would break)
        for (primals, Δs) in [([1, 2], [1, -1]), ([2, 4], [-1, -2]), ([2, 4], [-1, -1])]
            @test_skip "propagate f5"
            # test_propagate(f5, primals, Δs)
        end

        f6(blob, arr) = blob, f5(arr)

        # Tests for f6 missing (would break)
        @test_skip "propagate f6"

        function f7(mat, scalar)
            return mat * scalar, scalar + sum(mat)
        end

        test_propagate(f7, (rand(2, 2), 4.0), (rand(2, 2), 1.0); test_deltas = true)
    end
end

@testset "zero'ing of Inf/NaN (#79)" begin
    st = stochastic_triple(0.5)
    st_zero = zero(1 / zero(st))
    @test iszero(StochasticAD.value(st_zero))
    @test iszero(StochasticAD.delta(st_zero))
end

@testset "smooth_triple" begin
    f(p) = sum(rand(Bernoulli(p)) * i for i in 1:100)
    f2(p) = sum(smooth_triple(rand(Bernoulli(p))) * i for i in 1:100)
    p = 0.6
    f_est = mean(derivative_estimate(f, p) for i in 1:10000)
    f2_est = mean(derivative_estimate(f2, p) for i in 1:10000)
    @test f_est≈f2_est rtol=5e-2
end

@testset "No unnecessary float promotion" begin
    f(p) = rand(Bernoulli(p))^2
    st = stochastic_triple(f, 0.5)
    @test StochasticAD.valtype(st) == typeof(convert(Signed, f(0.5)))
end


================================================
FILE: tutorials/Project.toml
================================================
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GR = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
GaussianDistributions = "43dcc890-d446-5863-8d1a-14597580bb8d"
GeometryBasics = "5c1252a2-5f33-56bf-86c9-59e7332b4326"
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StochasticAD = "e4facb34-4f7e-4bec-b153-e122c37934ac"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"


================================================
FILE: tutorials/README.md
================================================
The raw source code for our tutorials. You likely want to look at the documentation instead, where they are presented more clearly! 


================================================
FILE: tutorials/game_of_life/core.jl
================================================
module GoLCore

using Random
using Distributions
using LinearAlgebra
using StochasticAD
using StaticArrays
using OffsetArrays

function update_state!(all_probs, N, board, board_old)
    for i in (-N):N
        for j in (-N):N
            neighbours = board_old[i + 1, j] + board_old[i - 1, j] + board_old[i, j - 1] +
                         board_old[i, j + 1]
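            # Flat index into all_probs: entries 1-5 hold the birth probabilities (dead
            # cell) and entries 6-10 the death probabilities (live cell), for 0-4 neighbours.
            # This arithmetic trick is necessary because stochastic triples that are not
            # <: Real are not supported.
            index = board[i, j] * 5 + neighbours + 1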
            b = rand(Bernoulli(all_probs[index]))
            board[i, j] += (1 - 2 * board[i, j]) * b
        end
    end
end

function play_game_of_life(p, all_probs, N, T, log = false)
    dual_type = promote_type(typeof(rand(Bernoulli(p))),
        typeof.(rand.(Bernoulli.(all_probs)))...) # TODO: better way of getting the concrete type
    board = OffsetArray(zeros(dual_type, 2 * N + 3, 2 * N + 3), (-(N + 1)):(N + 1),
        (-(N + 1)):(N + 1)) # pad by 1
    for i in (-N):N
        for j in (-N):N
            board[i, j] = rand(Bernoulli(p))
        end
    end
    board_old = similar(board)
    log && (history = [])
    for time_step in 1:T
        copy!(board_old, board)
        update_state!(all_probs, N, board, board_old)
        log && push!(history, copy(board))
    end
    if !log
        return sum(board)
    else
        return sum(board), board, history
    end
end

function play(p, θ = 0.1, N = 3, T = 3; log = false)
    # N is the board half-length, T is the number of game time steps, and θ sets the
    # low/high transition probabilities θ and 1 - θ.
    low = θ
    high = 1 - θ
    birth_probs = SA[low, low, low, high, low] # 0, 1, 2, 3, 4 neighbours
    death_probs = SA[high, high, low, low, high] # 0, 1, 2, 3, 4 neighbours
    return play_game_of_life(p, vcat(birth_probs, death_probs), N, T, log)
end

# An implementation of finite differences that uses "common random numbers"
# (the same RNG state for both evaluations) for more accurate checking. The step size h
# must stay finite, because the estimator suffers from weight degeneracy as h → 0.
function fd_clever(p, h = 0.01)
    state = copy(Random.default_rng())
    run1 = play(p + h)
    copy!(Random.default_rng(), state)
    run2 = play(p - h)
    (run1 - run2) / (2h)
end

# Provide some default parameters
p = 0.5
nsamples = 200_000

end
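`fd_clever` relies on common random numbers: both evaluations replay the same random stream, so their difference reflects only the change in `p`. The sketch below applies the same save-and-restore of `Random.default_rng()` to a hypothetical toy program. The names `toy_program`, `fd_crn` and `fd_independent` are illustrative and are not part of the tutorial. The sketch compares the result with a difference quotient built from independent draws. Both should average close to the exact derivative n = 20, but the common-random-numbers version should show a much smaller standard deviation at this step size.

```julia
using Random, Statistics

# Hypothetical toy program (not part of the tutorial): the number of successes
# among n independent Bernoulli(p) trials, each simulated as `rand() < p` with the
# global RNG so that the same random stream can be replayed at a different p.
toy_program(p; n = 20) = count(_ -> rand() < p, 1:n)

# Central finite difference with common random numbers, following the same steps
# as `fd_clever`: save the global RNG state, run at p + h, restore the state, and
# run at p - h, so both runs consume identical random numbers.
function fd_crn(f, p, h = 0.01)
    state = copy(Random.default_rng())
    run1 = f(p + h)
    copy!(Random.default_rng(), state)
    run2 = f(p - h)
    return (run1 - run2) / (2h)
end

# The same difference quotient with independent random numbers, for comparison.
fd_independent(f, p, h = 0.01) = (f(p + h) - f(p - h)) / (2h)

Random.seed!(1)
crn = [fd_crn(toy_program, 0.5) for _ in 1:10_000]
ind = [fd_independent(toy_program, 0.5) for _ in 1:10_000]

# d/dp E[toy_program(p)] = d/dp (n * p) = n = 20 for both estimators.
println("common random numbers: mean = ", mean(crn), ", std = ", std(crn))
println("independent draws:     mean = ", mean(ind), ", std = ", std(ind))
```

With common random numbers, the difference is nonzero only when a uniform falls in the window [p - h, p + h). Shrinking `h` therefore makes nonzero runs rarer but larger, which is the degeneracy the `fd_clever` comment warns about.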


================================================
FILE: tutorials/game_of_life/plot_board.jl
================================================
include("core.jl")
using StochasticAD # stochastic_triple and derivative_contribution are used in Main below
using Plots
using Statistics
using BenchmarkTools

p = 0.5
_, board, history = stochastic_triple(p -> GoLCore.play(p; log = true), p)

anim1 = @animate for (i, frame) in enumerate(history)
    heatmap(collect(StochasticAD.value.(frame)), title = "time $i", clim = (-1, 1),
        c = :grays)
end
anim2 = @animate for (i, frame) in enumerate(history)
    heatmap(collect(StochasticAD.derivative_contribution.(frame)), title = "time $i",
        clim = (-1, 1), c = :grays)
end

gif(anim1, "game.gif", fps = 15)
gif(anim2, "perturbation.gif", fps = 15)
fig1 = heatmap(collect(StochasticAD.value.(board)), clim = (-1, 1), c = :grays)
fig2 = heatmap(collect(derivative_contribution.(board)), clim = (-1, 1), c = :grays) # TODO: graph perturbed values instead of derivative contribution
savefig(fig1, "board.png")
savefig(fig2, "perturbation.png")


================================================
FILE: tutorials/particle_filter/benchmark.jl
================================================
include("core.jl")
include("model.jl")
using Plots, LaTeXStrings
using BenchmarkTools
using Measurements

# Benchmark for primal, forward- and reverse-mode AD of particle sampler

### compute gradients
# `secs` is the time budget (in seconds) for each benchmark; see https://juliaci.github.io/BenchmarkTools.jl/stable/
secs = 10

suite = BenchmarkGroup()
suite["scaling"] = BenchmarkGroup(["grads"])

suite["scaling"]["primal"] = @benchmarkable ParticleFilterCore.log_likelihood(
    particle_filter,
    θtrue)
suite["scaling"]["forward"] = @benchmarkable ParticleFilterCore.forw_grad(θtrue,
    particle_filter)
suite["scaling"]["backward"] = @benchmarkable ParticleFilterCore.back_grad(θtrue,
    particle_filter)

tune!(suite)
results = run(suite, verbose = true, seconds = secs)

t1 = measurement(mean(results["scaling"]["primal"].times),
    std(results["scaling"]["primal"].times) /
    sqrt(length(results["scaling"]["primal"].times)))
t2 = measurement(mean(results["scaling"]["forward"].times),
    std(results["scaling"]["forward"].times) /
    sqrt(length(results["scaling"]["forward"].times)))
t3 = measurement(mean(results["scaling"]["backward"].times),
    std(results["scaling"]["backward"].times) /
    sqrt(length(results["scaling"]["backward"].times)))
@show t1 t2 t3

ts = (t1, t2, t3) ./ 10^6 # ms
@show ts

BenchmarkTools.save("benchmark_data_" * string(d) * ".json", results)
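The three `measurement(...)` expressions above each compute the same two numbers for one benchmark: the mean of its timings and the standard error of that mean. A small helper removes the repetition. The sketch below uses only the Statistics standard library; `mean_and_sem` is a hypothetical name and is not part of the tutorial.

```julia
using Statistics

# Hypothetical helper (not part of the tutorial): mean and standard error of the
# mean of a vector of timings, i.e. mean(t) and std(t) / sqrt(length(t)).
function mean_and_sem(times::AbstractVector{<:Real})
    n = length(times)
    n >= 2 || throw(ArgumentError("need at least two samples for a standard error"))
    return mean(times), std(times) / sqrt(n)
end

# Made-up timings in nanoseconds (the unit of BenchmarkTools trial times).
times_ns = [1.2e6, 1.0e6, 1.1e6, 1.3e6]
m, se = mean_and_sem(times_ns)
println("mean = ", m / 1e6, " ms, standard error = ", se / 1e6, " ms")
```

With Measurements.jl loaded, `t1` above is equivalent to `measurement(mean_and_sem(results["scaling"]["primal"].times)...)`.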


================================================
FILE: tutorials/particle_filter/bias.jl
================================================
include("core.jl")
include("model.jl")
using Plots, LaTeXStrings
using Random
using StatsBase # fit(Histogram, ...) and normalize(::Histogram; mode) used below

# Comparison of the derivative of the particle filter with and without differentiating the resampling step.

### compute gradients
Random.seed!(seed)
X = [ParticleFilterCore.forw_grad(θtrue, particle_filter) for i in 1:1000] # gradient of the particle filter *with* differentiation of the resampling step
Random.seed!(seed)
Xbiased = [ParticleFilterCore.forw_grad_biased(θtrue, particle_filter) for i in 1:1000] # Gradient of the particle filter *without* differentiation of the resampling step
# pick an arbitrary coordinate
index = 1 # take derivative with respect to first parameter (2-dimensional example has a rotation matrix with four parameters in total)
# plot histograms for the sampled derivative values
fig = plot(normalize(fit(Histogram, getindex.(X, index), nbins = 50), mode = :pdf),
    legend = false) # ours
plot!(normalize(fit(Histogram, getindex.(Xbiased, index), nbins = 50), mode = :pdf)) # biased
vline!([mean(X)[index]], color = 1)
vline!([mean(Xbiased)[index]], color = 2)
# add derivative of differentiable Kalman filter as a comparison
XK = ParticleFilterCore.forw_grad_Kalman(θtrue, kalman_filter)
vline!([XK[index]], color = "black")

display(fig)
savefig(fig, "tails.pdf")
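`bias.jl` compares gradients computed with and without differentiating through the particle filter's resampling step. The `ParticleFilter` docstring in `core.jl` names stratified sampling (`sample_stratified`) as an example resampling strategy. A generic, non-differentiable version of that strategy can be sketched in plain Julia as follows. `stratified_indices` is a hypothetical name, and this is an assumption-based illustration rather than the tutorial's implementation.

```julia
using Random

# Generic stratified resampling, a sketch only: the tutorial's own `sample_stratified`
# (in core.jl) may differ, and this version does not propagate derivatives.
# Given weights `w` of m particles, draw one uniform in each stratum
# [(k - 1)/m, k/m) and return the first particle index whose cumulative
# normalized weight exceeds it (inverse-CDF lookup).
function stratified_indices(rng::AbstractRNG, w::AbstractVector{<:Real})
    m = length(w)
    cdf = cumsum(w) ./ sum(w)        # normalize the weights defensively
    idx = Vector{Int}(undef, m)
    j = 1
    for k in 1:m
        u = (k - 1 + rand(rng)) / m  # uniforms increase with k, so j never moves back
        while j < m && cdf[j] <= u   # `j < m` guards against round-off at the top end
            j += 1
        end
        idx[k] = j
    end
    return idx
end

rng = MersenneTwister(42)
w = [0.1, 0.2, 0.3, 0.4]
idx = stratified_indices(rng, w)
println("resampled particle indices: ", idx)
```

Drawing exactly one uniform per stratum typically gives lower variance in the number of copies per particle than independent (multinomial) resampling.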


================================================
FILE: tutorials/particle_filter/core.jl
================================================
module ParticleFilterCore

# load dependencies
using Distributions
using DistributionsAD
using Random
using Statistics
using StatsBase
using LinearAlgebra
using Zygote
using StochasticAD
using ForwardDiff
using GaussianDistributions
using GaussianDistributions: correct, ⊕
using UnPack

### Particle Filter Functions

# Model defs

"""
    StochasticModel{TType<:Integer,T1,T2,T3}

For parameters `θ`, `rand(start(θ))` samples the initial state from its prior distribution.
For current state `x` and parameters `θ`, `xnew = rand(dyn(x, θ))` samples the new state
(i.e. `dyn` returns, for each `x, θ`, a distribution-like object). Finally,
`y = rand(obs(x, θ))` samples an observation.

## Constructor

- `T`: total number of time steps.
- `start`: starting distribution for the initial state. For example, in the form of a narrow
   Gaussian `start(θ) = Gaussian(x0, 0.001 * I(d))`.
- `dyn`: pointwise differentiable stochastic program in the form of Markov transition densities.
   For example, `dyn(x, θ) = MvNormal(reshape(θ, d, d) * x, Q(θ))`, where `Q(θ)` denotes the
   covariance matrix.
- `obs`: observation model having a smooth conditional probability density depending on
   current state `x` and parameters `θ`. For example, `obs(x, θ) = MvNormal(x, R(θ))`,
   where `R(θ)` denotes the covariance matrix.
"""
struct StochasticModel{TType <: Integer, T1, T2, T3}
    T::TType # time steps
    start::T1 # prior
    dyn::T2 # dynamical model
    obs::T3 # observation model
end

# Particle filter
"""

    ParticleFilter{mType<:Integer,MType<:StochasticModel,yType,sType}

Wraps a stochastic model `StochM::StochasticModel` and observational data `ys`.
Assumes that the observation likelihood is available via `pdf(obs(x, θ), y)`.

## Constructor

- `m`: number of particles.
- `StochM`: stochastic model of type `StochasticModel`.
- `ys`: observations.
- `sample_strategy`: strategy for the resampling step of the particle filter. For example,
  stratified sampling as implemented in `sample_stratified`.
"""
struct ParticleFilter{mType <: Integer, MType <: StochasticModel, yType, sType}
    m::mType # number of particles
    StochM::MType # stochastic model
    ys::yType # observations
    sample_strategy::sType # sampling function
end

# Kalman filter
"""

    KalmanFilter{dType<:Integer,MType<:StochasticModel,HType,RType,QType,yType}

Differentiable Kalman filter following https://github.com/mschauer/Kalman.jl/blob/master/README.md.
Wraps a stochastic mode