Showing preview only (237K chars total). Download the full file or copy to clipboard to get everything.
Repository: gaurav-arya/StochasticAD.jl
Branch: main
Commit: 24788b5bc82d
Files: 72
Total size: 217.3 KB
Directory structure:
gitextract__gvecf0r/
├── .JuliaFormatter.toml
├── .git-blame-ignore-revs
├── .github/
│ └── workflows/
│ ├── CI.yml
│ ├── CompatHelper.yml
│ ├── Documentation.yml
│ ├── FormatCheck.yml
│ ├── TagBot.yml
│ └── benchmark.yml
├── .gitignore
├── CITATION.bib
├── LICENSE
├── Project.toml
├── README.md
├── benchmark/
│ ├── benchmarks.jl
│ ├── game_of_life.jl
│ ├── iteration.jl
│ ├── random_walk.jl
│ ├── runbenchmarks.jl
│ ├── simple_ops.jl
│ └── utils.jl
├── docs/
│ ├── Project.toml
│ ├── make.jl
│ └── src/
│ ├── assets/
│ │ └── extra_styles.css
│ ├── devdocs.md
│ ├── index.md
│ ├── limitations.md
│ ├── public_api.md
│ └── tutorials/
│ ├── game_of_life.md
│ ├── optimizations.md
│ ├── particle_filter.md
│ ├── random_walk.md
│ └── reverse_demo.md
├── ext/
│ └── StochasticADEnzymeExt.jl
├── src/
│ ├── StochasticAD.jl
│ ├── algorithms.jl
│ ├── backends/
│ │ ├── abstract_wrapper.jl
│ │ ├── dict.jl
│ │ ├── pruned.jl
│ │ ├── pruned_aggressive.jl
│ │ ├── smoothed.jl
│ │ └── strategy_wrapper.jl
│ ├── discrete_randomness.jl
│ ├── finite_infinitesimals.jl
│ ├── general_rules.jl
│ ├── misc.jl
│ ├── prelude.jl
│ ├── propagate.jl
│ ├── smoothing.jl
│ └── stochastic_triple.jl
├── test/
│ ├── game_of_life.jl
│ ├── random_walk.jl
│ ├── resampling.jl
│ ├── runtests.jl
│ └── triples.jl
└── tutorials/
├── Project.toml
├── README.md
├── game_of_life/
│ ├── core.jl
│ └── plot_board.jl
├── particle_filter/
│ ├── benchmark.jl
│ ├── bias.jl
│ ├── core.jl
│ ├── model.jl
│ ├── variance.jl
│ └── visualize.jl
├── random_walk/
│ ├── compare_score.jl
│ ├── core.jl
│ └── show_unbiased.jl
├── reverse_example/
│ └── reverse_demo.jl
└── toy_optimizations/
├── Project.toml
├── igarch.jl
├── intro.jl
└── variational.jl
================================================
FILE CONTENTS
================================================
================================================
FILE: .JuliaFormatter.toml
================================================
style = "sciml"
================================================
FILE: .git-blame-ignore-revs
================================================
# Run this command to always ignore these in local `git blame`:
# git config blame.ignoreRevsFile .git-blame-ignore-revs
# Run formatter
70fd432667fb431e08ba52728734108d822a1922
# Run formatter
21038a047c023330876feb9259cd5c92add3ca81
# Run formatter after bracket alignment removal
799277f9652258282a91ecfe976df5fb8ab64c82
# Format
db4333c604cc23c3c36420f09aa998d01ef0214b
================================================
FILE: .github/workflows/CI.yml
================================================
name: CI
on:
pull_request:
push:
branches:
- main
tags: '*'
jobs:
unittest:
runs-on: ubuntu-latest
strategy:
matrix:
group:
- Core
version:
- '1'
- '1.7'
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.version }}
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v2
with:
file: lcov.info
================================================
FILE: .github/workflows/CompatHelper.yml
================================================
name: CompatHelper
on:
schedule:
- cron: 0 0 * * *
workflow_dispatch:
jobs:
CompatHelper:
runs-on: ubuntu-latest
steps:
- name: Pkg.add("CompatHelper")
run: julia -e 'using Pkg; Pkg.add("CompatHelper")'
- name: CompatHelper.main()
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }}
run: julia -e 'using CompatHelper; CompatHelper.main()'
================================================
FILE: .github/workflows/Documentation.yml
================================================
name: Documentation
on:
push:
branches:
- main
tags: '*'
pull_request:
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
with:
version: '1'
- name: Install dependencies
run: julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()'
- name: Build and deploy
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # For authentication with GitHub Actions token
DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # For authentication with SSH deploy key
DATADEPS_ALWAYS_ACCEPT: true
run: julia --project=docs/ docs/make.jl
================================================
FILE: .github/workflows/FormatCheck.yml
================================================
name: format-check
on:
push:
branches:
- 'main'
- 'release-'
tags: '*'
pull_request:
jobs:
build:
runs-on: ${{ matrix.os }}
strategy:
matrix:
julia-version: [1]
julia-arch: [x86]
os: [ubuntu-latest]
steps:
- uses: julia-actions/setup-julia@latest
with:
version: ${{ matrix.julia-version }}
- uses: actions/checkout@v1
- name: Install JuliaFormatter and format
# This will use the latest version by default but you can set the version like so:
#
# julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter", version="0.13.0"))'
run: |
julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))'
julia -e 'using JuliaFormatter; format(".", verbose=true)'
- name: Format check
run: |
julia -e '
out = Cmd(`git diff`) |> read |> String
if out == ""
exit(0)
else
@error "Some files have not been formatted !!!"
write(stdout, out)
exit(1)
end'
================================================
FILE: .github/workflows/TagBot.yml
================================================
name: TagBot
on:
issue_comment:
types:
- created
workflow_dispatch:
jobs:
TagBot:
if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot'
runs-on: ubuntu-latest
steps:
- uses: JuliaRegistries/TagBot@v1
with:
token: ${{ secrets.GITHUB_TOKEN }}
ssh: ${{ secrets.DOCUMENTER_KEY }}
================================================
FILE: .github/workflows/benchmark.yml
================================================
name: Benchmarks
on:
pull_request:
push:
branches:
- main
tags: '*'
jobs:
benchmark:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@latest
with:
version: 1
- name: Install dependencies
run: julia -e 'using Pkg; Pkg.activate("tutorials"); Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate();'
- name: Run benchmarks
run: julia --project=tutorials --color=yes benchmark/runbenchmarks.jl
================================================
FILE: .gitignore
================================================
Manifest.toml
================================================
FILE: CITATION.bib
================================================
@inproceedings{arya2022automatic,
author = {Arya, Gaurav and Schauer, Moritz and Sch\"{a}fer, Frank and Rackauckas, Christopher},
booktitle = {Advances in Neural Information Processing Systems},
editor = {S. Koyejo and S. Mohamed and A. Agarwal and D. Belgrave and K. Cho and A. Oh},
pages = {10435--10447},
publisher = {Curran Associates, Inc.},
title = {Automatic Differentiation of Programs with Discrete Randomness},
url = {https://proceedings.neurips.cc/paper_files/paper/2022/file/43d8e5fc816c692f342493331d5e98fc-Paper-Conference.pdf},
volume = {35},
year = {2022}
}
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2022 Gaurav Arya <aryag@mit.edu> and contributors
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: Project.toml
================================================
name = "StochasticAD"
uuid = "e4facb34-4f7e-4bec-b153-e122c37934ac"
authors = ["Gaurav Arya <aryag@mit.edu> and contributors"]
version = "0.1.26"
[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesOverloadGeneration = "f51149dc-2911-5acf-81fc-2076a2a81d4f"
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
[weakdeps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
[extensions]
StochasticADEnzymeExt = "Enzyme"
[compat]
ChainRulesCore = "1.15"
ChainRulesOverloadGeneration = "0.1"
Dictionaries = "0.3"
Distributions = "0.25"
DistributionsAD = "0.6"
ExprTools = "0.1"
ForwardDiff = "0.10"
Functors = "0.4.3"
julia = "1"
[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GaussianDistributions = "43dcc890-d446-5863-8d1a-14597580bb8d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[targets]
test = ["LinearAlgebra", "Pkg", "Printf", "Test", "Statistics", "SafeTestsets", "OffsetArrays", "StaticArrays", "Zygote", "ForwardDiff", "GaussianDistributions", "Measurements", "UnPack", "StatsBase", "DiffResults", "ChainRulesCore"]
================================================
FILE: README.md
================================================


# StochasticAD
[](https://github.com/gaurav-arya/StochasticAD.jl/actions?query=workflow:CI)
[](https://gaurav-arya.github.io/StochasticAD.jl/dev/)
[](https://arxiv.org/abs/2210.08572)
StochasticAD is an experimental, research package for automatic differentiation (AD) of stochastic programs. It implements AD algorithms for handling programs that can contain *discrete* randomness, based on the methodology developed in [this NeurIPS 2022 paper](https://doi.org/10.48550/arXiv.2210.08572). We're still working on docs and code cleanup!
## Installation
The package can be installed with the Julia package manager:
```julia
julia> using Pkg;
julia> Pkg.add("StochasticAD");
```
## Citation
```
@inproceedings{arya2022automatic,
author = {Arya, Gaurav and Schauer, Moritz and Sch\"{a}fer, Frank and Rackauckas, Christopher},
booktitle = {Advances in Neural Information Processing Systems},
editor = {S. Koyejo and S. Mohamed and A. Agarwal and D. Belgrave and K. Cho and A. Oh},
pages = {10435--10447},
publisher = {Curran Associates, Inc.},
title = {Automatic Differentiation of Programs with Discrete Randomness},
url = {https://proceedings.neurips.cc/paper_files/paper/2022/file/43d8e5fc816c692f342493331d5e98fc-Paper-Conference.pdf},
volume = {35},
year = {2022}
}
```
================================================
FILE: benchmark/benchmarks.jl
================================================
# Top-level benchmark suite: pull in each per-topic benchmark module and
# register its sub-suite under a descriptive key.
using BenchmarkTools

include("random_walk.jl")
include("game_of_life.jl")
include("iteration.jl")
include("simple_ops.jl")

const SUITE = BenchmarkGroup()
for (key, subsuite) in ("random_walk" => RandomWalkBenchmark.suite,
    "game_of_life" => GoLBenchmark.suite,
    "iteration" => IterationBenchmark.suite,
    "simple_ops" => SimpleOpsBenchmark.suite)
    SUITE[key] = subsuite
end
================================================
FILE: benchmark/game_of_life.jl
================================================
# Benchmarks for the stochastic Game of Life tutorial model: one baseline run
# of the primal program, plus a derivative estimate under each StochasticAD
# backend.
module GoLBenchmark
using BenchmarkTools
using StochasticAD
using Statistics
using ForwardDiff: derivative
# NOTE(review): Statistics and ForwardDiff are imported but not referenced in
# this module — possibly kept for parity with random_walk.jl; confirm.
# Reuse the tutorial's model definition rather than duplicating it here.
include("../tutorials/game_of_life/core.jl")
using .GoLCore: play, p
const suite = BenchmarkGroup()
# Baseline: a single forward evaluation with no differentiation.
suite["original"] = @benchmarkable $play($p)
# Derivative estimates. The backend constructor calls are not `$`-interpolated,
# so they are evaluated inside the benchmarked expression each sample.
suite["PrunedFIs"] = @benchmarkable derivative_estimate($play, $p;
backend = PrunedFIsBackend())
suite["PrunedFIsAggressive"] = @benchmarkable derivative_estimate($play, $p;
backend = PrunedFIsAggressiveBackend())
suite["SmoothedFIs"] = @benchmarkable derivative_estimate($play, $p;
backend = SmoothedFIsBackend())
end
================================================
FILE: benchmark/iteration.jl
================================================
"""
Benchmarks of iteration over small nested structures.

In the library we have tried to avoid generated functions, instead using
reductions from Base, with the hope that the iteration will be optimized to
be zero-cost. This suite tests the performance of iteration on small nested
structures, which crop up when `propagate` is called on small structures of
scalars.

The `couple` and `combine` operations of FIs collections, which use iteration,
are benchmarked.
"""
module IterationBenchmark
using BenchmarkTools
using StochasticAD
using StaticArrays
const suite = BenchmarkGroup()
# Examples consist of flat and non-flat versions of structures, to test zero-cost iteration.
# Each entry maps a difficulty level to a (flat, nested) pair of structures
# holding the same scalars.
tups = Dict("easy" => (ntuple(identity, 3), (1, (2, 3))),
"hard" => (ntuple(identity, 9), (1, (2, 3), (4, (5, (6, 7, 8), 9)))))
SAs = Dict("easy" => (SA[1, 2, 3], (1, SA[2, 3])),
"hard" => (SA[1, 2, 3, 4, 5, 6, 7, 8, 9],
(1, SA[2, 3], (4, (5, SA[6, 7, 8], 9)))))
for (setname, set) in (("tups", tups), ("SAs", SAs))
suite[setname] = BenchmarkGroup()
setsuite = suite[setname]
for case in ["easy", "hard"]
casesuite = setsuite[case] = BenchmarkGroup()
for isflat in [false, true]
flatsuite = casesuite[isflat ? "flat" : "not flat"] = BenchmarkGroup()
# Select the flat (index 1) or nested (index 2) variant of this case.
values = set[case][isflat ? 1 : 2]
flatsuite["make_iterate_values"] = @benchmarkable StochasticAD.structural_iterate($values)
iter_values = StochasticAD.structural_iterate(values)
flatsuite["foldl_values"] = @benchmarkable foldl(+, $(iter_values))
# Empty loop body: measures the cost of bare iteration itself.
flatsuite["iterate_values"] = @benchmarkable for i in $(iter_values)
end
for backend in [PrunedFIsBackend(), PrunedFIsAggressiveBackend()]
FIs_suite = flatsuite[backend] = BenchmarkGroup()
# Build a collection with a single perturbation (size 1, weight 1), then
# broadcast a copy of it over every scalar in `values`.
Δs = StochasticAD.create_Δs(backend, Int)
Δs1 = StochasticAD.similar_new(Δs, 1, 1)
Δs_all = StochasticAD.structural_map(x -> map(Δ -> x, Δs1), values)
FIs_suite["make_iterate_Δs"] = @benchmarkable StochasticAD.structural_iterate($Δs_all)
# We don't interpolate backend directly in below (i.e. do $FIs) because string interpolating a type
# seems to lead to slow benchmarks.
FIs_suite["couple_same"] = @benchmarkable StochasticAD.couple(typeof($Δs),
$Δs_all)
FIs_suite["combine_same"] = @benchmarkable StochasticAD.combine(
typeof($Δs),
$Δs_all)
end
end
end
end
end
================================================
FILE: benchmark/random_walk.jl
================================================
# Benchmarks for the random walk tutorial model: one baseline run of the
# primal program, a derivative estimate under each StochasticAD backend, and a
# ForwardDiff baseline on a smoothed variant of the walk.
module RandomWalkBenchmark
using BenchmarkTools
using StochasticAD
using Statistics
using ForwardDiff: derivative
# Reuse the tutorial's model definition rather than duplicating it here.
include("../tutorials/random_walk/core.jl")
using .RandomWalkCore: n, p, nsamples
using .RandomWalkCore: fX, get_dfX
const suite = BenchmarkGroup()
# Baseline: a single forward evaluation with no differentiation.
suite["original"] = @benchmarkable $(fX)($p)
# Derivative estimates. The backend constructor calls are not `$`-interpolated,
# so they are evaluated inside the benchmarked expression each sample.
suite["PrunedFIs"] = @benchmarkable derivative_estimate($fX, $p;
backend = PrunedFIsBackend())
suite["PrunedFIsAggressive"] = @benchmarkable derivative_estimate($fX, $p;
backend = PrunedFIsAggressiveBackend())
suite["SmoothedFIs"] = @benchmarkable derivative_estimate($fX, $p;
backend = SmoothedFIsBackend())
# Comparison point: plain forward-mode AD on a variant of the walk with the
# discrete left/right step hardcoded to a smoothed form.
forwarddiff_func = p -> fX(p; hardcode_leftright_step = true)
suite["ForwardDiff_smoothing"] = @benchmarkable derivative($forwarddiff_func, $p)
end
================================================
FILE: benchmark/runbenchmarks.jl
================================================
# Entry point: run the package's benchmark suite via PkgBenchmark and print a
# human-readable summary of the resulting benchmark group.
using PkgBenchmark

include("utils.jl")
using .Utils

# Pin thread counts to 1 so timings are not affected by parallelism.
benchmark_config = BenchmarkConfig(env = Dict("JULIA_NUM_THREADS" => "1",
    "OMP_NUM_THREADS" => "1"))
result_path = joinpath(@__DIR__, "result.json")

results = benchmarkpkg(dirname(@__DIR__), benchmark_config;
    resultfile = result_path)
@show results = print_group(results.benchmarkgroup)
================================================
FILE: benchmark/simple_ops.jl
================================================
# Micro-benchmarks of a single addition: plain `+`, `+` routed through
# `StochasticAD.propagate`, and both of these on stochastic triples under each
# pruning backend. Used to track the per-operation overhead of propagation.
module SimpleOpsBenchmark
using BenchmarkTools
using StochasticAD
const suite = BenchmarkGroup()
suite["add"] = BenchmarkGroup()
suite["add_via_propagate_nodeltas"] = BenchmarkGroup()
suite["add_via_propagate"] = BenchmarkGroup()
# Baselines on plain floats (no triples involved).
suite["add"]["original"] = @benchmarkable +(0.5, 0.5)
suite["add_via_propagate_nodeltas"]["original"] = @benchmarkable StochasticAD.propagate(+,
0.5,
0.5)
# NOTE(review): `Val{true,}` below is the *type* `Val{true}` (trailing comma is
# inert), not the instance `Val(true)` — confirm this matches what
# `propagate`'s `keep_deltas` kwarg expects.
suite["add_via_propagate"]["original"] = @benchmarkable StochasticAD.propagate(+, 0.5, 0.5;
keep_deltas = Val{
true,
})
for backend in [PrunedFIsBackend(), PrunedFIsAggressiveBackend()]
# The triple is constructed in `setup` so that only the operation itself is
# timed, not the creation of the stochastic triple.
suite["add"][backend] = @benchmarkable +(st, st) setup=(st = stochastic_triple(0.5;
backend = $backend))
suite["add_via_propagate_nodeltas"][backend] = @benchmarkable StochasticAD.propagate(+,
st,
st) setup=(st = stochastic_triple(0.5;
backend = $backend))
suite["add_via_propagate"][backend] = @benchmarkable StochasticAD.propagate(+, st, st;
keep_deltas = Val{
true,
}) setup=(st = stochastic_triple(0.5;
backend = $backend))
end
end
================================================
FILE: benchmark/utils.jl
================================================
module Utils

export print_group

using Functors
using BenchmarkTools

## Printing

# Type piracy, fine since just in benchmarking. (The design of Functors should
# probably allow for user-customized functors.)
@functor BenchmarkTools.BenchmarkGroup

# Render a single benchmark trial as a short summary string:
# "«pretty time», «alloc count» allocs".
function print_trial(trial)
    return string(BenchmarkTools.prettytime(time(trial)), ", ",
        allocs(trial), " allocs")
end

# Walk a (possibly nested) benchmark group and replace each Trial leaf with
# its human-readable summary; all other leaves pass through unchanged.
function print_group(group)
    return fmap(group) do leaf
        leaf isa BenchmarkTools.Trial ? print_trial(leaf) : leaf
    end
end

end
================================================
FILE: docs/Project.toml
================================================
[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocThemeIndigo = "8bac0ac5-51bf-41f9-885e-2bf1ac2bec5f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
StochasticAD = "e4facb34-4f7e-4bec-b153-e122c37934ac"
================================================
FILE: docs/make.jl
================================================
# Build the StochasticAD.jl documentation with Documenter, then attempt to
# deploy it (deployment failures are reported but do not abort the build).
using Pkg
using Documenter
using StochasticAD
using DocThemeIndigo
using Literate

### Formatting
indigo = DocThemeIndigo.install(StochasticAD)
format = Documenter.HTML(prettyurls = false,
    assets = [indigo, "assets/extra_styles.css"],
    repolink = "https://github.com/gaurav-arya/StochasticAD.jl",
    edit_link = "main")

### Pagination
tutorial_pages = [
    "tutorials/random_walk.md",
    "tutorials/game_of_life.md",
    "tutorials/particle_filter.md",
    "tutorials/optimizations.md",
    "tutorials/reverse_demo.md"
]
pages = [
    "Overview" => "index.md",
    "Tutorials" => tutorial_pages,
    "Public API" => "public_api.md",
    "Developer documentation" => "devdocs.md",
    "Limitations" => "limitations.md"
]

### Prepare literate tutorials
# TODO (for now they are manually built into docs/src/tutorials and checked into repo)

### Make docs
makedocs(sitename = "StochasticAD.jl",
    authors = "Gaurav Arya and other contributors",
    modules = [StochasticAD],
    format = format,
    pages = pages,
    warnonly = [:missing_docs])

try
    deploydocs(repo = "github.com/gaurav-arya/StochasticAD.jl",
        devbranch = "main",
        push_preview = true)
catch e
    println("Error encountered while deploying docs:")
    showerror(stdout, e)
end
================================================
FILE: docs/src/assets/extra_styles.css
================================================
.display-light-only {display: block;}
.display-dark-only {display: none;}
.theme--documenter-dark .display-light-only {display: none;}
.theme--documenter-dark .display-dark-only {display: block;}
================================================
FILE: docs/src/devdocs.md
================================================
# Developer documentation (WIP)
## Writing a custom rule for stochastic triples
### via `StochasticAD.propagate`
To handle a deterministic discrete construct that `StochasticAD` does not automatically handle (e.g. branching via `if`, boolean comparisons), it is often sufficient to simply add a dispatch rule that calls out to `StochasticAD.propagate`.
```@docs
StochasticAD.propagate
```
### via a custom dispatch
If a function does not meet the conditions of `StochasticAD.propagate` and is not already supported, a custom
dispatch may be necessary. For example, consider the following function which manually implements a geometric random variable:
```@example rule
import Random
Random.seed!(1234) # hide
using Distributions
# make rng input explicit
# Sample a geometric random variable: the number of Bernoulli(p) failures
# before the first success (same counting convention as `rand(Geometric(p))`).
function mygeometric(rng, p)
x = 0
while !(rand(rng, Bernoulli(p)))
x += 1
end
return x
end
# Convenience method drawing from the global default RNG.
mygeometric(p) = mygeometric(Random.default_rng(), p)
```
This is equivalent to `rand(Geometric(p))` which is already supported, but for pedagogical purposes we will
implement our own rule from scratch. Using the stochastic derivative formulas from [Automatic Differentiation of Programs with Discrete Randomness](https://doi.org/10.48550/arXiv.2210.08572), the right stochastic derivative of this program is given by
```math
Y_R = X - 1, w_R = \frac{x}{p(1-p)},
```
and the left stochastic derivative of this program is given by
```math
Y_L = X + 1, w_L = -\frac{x+1}{p}.
```
Using these expressions, we can now write the dispatch rule for stochastic triples:
```@example rule
using StochasticAD
import StochasticAD: StochasticTriple, similar_new, similar_empty, combine
# Stochastic-triple rule for `mygeometric`: runs the primal program, attaches
# a new discrete perturbation from the left/right stochastic derivative
# (depending on the sign of the input's dual part), and propagates any
# pre-existing perturbations of p through the function via an RNG coupling.
function mygeometric(rng, p_st::StochasticTriple{T}) where {T}
p = p_st.value
rng_copy = copy(rng) # save a copy for coupling later
x = mygeometric(rng, p)
# Form the new discrete perturbations (combinations of weight w and perturbation Y - X)
Δs1 = if p_st.δ > 0
# right stochastic derivative: Y_R = X - 1, w_R = x / (p (1 - p));
# when x == 0 the jump to -1 is impossible, so no perturbation is created
w = p_st.δ * x / (p * (1 - p))
x > 0 ? similar_new(p_st.Δs, -1, w) : similar_empty(p_st.Δs, Int)
elseif p_st.δ < 0
# left stochastic derivative: Y_L = X + 1, w_L = -(x + 1) / p
w = -p_st.δ * (x + 1) / p # positive since the negativity of p_st.δ cancels out the negativity of w_L
similar_new(p_st.Δs, 1, w)
else
# zero dual part: no new perturbation
similar_empty(p_st.Δs, Int)
end
# Propagate any existing perturbations to p through the function
function map_func(Δ)
# Couple the samples by using the same RNG. (A simpler strategy would have been independent sampling, i.e. mygeometric(p + Δ) - x)
mygeometric(copy(rng_copy), p + Δ) - x
end
Δs2 = map(map_func, p_st.Δs)
# Return the output stochastic triple
StochasticTriple{T}(x, zero(x), combine((Δs2, Δs1)))
end
```
In the above, we used some of the interface functions supported by a collection of perturbations `Δs::StochasticAD.AbstractFIs`. These were `similar_empty(Δs, V)`, which created an empty perturbation of type `V`, `similar_new(Δs, Δ, w)`, which created a new perturbation of size `Δ` and weight `w`, `map(map_func, Δs)`,
which propagates a collection of perturbations through a mapping function, and `combine((Δs2, Δs1))` which combines multiple collections of perturbations together.
We can test out our rule:
```@example rule
@show stochastic_triple(mygeometric, 0.1)
# try feeding an input that already has a perturbation
f(x) = mygeometric(2 * x + 0.1 * rand(Bernoulli(x)))^2
@show stochastic_triple(f, 0.1)
# verify against black-box finite differences
N = 1000000
samples_stochad = [derivative_estimate(f, 0.1) for i in 1:N]
samples_fd = [(f(0.105) - f(0.095)) / 0.01 for i in 1:N]
println("Stochastic AD: $(mean(samples_stochad)) ± $(std(samples_stochad) / sqrt(N))")
println("Finite differences: $(mean(samples_fd)) ± $(std(samples_fd) / sqrt(N))")
nothing # hide
```
## Distribution-specific customization of differentiation algorithm
```@docs
randst
InversionMethodDerivativeCoupling
```
================================================
FILE: docs/src/index.md
================================================
```@raw html
<img class="display-light-only" src="images/path_skeleton.png">
<img class="display-dark-only" src="images/path_skeleton_dark.png">
```
# StochasticAD
[StochasticAD](https://github.com/gaurav-arya/StochasticAD.jl) is an experimental, research package for automatic differentiation (AD) of stochastic programs.
It implements AD algorithms for handling programs that can contain *discrete* randomness, based on the methodology developed in [this NeurIPS 2022 paper](https://doi.org/10.48550/arXiv.2210.08572).
## Introduction
Derivatives are all about how functions are affected by a tiny change `ε` in their input. First, let's imagine perturbing the input of a deterministic, differentiable function such as $f(p) = p^2$ at $p = 2$.
```@example continuous
using StochasticAD
f(p) = p^2
stochastic_triple(f, 2) # Feeds 2 + ε into f
```
The output tells us that if we change the input `2` by a tiny amount `ε`, the output of `f` will change by approximately `4ε`. This is the case we're familiar with: we can get the value `4` by applying the chain rule, $\frac{\mathrm{d}}{\mathrm{d} p} p^2 = 2p = 4$. Thinking in terms of tiny changes, the output above looks a lot like a [dual number](https://en.wikipedia.org/wiki/Dual_number). But what happens with a discrete random function? Let's give it a try.
```@example discrete
import Random # hide
Random.seed!(4321) # hide
using StochasticAD, Distributions
f(p) = rand(Bernoulli(p)) # 1 with probability p, 0 otherwise
stochastic_triple(f, 0.5) # Feeds 0.5 + ε into f
```
The output of a [Bernoulli variable](https://en.wikipedia.org/wiki/Bernoulli_distribution) cannot change by a tiny amount: it is either `0` or `1`. But in the probabilistic world, there is another way to change by a tiny amount *on average*: jump by a large amount, with tiny probability. `StochasticAD` introduces a stochastic triple object, which generalizes dual numbers by including a *third* component to describe these perturbations. Here, the stochastic triple says that the original random output was `0`, but given a small change `ε` in the input, the output will jump up to `1` with probability approximately `2ε`.
Stochastic triples can be used to construct a new random program whose average is the derivative of the average of the original program. We simply propagate stochastic triples through the program via [`stochastic_triple`](@ref), and then sum up the "dual" and "triple" components at the end via [`derivative_contribution`](@ref). This process is packaged together in the function [`derivative_estimate`](@ref). Let's try a crazier example, where we mix discrete and continuous randomness!
```@example estimate
using StochasticAD, Distributions
import Random # hide
Random.seed!(1234) # hide
function X(p)
a = p * (1 - p)
b = rand(Binomial(10, p))
c = 2 * b + 3 * rand(Bernoulli(p))
return a * c * rand(Normal(b, a))
end
st = @show stochastic_triple(X, 0.6) # sample a single stochastic triple at p = 0.6
@show derivative_contribution(st) # which produces a single derivative estimate...
samples = [derivative_estimate(X, 0.6) for i in 1:1000] # many samples from derivative program
derivative = mean(samples)
uncertainty = std(samples) / sqrt(1000)
println("derivative of 𝔼[X(p)] = $derivative ± $uncertainty")
```
## Index
See the [public API](public_api.md) for a walkthrough of the API, and the tutorials on differentiating a [random walk](tutorials/random_walk.md), a [stochastic game of life](tutorials/game_of_life.md), and a [particle filter](tutorials/particle_filter.md), as well as on solving [stochastic optimization and variational problems](tutorials/optimizations.md) with discrete randomness. This is a prototype package with a number of [limitations](limitations.md).
================================================
FILE: docs/src/limitations.md
================================================
# Limitations of StochasticAD
`StochasticAD` has a number of limitations that are important to be aware of:
* `StochasticAD` uses operator-overloading just like [ForwardDiff](https://juliadiff.org/ForwardDiff.jl/stable/), so all of the [limitations](https://juliadiff.org/ForwardDiff.jl/stable/user/limitations/) listed there apply here too. Also note that some useful features of `ForwardDiff`, such as chunking for greater efficiency with a large number of parameters, have not yet been implemented here.
* We have limited support for reverse-mode AD via [smoothing](public_api.md#Smoothing), which cannot be guaranteed to be unbiased in all cases.
* We do not yet support `if` statements with discrete random input. A workaround can be to use array indexing to express discrete random choices (see [the random walk tutorial](tutorials/random_walk.md) for an example).
* We do not yet support non-real values as intermediate values (e.g. a function such as `length(A[rand(Bernoulli(p))])` where `A` is an array of strings is in theory differentiable).
* We do not support discrete random variables that are implicitly implemented using continuous random variables, e.g. `rand() < p`.
* We support a limited assortment of discrete random variables: currently `Bernoulli`, `Binomial`, `Geometric`, `Poisson`, and `Categorical` from [Distributions](https://juliastats.org/Distributions.jl/). We are working on increasing coverage across `Distributions` as well as other libraries providing discrete random samplers such as [MeasureTheory](https://cscherrer.github.io/MeasureTheory.jl/stable/).
* Higher-order differentiation is not supported.
`StochasticAD` is still in active development! PRs are welcome.
================================================
FILE: docs/src/public_api.md
================================================
# API walkthrough
The function [`derivative_estimate`](@ref) transforms a stochastic program containing discrete randomness into a new program whose average is the derivative of the original.
```@docs
derivative_estimate
```
While [`derivative_estimate`](@ref) is self-contained, we can also use the functions below to work with stochastic triples directly.
```@docs
StochasticAD.stochastic_triple
StochasticAD.derivative_contribution
StochasticAD.value
StochasticAD.delta
StochasticAD.perturbations
```
Note that [`derivative_estimate`](@ref) is simply the composition of [`stochastic_triple`](@ref) and [`derivative_contribution`](@ref). We also provide a convenience function for mimicking the behaviour
of standard AD, where derivatives of discrete random steps are dropped:
```@docs
StochasticAD.dual_number
```
## Algorithms
```@docs
StochasticAD.ForwardAlgorithm
StochasticAD.EnzymeReverseAlgorithm
```
## Smoothing
What happens if we were to run [`derivative_contribution`](@ref) after each step, instead of only at the end? This is *smoothing*, which combines the second and third components of a single stochastic triple into a single dual component.
Smoothing no longer has a guarantee of unbiasedness, but is surprisingly accurate in a number of situations.
For example, the popular [straight through gradient estimator](https://stackoverflow.com/questions/38361314/the-concept-of-straight-through-estimator-ste) can be viewed as a special case of smoothing.
Forward smoothing rules are provided through `ForwardDiff`, and backward rules through `ChainRules`, so that e.g. `Zygote.gradient` and `ForwardDiff.derivative` will use smoothed rules for discrete random variables rather than dropping the gradients entirely.
Currently, special discrete->discrete constructs such as array indexing are not supported for smoothing.
## Optimization
We also provide utilities to make it easier to get started with forming and training a model via stochastic gradient descent:
```@docs
StochasticAD.StochasticModel
StochasticAD.stochastic_gradient
```
These are used in the [tutorial on stochastic optimization](tutorials/optimizations.md).
================================================
FILE: docs/src/tutorials/game_of_life.md
================================================
# Stochastic Game of Life
We consider a stochastic version of [Conway's Game of Life](https://en.wikipedia.org/wiki/Conway%27s_Game_of_Life), played on a two-dimensional board. We shall use the following packages,
```@setup game_of_life
import Pkg
Pkg.activate("../../../tutorials")
Pkg.develop(path="../../..")
Pkg.instantiate()
```
```@example game_of_life
using Distributions
using StochasticAD
using OffsetArrays
using StaticArrays
```
## Setting up the stochastic Game of Life
Each turn, the standard Game of Life applies the following rules to each cell,
```math
\text{dead and 3 neighbours alive} \to \text{ alive}, \\
\text{alive and 0, 1, or 4 neighbours alive} \to \text{ dead}.
```
The cell's status does not change otherwise. In our stochastic version, these rules instead occur with probability `1-θ`, while the opposite event has probability `θ`. To initialize the board at the beginning of the game, we randomly set each cell alive with probability `p`.
The following high level function sets up the probabilities and provides them to `play_game_of_life`.
```@example game_of_life
function play(p, θ=0.1, N=12, T=10; log=false)
# Set up the flip probabilities for the stochastic Game of Life and run it.
# p: probability a cell starts alive; θ: probability that a rule fires in reverse.
# N is the board half-length, T are game time steps
low = θ
high = 1-θ
# Probability a dead cell becomes alive, indexed by number of alive neighbours:
birth_probs = SA[low, low, low, high, low] # 0, 1, 2, 3, 4 neighbours
# Probability an alive cell dies, indexed by number of alive neighbours:
death_probs = SA[high, high, low, low, high] # 0, 1, 2, 3, 4 neighbours
return play_game_of_life(p, vcat(birth_probs, death_probs), N, T; log)
end
```
We can now implement the Game of Life based on the specification. At the end of the game, we return the total number of alive cells.
```@example game_of_life
# A single turn of the game
function update_state(all_probs, N, board_new, board_old)
for i in -N:N
for j in -N:N
# Count alive neighbours in the four cardinal directions; the board is padded
# by one cell on each side, so no bounds checks are needed at the edges.
neighbours = board_old[i+1, j] + board_old[i-1, j] + board_old[i, j-1] + board_old[i, j+1]
# The first 5 entries of all_probs are birth probabilities (cell dead),
# the last 5 are death probabilities (cell alive); see `play`.
index = board_new[i,j] * 5 + neighbours + 1
b = rand(Bernoulli(all_probs[index]))
# b == 1 flips the cell: 0 -> 1 if dead, 1 -> 0 if alive.
board_new[i,j] += (1 - 2 * board_new[i,j]) * b
end
end
end
function play_game_of_life(p, all_probs, N, T; log=false)
# Run the stochastic Game of Life for T turns on a (2N+1)×(2N+1) board and
# return the total number of alive cells at the end (plus the final board and
# the per-turn history when log=true).
dual_type = promote_type(typeof(rand(Bernoulli(p))), typeof.(rand.(Bernoulli.(all_probs)))...) # a hacky way of getting the correct array type
board = OffsetArray(zeros(dual_type, 2*N + 3, 2*N + 3), -(N+1):(N+1), -(N+1):(N+1)) # center board at (0,0), pad by 1
# initialize the board
for i in -N:N
for j in -N:N
board[i,j] = rand(Bernoulli(p))
end
end
board_old = similar(board)
log && (history = [])
# play the game
for time_step in 1:T
copy!(board_old, board) # snapshot so neighbour counts use the pre-turn state
update_state(all_probs, N, board, board_old)
log && push!(history, copy(board))
end
if !log
return sum(board)
else
return sum(board), board, history
end
end
play(0.5, 0.1) # play the game with p = 0.5 and θ = 0.1
```
!!! note
Note that we did have to be careful to write this program to be compatible with the [current capabilities of `StochasticAD`](../limitations.md). For example, we concatenated `birth_probs` and `death_probs` into a single array `all_probs` and used the index `board[i, j] * 5 + neighbours + 1` to find the probability, rather than use the more natural `if alive... else...` syntax.
## Differentiating the Game of Life
Let's differentiate the Game of Life!
```@example game_of_life
@show stochastic_triple(play, 0.5) # let's take a look at a single stochastic triple
samples = [derivative_estimate(play, 0.5) for i in 1:10000] # take many samples
derivative = mean(samples)
uncertainty = std(samples) / sqrt(10000)
println("derivative of 𝔼[play(p)] = $derivative ± $uncertainty")
```
The following sketch of the final state of the board for a single run gives some insight into what the stochastic triples are doing. The original board is depicted in grey and white for dead and alive, and the cells which flip from dead to alive in the "alternative" path considered by the triples are marked with + signs, while the cells which flip from alive to dead are marked with X signs.
```@raw html
<img src="../images/final_gol_board.png" width="50%"/>
``` ⠀
================================================
FILE: docs/src/tutorials/optimizations.md
================================================
# Stochastic optimizations with discrete randomness
```@setup random_walk
import Pkg
Pkg.activate("../../../tutorials/toy_optimizations")
Pkg.develop(path="../../..")
Pkg.instantiate()
```
In this tutorial, we solve two stochastic optimization problems using `StochasticAD` where the optimization objective is formed using discrete distributions. We will need the following packages:
```@example optimizations
using Distributions # defines several supported discrete distributions
using StochasticAD
using CairoMakie # for plotting
using Optimisers # for stochastic gradient descent
```
## Optimizing our toy program
Recall the "crazy" program from the intro:
```@example optimizations
function X(p)
# Toy stochastic program mixing continuous randomness (Normal) with discrete
# randomness (Binomial, Bernoulli); differentiating E[X(p)] requires handling
# the discrete draws, which standard AD would drop.
a = p * (1 - p)
b = rand(Binomial(10, p))
c = 2 * b + 3 * rand(Bernoulli(p))
return a * c * rand(Normal(b, a))
end
```
Let's maximize $\mathbb{E}[X(p)]$! First, let's setup the problem, using the [`StochasticModel`](@ref) helper utility to create a trainable model:
```@example optimizations
p0 = [0.5] # initial value of p, wrapped in an array for use in the stochastic model
m = StochasticModel(p -> -X(p[1]), p0) # formulate as minimization problem
```
Now, let's perform stochastic gradient descent using [Adam](https://arxiv.org/abs/1412.6980), where we use [`stochastic_gradient`](@ref) to obtain a gradient of the model.
```@example optimizations
iterations = 1000
trace = Float64[]
o = Adam() # use Adam for optimization
s = Optimisers.setup(o, m)
for i in 1:iterations
# Perform a gradient step
Optimisers.update!(s, m, stochastic_gradient(m))
push!(trace, m.p[])
end
p_opt = m.p[] # Our optimized value of p
```
Finally, let's plot the results of our optimization, and also perform a sweep through the parameter space to verify the accuracy of our estimator:
```@example optimizations
## Sweep through parameters to find average and derivative
ps = 0.02:0.02:0.98 # values of p to sweep
N = 1000 # number of samples at each p
avg = [mean(X(p) for _ in 1:N) for p in ps]
derivative = [mean(derivative_estimate(X, p) for _ in 1:N) for p in ps]
## Make plots
f = Figure()
ax = f[1, 1] = Axis(f, title = "Estimates", xlabel="Value of p")
lines!(ax, ps, avg, label = "≈ E[X(p)]")
lines!(ax, ps, derivative, label = "≈ d/dp E[X(p)]")
vlines!(ax, [p_opt], label = "p_opt", color = :green, linewidth = 2.0)
hlines!(ax, [0.0], color = :black, linewidth = 1.0)
ylims!(ax, (-50, 80))
f[1, 2] = Legend(f, ax, framevisible = false)
ax = f[2, 1:2] = Axis(f, title = "Optimizer trace", xlabel="Iterations", ylabel="Value of p")
lines!(ax, trace, color = :green, linewidth = 2.0)
save("crazy_opt.png", f, px_per_unit = 4) # hide
nothing # hide
```

## Solving a variational problem
Let's consider a toy variational program: we find a [Poisson distribution](https://en.wikipedia.org/wiki/Poisson_distribution) that is close to the distribution of a [negative Binomial](https://en.wikipedia.org/wiki/Negative_binomial_distribution), via minimization of the [Kullback-Leibler divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) $D_{\mathrm{KL}}$. Concretely, let us solve
```math
\underset{p \in \mathbb{R}}{\operatorname{argmin}}\; D_{\mathrm{KL}}\left(\mathrm{Pois}(p) \hspace{.3em}\middle\|\hspace{.3em} \mathrm{NBin}(10, 0.25) \right).
```
The following program produces an unbiased estimate of the objective:
```@example optimizations
function X(p)
# Single-sample unbiased estimate of KL(Poisson(p) || NegativeBinomial(10, 0.25)),
# computed as the log-density ratio evaluated at a draw i ~ Poisson(p).
i = rand(Poisson(p))
return logpdf(Poisson(p), i) - logpdf(NegativeBinomial(10, 0.25), i)
end
```
We can now optimize the KL-divergence via stochastic gradient descent!
```@example optimizations
# Minimize E[X] = KL(Poisson(p)| NegativeBinomial(10, 0.25))
iterations = 1000
p0 = [10.0]
m = StochasticModel(p -> X(p[1]), p0)
trace = Float64[]
o = Adam(0.1)
s = Optimisers.setup(o, m)
for i in 1:iterations
Optimisers.update!(s, m, stochastic_gradient(m))
push!(trace, m.p[])
end
p_opt = m.p[]
```
Let's plot our results in the same way as before:
```@example optimizations
ps = 10:0.5:50
N = 1000
avg = [mean(X(p) for _ in 1:N) for p in ps]
derivative = [mean(derivative_estimate(X, p) for _ in 1:N) for p in ps]
f = Figure()
ax = f[1, 1] = Axis(f, title = "Estimates", xlabel="Value of p")
lines!(ax, ps, avg, label = "≈ E[X(p)]")
lines!(ax, ps, derivative, label = "≈ d/dp E[X(p)]")
vlines!(ax, [p_opt], label = "p_opt", color = :green, linewidth = 2.0)
hlines!(ax, [0.0], color = :black, linewidth = 1.0)
ylims!(ax, (-2.5, 5))
f[1, 2] = Legend(f, ax, framevisible = false)
ax = f[2, 1:2] = Axis(f, title = "Optimizer trace", ylabel="Value of p", xlabel="Iterations")
lines!(ax, trace, color = :green, linewidth = 2.0)
save("variational.png", f, px_per_unit = 4) # hide
nothing # hide
```

================================================
FILE: docs/src/tutorials/particle_filter.md
================================================
# Differentiable particle filter
Using a bootstrap particle sampler, we can approximate the posterior distributions
of the states given noisy and partial observations of the state of a hidden Markov
model by a cloud of `K` weighted particles with weights `W`.
In this tutorial, we are going to:
- implement a differentiable particle filter based on `StochasticAD.jl`.
- visualize the particle filter in ``d = 2`` dimensions.
- compare the gradient based on the differentiable particle filter to a biased
gradient estimator as well as to the gradient of a differentiable Kalman filter.
- show how to benchmark primal evaluation, forward- and reverse-mode AD of the
particle filter.
## Setup
We will make use of several julia packages. For example, we are going to use
`Distributions` and `DistributionsAD` that implement the reparameterization trick
for Gaussian distributions used in the observation and state-transition model, which
we specify below. We also import `GaussianDistributions.jl` to implement the
differentiable Kalman filter.
### Package dependencies
```@setup particle_filter
import Pkg
Pkg.activate("../../../tutorials")
Pkg.develop(path="../../..")
Pkg.instantiate()
```
```@example particle_filter
# activate tutorial project file
# load dependencies
using StochasticAD
using Distributions
using DistributionsAD
using Random
using Statistics
using StatsBase
using LinearAlgebra
using Zygote
using ForwardDiff
using GaussianDistributions
using GaussianDistributions: correct, ⊕
using Measurements
using UnPack
using Plots
using LaTeXStrings
using BenchmarkTools
```
### Particle filter
For convenience, we first introduce the new type `StochasticModel` with the following
fields:
- `T`: total number of time steps.
- `start`: starting distribution for the initial state. For example, in the form of a narrow
Gaussian `start(θ) = Gaussian(x0, 0.001 * I(d))`.
- `dyn`: pointwise differentiable stochastic program in the form of Markov transition densities.
For example, `dyn(x, θ) = MvNormal(reshape(θ, d, d) * x, Q(θ))`, where `Q(θ)` denotes the
covariance matrix.
- `obs`: observation model having a smooth conditional probability density depending on
current state `x` and parameters `θ`. For example, `obs(x, θ) = MvNormal(x, R(θ))`,
where `R(θ)` denotes the covariance matrix.
For parameters `θ`, `rand(start(θ))` gives a sample from the prior distribution of the
starting distribution. For current state `x` and parameters `θ`, `xnew = rand(dyn(x, θ))`
samples the new state (i.e. `dyn` gives for each `x, θ` a distribution-like object). Finally,
`y = rand(obs(x, θ))` samples an observation.
We can then define the `ParticleFilter` type that wraps a stochastic model `StochM::StochasticModel`,
a sampling strategy (with arguments `p, K, sump=1`) and observational data `ys`.
For simplicity, our implementation assumes an observation-likelihood function being available
via `pdf(obs(x, θ), y)`.
```@example particle_filter
struct StochasticModel{TType<:Integer,T1,T2,T3}
T::TType # time steps
start::T1 # prior
dyn::T2 # dynamical model
obs::T3 # observation model
end
struct ParticleFilter{mType<:Integer,MType<:StochasticModel,yType,sType}
m::mType # number of particles
StochM::MType # stochastic model
ys::yType # observations
sample_strategy::sType # sampling function
end
```
### Kalman filter
We consider a stochastic program that fulfills the assumptions of a Kalman filter.
We follow [Kalman.jl](https://github.com/mschauer/Kalman.jl/blob/master/README.md) to implement a differentiable version.
Our `KalmanFilter` type wraps a stochastic model `StochM::StochasticModel` and observational data `ys`. It assumes an
observation-likelihood function is implemented via `llikelihood(yres, S)`. The Kalman filter
contains the following fields:
- `d`: dimension of the state-transition matrix ``\Phi`` according to ``x = \Phi x + w`` with ``w \sim \operatorname{Normal}(0,Q)``.
- `StochM`: Stochastic model of type `StochasticModel`.
- `H`: linear map from the state space into the observed space according to ``y = H x + \nu`` with ``\nu \sim \operatorname{Normal}(0,R)``.
- `R`: covariance matrix entering the observation model according to ``y = H x + \nu`` with ``\nu \sim \operatorname{Normal}(0,R)``.
- `Q`: covariance matrix entering the state-transition model according to ``x = \Phi x + w`` with ``w \sim \operatorname{Normal}(0,Q)``.
- `ys`: observations.
```@example particle_filter
llikelihood(yres, S) = GaussianDistributions.logpdf(Gaussian(zero(yres), Symmetric(S)), yres)
struct KalmanFilter{dType<:Integer,MType<:StochasticModel,HType,RType,QType,yType}
# H, R = obs
# θ, Q = dyn
d::dType
StochM::MType # stochastic model
H::HType # observation model, maps the true state space into the observed space
R::RType # observation model, covariance matrix
Q::QType # dynamical model, covariance matrix
ys::yType # observations
end
```
To get observations `ys` from the latent states `xs` based on the
(true, potentially unknown) parameters `θ`, we simulate a single particle
from the forward model returning a vector of observations (no resampling steps).
```@example particle_filter
function simulate_single(StochM::StochasticModel, θ)
# Simulate one trajectory of the forward model with no resampling steps:
# returns the latent states `xs` and observations `ys` over T time steps.
@unpack T, start, dyn, obs = StochM
x = rand(start(θ)) # initial state drawn from the prior
y = rand(obs(x, θ)) # observation of the initial state
xs = [x]
ys = [y]
for t in 2:T
x = rand(dyn(x, θ)) # propagate the state through the dynamics
y = rand(obs(x, θ)) # observe the new state
push!(xs, x)
push!(ys, y)
end
xs, ys
end
```
A particle filter becomes efficient if resampling steps are included. Resampling
is numerically attractive because particles with small weight are discarded, so
computational resources are not wasted on particles with vanishing weight.
Here, let us implement a stratified resampling strategy, see for example
[Murray (2012)](https://arxiv.org/abs/1202.6163), where `p` denotes the probabilities of `K` particles
with `sump = sum(p)`.
```@example particle_filter
function sample_stratified(p, K, sump=1)
    # Stratified resampling (see Murray (2012)): partition the total mass `sump`
    # into K equal strata and draw one index per stratum, sharing a single
    # uniform offset U across all strata. `p` holds the particle probabilities.
    n = length(p)
    U = rand()
    indices = zeros(Int, K)
    idx = 1
    cumulative = p[1]
    for k in 1:K
        # Threshold for stratum k: the (k - 1 + U)/K quantile of the total mass.
        threshold = sump * (k - 1 + U) / K
        # Advance through the cumulative distribution until it reaches the threshold.
        while cumulative < threshold && idx < n
            idx += 1
            @inbounds cumulative += p[idx]
        end
        indices[k] = idx
    end
    return indices
end
```
This sampling strategy can be used within a differentiable resampling step in our
particle filter using the `use_new_weight` function as implemented in
`StochasticAD.jl`. The `resample` function below returns the states `X_new`
and weights `W_new` of the resampled particles.
- `m`: number of particles.
- `X`: current particle states.
- `W`: current weight vector of the particles.
- `ω == sum(W)` is an invariant.
- `sample_strategy`: specific resampling strategy to be used. For example, `sample_stratified`.
- `use_new_weight=true`: Allows one to switch between biased, stop-gradient method and
differentiable resampling step.
```@example particle_filter
function resample(m, X, W, ω, sample_strategy, use_new_weight=true)
# Resample m particles from states X with weights W, where ω == sum(W).
# The index draw itself is hidden from Zygote; when use_new_weight=true,
# differentiability of the resampling step is restored through StochasticAD's
# `new_weight` applied to the normalized chosen weights.
js = Zygote.ignore(() -> sample_strategy(W, m, ω))
X_new = X[js]
if use_new_weight
# differentiable resampling
W_chosen = W[js]
W_new = map(w -> ω * new_weight(w / ω) / m, W_chosen)
else
# stop gradient, biased approach: equal weights, total mass ω preserved
W_new = fill(ω / m, m)
end
X_new, W_new
end
```
Note that we added an `if` condition that allows us to switch between the differentiable
resampling step and the stop-gradient approach.
We're now equipped with all primitive operations to set up the particle filter,
which propagates particles with weights `W` preserving the invariant `ω == sum(W)`.
We never normalize `W` and, therefore, `ω` in the code below contains likelihood
information. The particle-filter implementation defaults to return particle
positions and weights at `T` if `store_path=false` and takes the following input
arguments:
- `θ`: parameters for the stochastic program (state-transition and observation model).
- `store_path=false`: Option to store the path of the particles, e.g. to visualize/inspect
their trajectories.
- `use_new_weight=true`: Option to switch between the stop-gradient and our differentiable
resampling step method. Defaults to using differentiable resampling.
- `s`: controls the number of resampling steps according to `t > 1 && t < T && (t % s == 0)`.
```@example particle_filter
function (F::ParticleFilter)(θ; store_path=false, use_new_weight=true, s=1)
# Run the bootstrap particle filter for parameters θ. Returns the particle
# states (or their paths if store_path=true) together with the weights W;
# the invariant ω == sum(W) is maintained, so ω carries likelihood information.
# s controls the number of resampling steps
@unpack m, StochM, ys, sample_strategy = F
@unpack T, start, dyn, obs = StochM
X = [rand(start(θ)) for j in 1:m] # particles
W = [1 / m for i in 1:m] # weights
ω = 1 # total weight
store_path && (Xs = [X])
for (t, y) in zip(1:T, ys)
# update weights & likelihood using observations
wi = map(x -> pdf(obs(x, θ), y), X)
W = W .* wi
ω_old = ω
ω = sum(W)
# resample particles (skipped at the first and last step)
if t > 1 && t < T && (t % s == 0) # && 1 / sum((W / ω) .^ 2) < length(W) ÷ 32
X, W = resample(m, X, W, ω, sample_strategy, use_new_weight)
end
# update particle states
if t < T
X = map(x -> rand(dyn(x, θ)), X)
store_path && Zygote.ignore(() -> push!(Xs, X))
end
end
(store_path ? Xs : X), W
end
```
Following [Kalman.jl](https://github.com/mschauer/Kalman.jl/blob/master/README.md), we implement
a differentiable Kalman filter to check the ground-truth gradient. Our Kalman filter
returns an updated posterior state estimate and the log-likelihood and takes the
parameters of the stochastic program as an input.
```@example particle_filter
function (F::KalmanFilter)(θ)
# Run the Kalman filter with state-transition matrix Φ = reshape(θ, d, d).
# Returns the sequence of posterior state estimates and the log-likelihood.
# NOTE(review): `ys` is read from global scope here rather than from `F.ys`
# (it is not unpacked from F below) — consider unpacking it for self-containment.
@unpack d, StochM, H, R, Q = F
@unpack start = StochM
x = start(θ)
Φ = reshape(θ, d, d)
x, yres, S = GaussianDistributions.correct(x, ys[1] + R, H)
ll = llikelihood(yres, S)
xs = Any[x]
for i in 2:length(ys)
x = Φ * x ⊕ Q # predict: propagate through dynamics and add process noise Q
x, yres, S = GaussianDistributions.correct(x, ys[i] + R, H)
ll += llikelihood(yres, S)
push!(xs, x)
end
xs, ll
end
```
For both filters, it is straightforward to obtain the log-likelihood via:
```@example particle_filter
function log_likelihood(F::ParticleFilter, θ, use_new_weight=true, s=1)
_, W = F(θ; store_path=false, use_new_weight=use_new_weight, s=s)
log(sum(W))
end
```
and
```@example particle_filter
function log_likelihood(F::KalmanFilter, θ)
_, ll = F(θ)
ll
end
```
For convenience, we define functions for
- forward-mode AD (and differentiable resampling step) to compute the gradient of
the log-likelihood of the particle filter.
- reverse-mode AD (and differentiable resampling step) to compute the gradient of
the log-likelihood of the particle filter.
- forward-mode AD (and stop-gradient method) to compute the gradient of
the log-likelihood of the particle filter (without the `new_weight` function).
- forward-mode AD to compute the gradient of the log-likelihood of the Kalman filter.
```@example particle_filter
forw_grad(θ, F::ParticleFilter; s=1) = ForwardDiff.gradient(θ -> log_likelihood(F, θ, true, s), θ)
back_grad(θ, F::ParticleFilter; s=1) = Zygote.gradient(θ -> log_likelihood(F, θ, true, s), θ)[1]
forw_grad_biased(θ, F::ParticleFilter; s=1) = ForwardDiff.gradient(θ -> log_likelihood(F, θ, false, s), θ)
forw_grad_Kalman(θ, F::KalmanFilter) = ForwardDiff.gradient(θ -> log_likelihood(F, θ), θ)
```
## Model
Having set up all core functionalities, we can now define the specific stochastic
model.
We consider the following system with a ``d``-dimensional latent process,
```math
\begin{aligned}
x_i &= \Phi x_{i-1} + w_i &\text{ with } w_i \sim \operatorname{Normal}(0,Q),\\
y_i &= x_i + \nu_i &\text{ with } \nu_i \sim \operatorname{Normal}(0,R),
\end{aligned}
```
where ``\Phi`` is a ``d``-dimensional rotation matrix.
```@example particle_filter
seed = 423897
### Define model
# here: n-dimensional rotation matrix
Random.seed!(seed)
T = 20 # time steps
d = 2 # dimension
# generate a rotation matrix
M = randn(d, d)
c = 0.3 # scaling
O = exp(c * (M - transpose(M)) / 2)
@assert det(O) ≈ 1
@assert transpose(O) * O ≈ I(d)
θtrue = vec(O) # true parameter
# observation model
R = 0.01 * collect(I(d))
obs(x, θ) = MvNormal(x, R) # y = H x + ν with ν ~ Normal(0, R)
# dynamical model
Q = 0.02 * collect(I(d))
dyn(x, θ) = MvNormal(reshape(θ, d, d) * x, Q) # x = Φ*x + w with w ~ Normal(0,Q)
# starting position
x0 = randn(d)
# prior distribution
start(θ) = Gaussian(x0, 0.001 * collect(I(d)))
# put it all together
stochastic_model = StochasticModel(T, start, dyn, obs)
# relevant corresponding Kalman filtering defs
H_Kalman = collect(I(d))
R_Kalman = Gaussian(zeros(Float64, d), R)
# Φ_Kalman = O
Q_Kalman = Gaussian(zeros(Float64, d), Q)
###
### simulate model
Random.seed!(seed)
xs, ys = simulate_single(stochastic_model, θtrue)
```
## Visualization
Using `particle_filter(θ; store_path=true)` and `kalman_filter(θ)`, it is
straightforward to visualize both filters for our observed data.
```@example particle_filter
m = 1000
kalman_filter = KalmanFilter(d, stochastic_model, H_Kalman, R_Kalman, Q_Kalman, ys)
particle_filter = ParticleFilter(m, stochastic_model, ys, sample_stratified)
```
```@example particle_filter
### run and visualize filters
Xs, W = particle_filter(θtrue; store_path=true)
fig = plot(getindex.(xs, 1), getindex.(xs, 2), legend=false, xlabel=L"x_1", ylabel=L"x_2") # x1 and x2 are bad names..conflicting notation
scatter!(fig, getindex.(ys, 1), getindex.(ys, 2))
for i in 1:min(m, 100) # note that Xs has obs noise.
local xs = [Xs[t][i] for t in 1:T]
scatter!(fig, getindex.(xs, 1), getindex.(xs, 2), marker_z=1:T, color=:cool, alpha=0.1) # color to indicate time step
end
xs_Kalman, ll_Kalman = kalman_filter(θtrue)
plot!(getindex.(mean.(xs_Kalman), 1), getindex.(mean.(xs_Kalman), 2), legend=false, color="red")
png("pf_1") # hide
```

## Bias
We can also investigate the distribution of the gradients from the particle filter
with and without differentiable resampling step, as compared to the gradient computed
by differentiating the Kalman filter.
```@example particle_filter
### compute gradients
Random.seed!(seed)
X = [forw_grad(θtrue, particle_filter) for i in 1:200] # gradient of the particle filter *with* differentiation of the resampling step
Random.seed!(seed)
Xbiased = [forw_grad_biased(θtrue, particle_filter) for i in 1:200] # Gradient of the particle filter *without* differentiation of the resampling step
# pick an arbitrary coordinate
index = 1 # take derivative with respect to first parameter (2-dimensional example has a rotation matrix with four parameters in total)
# plot histograms for the sampled derivative values
fig = plot(normalize(fit(Histogram, getindex.(X, index), nbins=20), mode=:pdf), legend=false) # ours
plot!(normalize(fit(Histogram, getindex.(Xbiased, index), nbins=20), mode=:pdf)) # biased
vline!([mean(X)[index]], color=1)
vline!([mean(Xbiased)[index]], color=2)
# add derivative of differentiable Kalman filter as a comparison
XK = forw_grad_Kalman(θtrue, kalman_filter)
vline!([XK[index]], color="black")
png("pf_2") # hide
```

The estimator using the `new_weight` function agrees with the gradient value from
the Kalman filter and the [particle filter AD scheme developed by Ścibior and Wood](https://arxiv.org/abs/2106.10314),
unlike biased estimators that neglect the contribution of the derivative from the
resampling step. However, the biased estimator displays a smaller variance.
## Benchmark
Finally, we can use `BenchmarkTools.jl` to benchmark the run times of the primal
pass with respect to forward-mode and reverse-mode AD of the particle filter. As
expected, forward-mode AD outperforms reverse-mode AD for the small number of
parameters considered here.
```@example particle_filter
# secs for how long the benchmark should run, see https://juliaci.github.io/BenchmarkTools.jl/stable/
secs = 1
suite = BenchmarkGroup()
suite["scaling"] = BenchmarkGroup(["grads"])
suite["scaling"]["primal"] = @benchmarkable log_likelihood(particle_filter, θtrue)
suite["scaling"]["forward"] = @benchmarkable forw_grad(θtrue, particle_filter)
suite["scaling"]["backward"] = @benchmarkable back_grad(θtrue, particle_filter)
tune!(suite)
results = run(suite, verbose=true, seconds=secs)
t1 = measurement(mean(results["scaling"]["primal"].times), std(results["scaling"]["primal"].times) / sqrt(length(results["scaling"]["primal"].times)))
t2 = measurement(mean(results["scaling"]["forward"].times), std(results["scaling"]["forward"].times) / sqrt(length(results["scaling"]["forward"].times)))
t3 = measurement(mean(results["scaling"]["backward"].times), std(results["scaling"]["backward"].times) / sqrt(length(results["scaling"]["backward"].times)))
@show t1 t2 t3
ts = (t1, t2, t3) ./ 10^6 # ms
@show ts
```
================================================
FILE: docs/src/tutorials/random_walk.md
================================================
# Random walk
```@setup random_walk
import Pkg
Pkg.activate("../../../tutorials")
Pkg.develop(path="../../..")
Pkg.instantiate()
```
In this tutorial, we differentiate a random walk over the integers using `StochasticAD`. We will need the following packages,
```@example random_walk
using Distributions # defines several supported discrete distributions
using StochasticAD
using StaticArrays # for more efficient small arrays
```
## Setting up the random walk
Let's define a function for simulating the walk.
```@example random_walk
function simulate_walk(probs, steps, n)
# Simulate an n-step random walk over the integers, starting at 0.
# `probs` maps the current state to a vector of transition probabilities,
# one per entry of `steps` (the possible step sizes).
state = 0
for i in 1:n
probs_here = probs(state) # transition probabilities for possible steps
step_index = rand(Categorical(probs_here)) # which step do we take?
step = steps[step_index] # get size of step
state += step
end
return state
end
```
Here, `steps` is a (1-indexed) array of the possible steps we can take. Each of these steps has a certain probability. To make things more interesting, we take in a *function* `probs` to produce these probabilities that can depend on the current state of the random walk.
Let's zoom in on the two lines where discrete randomness is involved.
```
step_index = rand(Categorical(probs_here)) # which step do we take?
step = steps[step_index] # get size of step
```
This is a cute pattern for making a discrete choice. First, we sample from a `Categorical` distribution from `Distributions.jl`, using the probabilities `probs_here` at our current position. This gives us an index between `1` and `length(steps)`, which we can use to pick the actual step to take. Stochastic triples propagate through both steps!
## Differentiating the random walk
Let's define a toy problem. We consider a random walk with `-1` and `+1` steps, where the probability of `+1` starts off high but decays exponentially with a decay length of `p`. We take `n = 100` steps and set `p = 50`.
```@example random_walk
using StochasticAD
const steps = SA[-1, 1] # move left or move right
make_probs(p) = X -> SA[1 - exp(-X / p), exp(-X / p)]
f(p, n) = simulate_walk(make_probs(p), steps, n)
@show f(50, 100) # let's run a single random walk with p = 50
@show stochastic_triple(p -> f(p, 100), 50) # let's see how a single stochastic triple looks like at p = 50
```
Time to differentiate! For fun, let's differentiate the *square* of the output of the random walk.
```@example random_walk
f_squared(p, n) = f(p, n)^2
samples = [derivative_estimate(p -> f_squared(p, 100), 50) for i in 1:1000] # many samples from derivative program at p = 50
derivative = mean(samples)
uncertainty = std(samples) / sqrt(1000)
println("derivative of 𝔼[f_squared] = $derivative ± $uncertainty")
```
## Computing variance
A crucial figure of merit for a derivative estimator is its variance. We compute the standard deviation (square root of the variance) of our estimator over a range of `n`.
```@example random_walk
n_range = 10:10:100 # range for testing asymptotic variance behaviour
p_range = 2 .* n_range
nsamples = 10000
stds_triple = Float64[]
for (n, p) in zip(n_range, p_range)
std_triple = std(derivative_estimate(p -> f_squared(p, n), p)
for i in 1:(nsamples))
push!(stds_triple, std_triple)
end
@show stds_triple
```
For comparison with other unbiased estimators, we also compute `stds_score` and `stds_score_baseline` for the
[score function gradient estimator](https://arxiv.org/pdf/1906.10652.pdf), both without and with a variance-reducing batch-average control variate (CV). (For details, see [`core.jl`](https://github.com/gaurav-arya/StochasticAD.jl/blob/main/tutorials/random_walk/core.jl) and [`compare_score.jl`](https://github.com/gaurav-arya/StochasticAD.jl/blob/main/tutorials/random_walk/compare_score.jl).) We can now graph the standard deviation of each estimator versus $n$, observing lower variance in the unbiased derivative estimate produced by stochastic triples:
```@raw html
<img src="../images/compare_score.png" width="50%"/>
``` ⠀
================================================
FILE: docs/src/tutorials/reverse_demo.md
================================================
```@meta
EditURL = "../../../tutorials/reverse_example/reverse_demo.jl"
```
# Simple reverse mode example
```@setup random_walk
import Pkg
Pkg.activate("../../../tutorials")
Pkg.develop(path="../../..")
Pkg.instantiate()
import Random
Random.seed!(1234)
```
Load our packages
````@example reverse_demo
using StochasticAD
using Distributions
using Enzyme
using LinearAlgebra
````
Let us define our target function.
````@example reverse_demo
# Define a toy `StochasticAD`-differentiable function for computing an integer value from a string.
string_value(strings, index) = Int(sum(codepoint, strings[index]))
string_value(strings, index::StochasticTriple) = StochasticAD.propagate(index -> string_value(strings, index), index)
function f(θ; derivative_coupling = StochasticAD.InversionMethodDerivativeCoupling())
strings = ["cat", "dog", "meow", "woofs"]
index = randst(Categorical(θ); derivative_coupling)
return string_value(strings, index)
end
θ = [0.1, 0.5, 0.3, 0.1]
@show f(θ)
nothing
````
First, let's compute the sensitivity of `f` in a particular direction via forward-mode Stochastic AD.
````@example reverse_demo
u = [1.0, 2.0, 4.0, -7.0]
@show derivative_estimate(f, θ, StochasticAD.ForwardAlgorithm(PrunedFIsBackend()); direction = u)
nothing
````
Now, let's do the same with reverse-mode.
````@example reverse_demo
@show derivative_estimate(f, θ, StochasticAD.EnzymeReverseAlgorithm(PrunedFIsBackend(Val(:wins))))
````
Let's verify that our reverse-mode gradient is consistent with our forward-mode directional derivative.
````@example reverse_demo
forward() = derivative_estimate(f, θ, StochasticAD.ForwardAlgorithm(PrunedFIsBackend()); direction = u)
reverse() = derivative_estimate(f, θ, StochasticAD.EnzymeReverseAlgorithm(PrunedFIsBackend(Val(:wins))))
N = 40000
directional_derivs_fwd = [forward() for i in 1:N]
derivs_bwd = [reverse() for i in 1:N]
directional_derivs_bwd = [dot(u, δ) for δ in derivs_bwd]
println("Forward mode: $(mean(directional_derivs_fwd)) ± $(std(directional_derivs_fwd) / sqrt(N))")
println("Reverse mode: $(mean(directional_derivs_bwd)) ± $(std(directional_derivs_bwd) / sqrt(N))")
@assert isapprox(mean(directional_derivs_fwd), mean(directional_derivs_bwd), rtol = 3e-2)
nothing
````
---
*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*
================================================
FILE: ext/StochasticADEnzymeExt.jl
================================================
module StochasticADEnzymeExt
using StochasticAD
using Enzyme
function enzyme_target(u, X, p, backend)
    # Scalar objective for Enzyme: the forward-mode derivative contribution of X
    # at p in direction u (equivalent to derivative_estimate(X, p; backend,
    # direction = u)), specialized to real-valued outputs to make Enzyme happier.
    st = StochasticAD.stochastic_triple_direction(X, p, u; backend)
    if StochasticAD.valtype(st) <: Real
        return derivative_contribution(st)
    else
        error("EnzymeReverseAlgorithm only supports real-valued outputs.")
    end
end
function StochasticAD.derivative_estimate(X, p, alg::StochasticAD.EnzymeReverseAlgorithm;
direction = nothing, alg_data = (; forward_u = nothing))
if !isnothing(direction)
error("EnzymeReverseAlgorithm does not support keyword argument `direction`")
end
if p isa AbstractVector
Δu = zeros(float(eltype(p)), length(p))
u = isnothing(alg_data.forward_u) ?
rand(StochasticAD.RNG, float(eltype(p)), length(p)) : alg_data.forward_u
autodiff(Enzyme.Reverse, enzyme_target, Active, Duplicated(u, Δu),
Const(X), Const(p), Const(alg.backend))
return Δu
elseif p isa Real
u = isnothing(alg_data.forward_u) ? rand(StochasticAD.RNG, float(typeof(p))) :
forward_u
((du, _, _, _),) = autodiff(Enzyme.Reverse, enzyme_target, Active, Active(u),
Const(X), Const(p), Const(alg.backend))
return du
else
error("EnzymeReverseAlgorithm only supports p::Real or p::AbstractVector")
end
end
end
================================================
FILE: src/StochasticAD.jl
================================================
module StochasticAD

### Public API

export stochastic_triple, derivative_contribution, perturbations, smooth_triple,
       dual_number, StochasticTriple # For working with stochastic triples
export derivative_estimate, StochasticModel, stochastic_gradient # Higher level functionality
export new_weight # Particle resampling
export PrunedFIsBackend,
       PrunedFIsAggressiveBackend, DictFIsBackend, SmoothedFIsBackend,
       StrategyWrapperFIsBackend # User-facing backend selectors
export PrunedFIs, PrunedFIsAggressive, DictFIs, SmoothedFIs, StrategyWrapperFIs # Backend implementation types
export randst
export InversionMethodDerivativeCoupling

### Imports

using Random
using Distributions
using DistributionsAD
using ChainRulesCore
using ChainRulesOverloadGeneration
using ExprTools
using ForwardDiff
using Functors
import ChainRulesCore

# resolve conflicts while this code exists in both.
const on_new_rule = ChainRulesOverloadGeneration.on_new_rule
const refresh_rules = ChainRulesOverloadGeneration.refresh_rules

# Package-local RNG for internal randomness (e.g. pruning coin flips), copied
# from the default RNG at load time so internal draws do not advance the
# user-visible global stream.
const RNG = copy(Random.default_rng())

### Files responsible for backends
include("finite_infinitesimals.jl")
include("backends/pruned.jl")
include("backends/pruned_aggressive.jl")
include("backends/dict.jl")
include("backends/smoothed.jl")
include("backends/abstract_wrapper.jl")
include("backends/strategy_wrapper.jl")

# Bring the backend submodules' exported names into scope.
using .PrunedFIsModule
using .PrunedFIsAggressiveModule
using .DictFIsModule
using .SmoothedFIsModule
using .AbstractWrapperFIsModule
using .StrategyWrapperFIsModule

include("prelude.jl") # Defines global constants
include("smoothing.jl") # Smoothing rules. Placed before general rules so that new_weight frule is caught by overload generation.
include("stochastic_triple.jl") # Defines stochastic triple object and higher level functions
include("general_rules.jl") # Defines rules for propagation through deterministic functions
include("discrete_randomness.jl") # Defines rules for propagation through discrete random functions
include("propagate.jl") # Experimental generalized forward propagation functionality
include("algorithms.jl") # Add algorithm-based higher-level interface
include("misc.jl") # Miscellaneous functions that do not fit in the usual flow

end
================================================
FILE: src/algorithms.jl
================================================
# Root type for all differentiation algorithms accepted by `derivative_estimate`.
abstract type AbstractStochasticADAlgorithm end

"""
    ForwardAlgorithm(backend::StochasticAD.AbstractFIsBackend) <: AbstractStochasticADAlgorithm

A differentiation algorithm relying on forward propagation of stochastic triples.

The `backend` argument controls the algorithm used by the third component of the stochastic triples.

!!! note
    The required computation time for forward-mode AD scales linearly with the number of
    parameters in `p` (but is unaffected by the number of parameters in `X(p)`).
"""
struct ForwardAlgorithm{B <: StochasticAD.AbstractFIsBackend} <:
       AbstractStochasticADAlgorithm
    backend::B # backend strategy for the third (perturbation) triple component
end
"""
    EnzymeReverseAlgorithm(backend::StochasticAD.AbstractFIsBackend) <: AbstractStochasticADAlgorithm

A differentiation algorithm relying on transposing the propagation of stochastic triples to
produce a reverse-mode algorithm. The transposition is performed by [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl),
which must be loaded for the algorithm to run.

Currently, only real- and vector-valued inputs are supported, and only real-valued outputs are supported.

The `backend` argument controls the algorithm used by the third component of the stochastic triples.
In the call to `derivative_estimate`, this algorithm optionally accepts `alg_data` with the field `forward_u`,
which specifies the directional derivative used in the forward pass that will be transposed.
If `forward_u` is not provided, it is randomly generated.

!!! warning
    For the reverse-mode algorithm to yield correct results, the employed `backend` cannot use input-dependent pruning
    strategies. A suggested reverse-mode compatible backend is `PrunedFIsBackend(Val(:wins))`.

    Additionally, this algorithm relies on the ability of `Enzyme.jl` to differentiate the forward stochastic triple run.
    It is recommended to check that the primal function `X` is type stable for its input `p` using a tool such as
    [JET.jl](https://github.com/aviatesk/JET.jl), with all code executed in a function with no global state.
    In addition, sometimes `X` may be type stable but stochastic triples introduce additional type instabilities.
    This can be debugged by checking type stability of Enzyme's target, which is
    `Base.get_extension(StochasticAD, :StochasticADEnzymeExt).enzyme_target(u, X, p, backend)`,
    where `u` is a test direction.

!!! note
    For more details on the reverse-mode approach, see the following papers and talks:

    * ["You Only Linearize Once: Tangents Transpose to Gradients"](https://arxiv.org/abs/2204.10923), Radul et al. 2022.
    * ["Reverse mode ADEV via YOLO: tangent estimators transpose to gradient estimators"](https://www.youtube.com/watch?v=pnPmk-leSsE), Becker et al. 2024
    * ["Probabilistic Programming with Programmable Variational Inference"](https://pldi24.sigplan.org/details/pldi-2024-papers/87/Probabilistic-Programming-with-Programmable-Variational-Inference), Becker et al. 2024
"""
struct EnzymeReverseAlgorithm{B <: StochasticAD.AbstractFIsBackend} <:
       AbstractStochasticADAlgorithm
    # FIX: the struct previously did not subtype AbstractStochasticADAlgorithm,
    # contradicting its docstring and the documented
    # `derivative_estimate(X, p, alg::AbstractStochasticADAlgorithm)` signature.
    backend::B # backend strategy for the forward pass that Enzyme transposes
end
# Forward-mode algorithm: delegate to the backward-compatible keyword-based
# signature, carrying over the algorithm's backend and the optional direction.
# `alg_data` is accepted for interface uniformity but unused by this algorithm.
function derivative_estimate(X, p, alg::ForwardAlgorithm;
        direction = nothing, alg_data::NamedTuple = (;))
    backend = alg.backend
    return derivative_estimate(X, p; backend, direction)
end
@doc raw"""
    derivative_estimate(X, p, alg::AbstractStochasticADAlgorithm = ForwardAlgorithm(PrunedFIsBackend()); direction=nothing, alg_data::NamedTuple = (;))

Compute an unbiased estimate of ``\frac{\mathrm{d}\mathbb{E}[X(p)]}{\mathrm{d}p}``,
the derivative of the expectation of the random function `X(p)` with respect to its input `p`.

Both `p` and `X(p)` can be any object supported by [`Functors.jl`](https://fluxml.ai/Functors.jl/stable/),
e.g. scalars or abstract arrays.
The output of `derivative_estimate` has the same outer structure as `p`, but with each
scalar in `p` replaced by a derivative estimate of `X(p)` with respect to that entry.
For example, if `X(p) <: AbstractMatrix` and `p <: Real`, then the output would be a matrix.

The `alg` argument specifies the [algorithm](public_api.md#Algorithms) used to compute the derivative estimate.
For backward compatibility, an additional signature `derivative_estimate(X, p; backend, direction=nothing)`
is supported, which uses `ForwardAlgorithm` by default with the supplied `backend`.
The `alg_data` keyword argument can specify any additional data that specific algorithms accept or require.

When `direction` is provided, the output is only differentiated with respect to a perturbation
of `p` in that direction.

# Example
```jldoctest
julia> using Distributions, Random, StochasticAD; Random.seed!(4321);

julia> derivative_estimate(rand ∘ Bernoulli, 0.5) # A random quantity that averages to the true derivative.
2.0

julia> derivative_estimate(x -> [rand(Bernoulli(x * i/4)) for i in 1:3], 0.5)
3-element Vector{Float64}:
 0.2857142857142857
 0.6666666666666666
 0.0
```
"""
derivative_estimate
================================================
FILE: src/backends/abstract_wrapper.jl
================================================
module AbstractWrapperFIsModule

import ..StochasticAD

export AbstractWrapperFIs

"""
    AbstractWrapperFIs{V, FIs} <: StochasticAD.AbstractFIs{V}

A convenience type for backend strategies that wrap another backend. A subtype `WrapperFIs <: AbstractWrapperFIs`
should have a field called Δs containing the wrapped backend, and should also define the following methods:

* `StochasticAD.similar_type(::Type{<:WrapperFIs}, V, FIs)`: return the type of a new
  `WrapperFIs` with value type `V` and wrapped backend type `FIs`,
* `AbstractWrapperFIsModule.reconstruct_wrapper(wrapper_Δs::WrapperFIs, Δs::AbstractFIs)`: construct
  a new `WrapperFIs` wrapping `Δs` given an existing wrapped instance `wrapper_Δs`.
* `AbstractWrapperFIsModule.reconstruct_wrapper(::Type{<:WrapperFIs}, Δs::AbstractFIs)`: construct
  a new `WrapperFIs` wrapping `Δs` given the type of an existing `WrapperFIs`.

Then, all other methods will generically be forwarded to the inner backend, except those overloaded by the
specific wrapper type.
"""
abstract type AbstractWrapperFIs{V, FIs} <: StochasticAD.AbstractFIs{V} end

# Stub: each concrete wrapper defines its own `reconstruct_wrapper` methods (see docstring above).
function reconstruct_wrapper end

# Forward creation of a new perturbation to the wrapped backend, then re-wrap.
function StochasticAD.similar_new(Δs::AbstractWrapperFIs, Δ, w)
    reconstruct_wrapper(Δs, StochasticAD.similar_new(Δs.Δs, Δ, w))
end
function StochasticAD.similar_empty(Δs::AbstractWrapperFIs, V)
    reconstruct_wrapper(Δs, StochasticAD.similar_empty(Δs.Δs, V))
end
# Two-argument similar_type is derived from the wrapper-specific three-argument form.
function StochasticAD.similar_type(WrapperFIs::Type{<:AbstractWrapperFIs{V0, FIs}},
        V) where {V0, FIs}
    return StochasticAD.similar_type(WrapperFIs, V, StochasticAD.similar_type(FIs, V))
end
StochasticAD.valtype(Δs::AbstractWrapperFIs) = StochasticAD.valtype(Δs.Δs)

# Coupling: unwrap every element (and the `rep`, if given), couple the inner
# backends, and re-wrap the result using any wrapper instance from the collection.
function StochasticAD.couple(WrapperFIs::Type{<:AbstractWrapperFIs{V, FIs}},
        Δs_all;
        rep = nothing,
        kwargs...) where {V, FIs}
    _Δs_all = StochasticAD.structural_map(Δs -> Δs.Δs, Δs_all)
    _rep_kwarg = !isnothing(rep) ? (; rep = rep.Δs) : (;)
    return reconstruct_wrapper(StochasticAD.get_any(Δs_all),
        StochasticAD.couple(FIs, _Δs_all; _rep_kwarg..., kwargs...))
end
# Same unwrap/re-wrap pattern as `couple`, but for `combine`.
function StochasticAD.combine(WrapperFIs::Type{<:AbstractWrapperFIs{V, FIs}},
        Δs_all;
        rep = nothing,
        kwargs...) where {V, FIs}
    _Δs_all = StochasticAD.structural_map(Δs -> Δs.Δs, Δs_all)
    _rep_kwarg = !isnothing(rep) ? (; rep = rep.Δs) : (;)
    return reconstruct_wrapper(StochasticAD.get_any(Δs_all),
        StochasticAD.combine(FIs, _Δs_all; _rep_kwarg..., kwargs...))
end
function StochasticAD.get_rep(WrapperFIs::Type{<:AbstractWrapperFIs{V, FIs}},
        Δs_all;
        kwargs...) where {V, FIs}
    _Δs_all = StochasticAD.structural_map(Δs -> Δs.Δs, Δs_all)
    return reconstruct_wrapper(StochasticAD.get_any(Δs_all),
        StochasticAD.get_rep(FIs, _Δs_all; kwargs...))
end
# Scalarize the inner backend, then re-wrap each scalar leaf.
function StochasticAD.scalarize(Δs::AbstractWrapperFIs; rep = nothing, kwargs...)
    _rep_kwarg = !isnothing(rep) ? (; rep = rep.Δs) : (;)
    return StochasticAD.structural_map(StochasticAD.scalarize(
        Δs.Δs; _rep_kwarg..., kwargs...)) do _Δs
        reconstruct_wrapper(Δs, _Δs)
    end
end
function StochasticAD.derivative_contribution(Δs::AbstractWrapperFIs, Δs_all; kwargs...)
    StochasticAD.derivative_contribution(Δs.Δs, Δs_all; kwargs...)
end
StochasticAD.alltrue(f, Δs::AbstractWrapperFIs) = StochasticAD.alltrue(f, Δs.Δs)
StochasticAD.perturbations(Δs::AbstractWrapperFIs) = StochasticAD.perturbations(Δs.Δs)
function StochasticAD.filter_state(Δs::AbstractWrapperFIs, state)
    StochasticAD.filter_state(Δs.Δs, state)
end
function StochasticAD.weighted_map_Δs(f, Δs::AbstractWrapperFIs; kwargs...)
    reconstruct_wrapper(Δs, StochasticAD.weighted_map_Δs(f, Δs.Δs; kwargs...))
end
StochasticAD.new_Δs_strategy(Δs::AbstractWrapperFIs) = StochasticAD.new_Δs_strategy(Δs.Δs)

# Base container-like interface: all forwarded to the wrapped backend.
function Base.empty(WrapperFIs::Type{<:AbstractWrapperFIs{V, FIs}}) where {V, FIs}
    return reconstruct_wrapper(WrapperFIs, empty(FIs))
end
Base.empty(Δs::AbstractWrapperFIs) = reconstruct_wrapper(Δs, empty(Δs.Δs))
Base.isempty(Δs::AbstractWrapperFIs) = isempty(Δs.Δs)
Base.length(Δs::AbstractWrapperFIs) = length(Δs.Δs)
Base.iszero(Δs::AbstractWrapperFIs) = iszero(Δs.Δs)
function StochasticAD.derivative_contribution(Δs::AbstractWrapperFIs)
    StochasticAD.derivative_contribution(Δs.Δs)
end
# Conversion to a different value type converts the wrapped backend and re-wraps.
function Base.convert(::Type{<:AbstractWrapperFIs{V}}, Δs::AbstractWrapperFIs) where {V}
    reconstruct_wrapper(Δs, convert(StochasticAD.similar_type(typeof(Δs.Δs), V), Δs.Δs))
end
function StochasticAD.send_signal(
        Δs::AbstractWrapperFIs, signal::StochasticAD.AbstractPerturbationSignal)
    reconstruct_wrapper(Δs, StochasticAD.send_signal(Δs.Δs, signal))
end
# Display the wrapped backend directly (the wrapper itself is transparent).
function Base.show(io::IO, Δs::AbstractWrapperFIs)
    return show(io, Δs.Δs)
end

end
================================================
FILE: src/backends/dict.jl
================================================
module DictFIsModule

export DictFIsBackend, DictFIs

import ..StochasticAD
using Dictionaries

"""
    DictFIsBackend <: StochasticAD.AbstractFIsBackend

A dictionary backend algorithm which keeps entries for each perturbation that has occurred without pruning.
Currently very unoptimized.
"""
struct DictFIsBackend <: StochasticAD.AbstractFIsBackend end

"""
    DictFIsState

State maintained by dictionary backend.
"""
mutable struct DictFIsState
    tag_count::Int64 # counter used to mint a fresh unique tag per new event
    valid::Bool # false for states created detached from a live computation
    DictFIsState(valid = true) = new(0, valid)
end

# A single infinitesimal-probability event; identity and ordering are by tag only.
struct InfinitesimalEvent
    tag::Any # unique identifier
    w::Float64 # weight (infinitesimal probability wε)
end

Base.:<(event1::InfinitesimalEvent, event2::InfinitesimalEvent) = event1.tag < event2.tag
function Base.:(==)(event1::InfinitesimalEvent, event2::InfinitesimalEvent)
    event1.tag == event2.tag
end
Base.:isless(event1::InfinitesimalEvent, event2::InfinitesimalEvent) = event1 < event2

"""
    DictFIs{V} <: StochasticAD.AbstractFIs{V}

The implementing backend structure for DictFIsBackend.
"""
struct DictFIs{V} <: StochasticAD.AbstractFIs{V}
    dict::Dictionary{InfinitesimalEvent, V} # event => perturbation value of type V
    state::DictFIsState # shared mutable state used for tag generation
end

state(Δs::DictFIs) = Δs.state

### Empty / no perturbation
function DictFIs{V}(state::DictFIsState) where {V}
    DictFIs{V}(Dictionary{InfinitesimalEvent, V}(), state)
end
StochasticAD.similar_empty(Δs::DictFIs, V::Type) = DictFIs{V}(Δs.state)
Base.empty(Δs::DictFIs{V}) where {V} = StochasticAD.similar_empty(Δs::DictFIs, V::Type)
function Base.empty(::Type{<:DictFIs{V}}) where {V}
    # No instance is available here, so the live state cannot be recovered;
    # use a fresh invalidated state instead.
    DictFIs{V}(DictFIsState(false))
end

### Create a new perturbation with infinitesimal probability
function new_perturbation(Δ::V, w::Real, state::DictFIsState) where {V}
    # Mint a fresh tag so this perturbation is tracked independently of all others.
    state.tag_count += 1
    event = InfinitesimalEvent(state.tag_count, w)
    DictFIs{V}(Dictionary([event], [Δ]), state)
end
function StochasticAD.similar_new(Δs::DictFIs, Δ::V, w::Real) where {V}
    new_perturbation(Δ, w, Δs.state)
end

### Create Δs backend for the first stochastic triple of computation
StochasticAD.create_Δs(::DictFIsBackend, V) = DictFIs{V}(DictFIsState())

### Convert type of a backend
function Base.convert(::Type{DictFIs{V}}, Δs::DictFIs) where {V}
    DictFIs{V}(convert(Dictionary{InfinitesimalEvent, V}, Δs.dict), Δs.state)
end

### Getting information about Δs
Base.isempty(Δs::DictFIs) = isempty(Δs.dict)
Base.length(Δs::DictFIs) = length(Δs.dict)
Base.iszero(Δs::DictFIs) = isempty(Δs) || all(iszero.(Δs.dict))
function StochasticAD.derivative_contribution(Δs::DictFIs{V}) where {V}
    # `init = zero(V) * 0.0` yields a float-valued zero even for an empty dict.
    sum((Δ * event.w for (event, Δ) in pairs(Δs.dict)), init = zero(V) * 0.0)
end
function StochasticAD.perturbations(Δs::DictFIs)
    [(; Δ, weight = event.w, state = event) for (event, Δ) in pairs(Δs.dict)]
end

### Unary propagation
function StochasticAD.weighted_map_Δs(f, Δs::DictFIs; kwargs...)
    # Pass key as state in map
    mapped_values_and_weights = map(f, collect(Δs.dict), keys(Δs.dict))
    mapped_values = first.(mapped_values_and_weights)
    mapped_weights = last.(mapped_values_and_weights)
    # Scale each event's infinitesimal weight by the weight returned by f.
    scaled_events = map((event, a) -> InfinitesimalEvent(event.tag, event.w * a),
        keys(Δs.dict),
        mapped_weights) # TODO: should original events (with old tag) also be modified?
    dict = Dictionary(scaled_events, mapped_values)
    DictFIs(dict, Δs.state)
end
StochasticAD.alltrue(f, Δs::DictFIs) = all(map(f, collect(Δs.dict)))

### Coupling
function StochasticAD.get_rep(::Type{<:DictFIs}, Δs_all)
    # Prefer any instance whose state is still valid.
    for Δs in StochasticAD.structural_iterate(Δs_all)
        if Δs.state.valid
            return Δs
        end
    end
    # NOTE(review): the loop above uses structural_iterate, but this fallback
    # uses `first` — presumably fine for flat collections; confirm for nested
    # functor structures.
    return first(Δs_all)
end
function StochasticAD.couple(FIs::Type{<:DictFIs}, Δs_all;
        rep = StochasticAD.get_rep(FIs, Δs_all),
        out_rep = nothing,
        kwargs...)
    # Union of all event keys across the collection being coupled.
    all_keys = Iterators.map(StochasticAD.structural_iterate(Δs_all)) do Δs
        keys(Δs.dict)
    end
    distinct_keys = unique(all_keys |> Iterators.flatten)
    # For each event, gather every Δs's value for that event (or a zero if absent).
    Δs_coupled_dict = [StochasticAD.structural_map(
                           Δs -> isassigned(Δs.dict, key) ?
                                 Δs.dict[key] :
                                 zero(eltype(Δs.dict)),
                           Δs_all)
                       for key in distinct_keys]
    DictFIs(Dictionary(distinct_keys, Δs_coupled_dict), rep.state)
end
function StochasticAD.combine(FIs::Type{<:DictFIs}, Δs_all;
        rep = StochasticAD.get_rep(FIs, Δs_all), kwargs...)
    Δs_dicts = Iterators.map(Δs -> Δs.dict, StochasticAD.structural_iterate(Δs_all))
    # Merge all dictionaries, summing values for events present in several of them.
    Δs_combined_dict = reduce(Δs_dicts) do Δs_dict1, Δs_dict2
        mergewith((x, y) -> StochasticAD.structural_map(+, x, y), Δs_dict1, Δs_dict2)
    end
    DictFIs(Δs_combined_dict, rep.state)
end
function StochasticAD.scalarize(Δs::DictFIs; out_rep = nothing)
    # TODO: use vcat here?
    # Zip per-event structures together so each output leaf collects its values
    # across all events, then rebuild a DictFIs per leaf.
    tupleify(Δ1, Δ2) = StochasticAD.structural_map(tuple, Δ1, Δ2)
    Δ_all_allkeys = foldl(tupleify, values(Δs.dict))
    Δ_all_rep = first(values(Δs.dict))
    _keys = keys(Δs.dict)
    return StochasticAD.structural_map(Δ_all_rep, Δ_all_allkeys) do _, Δ_allkeys
        return DictFIs(Dictionary(_keys, Δ_allkeys), Δs.state)
    end
end
function StochasticAD.filter_state(Δs::DictFIs{V}, key) where {V}
    haskey(Δs.dict, key) ? Δs.dict[key] : zero(V)
end

### Miscellaneous
StochasticAD.similar_type(::Type{<:DictFIs}, V::Type) = DictFIs{V}
StochasticAD.valtype(::Type{<:DictFIs{V}}) where {V} = V

end
================================================
FILE: src/backends/pruned.jl
================================================
module PrunedFIsModule

import ..StochasticAD

export PrunedFIsBackend, PrunedFIs

"""
    PrunedFIsBackend <: StochasticAD.AbstractFIsBackend

A backend algorithm that prunes between perturbations as soon as they clash (e.g. added together).
Currently chooses uniformly between all perturbations.
"""
struct PrunedFIsBackend{M <: Val} <: StochasticAD.AbstractFIsBackend
    pruning_mode::M # Val(:weights) or Val(:wins): how pruning strength is measured
    function PrunedFIsBackend(pruning_mode::M = Val(:weights)) where {M}
        if pruning_mode isa Val{:weights} || pruning_mode isa Val{:wins}
            return new{M}(pruning_mode)
        else
            # FIX: error message previously had an unbalanced backtick.
            error("Unsupported pruning_mode $pruning_mode for `PrunedFIsBackend`.")
        end
    end
end

"""
    PrunedFIsState

State maintained by pruning backend.
"""
mutable struct PrunedFIsState{M, W}
    tag::Int32 # identity of this perturbation lineage (used by == and hash)
    weight::Float64 # accumulated infinitesimal probability weight
    valid::Bool # false once this perturbation loses a pruning contest
    # TODO: generalize (wins, pruning_mode) into a general interface for accumulating state
    # that informs future pruning decisions.
    wins::W # number of pruning contests won (Int in Val(:wins) mode, else nothing)
    pruning_mode::M
    function PrunedFIsState(pruning_mode::M, valid = true) where {M <: Val}
        wins = pruning_mode isa Val{:wins} ? (valid ? 1 : 0) : nothing
        # Pass pruning_mode explicitly (previously relied on the zero-size
        # singleton field being readable without initialization).
        state::PrunedFIsState = new{M, typeof(wins)}(0, 0.0, valid, wins, pruning_mode)
        state.tag = objectid(state) % typemax(Int32)
        return state
    end
end

# States are identified purely by tag.
Base.:(==)(state1::PrunedFIsState, state2::PrunedFIsState) = state1.tag == state2.tag
# c.f. https://github.com/JuliaLang/julia/blob/61c3521613767b2af21dfa5cc5a7b8195c5bdcaf/base/hashing.jl#L38C45-L38C51
# NOTE(review): only the one-argument `hash` is overloaded; `hash(state, h::UInt)`
# falls back to objectid-based hashing — confirm no container mixes the two forms.
Base.hash(state::PrunedFIsState) = state.tag

"""
    PrunedFIs{V} <: StochasticAD.AbstractFIs{V}

The implementing backend structure for PrunedFIsBackend.
"""
struct PrunedFIs{V, S <: PrunedFIsState} <: StochasticAD.AbstractFIs{V}
    Δ::V # the (lazily pruned) perturbation value
    state::S # shared mutable pruning state for this perturbation's lineage
end

### Empty / no perturbation
PrunedFIs{V}(Δ::V, state::S) where {V, S <: PrunedFIsState} = PrunedFIs{V, S}(Δ, state)
PrunedFIs{V}(state::PrunedFIsState) where {V} = PrunedFIs{V}(zero(V), state)
# TODO: avoid allocations here
function StochasticAD.similar_empty(Δs::PrunedFIs, V::Type)
    PrunedFIs{V}(PrunedFIsState(Δs.state.pruning_mode, false))
end
# Cleanup: removed redundant type assertions from the forwarding call.
Base.empty(Δs::PrunedFIs{V}) where {V} = StochasticAD.similar_empty(Δs, V)
# we truly have no clue what the state is here, so use an invalidated state
function Base.empty(::Type{<:PrunedFIs{V, S}}) where {V, M, S <: PrunedFIsState{M}}
    PrunedFIs{V}(PrunedFIsState(M(), false))
end

### Create a new perturbation with infinitesimal probability
function StochasticAD.similar_new(Δs::PrunedFIs, Δ::V, w::Real) where {V}
    # A zero-weight perturbation can never contribute; represent it as empty.
    if iszero(w)
        return StochasticAD.similar_empty(Δs, V)
    end
    state = PrunedFIsState(Δs.state.pruning_mode)
    state.weight += w
    Δs = PrunedFIs{V}(Δ, state)
    return Δs
end

### Create Δs backend for the first stochastic triple of computation
function StochasticAD.create_Δs(backend::PrunedFIsBackend, V)
    PrunedFIs{V}(PrunedFIsState(backend.pruning_mode, false))
end

### Convert type of a backend
function Base.convert(::Type{<:PrunedFIs{V}}, Δs::PrunedFIs) where {V}
    PrunedFIs{V}(convert(V, Δs.Δ), Δs.state)
end

### Getting information about perturbations
# "empty" here means no perturbation or a perturbation that has been pruned away
Base.isempty(Δs::PrunedFIs) = !Δs.state.valid
Base.length(Δs::PrunedFIs) = isempty(Δs) ? 0 : 1
function Base.iszero(Δs::PrunedFIs)
    isempty(Δs) || all(iszero, StochasticAD.structural_iterate(Δs.Δ))
end
Base.iszero(Δs::PrunedFIs{<:Real}) = isempty(Δs) || iszero(Δs.Δ)
Base.iszero(Δs::PrunedFIs{<:Tuple}) = isempty(Δs) || all(iszero.(Δs.Δ))
isapproxzero(Δs::PrunedFIs) = isempty(Δs) || isapprox(Δs.Δ, zero(Δs.Δ))

# we lazily prune, so check if empty first
function pruned_value(Δs::PrunedFIs{V}) where {V}
    isempty(Δs) ? StochasticAD.structural_map(zero, Δs.Δ) : Δs.Δ
end
pruned_value(Δs::PrunedFIs{<:Real}) = isempty(Δs) ? zero(Δs.Δ) : Δs.Δ
pruned_value(Δs::PrunedFIs{<:Tuple}) = isempty(Δs) ? zero.(Δs.Δ) : Δs.Δ
pruned_value(Δs::PrunedFIs{<:AbstractArray}) = isempty(Δs) ? zero.(Δs.Δ) : Δs.Δ

StochasticAD.derivative_contribution(Δs::PrunedFIs) = pruned_value(Δs) * Δs.state.weight
function StochasticAD.perturbations(Δs::PrunedFIs)
    ((; Δ = pruned_value(Δs), weight = Δs.state.weight, state = Δs.state),)
end

### Unary propagation
function StochasticAD.weighted_map_Δs(f, Δs::PrunedFIs; kwargs...)
    Δ_out, weight_out = f(pruned_value(Δs), Δs.state)
    # TODO: we could add a direct overload for map_Δs that elides the below line
    Δs.state.weight *= weight_out
    PrunedFIs(Δ_out, Δs.state)
end
StochasticAD.alltrue(f, Δs::PrunedFIs) = f(pruned_value(Δs))

### Coupling
function StochasticAD.get_rep(FIs::Type{<:PrunedFIs}, Δs_all)
    return empty(FIs) #StochasticAD.get_any(Δs_all)
end

# Fold over all perturbations, running randomized pruning contests between
# distinct states; exactly one state survives, reweighted by the inverse of its
# survival probability so the estimator stays unbiased.
function get_pruned_state(Δs_all; Δ_func = nothing, rep, out_rep = nothing)
    if !isnothing(Δ_func) && isnothing(out_rep)
        error("Specifying Δ_func requires out_rep to be specified.")
    end
    function op(cur_state, Δs)
        # lazy pruning optimization temporarily disabled with custom Δ_func
        # (because custom Δ_func's may prefer not to lazily prune)
        (isnothing(Δ_func) && isapproxzero(Δs)) && return cur_state
        candidate_state = Δs.state
        if !candidate_state.valid ||
           (candidate_state == cur_state)
            return cur_state
        end
        if !cur_state.valid
            return candidate_state
        end
        # Compute "strength" of each perturbation for pruning proposal
        if !isnothing(Δ_func)
            # TODO: structural_map for each state can take asymptotically more time than necessary when combining many distinct states
            candidate_Δ = StochasticAD.structural_map(
                Base.Fix2(StochasticAD.filter_state, candidate_state), Δs_all)
            candidate_Δ_func::Float64 = Δ_func(candidate_Δ, candidate_state, out_rep)
            cur_Δ = StochasticAD.structural_map(
                Base.Fix2(StochasticAD.filter_state, cur_state), Δs_all)
            cur_Δ_func::Float64 = Δ_func(cur_Δ, cur_state, out_rep)
        else
            candidate_Δ_func = 1.0
            cur_Δ_func = 1.0
        end
        candidate_intrinsic_strength = Δs.state.pruning_mode isa Val{:wins} ?
                                       candidate_state.wins : abs(candidate_state.weight)
        cur_intrinsic_strength = Δs.state.pruning_mode isa Val{:wins} ? cur_state.wins :
                                 abs(cur_state.weight)
        candidate_strength = candidate_intrinsic_strength * candidate_Δ_func
        cur_strength = cur_intrinsic_strength * cur_Δ_func
        both_states_bad = iszero(candidate_strength) && iszero(cur_strength)
        if both_states_bad
            cur_state.valid = false
            candidate_state.valid = false
            return cur_state
        end
        # Prune between perturbations
        total_strength = cur_strength + candidate_strength
        p = candidate_strength / total_strength
        if isone(p) || (rand(StochasticAD.RNG) < p)
            # Candidate survives: invalidate the incumbent and scale the weight
            # by 1/p to keep the surviving contribution unbiased.
            cur_state.valid = false
            if Δs.state.pruning_mode isa Val{:wins}
                candidate_state.wins += 1
            end
            candidate_state.weight *= 1 / p
            return candidate_state
        else
            candidate_state.valid = false
            if Δs.state.pruning_mode isa Val{:wins}
                cur_state.wins += 1
            end
            cur_state.weight *= 1 / (1 - p)
            return cur_state
        end
    end
    dummy_state = PrunedFIsState(rep.state.pruning_mode, false) # For type stability, as well as retval if no better state found. TODO: can this be avoided?
    _new_state = foldl(op, StochasticAD.structural_iterate(Δs_all); init = dummy_state)
    return _new_state::PrunedFIsState
end

# for pruning, coupling amounts to getting rid of perturbed values that have been
# lazily kept around even after (aggressive or lazy) pruning made the perturbation invalid.
function StochasticAD.couple(
        FIs::Type{<:PrunedFIs}, Δs_all; rep = StochasticAD.get_rep(FIs, Δs_all),
        out_rep = nothing, Δ_func = nothing, kwargs...)
    # FIX: previously `out_rep` was not forwarded, so any couple with a Δ_func
    # errored ("Specifying Δ_func requires out_rep") even when out_rep was given.
    state = get_pruned_state(Δs_all; rep, out_rep, Δ_func)
    Δ_coupled = StochasticAD.structural_map(pruned_value, Δs_all) # TODO: perhaps a performance optimization possible here
    PrunedFIs(Δ_coupled, state)
end
# basically couple combined with a sum.
function StochasticAD.combine(
        FIs::Type{<:PrunedFIs}, Δs_all; rep = StochasticAD.get_rep(FIs, Δs_all),
        Δ_func = nothing, out_rep = nothing, kwargs...)
    state = get_pruned_state(Δs_all;
        rep,
        out_rep,
        Δ_func = !isnothing(Δ_func) ? (Δ, state, val) -> Δ_func(sum(Δ), state, val) :
                 Δ_func)
    Δ_combined = sum(pruned_value, StochasticAD.structural_iterate(Δs_all))
    PrunedFIs(Δ_combined, state)
end
function StochasticAD.scalarize(Δs::PrunedFIs; out_rep = nothing)
    return StochasticAD.structural_map(Δs.Δ) do Δ
        return PrunedFIs(Δ, Δs.state)
    end
end
function StochasticAD.filter_state(Δs::PrunedFIs{V}, state) where {V}
    Δs.state == state ? pruned_value(Δs) : zero(V)
end

### Miscellaneous
# Cleanup: renamed the state type parameter M -> S (it is the state type, not
# the pruning mode); behavior unchanged.
function StochasticAD.similar_type(::Type{<:PrunedFIs{V0, S}}, V::Type) where {V0, S}
    PrunedFIs{V, S}
end
StochasticAD.valtype(::Type{<:PrunedFIs{V}}) where {V} = V

function Base.show(io::IO, Δs::PrunedFIs{V}) where {V}
    print(io, "$(pruned_value(Δs)) with probability $(Δs.state.weight)ε")
end

end
================================================
FILE: src/backends/pruned_aggressive.jl
================================================
module PrunedFIsAggressiveModule

import ..StochasticAD

export PrunedFIsAggressiveBackend, PrunedFIsAggressive

"""
    PrunedFIsAggressiveBackend <: StochasticAD.AbstractFIsBackend

A backend algorithm that aggressively prunes between perturbations as soon as they are created.
"""
struct PrunedFIsAggressiveBackend <: StochasticAD.AbstractFIsBackend end

"""
    PrunedFIsAggressiveState

State maintained by aggressive pruning backend.
"""
mutable struct PrunedFIsAggressiveState
    active_tag::Int64 # 0 is always a dummy tag
    weight::Float64 # accumulated infinitesimal probability weight
    tag_count::Int64 # counter used to mint fresh tags
    valid::Bool # false for states created detached from a live computation
    PrunedFIsAggressiveState(valid = true) = new(0, 0.0, 0, valid)
end

"""
    PrunedFIsAggressive{V} <: StochasticAD.AbstractFIs{V}

The implementing backend structure for PrunedFIsAggressiveBackend.
"""
struct PrunedFIsAggressive{V} <: StochasticAD.AbstractFIs{V}
    Δ::V # perturbation value; only meaningful while tag matches state.active_tag
    tag::Int # tag this perturbation was created under (-1 marks "empty")
    state::PrunedFIsAggressiveState # shared mutable pruning state
    # directly called when propagating an existing perturbation
end

### Empty / no perturbation
function PrunedFIsAggressive{V}(state::PrunedFIsAggressiveState) where {V}
    PrunedFIsAggressive{V}(zero(V), -1, state)
end
function StochasticAD.similar_empty(Δs::PrunedFIsAggressive, V::Type)
    PrunedFIsAggressive{V}(Δs.state)
end
function Base.empty(Δs::PrunedFIsAggressive{V}) where {V}
    StochasticAD.similar_empty(Δs, V)
end
# we truly have no clue what the state is here, so use an invalidated state
function Base.empty(::Type{<:PrunedFIsAggressive{V}}) where {V}
    PrunedFIsAggressive{V}(PrunedFIsAggressiveState(false))
end

### Create a new perturbation with infinitesimal probability
function new_perturbation(Δ::V, w::Real, state::PrunedFIsAggressiveState) where {V}
    # Prune immediately: keep the currently active perturbation with probability
    # proportional to its accumulated weight, else activate the new one.
    total_weight = state.weight + w
    if rand(StochasticAD.RNG) * total_weight < state.weight
        state.weight += w
        return PrunedFIsAggressive{V}(state)
    else
        state.tag_count += 1
        state.active_tag = state.tag_count
        state.weight += w
        return PrunedFIsAggressive{V}(Δ, state.active_tag, state)
    end
end
function StochasticAD.similar_new(Δs::PrunedFIsAggressive, Δ::V, w::Real) where {V}
    new_perturbation(Δ, w, Δs.state)
end

### Create Δs backend for the first stochastic triple of computation
function StochasticAD.create_Δs(::PrunedFIsAggressiveBackend, V)
    PrunedFIsAggressive{V}(PrunedFIsAggressiveState())
end

### Convert type of a backend
function Base.convert(::Type{PrunedFIsAggressive{V}}, Δs::PrunedFIsAggressive) where {V}
    PrunedFIsAggressive{V}(convert(V, Δs.Δ), Δs.tag, Δs.state)
end

### Getting information about perturbations
# "empty" here means no perturbation or a perturbation that has been pruned away
Base.isempty(Δs::PrunedFIsAggressive) = Δs.tag != Δs.state.active_tag
Base.length(Δs::PrunedFIsAggressive) = isempty(Δs) ? 0 : 1
Base.iszero(Δs::PrunedFIsAggressive) = isempty(Δs) || iszero(Δs.Δ)

# we lazily prune, so check if empty first
pruned_value(Δs::PrunedFIsAggressive{V}) where {V} = isempty(Δs) ? zero(V) : Δs.Δ

function StochasticAD.derivative_contribution(Δs::PrunedFIsAggressive)
    pruned_value(Δs) * Δs.state.weight
end
function StochasticAD.perturbations(Δs::PrunedFIsAggressive)
    ((; Δ = pruned_value(Δs), weight = Δs.state.weight, state = Δs.state),)
end

### Unary propagation
function StochasticAD.weighted_map_Δs(f, Δs::PrunedFIsAggressive; kwargs...)
    Δ_out, weight_out = f(Δs.Δ, nothing)
    Δs.state.weight *= weight_out
    PrunedFIsAggressive(Δ_out, Δs.tag, Δs.state)
end
StochasticAD.alltrue(f, Δs::PrunedFIsAggressive) = f(Δs.Δ)

### Coupling
function StochasticAD.get_rep(::Type{<:PrunedFIsAggressive}, Δs_all)
    # Get some Δs with a valid state, or any if all are invalid.
    return reduce((Δs1, Δs2) -> Δs1.state.valid ? Δs1 : Δs2,
        StochasticAD.structural_iterate(Δs_all))
end
# for pruning, coupling amounts to getting rid of perturbed values that have been
# lazily kept around even after (aggressive or lazy) pruning made the perturbation invalid.
function StochasticAD.couple(FIs::Type{<:PrunedFIsAggressive}, Δs_all;
        rep = StochasticAD.get_rep(FIs, Δs_all),
        out_rep = nothing, kwargs...)
    state = rep.state
    Δ_coupled = StochasticAD.structural_map(pruned_value, Δs_all) # TODO: perhaps a performance optimization possible here
    PrunedFIsAggressive(Δ_coupled, state.active_tag, state)
end
# basically couple combined with a sum.
function StochasticAD.combine(FIs::Type{<:PrunedFIsAggressive}, Δs_all;
        rep = StochasticAD.get_rep(FIs, Δs_all), kwargs...)
    state = rep.state
    Δ_combined = sum(pruned_value, StochasticAD.structural_iterate(Δs_all))
    PrunedFIsAggressive(Δ_combined, state.active_tag, state)
end
function StochasticAD.scalarize(Δs::PrunedFIsAggressive; out_rep = nothing)
    return StochasticAD.structural_map(Δs.Δ) do Δ
        return PrunedFIsAggressive(Δ, Δs.tag, Δs.state)
    end
end
StochasticAD.filter_state(Δs::PrunedFIsAggressive, _) = pruned_value(Δs)

### Miscellaneous
StochasticAD.similar_type(::Type{<:PrunedFIsAggressive}, V::Type) = PrunedFIsAggressive{V}
StochasticAD.valtype(::Type{<:PrunedFIsAggressive{V}}) where {V} = V

# NOTE(review): the MIME"text/plain" method below renders identically to the
# plain two-argument `show`; the MIME overload may be redundant — confirm.
function Base.show(io::IO, mime::MIME"text/plain",
        Δs::PrunedFIsAggressive{V}) where {V}
    print(io, "$(pruned_value(Δs)) with probability $(Δs.state.weight)ε, tag $(Δs.tag)")
end
function Base.show(io::IO, Δs::PrunedFIsAggressive{V}) where {V}
    print(io, "$(pruned_value(Δs)) with probability $(Δs.state.weight)ε, tag $(Δs.tag)")
end

end
================================================
FILE: src/backends/smoothed.jl
================================================
module SmoothedFIsModule

import ..StochasticAD

export SmoothedFIsBackend, SmoothedFIs

"""
    SmoothedFIsBackend <: StochasticAD.AbstractFIsBackend

A backend algorithm that smooths perturbations together.
"""
struct SmoothedFIsBackend <: StochasticAD.AbstractFIsBackend end

"""
    SmoothedFIs{V} <: StochasticAD.AbstractFIs{V}

The implementing backend structure for SmoothedFIsBackend.
"""
# TODO: make type of δ generic
struct SmoothedFIs{V, V_float} <: StochasticAD.AbstractFIs{V}
    # Smoothed derivative contribution, stored in a Float64 representation.
    δ::V_float
    function SmoothedFIs{V}(δ) where {V}
        # hardcode Float64 representation for now, for simplicity.
        δ_f64 = StochasticAD.structural_map(Base.Fix1(convert, Float64), δ)
        return new{V, typeof(δ_f64)}(δ_f64)
    end
end

### Empty / no perturbation
# An "empty" smoothed collection is simply a zero contribution.
StochasticAD.similar_empty(::SmoothedFIs, V::Type) = SmoothedFIs{V}(0.0)
Base.empty(::Type{<:SmoothedFIs{V}}) where {V} = SmoothedFIs{V}(0.0)
Base.empty(Δs::SmoothedFIs) = empty(typeof(Δs))

### Create a new perturbation with infinitesimal probability
# A perturbation Δ with weight w contributes Δ * w to the smoothed derivative.
function StochasticAD.similar_new(::SmoothedFIs, Δ::V, w::Real) where {V}
    SmoothedFIs{V}(Δ * w)
end
StochasticAD.new_Δs_strategy(::SmoothedFIs) = StochasticAD.TwoSidedStrategy()

### Create Δs backend for the first stochastic triple of computation
StochasticAD.create_Δs(::SmoothedFIsBackend, V) = SmoothedFIs{V}(0.0)

### Convert type of a backend
function Base.convert(FIs::Type{<:SmoothedFIs{V}}, Δs::SmoothedFIs) where {V}
    SmoothedFIs{V}(Δs.δ)::FIs
end

### Getting information about perturbations
# NOTE(review): isempty is hardcoded to false — a SmoothedFIs always carries a
# (possibly zero) smoothed contribution rather than a discrete set of perturbations.
Base.isempty(Δs::SmoothedFIs) = false
Base.iszero(Δs::SmoothedFIs) = iszero(Δs.δ)
Base.iszero(Δs::SmoothedFIs{<:Tuple}) = all(iszero.(Δs.δ))
StochasticAD.derivative_contribution(Δs::SmoothedFIs) = Δs.δ

### Unary propagation
# Smoothed propagation ignores the discrete map f and applies only deriv to δ.
function StochasticAD.weighted_map_Δs(f, Δs::SmoothedFIs; deriv, out_rep, kwargs...)
    SmoothedFIs{typeof(out_rep)}(deriv(Δs.δ))
end
StochasticAD.alltrue(f, Δs::SmoothedFIs) = true

### Coupling
StochasticAD.get_rep(::Type{<:SmoothedFIs}, Δs_all) = StochasticAD.get_any(Δs_all)
# Coupling collects the δ fields into a structure matching Δs_all.
function StochasticAD.couple(
        ::Type{<:SmoothedFIs}, Δs_all; rep = nothing, out_rep, kwargs...)
    SmoothedFIs{typeof(out_rep)}(StochasticAD.structural_map(Δs -> Δs.δ, Δs_all))
end
# Combining sums the δ contributions of all collections.
function StochasticAD.combine(::Type{<:SmoothedFIs}, Δs_all; rep = nothing, kwargs...)
    V_out = StochasticAD.valtype(first(StochasticAD.structural_iterate(Δs_all)))
    Δ_combined = sum(Δs -> Δs.δ, StochasticAD.structural_iterate(Δs_all))
    SmoothedFIs{V_out}(Δ_combined)
end
function StochasticAD.scalarize(Δs::SmoothedFIs; out_rep)
    return StochasticAD.structural_map(out_rep, Δs.δ) do out, δ
        return SmoothedFIs{typeof(out)}(δ)
    end
end

### Miscellaneous
# NOTE(review): hardcodes the Float64 representation chosen in the inner constructor;
# if the type of δ is made generic (see TODO above), update this accordingly.
StochasticAD.similar_type(::Type{<:SmoothedFIs}, V::Type) = SmoothedFIs{V, Float64}
StochasticAD.valtype(::Type{<:SmoothedFIs{V}}) where {V} = V

function Base.show(io::IO, Δs::SmoothedFIs)
    print(io, "$(Δs.δ)ε")
end

end
================================================
FILE: src/backends/strategy_wrapper.jl
================================================
module StrategyWrapperFIsModule

using ..StochasticAD
using ..StochasticAD.AbstractWrapperFIsModule

export StrategyWrapperFIsBackend, StrategyWrapperFIs

# A backend that pairs an underlying FIs backend with a perturbation strategy,
# so the strategy travels alongside the wrapped perturbation collections.
struct StrategyWrapperFIsBackend{
    B <: StochasticAD.AbstractFIsBackend,
    S <: StochasticAD.AbstractPerturbationStrategy
} <:
       StochasticAD.AbstractFIsBackend
    backend::B
    strategy::S
end

# The FIs counterpart: wraps the underlying collection and carries the strategy.
struct StrategyWrapperFIs{
    V,
    FIs <: StochasticAD.AbstractFIs{V},
    S <: StochasticAD.AbstractPerturbationStrategy
} <:
       AbstractWrapperFIs{V, FIs}
    Δs::FIs
    strategy::S
end

# Create the wrapped Δs by delegating creation to the inner backend.
function StochasticAD.create_Δs(backend::StrategyWrapperFIsBackend, V)
    return StrategyWrapperFIs(StochasticAD.create_Δs(backend.backend, V), backend.strategy)
end

function StochasticAD.similar_type(::Type{<:StrategyWrapperFIs{V0, FIs0, S}},
        V,
        FIs) where {V0, FIs0, S}
    return StrategyWrapperFIs{V, FIs, S}
end

function AbstractWrapperFIsModule.reconstruct_wrapper(wrapper_Δs::StrategyWrapperFIs, Δs)
    return StrategyWrapperFIs(Δs, wrapper_Δs.strategy)
end

# Type-level reconstruction rebuilds the strategy via S().
# NOTE(review): assumes every strategy type has a zero-argument constructor — confirm.
function AbstractWrapperFIsModule.reconstruct_wrapper(
        ::Type{
            <:StrategyWrapperFIs{V, FIs, S},
        },
        Δs) where {V, FIs, S}
    return StrategyWrapperFIs(Δs, S())
end

# Report the wrapped strategy instead of the global default.
StochasticAD.new_Δs_strategy(Δs::StrategyWrapperFIs) = Δs.strategy

end
================================================
FILE: src/discrete_randomness.jl
================================================
## Helper functions for discrete distributions
# index of the differentiated parameter p within params(d)
for (D, i) in ((:Geometric, 1), (:Bernoulli, 1), (:Binomial, 2), (:Poisson, 1),
    (:Categorical, 1))
    @eval _param_index(::$D) = $i
end
# Extract the differentiated parameter from a distribution.
function _get_parameter(d)
    return params(d)[_param_index(d)]
end
# constructors: map a distribution instance back to its family's constructor
_constructor(::Geometric) = Geometric
_constructor(::Bernoulli) = Bernoulli
_constructor(::Binomial) = Binomial
_constructor(::Poisson) = Poisson
_constructor(::Categorical) = Categorical
# reconstruct probability distribution with new parameter value
function _reconstruct(d, p)
    idx = _param_index(d)
    before = params(d)[1:(idx - 1)]
    after = params(d)[(idx + 1):end]
    return _constructor(d)(before..., p, after...)
end
# support of probability distribution
_has_finite_support(::Any) = false
_has_finite_support(::Union{Bernoulli, Binomial, Categorical}) = true
_get_support(d::Union{Bernoulli, Binomial, Categorical}) = minimum(d):maximum(d)
# manual overloads to ensure that static-ness is preserved for Bernoulli's and Categoricals with static arrays.
# since mapping over the range above could result in allocating vectors.
_get_support(::Bernoulli) = (0, 1)
# the map below looks a bit silly, but it gives us a collection of the categories with the same structure as probs(d).
_get_support(d::Categorical) = map((val, _) -> val, 1:ncategories(d), probs(d))
## Derivative couplings
# Derivative coupling approaches, determining which weighted perturbations to consider
abstract type AbstractDerivativeCoupling end

"""
    InversionMethodDerivativeCoupling(; mode::Val = Val(:positive_weight), handle_zeroprob::Val = Val(true))

Specifies an inversion method coupling for generating perturbations from a univariate distribution.
Valid choices of `mode` are `Val(:positive_weight)`, `Val(:always_right)`, and `Val(:always_left)`.

# Example

```jldoctest
julia> using StochasticAD, Distributions, Random; Random.seed!(4321);

julia> function X(p)
           return randst(Bernoulli(1 - p); derivative_coupling = InversionMethodDerivativeCoupling(; mode = Val(:always_right)))
       end
X (generic function with 1 method)

julia> stochastic_triple(X, 0.5)
StochasticTriple of Int64:
0 + 0ε + (1 with probability -2.0ε)
```
"""
Base.@kwdef struct InversionMethodDerivativeCoupling{M, HZP}
    # Direction(s) of perturbation to generate; stored as a Val for dispatch.
    mode::M = Val(:positive_weight)
    # Whether to handle the edge case of zero-probability categories with non-zero
    # derivative when selecting the perturbed outcome (see the Categorical _δtoΔs).
    handle_zeroprob::HZP = Val(true)
end
# Strategies for precisely which perturbations to form given a derivative coupling
# Form only the perturbation matching the sign of δ (see δtoΔs dispatches below).
struct SingleSidedStrategy <: AbstractPerturbationStrategy end
# Form both the +δ and -δ perturbations, combined with weights ±1/2.
struct TwoSidedStrategy <: AbstractPerturbationStrategy end
# Straight-through variant for smoothed backends: differentiates the distribution mean.
struct SmoothedStraightThroughStrategy <: AbstractPerturbationStrategy end
# Straight-through variant weighting the two sides by (1 - p) and -p (Bernoulli/Binomial).
struct StraightThroughStrategy <: AbstractPerturbationStrategy end
# Drop the discrete derivative contribution entirely.
struct IgnoreDiscreteStrategy <: AbstractPerturbationStrategy end
# Default strategy; backends can override via their own `new_Δs_strategy` methods.
new_Δs_strategy(Δs) = SingleSidedStrategy()
# Derivative coupling high-level interface
"""
    δtoΔs(d, val, δ, Δs::AbstractFIs)

Given a sample `val` of a distribution `d` and an infinitesimal change `δ` of its
parameter, return the discrete change in the output, with a similar representation to `Δs`.
"""
function δtoΔs(d, val, δ, Δs, derivative_coupling)
    return δtoΔs(d, val, δ, Δs, derivative_coupling, new_Δs_strategy(Δs))
end
# Single-sided: form only the perturbation dictated by δ itself.
function δtoΔs(d, val, δ, Δs, derivative_coupling, ::SingleSidedStrategy)
    return _δtoΔs(d, val, δ, Δs, derivative_coupling)
end
# Two-sided: combine the +δ and -δ perturbations with weights ±1/2.
function δtoΔs(d, val, δ, Δs, derivative_coupling, ::TwoSidedStrategy)
    Δs_pos = _δtoΔs(d, val, δ, Δs, derivative_coupling)
    Δs_neg = _δtoΔs(d, val, -δ, Δs, derivative_coupling)
    return combine((scale(Δs_pos, 0.5), scale(Δs_neg, -0.5)))
end
# TODO: implement this ST for other distributions and couplings, if meaningful?
# Straight-through for Bernoulli/Binomial: combine the two one-sided perturbations
# with weights (1 - p) and -p.
function δtoΔs(d::Union{Bernoulli, Binomial},
        val,
        δ,
        Δs,
        derivative_coupling::InversionMethodDerivativeCoupling,
        ::StraightThroughStrategy)
    p = succprob(d)
    Δs_pos = _δtoΔs(d, val, δ, Δs, derivative_coupling)
    Δs_neg = _δtoΔs(d, val, -δ, Δs, derivative_coupling)
    return combine((scale(Δs_pos, 1 - p), scale(Δs_neg, -p)))
end
# Ignore the discrete derivative contribution entirely.
function δtoΔs(d, val::V, δ, Δs, derivative_coupling, ::IgnoreDiscreteStrategy) where {V}
    return similar_empty(Δs, V)
end
# Implement straight through strategy, works for all distrs, but does something that is only
# meaningful for smoothed backends (using one(val))
function δtoΔs(d, val, δ, Δs, derivative_coupling, ::SmoothedStraightThroughStrategy)
    p = _get_parameter(d)
    # Derivative of the reconstructed distribution's mean along direction δ.
    mean_along_δ = a -> mean(_reconstruct(d, p + a * δ))
    δout = ForwardDiff.derivative(mean_along_δ, 0.0)
    return similar_new(Δs, one(val), δout)
end
# Derivative coupling low-level implementations
function _δtoΔs(d::Geometric,
        val::V,
        δ::Real,
        Δs::AbstractFIs,
        derivative_coupling::InversionMethodDerivativeCoupling) where {V <: Signed}
    p = succprob(d)
    mode = derivative_coupling.mode
    go_right = (mode isa Val{:positive_weight} && δ > 0) || mode isa Val{:always_right}
    go_left = (mode isa Val{:positive_weight} && δ < 0) || mode isa Val{:always_left}
    if go_right
        # Step down by one (only possible when val > 0), weight δ * val / (p * (1 - p)).
        return val > 0 ? similar_new(Δs, -one(V), δ * val / p / (1 - p)) :
               similar_empty(Δs, V)
    elseif go_left
        # Step up by one, weight -δ * (val + 1) / p.
        return similar_new(Δs, one(V), -δ * (val + 1) / p)
    else
        return similar_empty(Δs, V)
    end
end
function _δtoΔs(d::Bernoulli,
        val::V,
        δ::Real,
        Δs::AbstractFIs,
        derivative_coupling::InversionMethodDerivativeCoupling) where {V <: Signed}
    p = succprob(d)
    mode = derivative_coupling.mode
    go_right = (mode isa Val{:positive_weight} && δ > 0) || mode isa Val{:always_right}
    go_left = (mode isa Val{:positive_weight} && δ < 0) || mode isa Val{:always_left}
    if go_right
        # Flip 0 -> 1 with weight δ / (1 - p); impossible when already at 1.
        return isone(val) ? similar_empty(Δs, V) : similar_new(Δs, one(V), δ / (1 - p))
    elseif go_left
        # Flip 1 -> 0 with weight -δ / p; impossible when already at 0.
        return isone(val) ? similar_new(Δs, -one(V), -δ / p) : similar_empty(Δs, V)
    else
        return similar_empty(Δs, V)
    end
end
function _δtoΔs(d::Binomial,
        val::V,
        δ::Real,
        Δs::AbstractFIs,
        derivative_coupling::InversionMethodDerivativeCoupling) where {V <: Signed}
    p = succprob(d)
    n = ntrials(d)
    mode = derivative_coupling.mode
    go_right = (mode isa Val{:positive_weight} && δ > 0) || mode isa Val{:always_right}
    go_left = (mode isa Val{:positive_weight} && δ < 0) || mode isa Val{:always_left}
    if go_right
        # One extra success with weight δ * (n - val) / (1 - p); impossible at val == n.
        return val == n ? similar_empty(Δs, V) :
               similar_new(Δs, one(V), δ * (n - val) / (1 - p))
    elseif go_left
        # One fewer success with weight -δ * val / p; impossible at val == 0.
        return iszero(val) ? similar_empty(Δs, V) : similar_new(Δs, -one(V), -δ * val / p)
    else
        return similar_empty(Δs, V)
    end
end
# Poisson: a right perturbation adds one count (weight δ); a left perturbation
# removes one count (weight -δ * val / rate), only possible when val > 0.
function _δtoΔs(d::Poisson,
        val::V,
        δ::Real,
        Δs::AbstractFIs,
        derivative_coupling::InversionMethodDerivativeCoupling) where {V <: Signed}
    p = mean(d) # rate
    if (derivative_coupling.mode isa Val{:positive_weight} && δ > 0) ||
       (derivative_coupling.mode isa Val{:always_right})
        # Use one(V) / -one(V) (rather than literal 1 / -1) for consistency with the
        # other _δtoΔs methods, so the perturbation value has the same type as val.
        return similar_new(Δs, one(V), δ)
    elseif (derivative_coupling.mode isa Val{:positive_weight} && δ < 0) ||
           (derivative_coupling.mode isa Val{:always_left})
        return val > 0 ? similar_new(Δs, -one(V), -δ * val / p) : similar_empty(Δs, V)
    else
        return similar_empty(Δs, V)
    end
end
function _δtoΔs(d::Categorical,
        val::V,
        δs,
        Δs::AbstractFIs,
        derivative_coupling::InversionMethodDerivativeCoupling) where {V <: Signed}
    # Probability vector of the categorical distribution.
    p = params(d)[1]
    # NB: Although we might expect sum(δs) = 0, it is useful to handle things more generally, viewing δs
    # as perturbing the Categorical distribution locally along some direction in the space of general measures.
    # The below formulation gets things right in this case too.
    # Perturbation mass of the CDF strictly left of val, and up to and including val.
    left_sum = sum(δs[1:(val - 1)], init = zero(eltype(δs)))
    right_sum = sum(δs[1:val], init = zero(eltype(δs)))
    if (derivative_coupling.mode isa Val{:positive_weight} && left_sum > 0) ||
       (derivative_coupling.mode isa Val{:always_left} && !iszero(left_sum))
        # compute left_nonzero
        if derivative_coupling.handle_zeroprob isa Val{true}
            stop = rand() * left_sum
            upto = zero(eltype(δs)) # The "upto" logic handles an edge case of probability 0 events that have non-zero derivative.
            # It's a lot of logic to handle an edge case, but hopefully it's optimized away.
            left_nonzero = val
            # Scan leftward until hitting a category with nonzero probability, or until
            # the accumulated perturbation mass exceeds the random threshold `stop`.
            for i in (val - 1):-1:1
                if !iszero(p[i]) || ((upto += δs[i]) > stop)
                    left_nonzero = i
                    break
                end
            end
        else
            left_nonzero = val - 1
        end
        Δs_left = similar_new(Δs, left_nonzero - val, left_sum / p[val])
    else
        Δs_left = similar_empty(Δs, typeof(val))
    end
    if (derivative_coupling.mode isa Val{:positive_weight} && right_sum < 0) ||
       (derivative_coupling.mode isa Val{:always_right} && !iszero(right_sum))
        # compute right_nonzero
        if derivative_coupling.handle_zeroprob isa Val{true}
            # Mirror of the leftward scan; in positive-weight mode right_sum < 0 here,
            # so `stop` is nonnegative.
            stop = -rand() * right_sum
            upto = zero(eltype(δs))
            right_nonzero = val
            for i in (val + 1):length(p)
                if !iszero(p[i]) || ((upto += δs[i]) > stop)
                    right_nonzero = i
                    break
                end
            end
        else
            right_nonzero = val + 1
        end
        Δs_right = similar_new(Δs, right_nonzero - val, -right_sum / p[val])
    else
        Δs_right = similar_empty(Δs, typeof(val))
    end
    return combine((Δs_left, Δs_right); rep = Δs)
end
## Propagation couplings
# Propagation coupling approaches, determining how pre-existing finite perturbations
# of a parameter are pushed through a random primitive.
abstract type AbstractPropagationCoupling end

"""
    InversionMethodPropagationCoupling

Specifies an inversion method coupling for propagating perturbations.
"""
struct InversionMethodPropagationCoupling <: AbstractPropagationCoupling end
function _map_func(d, val, Δ, ::InversionMethodPropagationCoupling)
    # construct alternative distribution with the parameter shifted by Δ
    alt_d = _reconstruct(d, _get_parameter(d) + Δ)
    # bounds of the original value's quantile interval
    low = cdf(d, val - 1)
    high = cdf(d, val)
    # resample uniformly within that interval and invert under the alternative distribution
    u = rand(RNG) * (high - low) + low
    alt_val = quantile(alt_d, u)
    return convert(Signed, alt_val - val)
end
function _map_enumeration(d, val, Δ, ::InversionMethodPropagationCoupling)
    # construct alternative distribution with the parameter shifted by Δ
    alt_d = _reconstruct(d, _get_parameter(d) + Δ)
    # bounds of the original value's quantile interval
    low = cdf(d, val - 1)
    high = cdf(d, val)
    _has_finite_support(alt_d) ||
        error("enumeration not supported for distribution $d. Does $d have finite support?")
    return map(_get_support(alt_d)) do alt_val
        # probability that the uniform draw lands in the intersection of the
        # alternative value's quantile interval with (low, high), normalized
        alt_low = cdf(alt_d, alt_val - 1)
        alt_high = cdf(alt_d, alt_val)
        overlap = max(0.0, min(alt_high, high) - max(alt_low, low))
        return (alt_val - val, overlap / (high - low))
    end
end
## Overloading of random sampling
# Define randst interface
"""
    randst(rng, d::Distributions.Sampleable; kwargs...)

When no keyword arguments are provided, `randst` behaves identically to `rand(rng, d)` in both ordinary computation
and for stochastic triple dispatches. However, `randst` also allows the user to provide various keyword arguments
for customizing the differentiation logic. The set of allowed keyword arguments depends on the type of `d`: a couple
common ones are `derivative_coupling` and `propagation_coupling`.

For developers: if you wish to accept custom keyword arguments in a stochastic triple dispatch, you should overload
`randst`, and redirect `rand` to your `randst` method. If you do not, it suffices to just overload `rand`.
"""
function randst(rng, d::Distributions.Sampleable; kwargs...)
    return rand(rng, d)
end
function randst(d::Distributions.Sampleable; kwargs...)
    return randst(Random.default_rng(), d; kwargs...)
end
# Define stochastic triple rules
for dist in [:Geometric, :Bernoulli, :Binomial, :Poisson]
    # Redirect rand on a distribution with a stochastic-triple parameter to randst.
    @eval function Base.rand(rng::AbstractRNG,
            d_st::$dist{StochasticTriple{T, V, FIs}}) where {T, V, FIs}
        return randst(rng, d_st)
    end
    @eval function randst(rng::AbstractRNG,
            d_st::$dist{StochasticTriple{T, V, FIs}};
            Δ_kwargs = (;),
            derivative_coupling = InversionMethodDerivativeCoupling(),
            propagation_coupling = InversionMethodPropagationCoupling()) where {T, V, FIs}
        # Sample from the distribution reconstructed at the primal parameter value.
        st = _get_parameter(d_st)
        d = _reconstruct(d_st, st.value)
        val = convert(Signed, rand(rng, d))
        # New discrete perturbations induced by the infinitesimal parameter change st.δ.
        Δs1 = δtoΔs(d, val, st.δ, st.Δs, derivative_coupling)
        # Propagate pre-existing finite perturbations of the parameter through the sample.
        Δs2 = map(Δ -> _map_func(d, val, Δ, propagation_coupling),
            st.Δs;
            enumeration = (Δ, _) -> _map_enumeration(d, val, Δ, propagation_coupling),
            deriv = δ -> smoothed_delta(d, val, δ, derivative_coupling),
            out_rep = val,
            Δ_kwargs...)
        StochasticTriple{T}(val, zero(val), combine((Δs2, Δs1); rep = Δs1)) # ensure that tags are in order in combine, in case backend wishes to exploit this
    end
end
# currently handle Categorical separately since parameter is a vector
# what if some elements in vector are not stochastic triples... promotion should take care of that?
function Base.rand(rng::AbstractRNG,
        d_st::Categorical{StochasticTriple{T, V, FIs}}) where {T, V, FIs}
    return randst(rng, d_st)
end
function randst(rng::AbstractRNG,
        d_st::Categorical{<:StochasticTriple{T},
            <:AbstractVector{<:StochasticTriple{T, V}}};
        Δ_kwargs = (;),
        derivative_coupling = InversionMethodDerivativeCoupling(),
        propagation_coupling = InversionMethodPropagationCoupling()) where {T, V}
    sts = _get_parameter(d_st) # stochastic triple for each probability
    p = map(st -> st.value, sts) # try to keep the same type. e.g. static array -> static array. TODO: avoid allocations
    d = _reconstruct(d_st, p)
    # Sample from the distribution at the primal probability vector.
    val = convert(Signed, rand(rng, d))
    Δs_all = map(st -> st.Δs, sts)
    Δs_rep = get_rep(Δs_all)
    # New discrete perturbations from the infinitesimal changes of the probabilities.
    Δs1 = δtoΔs(d, val, map(st -> st.δ, sts), Δs_rep, derivative_coupling)
    # Couple the per-probability perturbation collections, then propagate the
    # pre-existing finite perturbations through the sample.
    Δs_coupled = couple(Δs_all; rep = Δs_rep, out_rep = p) # TODO: again, there are possible allocations here
    Δs2 = map(Δ -> _map_func(d, val, Δ, propagation_coupling),
        Δs_coupled;
        enumeration = (Δ, _) -> _map_enumeration(d, val, Δ, propagation_coupling),
        deriv = δ -> smoothed_delta(d, val, δ, derivative_coupling),
        out_rep = val,
        Δ_kwargs...)
    # Keep tags in order (Δs2 before Δs1), as in the scalar-parameter methods.
    Δs = combine((Δs2, Δs1); rep = Δs1, out_rep = val, Δ_kwargs...)
    StochasticTriple{T}(val, zero(val), Δs)
end
## Handling finite perturbation to Binomial number of trials
"""
    DiscreteDeltaStochasticTriple{T, V, FIs <: AbstractFIs}

An experimental discrete stochastic triple type used internally for representing perturbations
to non-real quantities. Currently only used to represent a finite perturbation to the Binomial
parameter n.

## Constructor

- `value`: the primal value.
- `Δs`: some representation of the perturbation to the primal, which can have an unconventional
    interpretation depending on `T`.
"""
struct DiscreteDeltaStochasticTriple{T, V, FIs <: AbstractFIs}
    value::V
    Δs::FIs
    function DiscreteDeltaStochasticTriple{T, V, FIs}(value::V,
            Δs::FIs) where {T, V,
            FIs <: AbstractFIs}
        new{T, V, FIs}(value, Δs)
    end
end
# Convenience constructor inferring V and FIs from the arguments.
function DiscreteDeltaStochasticTriple{T}(val::V, Δs::FIs) where {T, V, FIs <: AbstractFIs}
    DiscreteDeltaStochasticTriple{T, V, FIs}(val, Δs)
end
# A Binomial whose trial count n is a stochastic triple: keep the primal
# distribution alongside the perturbations of n.
function Distributions.Binomial(n::StochasticTriple{T}, p::Real) where {T}
    primal_d = Binomial(n.value, p)
    return DiscreteDeltaStochasticTriple{T}(primal_d, n.Δs)
end
# TODO: Support functions other than `rand` called on a perturbed Binomial.
function Base.rand(rng::AbstractRNG,
        d_st::DiscreteDeltaStochasticTriple{T, <:Binomial}) where {T}
    return randst(rng, d_st)
end
function randst(rng::AbstractRNG,
        d_st::DiscreteDeltaStochasticTriple{T, <:Binomial}) where {T}
    d = d_st.value
    val = rand(rng, d)
    # Translate a finite perturbation Δ of the trial count into a perturbation of the
    # sampled value: extra trials draw fresh Binomial successes, while removed trials
    # subtract a Hypergeometric draw from the existing ones.
    propagate_n_perturbation = Δ -> begin
        if Δ >= 0
            return rand(StochasticAD.RNG, Binomial(Δ, value(succprob(d))))
        else
            successes = value(val)
            failures = ntrials(d) - successes
            return -rand(StochasticAD.RNG, Hypergeometric(successes, failures, -Δ))
        end
    end
    Δs = map(propagate_n_perturbation, d_st.Δs)
    # val can itself be a stochastic triple (presumably when succprob is a triple);
    # merge its perturbations in that case.
    val isa StochasticTriple &&
        return StochasticTriple{T}(val.value, val.δ, combine((Δs, val.Δs); rep = Δs))
    return StochasticTriple{T}(val, zero(val), Δs)
end
================================================
FILE: src/finite_infinitesimals.jl
================================================
# TODO: make this a module, with the interface exported?
##
"""
    AbstractFIsBackend

An abstract type for backend strategies of Finite perturbations that occur with Infinitesimal probability (FIs).
"""
abstract type AbstractFIsBackend end

"""
    AbstractFIs{V}

An abstract type for concrete backend representations of Finite Infinitesimals.
"""
abstract type AbstractFIs{V} end

### Some of the necessary interface notes below.
# TODO: document
# Construct the Δs collection for a fresh computation: (backend, value type V).
function create_Δs end
# Create a new single perturbation (value, weight), similar in kind to a given Δs.
function similar_new end
# Create an empty perturbation collection, similar in kind to a given Δs.
function similar_empty end
# Map a backend type to the analogous type with a different value type.
function similar_type end

# Value type of a perturbation collection; instance form delegates to the type form.
valtype(Δs::AbstractFIs) = valtype(typeof(Δs))

# TODO: typeof ∘ first is a loose check, should make more robust.
# TODO: perhaps deprecate these methods in favor of an explicit first argument?
# Convenience forms dispatching on the type of the first collection.
couple(Δs_all; kwargs...) = couple(typeof(first(Δs_all)), Δs_all; kwargs...)
combine(Δs_all; kwargs...) = combine(typeof(first(Δs_all)), Δs_all; kwargs...)
get_rep(Δs_all; kwargs...) = get_rep(typeof(first(Δs_all)), Δs_all; kwargs...)

function scalarize end
function derivative_contribution end
function alltrue end
function perturbations end
function filter_state end
function weighted_map_Δs end
function map_Δs(f, Δs::AbstractFIs; kwargs...)
    # Lift f to a weighted map with a trivial (unit) weight.
    weighted_f = (Δ, state) -> (f(Δ, state), 1.0)
    return StochasticAD.weighted_map_Δs(weighted_f, Δs; kwargs...)
end
function Base.map(f, Δs::AbstractFIs; kwargs...)
    # Plain map ignores the backend state entirely.
    return StochasticAD.map_Δs((Δ, _) -> f(Δ), Δs; kwargs...)
end
# We also add a scale to deriv for scaling smoothed perturbations
function scale(Δs::AbstractFIs, a::Real)
    return StochasticAD.weighted_map_Δs((Δ, state) -> (Δ, a), Δs;
        deriv = δ -> a * δ,
        out_rep = Δs)
end
function new_Δs_strategy end
# utility function useful e.g. for get_rep in some backends
function get_any(Δs_all)
    # The code below is a bit ridiculous, but it's faster than `first` for small structures:)
    keep_first = (Δs1, Δs2) -> Δs1
    return foldl(keep_first, StochasticAD.structural_iterate(Δs_all))
end
# Strategy types select which perturbations to form; signal types carry backend hints.
abstract type AbstractPerturbationStrategy end
abstract type AbstractPerturbationSignal end
function send_signal end
# Ignore signals by default since they do not change semantics.
function StochasticAD.send_signal(
        Δs::StochasticAD.AbstractFIs, ::StochasticAD.AbstractPerturbationSignal)
    return Δs
end
================================================
FILE: src/general_rules.jl
================================================
"""
Operators which have already been overloaded by StochasticAD.
"""
const handled_ops = Tuple{DataType, Int}[]
"""
define_triple_overload(sig)
Given the signature type-type of the primal function, define operator
overloading rules for stochastic triples.
Currently supports functions with all-real inputs and one real output.
"""
# TODO: special case optimizations
# TODO: generalizations to not-all-real inputs and/or not-one-real output
function define_triple_overload(sig)
opT, argTs = Iterators.peel(ExprTools.parameters(sig))
opT <: Type{<:Type} && return # not handling constructors
sig <: Tuple{Type, Vararg{Any}} && return
opT <: Core.Builtin && return false # can't do operator overloading for builtins
isabstracttype(opT) || fieldcount(opT) == 0 || return false # not handling functors
isempty(argTs) && return false # we are an operator overloading AD, need operands
all(argT isa Type && Real <: argT for argT in argTs) || return
N = length(ExprTools.parameters(sig)) - 1 # skip the op
# Skip already-handled ops, as well as ops that will be handled manually later (and more correctly, see #79).
if (opT, N) in handled_ops || (opT.instance in UNARY_TYPEFUNCS_WRAP)
return
end
push!(handled_ops, (opT, N))
if opT.instance in UNARY_PREDICATES && (N == 1)
@eval function (f::$opT)(st::StochasticTriple)
val = value(st)
out = f(val)
if !alltrue(Δ -> (f(val + Δ) == out), st.Δs)
error("Output of boolean predicate cannot depend on input (unsupported by StochasticAD)")
end
return out
end
elseif opT.instance in BINARY_PREDICATES && (N == 2)
# Special case equality comparisons as in https://github.com/JuliaDiff/ForwardDiff.jl/pull/481
if opT.instance == Base.:(==)
return_value_real = quote
out && iszero(delta(st))
end
return_value_st = quote
out2 = out && (delta(st1) == delta(st2))
end
else
return_value_real = quote
out
end
return_value_st = quote
out
end
end
@eval function (f::$opT)(st::StochasticTriple, x::Real)
val = value(st)
out = f(val, x)
if !alltrue(Δ -> (f(val + Δ, x) == out), st.Δs)
error("Output of boolean predicate cannot depend on input (unsupported by StochasticAD)")
end
return $return_value_real
end
@eval function (f::$opT)(x::Real, st::StochasticTriple)
val = value(st)
out = f(x, val)
if !alltrue(Δ -> (f(x, val + Δ) == out), st.Δs)
error("Output of boolean predicate cannot depend on input (unsupported by StochasticAD)")
end
return $return_value_real
end
@eval function (f::$opT)(st1::StochasticTriple, st2::StochasticTriple)
val1 = value(st1)
val2 = value(st2)
out = f(val1, val2)
Δs_coupled = couple((st1.Δs, st2.Δs); out_rep = (val1, val2))
safe_perturb = alltrue(Δs -> f(val1 + Δs[1], val2 + Δs[2]) == out, Δs_coupled)
if !safe_perturb
error("Output of boolean predicate cannot depend on input (unsupported by StochasticAD)")
end
return $return_value_st
end
elseif N == 1
if Base.return_types(frule, (Tuple{NoTangent, Real}, opT, Real))[1] <:
Tuple{Any, NoTangent}
return
end
@eval function (f::$opT)(st::StochasticTriple{T}; kwargs...) where {T}
run_frule = δ -> begin
args_tangent = (NoTangent(), δ)
return frule(args_tangent, f, value(st); kwargs...)
end
val, δ0 = run_frule(delta(st))
δ::typeof(val) = (δ0 isa ZeroTangent || δ0 isa NoTangent) ? zero(value(st)) : δ0
if !iszero(st.Δs)
Δs = map(Δ -> f(st.value + Δ; kwargs...) - val, st.Δs;
deriv = last ∘ run_frule, out_rep = val)
else
Δs = similar_empty(st.Δs, typeof(val))
end
return StochasticTriple{T}(val, δ, Δs)
end
elseif N == 2
if Base.return_types(frule, (Tuple{NoTangent, Real, Real}, opT, Real, Real))[1] <:
Tuple{Any, NoTangent}
return
end
for R in AMBIGUOUS_TYPES
@eval function (f::$opT)(st::StochasticTriple{T}, x::$R; kwargs...) where {T}
run_frule = δ -> begin
args_tangent = (NoTangent(), δ, zero(x))
return frule(args_tangent, f, value(st), x; kwargs...)
end
val, δ0 = run_frule(delta(st))
δ::typeof(val) = (δ0 isa ZeroTangent || δ0 isa NoTangent) ?
zero(value(st)) : δ0
if !iszero(st.Δs)
Δs = map(Δ -> f(st.value + Δ, x; kwargs...) - val, st.Δs;
deriv = last ∘ run_frule, out_rep = val)
else
Δs = similar_empty(st.Δs, typeof(val))
end
return StochasticTriple{T}(val, δ, Δs)
end
@eval function (f::$opT)(x::$R, st::StochasticTriple{T}; kwargs...) where {T}
run_frule = δ -> begin
args_tangent = (NoTangent(), zero(x), δ)
return frule(args_tangent, f, x, value(st); kwargs...)
end
val, δ0 = run_frule(delta(st))
δ::typeof(val) = (δ0 isa ZeroTangent || δ0 isa NoTangent) ?
zero(value(st)) : δ0
if !iszero(st.Δs)
Δs = map(Δ -> f(x, st.value + Δ; kwargs...) - val, st.Δs;
deriv = last ∘ run_frule, out_rep = val)
else
Δs = similar_empty(st.Δs, typeof(val))
end
return StochasticTriple{T}(val, δ, Δs)
end
end
@eval function (f::$opT)(sts::Vararg{StochasticTriple{T}, 2}; kwargs...) where {T}
run_frule = δs -> begin
args_tangent = (NoTangent(), δs...)
args = (f, value.(sts)...)
return frule(args_tangent, args...; kwargs...)
end
val, δ0 = run_frule(delta.(sts))
δ::typeof(val) = (δ0 isa ZeroTangent || δ0 isa NoTangent) ? zero(value(st)) : δ0
Δs_all = map(st -> getfield(st, :Δs), sts)
if all(iszero.(Δs_all))
Δs = similar_empty(first(sts).Δs, typeof(val))
else
vals_in = value.(sts)
Δs_coupled = couple(Tuple(Δs_all); out_rep = vals_in)
mapfunc = let vals_in = vals_in
Δ -> (f((vals_in .+ Δ)...; kwargs...) - val)
end
Δs = map(mapfunc, Δs_coupled; deriv = last ∘ run_frule, out_rep = val)
end
return StochasticTriple{T}(val, δ, Δs)
end
end
end
on_new_rule(define_triple_overload, frule)
### Extra overloads
# TODO: generalize the below logic to compactly handle a wider range of functions.
# See also https://github.com/JuliaDiff/ForwardDiff.jl/blob/master/src/dual.jl.
function Base.hash(st::StochasticTriple, hsh::UInt)
    # Hashing is only well-defined when there are no perturbations to account for.
    isempty(st.Δs) ||
        error("Hashing a stochastic triple with perturbations not yet supported.")
    return hash(StochasticAD.value(st), hsh)
end
#=
This is a hacky experimental way to convert a float-like stochastic triple
into an integer-like one, to facilitate generic coding.
=#
function Base.round(I::Type{<:Integer}, st::StochasticTriple{T, V}) where {T, V}
    rounded = round(I, st.value)
    # Each perturbation is re-expressed relative to the rounded primal.
    rounded_Δs = map(Δ -> round(I, st.value + Δ), st.Δs)
    return StochasticTriple{T}(rounded, rounded_Δs)
end
# Type-level functions forwarded to the value type V without wrapping the result.
for op in UNARY_TYPEFUNCS_NOWRAP
    function (::typeof(op))(::Type{<:StochasticTriple{T, V, FIs}}) where {T, V, FIs}
        return op(V)
    end
end
# Type-level functions whose result is wrapped back into a constant stochastic
# triple (zero dual part, empty perturbations).
for op in UNARY_TYPEFUNCS_WRAP
    function (::typeof(op))(::Type{StochasticTriple{T, V, FIs}}) where {T, V, FIs}
        return StochasticTriple{T, V, FIs}(op(V), zero(V), empty(FIs))
    end
    function (::typeof(op))(st::StochasticTriple)
        return op(typeof(st))
    end
end
# RNG-based type functions (rand, randn, randexp): sample a V and wrap it as a
# constant stochastic triple.
for op in RNG_TYPEFUNCS_WRAP
    function (::typeof(op))(rng::AbstractRNG,
            ::Type{StochasticTriple{T, V, FIs}}) where {T, V, FIs}
        return StochasticTriple{T, V, FIs}(op(rng, V), zero(V), empty(FIs))
    end
end
#=
The short-circuit "x == y" case in Base.isapprox is bad for us
because it could unnecessarily lead to a boolean-predicate
depends on output error where StochasticAD cannot prove correctness.
We patch up the rule by removing the short-circuit, allowing some common
cases to work.
In the future, we will ideally handle the overloading rule in a more general
way. (E.g. by catching the chain rule for isapprox and recursively calling isapprox
on the values.)
=#
# Same formula as Base.isapprox, minus the `x == y` short-circuit.
function Base.isapprox(st1::StochasticTriple, st2::StochasticTriple;
        atol::Real = 0, rtol::Real = Base.rtoldefault(st1, st2, atol),
        nans::Bool = false, norm::Function = abs)
    (isfinite(st1) && isfinite(st2) &&
     norm(st1 - st2) <= max(atol, rtol * max(norm(st1), norm(st2)))) ||
        (nans && isnan(st1) && isnan(st2))
end
function Base.isapprox(st1::StochasticTriple, x::Real;
        atol::Real = 0, rtol::Real = Base.rtoldefault(st1, x, atol),
        nans::Bool = false, norm::Function = abs)
    (isfinite(st1) && isfinite(x) &&
     norm(st1 - x) <= max(atol, rtol * max(norm(st1), norm(x)))) ||
        (nans && isnan(st1) && isnan(x))
end
# Symmetric case: delegate to the triple-first method.
function Base.isapprox(x::Real, st::StochasticTriple; kwargs...)
    return Base.isapprox(st, x; kwargs...)
end
# Alternate version of _isassigned that does not fall back on try/catch:
# an index is "assigned" iff it lies within the collection's index range.
function _isassigned(C, i)
    return i in eachindex(C)
end
"""
Base.getindex(C::AbstractArray, st::StochasticTriple{T})
A simple prototype rule for array indexing. Assumes that underlying type of `st` can index into collection C.
"""
# TODO: support multiple indices, cartesian indices, non abstract array indexables, other use cases...
# Example to fix: A[:, :, st]
function Base.getindex(C::AbstractArray, st::StochasticTriple{T, V, FIs}) where {T, V, FIs}
val = C[st.value]
do_map = (Δ, state) -> begin
return value(C[st.value + Δ], state) - value(val, state)
end
# TODO: below doesn't support sparse arrays, use something like nextind
deriv = δ -> begin
scale = if _isassigned(C, st.value + 1) && _isassigned(C, st.value - 1)
1 / 2 * (value(C[st.value + 1]) - value(C[st.value - 1]))
elseif _isassigned(C, st.value + 1)
value(C[st.value + 1]) - value(C[st.value])
elseif _isassigned(C, st.value - 1)
value(C[st.value]) - value(C[st.value - 1])
else
zero(eltype(C))
end
return scale * δ
end
Δs = StochasticAD.map_Δs(do_map, st.Δs; deriv, out_rep = value(val))
if val isa StochasticTriple
Δs = combine((Δs, val.Δs))
end
return StochasticTriple{T}(value(val), delta(val), Δs)
end
================================================
FILE: src/misc.jl
================================================
@doc raw"""
StochasticModel(X, p)
Combine stochastic program `X` with parameter `p` into
a trainable model using [Functors](https://fluxml.ai/Functors.jl/stable/), where
`p <: AbstractArray`.
Formulate as a minimization problem, i.e. find ``p`` that minimizes ``\mathbb{E}[X(p)]``.
"""
struct StochasticModel{S <: AbstractArray, T}
X::T
p::S
end
@functor StochasticModel (p,)
"""
stochastic_gradient(m::StochasticModel)
Compute gradient with respect to the trainable parameter `p` of `StochasticModel(X, p)`.
"""
function stochastic_gradient(m::StochasticModel)
fmap(p -> derivative_estimate(m.X, p), m)
end
================================================
FILE: src/prelude.jl
================================================
# Real subtypes used to disambiguate binary-op overloads against StochasticTriple
# (see the N == 2 case in define_triple_overload).
const AMBIGUOUS_TYPES = (AbstractFloat, Irrational, Integer, Rational, Real, RoundingMode)
# Predicates whose output must be invariant under the input's perturbations.
const UNARY_PREDICATES = [isinf, isnan, isfinite, iseven, isodd, isreal, isinteger]
const BINARY_PREDICATES = [
    isequal,
    isless,
    <,
    >,
    ==,
    !=,
    <=,
    >=
]
# Type-level functions forwarded to the value type without wrapping the result.
const UNARY_TYPEFUNCS_NOWRAP = [Base.rtoldefault]
# Type-level functions whose result is wrapped back into a constant triple.
const UNARY_TYPEFUNCS_WRAP = [
    typemin,
    typemax,
    floatmin,
    floatmax,
    zero,
    one
]
# RNG-based samplers overloaded on the stochastic triple type.
const RNG_TYPEFUNCS_WRAP = [rand, randn, randexp]
"""
    structural_iterate(args)

Internal helper function for iterating through the scalar values of a functor,
where AbstractFIs are also counted as scalars.
"""
function structural_iterate(args)
    # Leaves that are not arrays are wrapped in a 1-tuple so everything is iterable.
    wrap_leaf(x) = x isa AbstractArray ? x : (x,)
    # Stop descending at functor leaves, and also treat AbstractFIs as scalars.
    stop_at(x) = Functors.isleaf(x) || (x isa AbstractFIs)
    return fmap(wrap_leaf, args; walk = Functors.IterateWalk(),
        cache = nothing, exclude = stop_at)
end
# Fast paths for flat collections of scalars and for a single scalar.
structural_iterate(args::NTuple{N, Union{Real, AbstractFIs}}) where {N} = args
structural_iterate(args::AbstractArray{T}) where {T <: Union{Real, AbstractFIs}} = args
structural_iterate(args::T) where {T <: Real} = (args,)
"""
    structural_map(f, args)

Internal helper function for a structure-preserving map,
often to be used on a function's input/output arguments.
Currently uses [fmap](https://fluxml.ai/Functors.jl/stable/api/#Functors.fmap)
from Functors.jl as a backend.
"""
function structural_map(f, args...; only_vals = nothing)
    # Select the Functors walk strategy based on the only_vals flag.
    if only_vals isa Val{true}
        walk = Functors.StructuralWalk()
    elseif isnothing(only_vals) || (only_vals isa Val{false})
        walk = Functors.DefaultWalk()
    else
        error("Unsupported argument only_vals = $only_vals")
    end
    # Arrays at the leaves are mapped elementwise via broadcasting.
    map_leaf = function (leaves...)
        first(leaves) isa AbstractArray ? f.(leaves...) : f(leaves...)
    end
    return fmap(map_leaf, args...; cache = nothing, walk)
end
================================================
FILE: src/propagate.jl
================================================
"""
A version of `value` that allows unrecognized args to pass through.
"""
function get_value(arg)
    # potentially dangerous for non-triple args, see also note in get_Δs
    return arg isa StochasticTriple ? value(arg) : arg
end
# Extract the perturbation collection of `arg`, or an empty one of backend type
# `FIs` (with matching value type) when `arg` is not a stochastic triple.
function get_Δs(arg, FIs)
    arg isa StochasticTriple && return arg.Δs
    #=
    this case is a bit dangerous: perturbations could be dropped here
    if a leaf of a functor somehow contains a type that is not one of
    the two above.
    =#
    return empty(similar_type(FIs, typeof(arg)))
end
# Drop the perturbation component of a stochastic triple, keeping only the primal
# and dual parts (as a ForwardDiff.Dual by default, or as a triple with empty Δs
# when `use_dual = Val(false)`). Non-triples pass through unchanged.
function strip_Δs(arg; use_dual = Val(true))
    arg isa StochasticTriple || return arg
    # TODO: replace check below with a more robust notion of discreteness.
    if valtype(arg) <: Integer
        return value(arg)
    end
    if use_dual isa Val{true}
        return ForwardDiff.Dual{tag(arg)}(value(arg), delta(arg))
    end
    return StochasticAD.StochasticTriple{tag(arg)}(
        value(arg), delta(arg), empty(arg.Δs))
end
"""
    propagate(f, args...; keep_deltas = Val(false))

Propagates `args` through a function `f`, handling stochastic triples by independently running `f` on the primal
and the alternatives, rather than by inspecting the internals of `f` (which may possibly be unsupported by `StochasticAD`).
Currently handles deterministic functions `f` with any input and output that is `fmap`-able by `Functors.jl`.
If `f` has a continuously differentiable component, provide `keep_deltas = Val(true)`.

This functionality is orthogonal to dispatch: the idea is for this function to be the "backend" for operator
overloading rules based on dispatch. For example:
```jldoctest
using StochasticAD, Distributions
import Random # hide
Random.seed!(4321) # hide

function mybranch(x)
    str = repr(x) # string-valued intermediate!
    if length(str) < 2
        return 3
    else
        return 7
    end
end

function f(x)
    return mybranch(9 + rand(Bernoulli(x)))
end

# stochastic_triple(f, 0.5) # this would fail

# Add a dispatch rule for mybranch using StochasticAD.propagate
mybranch(x::StochasticAD.StochasticTriple) = StochasticAD.propagate(mybranch, x)

stochastic_triple(f, 0.5) # now works

# output

StochasticTriple of Int64:
3 + 0ε + (1 with probability 2.0ε)
```

!!! warning
    This function is experimental and subject to change.
"""
function propagate(f,
        args...;
        keep_deltas = Val(false),
        provided_st_rep = nothing,
        deriv = nothing)
    # TODO: support kwargs to f (or just use kwfunc in macro)
    #=
    TODO: maybe don't iterate through every scalar of array below,
    but rather have special array dispatch
    =#
    # Find a representative stochastic triple among the argument leaves (or use
    # the caller-provided one); it supplies the tag and the FIs backend type.
    st_rep = if provided_st_rep === nothing
        args_iter = structural_iterate(args)
        # Fold that keeps the first triple seen, erroring on mismatched tags.
        function args_fold(arg1, arg2)
            if arg1 isa StochasticTriple
                if (arg2 isa StochasticTriple) && (tag(arg1) !== tag(arg2))
                    throw(ArgumentError("Tags of combined stochastic triples do not match!"))
                end
                return arg1
            else
                return arg2
            end
        end
        foldl(args_fold, args_iter)
    else
        provided_st_rep
    end
    # No stochastic triples among the inputs: just run f as-is.
    if !(st_rep isa StochasticTriple)
        return f(args...)
    end
    primal_args = structural_map(get_value, args)
    # With keep_deltas, retain the infinitesimal (dual) parts via strip_Δs so that
    # the continuously differentiable component of f is also differentiated.
    input_args = keep_deltas isa Val{false} ? primal_args : structural_map(strip_Δs, args)
    #=
    TODO: the below is dangerous in general.
    It should be safe so long as f does not close over stochastic triples.
    (If f is a closure, the parameters of f should be treated like any other parameters;
    if they are stochastic triples and we are ignoring them, dangerous in general.)
    =#
    out = f(input_args...)
    val = structural_map(value, out)
    # TODO: what does the only_vals do in the below and why?
    # Collect the perturbation collection (Δs) of every argument leaf.
    Δs_all = structural_map(Base.Fix2(get_Δs, backendtype(st_rep)), args;
        only_vals = Val{true}())
    # TODO: Coupling approach below needs to handle non-perturbable objects.
    Δs_coupled = couple(backendtype(st_rep), Δs_all; rep = st_rep.Δs, out_rep = val)
    # For each coupled perturbation: rerun f on perturbed primal inputs and record
    # the finite change of the output relative to the unperturbed value.
    function map_func(Δ_coupled)
        perturbed_args = structural_map(+, primal_args, Δ_coupled)
        #=
        TODO: for f discrete random with randomness independent of params,
        could couple here. But difficult without a splittable RNG.
        =#
        alt = f(perturbed_args...)
        return structural_map((x, y) -> value(x) - y, alt, val)
    end
    Δs = map(map_func, Δs_coupled; out_rep = val, deriv)
    # TODO: make sure all FI backends support interface needed below
    # Reassemble the output: each scalar leaf becomes a stochastic triple carrying
    # its primal value, dual part, and the scalarized perturbations from above.
    new_out = structural_map(out, scalarize(Δs; out_rep = val)) do leaf_out, leaf_Δs
        StochasticAD.StochasticTriple{tag(st_rep)}(value(leaf_out), delta(leaf_out),
            leaf_Δs)
    end
    return new_out
end
================================================
FILE: src/smoothing.jl
================================================
### Particle resampling
@doc raw"""
    new_weight(p::Real)

Simulate a Bernoulli variable whose primal output is always 1.

Uses a smoothing rule for use in forward and reverse-mode AD, which is exactly unbiased when the quantity is only
used in linear functions (e.g. used as an [importance weight](https://en.wikipedia.org/wiki/Importance_sampling)).
"""
# Return `one(p)` rather than the literal `1` so the output type matches the input
# type, consistent with the dual overload and the frule/rrule below (which return
# `one(...)`), keeping downstream weight arithmetic type-stable.
new_weight(p::Real) = one(p)
# Dual-number overload: primal is 1, derivative is Δp / p (smoothed Bernoulli rule).
function new_weight(p::ForwardDiff.Dual{T}) where {T}
    partials = ForwardDiff.partials(p)
    # Clamp the primal away from zero before dividing. TODO: is this necessary?
    primal = max(1e-5, ForwardDiff.value(p))
    return ForwardDiff.Dual{T}(one(primal), partials / primal)
end
# Forward-mode ChainRules rule mirroring the dual-number overload above.
function ChainRulesCore.frule((_, Δp), ::typeof(new_weight), p::Real)
    clamped = max(1e-5, p) # TODO: is this necessary?
    return (one(p), Δp / clamped)
end
# Reverse-mode ChainRules rule for new_weight.
function ChainRulesCore.rrule(::typeof(new_weight), p)
    # Clamp the primal away from zero before dividing, matching the dual-number
    # overload and the frule above; previously the pullback divided by the raw p,
    # which was inconsistent with the forward rules for p near zero.
    val_p = max(1e-5, p) # TODO: is this necessary?
    function new_weight_pullback(∇Ω)
        return (ChainRulesCore.NoTangent(), ∇Ω / val_p)
    end
    return (one(p), new_weight_pullback)
end
# Smoothed rules for univariate single-parameter distributions.
# Compute the smoothed derivative contribution of an input perturbation δ to the
# sampled value `val` of distribution `d`, via an empty SmoothedFIs accumulator.
function smoothed_delta(d, val, δ, derivative_coupling)
    empty_Δs = SmoothedFIs{typeof(val)}(0.0)
    Δs = δtoΔs(d, val, δ, empty_Δs, derivative_coupling)
    return derivative_contribution(Δs)
end
# Generate smoothing rules (dual overload, frule, rrule) for each supported
# discrete distribution. `i` is the positional index of the differentiated
# parameter in `params(d)`, and `field` is its name in the distribution struct.
for (dist, i, field) in [
    (:Geometric, :1, :p),
    (:Bernoulli, :1, :p),
    (:Binomial, :2, :p),
    (:Poisson, :1, :λ),
    (:Categorical, :1, :p)
] # i = index of parameter p
    # dual overloading
    @eval function Base.rand(rng::AbstractRNG,
            d_dual::$dist{<:ForwardDiff.Dual{T}}) where {T}
        return randst(rng, d_dual)
    end
    @eval function randst(rng::AbstractRNG,
            d_dual::$dist{<:ForwardDiff.Dual{T}};
            derivative_coupling = InversionMethodDerivativeCoupling()) where {T}
        dual = params(d_dual)[$i]
        # dual could represent an array of duals or a single one; map handles both cases.
        p = map(value, dual)
        # Generate a δ for each partial component.
        partials_indices = ntuple(identity, length(first(dual).partials))
        δs = map(i -> map(d -> ForwardDiff.partials(d)[i], dual), partials_indices)
        # Rebuild the distribution with the plain (non-dual) parameter value.
        d = $dist(params(d_dual)[1:($i - 1)]..., p,
            params(d_dual)[($i + 1):end]...)
        val = convert(Signed, rand(rng, d))
        # One smoothed derivative per partial direction.
        partials = ForwardDiff.Partials(map(
            δ -> smoothed_delta(d, val, δ, derivative_coupling), δs))
        ForwardDiff.Dual{T}(val, partials)
    end
    # frule
    @eval function ChainRulesCore.frule(Δargs, ::typeof(rand), rng::AbstractRNG,
            d::$dist)
        return frule(Δargs, randst, rng, d)
    end
    @eval function ChainRulesCore.frule((_, _, Δd), ::typeof(randst), rng::AbstractRNG,
            d::$dist; derivative_coupling = InversionMethodDerivativeCoupling())
        val = convert(Signed, rand(rng, d))
        Δval = smoothed_delta(d, val, Δd, derivative_coupling)
        return (val, Δval)
    end
    # rrule
    @eval function ChainRulesCore.rrule(::typeof(rand), rng::AbstractRNG, d::$dist)
        return rrule(randst, rng, d)
    end
    @eval function ChainRulesCore.rrule(::typeof(randst),
            rng::AbstractRNG,
            d::$dist;
            derivative_coupling = InversionMethodDerivativeCoupling())
        val = convert(Signed, rand(rng, d))
        function rand_pullback(∇out)
            p = params(d)[$i]
            if p isa Real
                Δp = smoothed_delta(d, val, one(val), derivative_coupling)
            else
                # TODO: this rule is O(length(p)^2), whereas we should be able to do O(length(p)) by reversing through δtoΔs.
                I = eachindex(p)
                V = eltype(p)
                # Sensitivity to each parameter component via a one-hot input δ.
                onehot(i) = map(j -> j == i ? one(V) : zero(V), I)
                Δp = map(i -> smoothed_delta(d, val, onehot(i), derivative_coupling), I)
            end
            # rrule_via_ad approach below not used because slow.
            # Δp = rrule_via_ad(config, smoothed_delta, d, val, map(one, p))[2](∇out)[4]
            Δd = ChainRulesCore.Tangent{typeof(d)}(; $field = ∇out * Δp)
            return (ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), Δd)
        end
        return (val, rand_pullback)
    end
end
================================================
FILE: src/stochastic_triple.jl
================================================
"""
    StochasticTriple{T, V <: Real, FIs <: AbstractFIs{V}}

Stores the primal value of the computation, alongside a "dual" component
representing an infinitesimal change, and a "triple" component that tracks
discrete change(s) with infinitesimal probability.

Pretty printed as "value + δε + ({pretty print of Δs})".

## Constructor

- `value`: the primal value.
- `δ`: the value of the almost-sure derivative, i.e. the rate of "infinitesimal" change.
- `Δs`: alternate values with associated weights, i.e. Finite perturbations with Infinitesimal probability,
  represented by a backend `FIs <: AbstractFIs`.
"""
struct StochasticTriple{T, V <: Real, FIs <: AbstractFIs{V}} <: Real
    value::V # primal value
    δ::V # infinitesimal change
    Δs::FIs # finite changes with infinitesimal probabilities # (Δ = 3, p = 1*h)
    # Inner constructor pins the type parameters to the concrete argument types.
    function StochasticTriple{T, V, FIs}(value::V, δ::V,
            Δs::FIs) where {T, V, FIs <: AbstractFIs{V}}
        new{T, V, FIs}(value, δ, Δs)
    end
end
"""
    value(st::StochasticTriple)

Return the primal value of `st`.
"""
value(x::Real, state = nothing) = x # plain reals are their own primal value
value(st::StochasticTriple) = st.value
# Experimental method for obtaining the alternate value of a stochastic triple associated with a certain backend state.
function value(st::StochasticTriple, state)
    st.value + filter_state(st.Δs, state)
end
#=
Support ForwardDiff.Dual for internal usage.
Assumes batch size is 1.
=#
value(d::ForwardDiff.Dual, state = nothing) = ForwardDiff.value(d)
"""
    delta(st::StochasticTriple)

Return the almost-sure derivative of `st`, i.e. the rate of infinitesimal change.
"""
delta(x::Real) = zero(x) # plain reals carry no infinitesimal change
delta(st::StochasticTriple) = st.δ
# Support ForwardDiff.Dual for internal usage (uses the first partial only).
delta(d::ForwardDiff.Dual) = ForwardDiff.partials(d)[1]
"""
    perturbations(st::StochasticTriple)

Return the finite perturbation(s) of `st`, in a format dependent on the [backend](devdocs.md) used for storing perturbations.
"""
perturbations(x::Real) = () # plain reals carry no perturbations
perturbations(st::StochasticTriple) = perturbations(st.Δs)
"""
    send_signal(st::StochasticTriple, signal::AbstractPerturbationSignal)
    send_signal(Δs::StochasticAD.AbstractFIs, signal::AbstractPerturbationSignal)

Send a certain signal to a stochastic triple's perturbation collection `st.Δs` (or to a `Δs` directly),
which the backend may process as it wishes. Semantically, unbiasedness should not be affected by the
sending of the signal. The new version of the first argument (`st` or `Δs`) after signal processing is
returned.
"""
send_signal(st::Real, ::AbstractPerturbationSignal) = st # no-op for plain reals
function send_signal(st::StochasticTriple{T}, signal::AbstractPerturbationSignal) where {T}
    # Only the perturbation component is affected; value and dual are preserved.
    new_Δs = send_signal(st.Δs, signal)
    return StochasticTriple{T}(st.value, st.δ, new_Δs)
end
"""
    derivative_contribution(st::StochasticTriple)

Return the derivative estimate given by combining the dual and triple components of `st`.
"""
derivative_contribution(x::Real) = zero(x) # plain reals contribute nothing
derivative_contribution(d::ForwardDiff.Dual) = delta(d) # duals contribute only their partial
derivative_contribution(st::StochasticTriple) = st.δ + derivative_contribution(st.Δs)
"""
    tag(st::StochasticTriple)
    tag(::Type{<:StochasticTriple})

Get the tag of a stochastic triple.
"""
tag(::Type{<:StochasticTriple{T}}) where {T} = T
tag(::Type{<:ForwardDiff.Dual{T}}) where {T} = T # also support duals internally
tag(::StochasticTriple{T}) where {T} = T
tag(::ForwardDiff.Dual{T}) where {T} = T
"""
    valtype(st::StochasticTriple)
    valtype(st::Type{<:StochasticTriple})

Get the underlying type of the value tracked by a stochastic triple.
"""
valtype(st::StochasticTriple) = valtype(typeof(st))
valtype(::Type{<:StochasticTriple{T, V}}) where {T, V} = V
"""
    backendtype(st::StochasticTriple)
    backendtype(st::Type{<:StochasticTriple})

Get the backend type of a stochastic triple.
"""
backendtype(st::StochasticTriple) = backendtype(typeof(st))
backendtype(::Type{<:StochasticTriple{T, V, FIs}}) where {T, V, FIs} = FIs
"""
    smooth_triple(st::StochasticTriple)

Smooth the dual and triple components of a stochastic triple into a single dual component.
Useful for avoiding unnecessary pruning when running multilinear functions on triples.
"""
smooth_triple(x::Real) = x
function smooth_triple(st::StochasticTriple{T, V, FIs}) where {T, V, FIs}
    # Fold the perturbation contribution into the dual part and empty out Δs.
    combined_δ = derivative_contribution(st)
    return StochasticTriple{T}(value(st), combined_δ, empty(FIs))
end
### Extra constructors of stochastic triples
# Convenience constructor when value, δ, and the FIs element type already agree.
function StochasticTriple{T}(value::V, δ::V, Δs::FIs) where {T, V, FIs <: AbstractFIs{V}}
    StochasticTriple{T, V, FIs}(value, δ, Δs)
end
# Zero dual part by default.
function StochasticTriple{T}(value::V, Δs::FIs) where {T, V, FIs <: AbstractFIs{V}}
    StochasticTriple{T}(value, zero(value), Δs)
end
# Heterogeneous argument types: promote value, δ, and the FIs value type first.
function StochasticTriple{T}(value::A, δ::B,
        Δs::FIs) where {T, A, B, C, FIs <: AbstractFIs{C}}
    V = promote_type(A, B, C)
    StochasticTriple{T}(convert(V, value), convert(V, δ), convert(similar_type(FIs, V), Δs))
end
### Conversion rules
# TODO: is this the right thing to do? Maybe, different from the promote case because there V was guaranteed to be an ancestor.
# Also, bad to do when already same type?
# Converting between triples is only allowed when the tags match.
function Base.convert(::Type{StochasticTriple{T1, V, FIs}},
        x::StochasticTriple{T2}) where {T1, T2, V, FIs}
    (T1 !== T2) && throw(ArgumentError("Tags of combined stochastic triples do not match."))
    StochasticTriple{T1, V, FIs}(convert(V, x.value), convert(V, x.δ), convert(FIs, x.Δs))
end
# TODO: ForwardDiff's promotion rules are a little more complicated, see https://github.com/JuliaDiff/ForwardDiff.jl/issues/322
# May need to look into why and possibly use them here too.
# Two triples with the same tag promote by promoting their value types.
function Base.promote_rule(::Type{StochasticTriple{T, V1, FIs}},
        ::Type{StochasticTriple{T, V2, FIs2}}) where {T, V1, FIs, V2,
        FIs2}
    V = promote_type(V1, V2)
    StochasticTriple{T, V, similar_type(FIs, V)}
end
# A triple and a plain real promote to a triple with the promoted value type.
function Base.promote_rule(::Type{StochasticTriple{T, V1, FIs}},
        ::Type{V2}) where {T, V1, FIs, V2 <: Real}
    V = promote_type(V1, V2)
    StochasticTriple{T, V, similar_type(FIs, V)}
end
# A plain real converts to a triple with zero dual and empty perturbations.
function Base.convert(::Type{StochasticTriple{T, V, FIs}}, x::Real) where {T, V, FIs}
    StochasticTriple{T, V, FIs}(convert(V, x), zero(V), empty(FIs))
end
### Creating the first stochastic triple in a computation
# Construct a triple from a backend specification by creating a fresh empty Δs.
function StochasticTriple{T}(value::V, δ::V, backend::AbstractFIsBackend) where {T, V}
    StochasticTriple{T}(value, δ, create_Δs(backend, V))
end
# Zero dual part by default.
function StochasticTriple{T}(value::V, backend::AbstractFIsBackend) where {T, V}
    StochasticTriple{T}(value, zero(V), backend)
end
# Heterogeneous value/δ types: promote first.
function StochasticTriple{T}(value::A, δ::B, backend::AbstractFIsBackend) where {T, A, B}
    V = promote_type(A, B)
    StochasticTriple{T}(convert(V, value), convert(V, δ), backend)
end
### Showing a stochastic triple
# One-line summary used in the multi-line display header.
function Base.summary(::StochasticTriple{T, V}) where {T, V}
    return "StochasticTriple of $V"
end
# Multi-line display: summary header followed by the compact form.
function Base.show(io::IO, ::MIME"text/plain", st::StochasticTriple)
    println(io, summary(st), ":")
    show(io, st)
end
# Compact display: "value + δε", plus the perturbations when any are present.
function Base.show(io::IO, st::StochasticTriple)
    print(io, st.value, " + ", st.δ, "ε")
    isempty(st.Δs) || print(io, " + (", repr(st.Δs), ")")
end
### Higher level functions
# Tag type parameterized by the function type F and input type V, used as the
# triple tag for a single derivative computation (cf. ForwardDiff-style tags).
struct Tag{F, V}
end
# Run f once with every scalar leaf of p seeded with the corresponding component
# of `direction` as its dual part.
function stochastic_triple_direction(f, p::V, direction; backend) where {V}
    Δs = create_Δs(backend, Int) # TODO: necessity of hardcoding some type here suggests interface improvements
    sts = structural_map(p, direction) do p_i, direction_i
        StochasticTriple{Tag{typeof(f), V}}(p_i, direction_i,
            similar_empty(Δs, typeof(p_i)))
    end
    return f(sts)
end
"""
    stochastic_triple(X, p; backend=PrunedFIsBackend(), direction=nothing)
    stochastic_triple(p; backend=PrunedFIsBackend(), direction=nothing)

For any `p` that is supported by [`Functors.jl`](https://fluxml.ai/Functors.jl/stable/),
e.g. scalars or abstract arrays,
differentiate the output with respect to each value of `p`,
returning an output of similar structure to `p`, where a particular value contains
the stochastic-triple output of `X` when perturbing the corresponding value in `p`
(i.e. replacing the original value `x` with `x + ε`).

When `direction` is provided, return only the stochastic-triple output of `X` with respect to a perturbation
of `p` in that particular direction.
When `X` is not provided, the identity function is used.

The `backend` keyword argument describes the algorithm used by the third component
of the stochastic triple, see [technical details](devdocs.md) for more details.

# Example
```jldoctest
julia> using Distributions, Random, StochasticAD; Random.seed!(4321);

julia> stochastic_triple(rand ∘ Bernoulli, 0.5)
StochasticTriple of Int64:
0 + 0ε + (1 with probability 2.0ε)
```
"""
function stochastic_triple(
        f, p; direction = nothing, backend::AbstractFIsBackend = PrunedFIsBackend())
    # Single-direction case: one run of f suffices.
    if !isnothing(direction)
        return stochastic_triple_direction(f, p, direction; backend)
    end
    # Assign consecutive indices 1, 2, ... to the scalar leaves of p.
    leaf_count = 0
    indices = structural_map(_ -> (leaf_count += 1), p)
    # One run of f per leaf, perturbing exactly that leaf (one-hot direction).
    run_for_leaf = perturbed_index -> begin
        direction = structural_map(indices, p) do i, p_i
            i == perturbed_index ? one(p_i) : zero(p_i)
        end
        stochastic_triple_direction(f, p, direction; backend)
    end
    return structural_map(run_for_leaf, indices)
end
stochastic_triple(p; kwargs...) = stochastic_triple(identity, p; kwargs...)
"""
    dual_number(X, p; backend=PrunedFIsBackend(), direction=nothing)
    dual_number(p; backend=PrunedFIsBackend(), direction=nothing)

A lightweight wrapper around [`stochastic_triple`](#StochasticAD.stochastic_triple) that entirely ignores the
derivative contribution of all discrete random components, so that it behaves like a regular dual number.
Mostly for fun -- this, of course, leads to a useless derivative estimate for discrete random functions!
"""
function dual_number(f, p; backend = PrunedFIsBackend(), kwargs...)
    # Wrap the backend so that discrete perturbations are ignored.
    wrapped_backend = StrategyWrapperFIsBackend(backend, IgnoreDiscreteStrategy())
    return stochastic_triple(f, p; backend = wrapped_backend, kwargs...)
end
dual_number(p; kwargs...) = dual_number(identity, p; kwargs...)
# Run the program once with stochastic triples and collapse each output triple's
# dual and perturbation parts into a scalar derivative estimate. Accepts the same
# keyword arguments as `stochastic_triple` (e.g. `backend`, `direction`).
function derivative_estimate(f, p; kwargs...)
    StochasticAD.structural_map(derivative_contribution, stochastic_triple(f, p; kwargs...))
end
================================================
FILE: test/game_of_life.jl
================================================
using StochasticAD
using Test
using Statistics
include("../tutorials/game_of_life/core.jl")
using .GoLCore: fd_clever, play, p, nsamples
# Compare the Monte Carlo mean of stochastic-triple derivative estimates against
# the tutorial's finite-difference estimator on the game-of-life program.
@testset "AD and Finite Differences" begin
    samples_fd_clever = [fd_clever(p) for i in 1:nsamples]
    samples_st = [derivative_estimate(play, p) for i in 1:nsamples]
    @test mean(samples_st)≈mean(samples_fd_clever) rtol=5e-2
end
================================================
FILE: test/random_walk.jl
================================================
using StochasticAD
using Test
using Statistics
using ForwardDiff: derivative
include("../tutorials/random_walk/core.jl")
using .RandomWalkCore: n, p, nsamples
using .RandomWalkCore: fX, get_dfX
# The averaged triple derivative estimates should match the exact derivative of
# the expected value (computed via ForwardDiff on the analytic mean get_dfX).
@testset "Check unbiasedness" begin
    fX_deriv = derivative(p -> get_dfX(p, n), p)
    fX_deriv_estimate = mean(derivative_estimate(fX, p) for i in 1:nsamples)
    @test isapprox(fX_deriv, fX_deriv_estimate; rtol = 1e-2)
end
================================================
FILE: test/resampling.jl
================================================
using StochasticAD
using Random, Test
using Distributions
using LinearAlgebra
using ForwardDiff
# test forward-mode AD and reverse-mode AD on the particle filter
### Particle Filter Functions
include("../tutorials/particle_filter/core.jl")
seed = 237347578 # fixed seed so forward and reverse runs can be compared exactly
### Define model
Random.seed!(seed)
T = 3 # number of observation time steps
d = 2 # dimension of the latent state
# Rotation-by-θ[] dynamics matrix with a mild exp(-a) contraction factor.
A(θ, a = 0.01) = [exp(-a)*cos(θ[]) exp(-a)*sin(θ[])
    -exp(-a)*sin(θ[]) exp(-a)*cos(θ[])]
obs(x, θ) = MvNormal(x, 0.01 * collect(I(d))) # Gaussian observation noise around x
dyn(x, θ) = MvNormal(A(θ) * x, 0.02 * collect(I(d))) # Gaussian transition noise
x0 = [2.0, 0.0] # start value of the simulation
start(θ) = Dirac(x0) # deterministic initial condition
θtrue = [0.20] # ground-truth parameter used to simulate the data
# put it all together
stochastic_model = ParticleFilterCore.StochasticModel(T, start, dyn, obs)
### simulate model
Random.seed!(seed)
xs, ys = ParticleFilterCore.simulate_single(stochastic_model, θtrue)
###
### initialize sampler
m = 1000 # number of particles
particle_filter = ParticleFilterCore.ParticleFilter(m, stochastic_model, ys,
    ParticleFilterCore.sample_stratified)
###
@testset "new weight" begin
    p = 0.5
    st = stochastic_triple(p)
    d = ForwardDiff.Dual(p, (1.0, 2.0))
    # new_weight has primal 1 and derivative 1/p under both triple and dual inputs.
    @test new_weight(p) == one(p)
    @test StochasticAD.value(new_weight(st)) == one(p)
    @test StochasticAD.delta(new_weight(st)) == 1.0 / p
    @test ForwardDiff.value(new_weight(d)) == one(p)
    @test collect(ForwardDiff.partials(new_weight(d))) == [1.0 / p, 2.0 / p]
end
@testset "forward-mode and reverse-mode AD: single run" begin
    # Same seed ⇒ identical particle trajectories ⇒ gradients should agree.
    Random.seed!(seed)
    grad_forw = ParticleFilterCore.forw_grad(θtrue, particle_filter)
    Random.seed!(seed)
    grad_back = ParticleFilterCore.back_grad(θtrue, particle_filter)
    @test grad_forw ≈ grad_back
end
@testset "AD and Finite Differences" begin
    h = 0.02 # finite diff step size
    N = 500 # number of samples
    grad_fw = [ParticleFilterCore.forw_grad(θtrue, particle_filter)[1] for i in 1:N]
    # grad_bw = @time [back_grad(θtrue, particle_filter) for i in 1:N]
    # Central finite differences of the log-likelihood as a reference estimator.
    grad_fd = [(ParticleFilterCore.log_likelihood(particle_filter, θtrue .+ h) -
                ParticleFilterCore.log_likelihood(particle_filter, θtrue .- h)) / (2h)
               for i in 1:N]
    @test mean(grad_fd)≈mean(grad_fw) rtol=5e-2
end
================================================
FILE: test/runtests.jl
================================================
using SafeTestsets
using Test, Pkg
import Random
Random.seed!(1234)
# Test group selection via the GROUP environment variable (default: run all).
const GROUP = get(ENV, "GROUP", "All")
# NOTE(review): is_APPVEYOR appears unused in this file — confirm whether it can be removed.
const is_APPVEYOR = Sys.iswindows() && haskey(ENV, "APPVEYOR")
@time begin
    if GROUP == "All"
        # Each suite runs in its own module via @safetestset to avoid cross-contamination.
        @time @safetestset "Triples" begin
            include("triples.jl")
        end
        @time @safetestset "Game of life" begin
            include("game_of_life.jl")
        end
        @time @safetestset "Random walk" begin
            include("random_walk.jl")
        end
        @time @safetestset "Resampling" begin
            include("resampling.jl")
        end
    end
end
================================================
FILE: test/triples.jl
================================================
using StochasticAD
using Test
using Distributions
using ForwardDiff
using OffsetArrays
using ChainRulesCore
using Random
using Zygote
# Finite-perturbation backends exercised by most tests below.
const backends = [
    PrunedFIsBackend(),
    PrunedFIsAggressiveBackend(),
    DictFIsBackend()
]
# Smoothing-based backends; tested only on linear downstream functions below.
const backends_smoothed = [
    SmoothedFIsBackend(),
    StrategyWrapperFIsBackend(SmoothedFIsBackend(), StochasticAD.TwoSidedStrategy())
]
# Unbiasedness checks for a suite of discrete distributions, across the pruning/
# dict backends, the smoothing backends, and smoothing via AD rules.
@testset "Distributions w.r.t. continuous parameter" begin
    for backend in vcat(backends,
        backends_smoothed,
        :smoothing_autodiff)
        MAX = 10000 # support cutoff used to approximate infinite-support means
        nsamples = 100000
        rtol = 5e-2 # friendly tolerance for stochastic comparisons. TODO: more motivated choice of tolerance.
        ### Make test cases
        distributions = [
            Bernoulli,
            Geometric,
            Poisson,
            (p -> Categorical([p^2, 1 - p^2])),
            (p -> Categorical([0, p^2, 0, 0, 1 - p^2])), # check that 0's are skipped over
            (p -> Categorical([1.0, exp(p)] ./ (1.0 + exp(p)))), # test fix for #38 (floating point comparisons in Categorical logic)
            (p -> Binomial(3, p)),
            (p -> Binomial(20, p))
        ]
        p_ranges = [(0.2, 0.8) for _ in 1:8]
        out_ranges = [0:1, 0:MAX, 0:MAX, 1:2, 1:5, 1:2, 0:3, 0:20]
        test_cases = collect(zip(distributions, p_ranges, out_ranges))
        test_funcs = [x -> 7 * x - 3, x -> (x + 1)^2, x -> sqrt(x + 1)]
        if backend isa DictFIsBackend
            # Only test dictionary backend on Bernoulli to speed things up. Should still cover interface.
            test_cases = test_cases[1:1]
        elseif backend == :smoothing_autodiff || backend in backends_smoothed
            # Only test smoothing backend on each unique distribution once to speed tests up.
            test_cases = vcat(test_cases[1:4], test_cases[7])
            # Only test unbiasedness of smoothing for linear function
            test_funcs = test_funcs[1:1]
        end
        for (distr, p_range, out_range) in test_cases
            for f in test_funcs
                # Exact E[f(X_p)] via summation over the (truncated) support.
                function get_mean(p)
                    dp = distr(p)
                    sum(pdf(dp, i) * f(i) for i in out_range)
                end
                low_p, high_p = p_range
                for g in [p -> p, p -> high_p + low_p - p] # test both sides of derivative
                    full_func = f ∘ rand ∘ distr ∘ g
                    p = low_p + (high_p - low_p) * rand()
                    exact_deriv = ForwardDiff.derivative(p -> get_mean(g(p)), p)
                    if backend == :smoothing_autodiff
                        batched_full_func(p) = mean([full_func(p) for i in 1:nsamples])
                        # The array input used for ForwardDiff below is a trick to test multiple partials
                        triple_deriv_forward = mean(ForwardDiff.gradient(
                            arr -> batched_full_func(sum(arr)),
                            [2 * p, -p]))
                        triple_deriv_backward = Zygote.gradient(batched_full_func, p)[1]
                        @test isapprox(triple_deriv_forward, exact_deriv, rtol = rtol)
                        @test isapprox(triple_deriv_backward, exact_deriv, rtol = rtol)
                    else
                        get_deriv = () -> derivative_estimate(full_func, p; backend)
                        triple_deriv = mean(get_deriv() for i in 1:nsamples)
                        @test isapprox(triple_deriv, exact_deriv, rtol = rtol)
                    end
                end
            end
        end
    end
end
@testset "Perturbing n of binomial" begin
    function get_triple_deriv(Δ)
        # Manually create a finite perturbation to avoid any randomness in its creation
        Δs = StochasticAD.similar_new(StochasticAD.create_Δs(PrunedFIsBackend(), Int), Δ,
            3.5)
        st = StochasticAD.StochasticTriple{0}(5, 0, Δs)
        st_continuous = stochastic_triple(0.5)
        return derivative_contribution(rand(Binomial(st, st_continuous)))
    end
    for Δ in -2:2
        triple_deriv = mean(get_triple_deriv(Δ) for i in 1:100000)
        # Expected contribution: weight (3.5) * p (0.5) * Δ from perturbing n,
        # plus n (= 5) from the dual part of p (d(np)/dp = n).
        exact_deriv = 3.5 * 0.5 * Δ + 5
        @test isapprox(triple_deriv, exact_deriv, rtol = 5e-2)
    end
end
@testset "Nested binomials" begin
    binbin = p -> rand(Binomial(rand(Binomial(10, p)), p)) # ∼ Binomial(10, p^2)
    for p in [0.3, 0.7]
        triple_deriv = mean(derivative_estimate(binbin, p) for i in 1:100000)
        exact_deriv = 10 * 2 * p # d/dp of the mean 10p^2
        @test isapprox(triple_deriv, exact_deriv, rtol = 5e-2)
    end
end
@testset "Boolean comparisons" begin
    for backend in backends
        tested = falses(2)
        # Keep sampling until both Bernoulli outcomes (0 and 1) have been exercised.
        while !(all(tested))
            st = stochastic_triple(rand ∘ Bernoulli, 0.5; backend)
            x = StochasticAD.value(st)
            if x == 0
                # Ensure errors on unsafe/unsupported boolean comparisons
                @test_throws Exception st>0.5
                @test_throws Exception 0.5<st
                @test_throws Exception st==0
            else
                @test st > 0.5
                @test 0.5 < st
                @test st == 1
            end
            tested[x + 1] = true
        end
        @test stochastic_triple(1.0; backend) != 1
    end
end
@testset "Array indexing" begin
    for backend in vcat(backends, backends_smoothed)
        p = 0.3
        # Test indexing into array of floats with stochastic triple index
        arr = [3.5, 5.2, 8.4]
        (backend in backends_smoothed) && (arr[3] = 6.9) # make linear for smoothing test
        function array_index(p)
            index = rand(Categorical([p / 2, p / 2, 1 - p]))
            return arr[index]
        end
        # Exact mean: probability-weighted sum over the three outcomes.
        array_index_mean(p) = sum([p / 2, p / 2, (1 - p)] .* arr)
        triple_array_index_deriv = mean(derivative_estimate(array_index, p; backend)
        for i in 1:50000)
        exact_array_index_deriv = ForwardDiff.derivative(array_index_mean, p)
        @test isapprox(triple_array_index_deriv, exact_array_index_deriv, rtol = 5e-2)
        # Don't run subsequent tests with smoothing backend
        (backend in backends_smoothed) && continue
        # Test indexing into array of stochastic triples with stochastic triple index
        function array_index2(p)
            arr2 = [rand(Bernoulli(p)), rand(Bernoulli(p)), rand(Bernoulli(p))] .* arr
            index = rand(Categorical([p / 2, p / 2, 1 - p]))
            return arr2[index]
        end
        array_index2_mean(p) = sum([p / 2 * p, p / 2 * p, (1 - p) * p] .* arr)
        triple_array_index2_deriv = mean(derivative_estimate(array_index2, p; backend)
        for i in 1:50000)
        exact_array_index2_deriv = ForwardDiff.derivative(array_index2_mean, p)
        @test isapprox(triple_array_index2_deriv, exact_array_index2_deriv, rtol = 5e-2)
        # Test case where triple and alternate array value are coupled
        function array_index3(p)
            st = rand(Bernoulli(p))
            arr2 = [-5, st]
            return arr2[st + 1]
        end
        array_index3_mean(p) = -5 * (1 - p) + 1 * p
        triple_array_index3_deriv = mean(derivative_estimate(array_index3, p; backend)
        for i in 1:50000)
        exact_array_index3_deriv = ForwardDiff.derivative(array_index3_mean, p)
        @test isapprox(triple_array_index3_deriv, exact_array_index3_deriv, rtol = 5e-2)
    end
end
@testset "Array/functor inputs to higher level functions" begin
    for backend in backends
        # Try a deterministic test function to compare to ForwardDiff
        f(x) = (x[1] * x[2] * sin(x[3]) + exp(x[1] * x[2])) / x[3]
        x = [1, 2, π / 2]
        stochastic_ad_grad = derivative_estimate(f, x; backend)
        stochastic_ad_grad2 = derivative_contribution.(stochastic_triple(f, x; backend))
        # Directional derivative along the first two coordinates.
        stochastic_ad_grad_firsttwo = derivative_estimate(
            f, x; direction = [1.0, 1.0, 0.0],
            backend)
        fd_grad = ForwardDiff.gradient(f, x)
        @test stochastic_ad_grad ≈ fd_grad
        @test stochastic_ad_grad ≈ stochastic_ad_grad2
        @test stochastic_ad_grad[1] + stochastic_ad_grad[2] ≈ stochastic_ad_grad_firsttwo
        # Try an OffsetArray too
        f_off(x) = (x[0] * x[1] * sin(x[2]) + exp(x[0] * x[1])) / x[2]
        x_off = OffsetArray([1, 2, π / 2], 0:2)
        stochastic_ad_grad_off = derivative_estimate(f_off, x_off)
        @test stochastic_ad_grad_off ≈ OffsetArray(stochastic_ad_grad, 0:2)
        # Try a Functor
        f_func(x) = (x[1] * x[2][1] * sin(x[2][2]) + exp(x[1] * x[2][1])) / x[2][2]
        x_func = (1, [2, π / 2])
        stochastic_ad_grad_func = derivative_estimate(f_func, x_func)
        stochastic_ad_grad_func_expected = (stochastic_ad_grad[1], stochastic_ad_grad[2:3])
        compare_grad_funcs = StochasticAD.structural_map(≈, stochastic_ad_grad_func,
            stochastic_ad_grad_func_expected)
        @test all(compare_grad_funcs |> StochasticAD.structural_iterate)
        # Test StochasticModel + stochastic_gradient combination
        m = StochasticModel(f, x)
        @test stochastic_gradient(m).p ≈ stochastic_ad_grad
    end
end
@testset "Propagation using frule with ZeroTangent" begin
    st = stochastic_triple(0.5)
    # Verify that the rule for imag indeed gives a ZeroTangent
    value = StochasticAD.value(st)
    δ = StochasticAD.delta(st)
    @test frule((NoTangent(), δ), imag, value)[2] isa ZeroTangent
    # Test that stochastic triples flow through this rule
    out_st = imag(st)
    @test StochasticAD.value(out_st) ≈ 0
    @test StochasticAD.delta(out_st) ≈ 0
    @test isempty(out_st.Δs)
end
@testset "Unary functions converting type to fixed instance" begin
    for val in [0.5, 1]
        st = stochastic_triple(val)
        # Wrapping type funcs (zero, one, typemin, ...) should produce triples
        # with zero dual and no perturbations.
        for op in StochasticAD.UNARY_TYPEFUNCS_WRAP
            f = getfield(Base, Symbol(op))
            out_st = f(st)
            @test out_st isa StochasticAD.StochasticTriple
            @test StochasticAD.value(out_st) ≈ f(val) ≈ f(typeof(val))
            @test StochasticAD.delta(out_st) ≈ 0
            @test isempty(out_st.Δs)
            @test f(typeof(st)) == out_st
        end
        #=
        It so happens that the UNARY_TYPEFUNCS_WRAP funcs all support both instances and types
        whereas UNARY_TYPEFUNCS_NOWRAP only supports types, so we only test types in the below,
        but this is a coincidence that may not hold in the future.
        =#
        for op in StochasticAD.UNARY_TYPEFUNCS_NOWRAP
            f = getfield(Base, Symbol(op))
            out = f(typeof(st))
            @test out isa typeof(val)
            @test out ≈ f(typeof(val))
        end
        RNG = copy(Random.GLOBAL_RNG)
        # RNG type funcs on the triple type should match the value type's draw
        # when given identical RNG state.
        for op in StochasticAD.RNG_TYPEFUNCS_WRAP
            f = getfield(Random, Symbol(op))
            out_st = f(copy(RNG), typeof(st))
            @test out_st isa StochasticAD.StochasticTriple
            @test StochasticAD.value(out_st) ≈ f(copy(RNG), typeof(val))
            @test StochasticAD.delta(out_st) ≈ 0
            @test isempty(out_st.Δs)
        end
    end
end
@testset "Hashing" begin
    triple = stochastic_triple(3.0)
    @test_nowarn hash(triple)
    @test_nowarn hash(triple, UInt(5))
    # A triple should behave like its primal as a dictionary key.
    lookup = Dict()
    @test_nowarn lookup[triple] = 5
    @test lookup[triple] == 5
    @test lookup[3] == 5
    # Discrete random indices are unsupported; indexing a Dict with one must
    # error loudly rather than fail silently.
    Δs = StochasticAD.similar_new(StochasticAD.create_Δs(PrunedFIsBackend(), Int), 1.0, 1.0)
    triple = StochasticAD.StochasticTriple{0}(1.0, 0, Δs)
    @test_throws ErrorException lookup[rand(Bernoulli(triple))]
end
@testset "Coupled comparison" begin
    # Two triples carrying unrelated perturbation collections.
    make_Δs() = StochasticAD.similar_new(
        StochasticAD.create_Δs(PrunedFIsBackend(), Int), 1.0, 1.0)
    triple_a = StochasticAD.StochasticTriple{0}(1.0, 0, make_Δs())
    triple_b = StochasticAD.StochasticTriple{0}(1.0, 0, make_Δs())
    # Comparing a triple with itself works...
    @test triple_a == triple_a
    # ...but comparing triples with uncoupled perturbations must error.
    @test_throws ErrorException triple_a==triple_b
end
@testset "Converting float stochastic triples to integer triples" begin
    triple = stochastic_triple(0.6)
    # Rounding stays within the stochastic-triple type...
    @test round(Int, triple) isa StochasticAD.StochasticTriple
    # ...with no continuous (dual) derivative part...
    @test StochasticAD.delta(round(Int, triple)) ≈ 0
    # ...and the primal rounded as usual.
    @test round(Int, triple) ≈ 1
end
@testset "Approximate comparisons" begin
    triple = stochastic_triple(0.5)
    @test triple ≈ triple
    # The default rtol should tolerate tiny perturbations but reject large ones.
    @test triple ≈ triple + 1e-14
    @test !(triple ≈ triple + 1)
    # Inf ≈ Inf currently fails for triples; tracked as broken.
    @test_broken stochastic_triple(Inf) ≈ stochastic_triple(Inf)
end
@testset "Error on unmatched tags" begin
    plain_triple = stochastic_triple(0.5)
    tagged_triple = stochastic_triple(x -> x^2, 0.5)
    # Triples created under different tags must refuse to convert into each other.
    @test_throws ArgumentError convert(typeof(plain_triple), tagged_triple)
end
@testset "Finite perturbation backend interface" begin
    # Exercise every registered backend, including the smoothed ones, against
    # the (currently informal) AbstractFIs backend contract.
    for backend in vcat(backends,
        backends_smoothed)
        # this boolean may need to become more fine-grained in the future
        is_smoothed_backend = backend in backends_smoothed
        #=
        Test the backend interface across the finite perturbation backends,
        which is currently a bit implicitly defined.
        =#
        V0 = Int
        V1 = Float64
        #=
        All four of the below approaches should create an empty backend,
        although the backend's internal state management may differ.
        =#
        Δs0 = StochasticAD.create_Δs(backend, V0) # used to create first triple in computation
        FIs = typeof(Δs0)
        Δs1 = empty(Δs0)
        Δs2 = empty(typeof(Δs0))
        Δs3 = StochasticAD.similar_empty(Δs0, V1)
        # Every empty collection must report its value type and contribute zero.
        for (Δs, V) in ((Δs0, V0), (Δs1, V0), (Δs2, V0), (Δs3, V1))
            @test StochasticAD.valtype(Δs) === V
            @test Δs isa StochasticAD.similar_type(FIs, V)
            !is_smoothed_backend && @test isempty(Δs)
            @test iszero(derivative_contribution(Δs))
        end
        # Test creation of a single perturbation
        for Δ in (1, 3.0)
            Δs0 = StochasticAD.create_Δs(backend, V0)
            Δs1 = StochasticAD.similar_new(Δs0, Δ, 3.0)
            @test StochasticAD.valtype(Δs1) === typeof(Δ)
            @test Δs1 isa StochasticAD.similar_type(FIs, typeof(Δ))
            !is_smoothed_backend && @test !isempty(Δs1)
            # A single perturbation Δ with weight 3.0 contributes 3Δ.
            @test derivative_contribution(Δs1) == 3Δ
            # Test StochasticAD.alltrue
            @test StochasticAD.alltrue(_Δ -> true, Δs1)
            @test !StochasticAD.alltrue(_Δ -> false, Δs1) || is_smoothed_backend
            # Test map
            # We use a dummy deriv here and below. TODO: use a more interesting dummy for better testing.
            Δs1_map = Base.map(Δ -> Δ^2, Δs1; deriv = identity, out_rep = Δ)
            !is_smoothed_backend && @test derivative_contribution(Δs1_map) ≈ Δ^2 * 3.0
            # Test map with weight (make a new copy so that original does not get reweighted)
            Δs2 = StochasticAD.similar_new(StochasticAD.create_Δs(backend, V0), Δ, 3.0)
            Δs2_weight_map = StochasticAD.weighted_map_Δs((Δ, _) -> (Δ^2, 2.0),
                Δs2;
                deriv = identity,
                out_rep = Δ)
            !is_smoothed_backend &&
                @test derivative_contribution(Δs2_weight_map) ≈ Δ^2 * 3.0 * 2.0
            # Also test scale
            w2 = derivative_contribution(Δs2)
            Δs2_scaled = StochasticAD.scale(Δs2, 2.0)
            w2_scaled = derivative_contribution(Δs2_scaled)
            @test w2_scaled ≈ 2.0 * w2
            # Test map_Δs with filter state
            if !is_smoothed_backend
                # filter_state pulls the perturbations of another collection
                # that are coupled with the one currently being mapped over.
                Δs1_plus_Δs0 = StochasticAD.map_Δs(
                    (Δ, state) -> Δ +
                                  StochasticAD.filter_state(Δs0,
                        state),
                    Δs1)
                @test derivative_contribution(Δs1_plus_Δs0) ≈ Δ * 3.0
                Δs1_plus_mapped = StochasticAD.map_Δs(
                    (Δ, state) -> Δ +
                                  StochasticAD.filter_state(Δs1,
                        state),
                    Δs1_map)
                @test derivative_contribution(Δs1_plus_mapped) ≈ Δ * 3.0 + Δ^2 * 3.0
            end
        end
        # Test coupling
        # Use scalar, array, and mixed-tuple shapes as representative structures.
        Δ_coupleds = (3, [4.0, 5.0], (2, [3.0, 4.0]))
        for Δ_coupled in Δ_coupleds
            function get_Δs_coupled(; do_combine = false, use_get_rep = false)
                Δs0 = StochasticAD.create_Δs(backend, Int)
                Δs1 = StochasticAD.similar_new(Δs0, 1, 3.0) # perturbation 1
                Δs2 = StochasticAD.similar_new(Δs0, 1, 2.0) # perturbation 2
                # A group of perturbations that all stem from perturbation 1.
                Δs_all1 = StochasticAD.structural_map(Δ_coupled) do Δ
                    Base.map(_Δ -> Δ, Δs1; deriv = identity, out_rep = Δ)
                end
                # A group of perturbations that all stem from perturbation 2.
                Δs_all2 = StochasticAD.structural_map(Δ_coupled) do Δ
                    Base.map(_Δ -> 2 * Δ, Δs2; deriv = (δ -> 2δ), out_rep = Δ)
                end
                # Join them into a single structure that should be coupled
                Δs_all = (Δs_all1, Δs_all2)
                kwargs = use_get_rep ? (; rep = StochasticAD.get_rep(FIs, Δs_all)) : (;)
                if do_combine
                    return StochasticAD.combine(FIs, Δs_all; kwargs...)
                else
                    return StochasticAD.couple(FIs, Δs_all;
                        out_rep = (Δ_coupled, Δ_coupled),
                        kwargs...)
                end
            end
            #=
            As a test function to apply to the coupled perturbation, we apply
            a matmul followed by a sigmoid activation function and a sum.
            =#
            l = 2 * length(collect(StochasticAD.structural_iterate(Δ_coupled)))
            A = rand(l, l)
            function mapfunc(Δ_coupled)
                arr = collect(StochasticAD.structural_iterate(Δ_coupled))
                sum(x -> 1 / (1 + exp(-x)), A * arr)
            end
            # Test the above function, and also a simple sum.
            for use_get_rep in (false, true)
                Δs_coupled = get_Δs_coupled(; use_get_rep)
                @test StochasticAD.valtype(Δs_coupled) == typeof((Δ_coupled, Δ_coupled))
                for (mapfunc, check_combine) in ((mapfunc, false),
                    (Δ_coupled -> sum(StochasticAD.structural_iterate(Δ_coupled)),
                        true))
                    function get_contribution()
                        Δs_coupled = get_Δs_coupled(; use_get_rep)
                        Δs_coupled_mapped = map(mapfunc, Δs_coupled; deriv = (δ -> 1.0),
                            out_rep = 0.0)
                        return derivative_contribution(Δs_coupled_mapped)
                    end
                    # Expected value: each perturbation fires independently with
                    # its own weight (3.0 for the first group, 2.0 for the second).
                    zero_Δ_coupled = StochasticAD.structural_map(zero, Δ_coupled)
                    expected_contribution1 = 3.0 * mapfunc((Δ_coupled, zero_Δ_coupled))
                    expected_contribution2 = 2.0 * mapfunc((zero_Δ_coupled,
                        StochasticAD.structural_map(x -> 2x,
                            Δ_coupled)))
                    expected_contribution = expected_contribution1 + expected_contribution2
                    if !is_smoothed_backend
                        # Pruning backends are stochastic, so compare in expectation.
                        @test isapprox(mean(get_contribution() for i in 1:1000),
                            expected_contribution; rtol = 5e-2)
                    end
                    # For a simple sum, this should be equivalent to the combine behaviour.
                    if check_combine && !is_smoothed_backend
                        @test isapprox(
                            mean(derivative_contribution(get_Δs_coupled(;
                                do_combine = true))
                                 for i in 1:1000),
                            expected_contribution;
                            rtol = 5e-2)
                    end
                    # Check scalarize
                    # Scalarizing and re-coupling should round-trip to an
                    # equivalent collection.
                    Δs_coupled2 = StochasticAD.couple(FIs,
                        StochasticAD.scalarize(Δs_coupled;
                            out_rep = (Δ_coupled,
                                Δ_coupled)),
                        out_rep = (Δ_coupled, Δ_coupled))
                    @test derivative_contribution(map(mapfunc, Δs_coupled;
                        deriv = (δ -> 1.0),
                        out_rep = 0.0)) ≈
                          derivative_contribution(map(mapfunc, Δs_coupled2;
                        deriv = (δ -> 1.0),
                        out_rep = 0.0))
                end
            end
        end
    end
end
@testset "Getting information about stochastic triples" begin
    for backend in vcat(backends,
        backends_smoothed)
        # Fixed seed so the Bernoulli draw (and hence the triple) is reproducible.
        Random.seed!(4321)
        f(x) = rand(Bernoulli(x)) + x
        st = stochastic_triple(f, 0.5; backend)
        # Expected: 0.5 + 1.0ε + (1.0 with probability 2.0ε)
        dual = ForwardDiff.Dual(0.5, 1.0)
        # `value` and `delta` should work uniformly on plain reals, stochastic
        # triples, and ForwardDiff duals.
        @test StochasticAD.value(0.5) == 0.5
        @test StochasticAD.value(st) == 0.5
        @test StochasticAD.value(dual) == 0.5
        @test iszero(StochasticAD.delta(0.5))
        @test StochasticAD.delta(st) == 1.0
        @test StochasticAD.delta(dual) == 1.0
        if !(backend in backends_smoothed)
            #=
            NB: since the implementation of perturbations can be backend-specific, the
            below property need not hold in general, but does for the current non-smoothed backends.
            =#
            p = only(perturbations(st))
            @test p.Δ == 1 && p.weight == 2.0
            # Contribution = delta + Δ * weight = 1.0 + 1 * 2.0.
            @test derivative_contribution(st) == 3.0
        else
            # Since smoothed algorithm uses the two-sided strategy, we get a different derivative contribution.
            @test derivative_contribution(st) == 2.0
        end
        @test StochasticAD.tag(st) === StochasticAD.Tag{typeof(f), Float64}
        @test StochasticAD.valtype(st) === Float64
        @test StochasticAD.valtype(st.Δs) === Float64
    end
end
@testset "Propagation via StochasticAD.propagate" begin
    for backend in backends
        # Build a triple with primal `primal`, dual `δ`, and a single
        # perturbation `Δ` sharing the coupling state of `Δs_base`.
        function form_triple(primal, δ, Δ, Δs_base)
            Δs = map(_Δ -> Δ, Δs_base)
            return StochasticAD.StochasticTriple{0}(primal, δ, Δs)
        end
        # Check that propagating `f` through coupled input triples yields the
        # finite difference f(primals .+ Δs) - f(primals) as the output
        # perturbation, and (when test_deltas) the ForwardDiff dual as the delta.
        function test_propagate(f, primals, Δs; test_deltas = false)
            Δs_base = StochasticAD.similar_new(StochasticAD.create_Δs(backend, Int),
                0, 1.0)
            _form_triple(x, δ, Δ) = form_triple(x, δ, Δ, Δs_base)
            out = f(primals...)
            out_Δ_expected = StochasticAD.structural_map(-,
                f(StochasticAD.structural_map(+,
                    primals,
                    Δs)...),
                f(primals...))
            if test_deltas
                # Only float inputs get a (random) dual part; others pass through.
                duals = StochasticAD.structural_map(primals) do x
                    x isa AbstractFloat ? ForwardDiff.Dual{0}(x, rand(typeof(x))) : x
                end
                δs = StochasticAD.structural_map(StochasticAD.delta, duals)
                out_δ_expected = StochasticAD.structural_map(StochasticAD.delta,
                    f(duals...))
            else
                δs = StochasticAD.structural_map(zero, primals)
                out_δ_expected = StochasticAD.structural_map(zero, out)
            end
            input_sts = StochasticAD.structural_map(_form_triple, primals, δs, Δs)
            out_st = StochasticAD.propagate(f, input_sts...; keep_deltas = Val{test_deltas})
            # Test type
            StochasticAD.structural_map(out_st, out, out_δ_expected,
                out_Δ_expected) do x_st, x, δ, Δ
                @test x_st isa StochasticAD.StochasticTriple{0, typeof(x)}
                @test StochasticAD.value(x_st) == x
                @test StochasticAD.delta(x_st) ≈ δ
                p = only(perturbations(x_st))
                @test p.Δ == Δ && p.weight == 1.0
            end
        end
        #=
        Test propagation through some simple functions.
        f1: a simple if statement.
        f2: involves array-containing-functor input and output.
        f3: involves array-containing-functor input, but real output.
        f4: length ∘ repr (real or array input, real output).
        f5: mutates input array! Broken since unsupported.
        f6: the first-arg (blob) should just be passed through without attempting
        to perturb. Broken since unsupported.
        f7: involves matrix-containing-functor input and output.
        =#
        function f1(x)
            if x == 0
                return 1
            elseif x == 3
                return 2
            else
                return 5
            end
        end
        @test StochasticAD.propagate(f1, 0) === f1(0)
        for (primal, Δ) in [(0, 3), (0, 4), (3, -1)]
            test_propagate(f1, (primal,), (Δ,))
        end
        function f2(arr, scalar)
            if sum(arr) + scalar <= 5
                return arr .* scalar, sum(arr) * scalar
            else
                return arr .- scalar, sum(arr) - scalar
            end
        end
        f3(arr, scalar) = f2(arr, scalar)[2]
        # Inputs chosen so the perturbation does / does not flip f2's branch.
        primals1 = ([1, 1], 2)
        Δs1 = ([2, 3], 5)
        primals2 = ([1, 2], 1)
        Δs2 = ([1, -2], 1)
        primals3 = ([5, 2], -1)
        Δs3 = ([-3, 1], 0)
        for (primals, Δs) in [(primals1, Δs1), (primals2, Δs2), (primals3, Δs3)]
            for test_deltas in (false, true)
                if test_deltas
                    primals = StochasticAD.structural_map(float, primals)
                    Δs = StochasticAD.structural_map(float, Δs)
                end
                test_propagate(f2, primals, Δs; test_deltas)
                test_propagate(f3, primals, Δs; test_deltas)
            end
        end
        f4(x) = Base.length(repr(x))
        for (primals, Δs) in [(2, 11), (([3, 14],), ([14, -152],))]
            test_propagate(f4, primals, Δs)
        end
        function f5(arr)
            if arr == [1, 2]
                arr .+= 1
            else
                arr .-= 1
            end
        end
        # Tests for f5 skipped (would break)
        for (primals, Δs) in [([1, 2], [1, -1]), ([2, 4], [-1, -2]), ([2, 4], [-1, -1])]
            @test_skip "propagate f5"
            # test_propagate(f5, primals, Δs)
        end
        f6(blob, arr) = blob, f5(arr)
        # Tests for f6 missing (would break)
        @test_skip "propagate f6"
        function f7(mat, scalar)
            return mat * scalar, scalar + sum(mat)
        end
        test_propagate(f7, (rand(2, 2), 4.0), (rand(2, 2), 1.0); test_deltas = true)
    end
end
@testset "zero'ing of Inf/NaN (#79)" begin
    triple = stochastic_triple(0.5)
    # 1 / zero(st) blows up both primal and dual parts; zero() must wipe both.
    zeroed = zero(1 / zero(triple))
    @test iszero(StochasticAD.value(zeroed))
    @test iszero(StochasticAD.delta(zeroed))
end
@testset "smooth_triple" begin
    plain(p) = sum(rand(Bernoulli(p)) * i for i in 1:100)
    smoothed(p) = sum(smooth_triple(rand(Bernoulli(p))) * i for i in 1:100)
    p = 0.6
    # Smoothing each Bernoulli draw should leave the derivative estimate
    # unbiased, so the two Monte Carlo means must agree.
    plain_est = mean(derivative_estimate(plain, p) for i in 1:10000)
    smooth_est = mean(derivative_estimate(smoothed, p) for i in 1:10000)
    @test plain_est≈smooth_est rtol=5e-2
end
@testset "No unnecessary float promotion" begin
    square_draw(p) = rand(Bernoulli(p))^2
    triple = stochastic_triple(square_draw, 0.5)
    # Squaring an integer-valued Bernoulli draw must keep the triple's value
    # type integral rather than promoting to a float.
    @test StochasticAD.valtype(triple) == typeof(convert(Signed, square_draw(0.5)))
end
================================================
FILE: tutorials/Project.toml
================================================
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GR = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
GaussianDistributions = "43dcc890-d446-5863-8d1a-14597580bb8d"
GeometryBasics = "5c1252a2-5f33-56bf-86c9-59e7332b4326"
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StochasticAD = "e4facb34-4f7e-4bec-b153-e122c37934ac"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
================================================
FILE: tutorials/README.md
================================================
The raw source code for our tutorials. You likely want to look at the documentation instead, where they are presented more clearly!
================================================
FILE: tutorials/game_of_life/core.jl
================================================
module GoLCore
using Random
using Distributions
using LinearAlgebra
using StochasticAD
using StaticArrays
using OffsetArrays

# Advance `board` by one stochastic Game-of-Life step, reading neighbour counts
# from `board_old`. Each cell flips its state with probability
# `all_probs[index]`, where `index` encodes the cell's current state together
# with its number of live (von Neumann, 4-connected) neighbours.
function update_state!(all_probs, N, board, board_old)
    for i in (-N):N
        for j in (-N):N
            neighbours = board_old[i + 1, j] + board_old[i - 1, j] + board_old[i, j - 1] +
                         board_old[i, j + 1]
            index = board[i, j] * 5 + neighbours + 1 # trick necessary because we do not have implementation support for stochastic triple not <: Real
            b = rand(Bernoulli(all_probs[index]))
            # Flip the cell iff b == 1: (1 - 2 * state) is +1 for dead, -1 for alive.
            board[i, j] += (1 - 2 * board[i, j]) * b
        end
    end
end

# Run the stochastic game for `T` steps on a board of half-length `N` (padded
# by one dead cell on each side) whose cells start alive with probability `p`.
# Returns the final population count; with `log` true, also returns the final
# board and the full history of boards.
function play_game_of_life(p, all_probs, N, T, log = false)
    dual_type = promote_type(typeof(rand(Bernoulli(p))),
        typeof.(rand.(Bernoulli.(all_probs)))...) # TODO: better way of getting the concrete type
    board = OffsetArray(zeros(dual_type, 2 * N + 3, 2 * N + 3), (-(N + 1)):(N + 1),
        (-(N + 1)):(N + 1)) # pad by 1
    for i in (-N):N
        for j in (-N):N
            board[i, j] = rand(Bernoulli(p))
        end
    end
    board_old = similar(board)
    log && (history = [])
    for time_step in 1:T
        copy!(board_old, board)
        update_state!(all_probs, N, board, board_old)
        log && push!(history, copy(board))
    end
    if !log
        return sum(board)
    else
        return sum(board), board, history
    end
end

# Play with birth/death probabilities parameterised by θ: per the tables below,
# birth is likely only at exactly 3 live neighbours, and death is likely except
# at 2 or 3 live neighbours.
function play(p, θ = 0.1, N = 3, T = 3; log = false)
    # N is the board half-length, T are game time steps
    low = θ
    high = 1 - θ
    birth_probs = SA[low, low, low, high, low] # 0, 1, 2, 3, 4 neighbours
    death_probs = SA[high, high, low, low, high] # 0, 1, 2, 3, 4 neighbours
    return play_game_of_life(p, vcat(birth_probs, death_probs), N, T, log)
end

# An implementation of finite differences that uses "common random numbers"
# (the same seed), for more accurate checking, albeit with a finite step size h
# such that there is weight degeneracy as h → 0.
function fd_clever(p, h = 0.01)
    # Save the RNG state so both runs consume the identical random stream.
    state = copy(Random.default_rng())
    run1 = play(p + h)
    copy!(Random.default_rng(), state)
    run2 = play(p - h)
    (run1 - run2) / (2h)
end

# Provide some default parameters
p = 0.5
nsamples = 200_000
end
================================================
FILE: tutorials/game_of_life/plot_board.jl
================================================
include("core.jl")
using Plots
using Statistics
using BenchmarkTools

p = 0.5
# Run the game under a stochastic triple, recording the full board history.
_, board, history = stochastic_triple(p -> GoLCore.play(p; log = true), p)
# Animate the primal board states over time...
anim1 = @animate for (i, board) in enumerate(history)
    heatmap(collect(StochasticAD.value.(board)), title = "time $i", clim = (-1, 1),
        c = :grays)
end
# ...and the per-cell derivative contributions.
anim2 = @animate for (i, board) in enumerate(history)
    heatmap(collect(StochasticAD.derivative_contribution.(board)), title = "time $i",
        clim = (-1, 1), c = :grays)
end
gif(anim1, "game.gif", fps = 15)
gif(anim2, "perturbation.gif", fps = 15)
# Static snapshots of the final board (primal values and contributions).
fig1 = heatmap(collect(StochasticAD.value.(board)), clim = (-1, 1), c = :grays)
fig2 = heatmap(collect(derivative_contribution.(board)), clim = (-1, 1), c = :grays) # TODO: graph perturbed values instead of derivative contribution
savefig(fig1, "board.png")
savefig(fig2, "perturbation.png")
================================================
FILE: tutorials/particle_filter/benchmark.jl
================================================
include("core.jl")
include("model.jl")
using Plots, LaTeXStrings
using BenchmarkTools
using Measurements
# Benchmark for primal, forward- and reverse-mode AD of particle sampler
### compute gradients
# secs for how long the benchmark should run, see https://juliaci.github.io/BenchmarkTools.jl/stable/
secs = 10
suite = BenchmarkGroup()
suite["scaling"] = BenchmarkGroup(["grads"])
# Primal log-likelihood evaluation (no differentiation).
suite["scaling"]["primal"] = @benchmarkable ParticleFilterCore.log_likelihood(
    particle_filter,
    θtrue)
# Forward-mode gradient of the log-likelihood.
suite["scaling"]["forward"] = @benchmarkable ParticleFilterCore.forw_grad(θtrue,
    particle_filter)
# Reverse-mode gradient of the log-likelihood.
suite["scaling"]["backward"] = @benchmarkable ParticleFilterCore.back_grad(θtrue,
    particle_filter)
tune!(suite)
results = run(suite, verbose = true, seconds = secs)
# Summarise each timing as mean ± standard error using Measurements.
t1 = measurement(mean(results["scaling"]["primal"].times),
    std(results["scaling"]["primal"].times) /
    sqrt(length(results["scaling"]["primal"].times)))
t2 = measurement(mean(results["scaling"]["forward"].times),
    std(results["scaling"]["forward"].times) /
    sqrt(length(results["scaling"]["forward"].times)))
t3 = measurement(mean(results["scaling"]["backward"].times),
    std(results["scaling"]["backward"].times) /
    sqrt(length(results["scaling"]["backward"].times)))
@show t1 t2 t3
ts = (t1, t2, t3) ./ 10^6 # ms
@show ts
# NOTE(review): `d` is presumably a dimension defined in model.jl — confirm.
BenchmarkTools.save("benchmark_data_" * string(d) * ".json", results)
================================================
FILE: tutorials/particle_filter/bias.jl
================================================
include("core.jl")
include("model.jl")
using Plots, LaTeXStrings
using Random
# Comparison of the derivative of the particle filter with and without differentiating the resampling step.
### compute gradients
# Reseed before each batch so both estimators see the same random stream.
Random.seed!(seed)
X = [ParticleFilterCore.forw_grad(θtrue, particle_filter) for i in 1:1000] # gradient of the particle filter *with* differentiation of the resampling step
Random.seed!(seed)
Xbiased = [ParticleFilterCore.forw_grad_biased(θtrue, particle_filter) for i in 1:1000] # Gradient of the particle filter *without* differentiation of the resampling step
# pick an arbitrary coordinate
index = 1 # take derivative with respect to first parameter (2-dimensional example has a rotation matrix with four parameters in total)
# plot histograms for the sampled derivative values
fig = plot(normalize(fit(Histogram, getindex.(X, index), nbins = 50), mode = :pdf),
    legend = false) # ours
plot!(normalize(fit(Histogram, getindex.(Xbiased, index), nbins = 50), mode = :pdf)) # biased
# Mark the sample mean of each estimator with a vertical line.
vline!([mean(X)[index]], color = 1)
vline!([mean(Xbiased)[index]], color = 2)
# add derivative of differentiable Kalman filter as a comparison
XK = ParticleFilterCore.forw_grad_Kalman(θtrue, kalman_filter)
vline!([XK[index]], color = "black")
display(fig)
savefig(fig, "tails.pdf")
================================================
FILE: tutorials/particle_filter/core.jl
================================================
module ParticleFilterCore
# load dependencies
using Distributions
using DistributionsAD
using Random
using Statistics
using StatsBase
using LinearAlgebra
using Zygote
using StochasticAD
using ForwardDiff
using GaussianDistributions
using GaussianDistributions: correct, ⊕
using UnPack
### Particle Filter Functions
# Model defs
"""
    StochasticModel{TType<:Integer,T1,T2,T3}

For parameters `θ`, `rand(start(θ))` gives a sample from the prior distribution of the
starting distribution. For current state `x` and parameters `θ`, `xnew = rand(dyn(x, θ))`
samples the new state (i.e. `dyn` gives for each `x, θ` a distribution-like object). Finally,
`y = rand(obs(x, θ))` samples an observation.

## Constructor
- `T`: total number of time steps.
- `start`: starting distribution for the initial state. For example, in the form of a narrow
  Gaussian `start(θ) = Gaussian(x0, 0.001 * I(d))`.
- `dyn`: pointwise differentiable stochastic program in the form of Markov transition densities.
  For example, `dyn(x, θ) = MvNormal(reshape(θ, d, d) * x, Q(θ))`, where `Q(θ)` denotes the
  covariance matrix.
- `obs`: observation model having a smooth conditional probability density depending on
  current state `x` and parameters `θ`. For example, `obs(x, θ) = MvNormal(x, R(θ))`,
  where `R(θ)` denotes the covariance matrix.
"""
struct StochasticModel{TType <: Integer, T1, T2, T3}
    T::TType # time steps
    start::T1 # prior
    dyn::T2 # dynamical model
    obs::T3 # observation model
end
# Particle filter
"""
    ParticleFilter{mType<:Integer,MType<:StochasticModel,yType,sType}

Wraps a stochastic model `StochM::StochasticModel` and observational data `ys`.
Assumes an observation likelihood is available via `pdf(obs(x, θ), y)`.

## Constructor
- `m`: number of particles.
- `StochM`: stochastic model of type `StochasticModel`.
- `ys`: observations.
- `sample_strategy`: strategy for the resampling step of the particle filter. For example,
  stratified sampling as implemented in `sample_stratified`.
"""
struct ParticleFilter{mType <: Integer, MType <: StochasticModel, yType, sType}
    m::mType # number of particles
    StochM::MType # stochastic model
    ys::yType # observations
    sample_strategy::sType # sampling function
end
# Kalman filter
"""
KalmanFilter{dType<:Integer,MType<:StochasticModel,HType,RType,QType,yType}
Differentiable Kalman filter following https://github.com/mschauer/Kalman.jl/blob/master/README.md.
Wraps a stochastic mode
gitextract__gvecf0r/
├── .JuliaFormatter.toml
├── .git-blame-ignore-revs
├── .github/
│ └── workflows/
│ ├── CI.yml
│ ├── CompatHelper.yml
│ ├── Documentation.yml
│ ├── FormatCheck.yml
│ ├── TagBot.yml
│ └── benchmark.yml
├── .gitignore
├── CITATION.bib
├── LICENSE
├── Project.toml
├── README.md
├── benchmark/
│ ├── benchmarks.jl
│ ├── game_of_life.jl
│ ├── iteration.jl
│ ├── random_walk.jl
│ ├── runbenchmarks.jl
│ ├── simple_ops.jl
│ └── utils.jl
├── docs/
│ ├── Project.toml
│ ├── make.jl
│ └── src/
│ ├── assets/
│ │ └── extra_styles.css
│ ├── devdocs.md
│ ├── index.md
│ ├── limitations.md
│ ├── public_api.md
│ └── tutorials/
│ ├── game_of_life.md
│ ├── optimizations.md
│ ├── particle_filter.md
│ ├── random_walk.md
│ └── reverse_demo.md
├── ext/
│ └── StochasticADEnzymeExt.jl
├── src/
│ ├── StochasticAD.jl
│ ├── algorithms.jl
│ ├── backends/
│ │ ├── abstract_wrapper.jl
│ │ ├── dict.jl
│ │ ├── pruned.jl
│ │ ├── pruned_aggressive.jl
│ │ ├── smoothed.jl
│ │ └── strategy_wrapper.jl
│ ├── discrete_randomness.jl
│ ├── finite_infinitesimals.jl
│ ├── general_rules.jl
│ ├── misc.jl
│ ├── prelude.jl
│ ├── propagate.jl
│ ├── smoothing.jl
│ └── stochastic_triple.jl
├── test/
│ ├── game_of_life.jl
│ ├── random_walk.jl
│ ├── resampling.jl
│ ├── runtests.jl
│ └── triples.jl
└── tutorials/
├── Project.toml
├── README.md
├── game_of_life/
│ ├── core.jl
│ └── plot_board.jl
├── particle_filter/
│ ├── benchmark.jl
│ ├── bias.jl
│ ├── core.jl
│ ├── model.jl
│ ├── variance.jl
│ └── visualize.jl
├── random_walk/
│ ├── compare_score.jl
│ ├── core.jl
│ └── show_unbiased.jl
├── reverse_example/
│ └── reverse_demo.jl
└── toy_optimizations/
├── Project.toml
├── igarch.jl
├── intro.jl
└── variational.jl
Condensed preview — 72 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (236K chars).
[
{
"path": ".JuliaFormatter.toml",
"chars": 16,
"preview": "style = \"sciml\"\n"
},
{
"path": ".git-blame-ignore-revs",
"chars": 376,
"preview": "# Run this command to always ignore these in local `git blame`:\n# git config blame.ignoreRevsFile .git-blame-ignore-revs"
},
{
"path": ".github/workflows/CI.yml",
"chars": 602,
"preview": "name: CI\non:\n pull_request:\n push:\n branches:\n - main \n tags: '*'\njobs:\n unittest:\n runs-on: ubuntu-lat"
},
{
"path": ".github/workflows/CompatHelper.yml",
"chars": 457,
"preview": "name: CompatHelper\non:\n schedule:\n - cron: 0 0 * * *\n workflow_dispatch:\njobs:\n CompatHelper:\n runs-on: ubuntu-"
},
{
"path": ".github/workflows/Documentation.yml",
"chars": 727,
"preview": "name: Documentation\n\non:\n push:\n branches:\n - main\n tags: '*'\n pull_request:\n\njobs:\n build:\n runs-on: u"
},
{
"path": ".github/workflows/FormatCheck.yml",
"chars": 1130,
"preview": "name: format-check\n\non:\n push:\n branches:\n - 'main'\n - 'release-'\n tags: '*'\n pull_request:\n\njobs:\n b"
},
{
"path": ".github/workflows/TagBot.yml",
"chars": 362,
"preview": "name: TagBot\non:\n issue_comment:\n types:\n - created\n workflow_dispatch:\njobs:\n TagBot:\n if: github.event_n"
},
{
"path": ".github/workflows/benchmark.yml",
"chars": 525,
"preview": "\nname: Benchmarks\n\non:\n pull_request:\n push:\n branches:\n - main \n tags: '*'\n\njobs:\n benchmark:\n runs-on"
},
{
"path": ".gitignore",
"chars": 13,
"preview": "Manifest.toml"
},
{
"path": "CITATION.bib",
"chars": 583,
"preview": "@inproceedings{arya2022automatic,\n author = {Arya, Gaurav and Schauer, Moritz and Sch\\\"{a}fer, Frank and Rackauckas, Chr"
},
{
"path": "LICENSE",
"chars": 1101,
"preview": "MIT License\n\nCopyright (c) 2022 Gaurav Arya <aryag@mit.edu> and contributors\n\nPermission is hereby granted, free of char"
},
{
"path": "Project.toml",
"chars": 2053,
"preview": "name = \"StochasticAD\"\nuuid = \"e4facb34-4f7e-4bec-b153-e122c37934ac\"\nauthors = [\"Gaurav Arya <aryag@mit.edu> and contribu"
},
{
"path": "README.md",
"chars": 1636,
"preview": "\n\n"
},
{
"path": "benchmark/benchmarks.jl",
"chars": 341,
"preview": "using BenchmarkTools\n\ninclude(\"random_walk.jl\")\ninclude(\"game_of_life.jl\")\ninclude(\"iteration.jl\")\ninclude(\"simple_ops.j"
},
{
"path": "benchmark/game_of_life.jl",
"chars": 588,
"preview": "module GoLBenchmark\n\nusing BenchmarkTools\n\nusing StochasticAD\nusing Statistics\nusing ForwardDiff: derivative\ninclude(\".."
},
{
"path": "benchmark/iteration.jl",
"chars": 2530,
"preview": "\"\"\"\nIn the library we have tried to avoid generated functions, instead reductions from base with the\nhope that the itera"
},
{
"path": "benchmark/random_walk.jl",
"chars": 781,
"preview": "module RandomWalkBenchmark\n\nusing BenchmarkTools\n\nusing StochasticAD\nusing Statistics\nusing ForwardDiff: derivative\nincl"
},
{
"path": "benchmark/runbenchmarks.jl",
"chars": 295,
"preview": "using PkgBenchmark\n\ninclude(\"utils.jl\")\nusing .Utils\n\nresults = benchmarkpkg(dirname(@__DIR__),\n BenchmarkConfig(env "
},
{
"path": "benchmark/simple_ops.jl",
"chars": 1117,
"preview": "module SimpleOpsBenchmark\n\nusing BenchmarkTools\n\nusing StochasticAD\n\nconst suite = BenchmarkGroup()\n\nsuite[\"add\"] = Benc"
},
{
"path": "benchmark/utils.jl",
"chars": 487,
"preview": "module Utils\n\nexport print_group\n\nusing Functors\nusing BenchmarkTools\n\n## Printing\n\n# Type piracy, fine since just in be"
},
{
"path": "docs/Project.toml",
"chars": 274,
"preview": "[deps]\nDistributions = \"31c24e10-a181-5473-b8eb-7969acd0382f\"\nDocThemeIndigo = \"8bac0ac5-51bf-41f9-885e-2bf1ac2bec5f\"\nDo"
},
{
"path": "docs/make.jl",
"chars": 1273,
"preview": "using Pkg\n\nusing Documenter\nusing StochasticAD\nusing DocThemeIndigo\nusing Literate\n\n### Formatting\n\nindigo = DocThemeInd"
},
{
"path": "docs/src/assets/extra_styles.css",
"chars": 195,
"preview": ".display-light-only {display: block;}\n.display-dark-only {display: none;}\n.theme--documenter-dark .display-light-only {d"
},
{
"path": "docs/src/devdocs.md",
"chars": 3982,
"preview": "# Developer documentation (WIP)\n\n## Writing a custom rule for stochastic triples\n\n### via `StochasticAD.propagate`\n\nTo h"
},
{
"path": "docs/src/index.md",
"chars": 3758,
"preview": "```@raw html\n<img class=\"display-light-only\" src=\"images/path_skeleton.png\">\n<img class=\"display-dark-only\" src=\"images/"
},
{
"path": "docs/src/limitations.md",
"chars": 1711,
"preview": "# Limitations of StochasticAD\n\n`StochasticAD` has a number of limitations that are important to be aware of:\n\n* `Stochas"
},
{
"path": "docs/src/public_api.md",
"chars": 2155,
"preview": "# API walkthrough\n \nThe function [`derivative_estimate`](@ref) transforms a stochastic program containing discrete rando"
},
{
"path": "docs/src/tutorials/game_of_life.md",
"chars": 4180,
"preview": "# Stochastic Game of Life\n\nWe consider a stochastic version of [Conway's Game of Life](https://en.wikipedia.org/wiki/Con"
},
{
"path": "docs/src/tutorials/optimizations.md",
"chars": 4757,
"preview": "# Stochastic optimizations with discrete randomness\n\n```@setup random_walk\nimport Pkg\nPkg.activate(\"../../../tutorials/t"
},
{
"path": "docs/src/tutorials/particle_filter.md",
"chars": 16944,
"preview": "# Differentiable particle filter\n\nUsing a bootstrap particle sampler, we can approximate the posterior distributions\nof "
},
{
"path": "docs/src/tutorials/random_walk.md",
"chars": 4024,
"preview": "# Random walk\n\n```@setup random_walk\nimport Pkg\nPkg.activate(\"../../../tutorials\")\nPkg.develop(path=\"../../..\")\nPkg.inst"
},
{
"path": "docs/src/tutorials/reverse_demo.md",
"chars": 2355,
"preview": "```@meta\nEditURL = \"../../../tutorials/reverse_example/reverse_demo.jl\"\n```\n\n# Simple reverse mode example\n\n```@setup ra"
},
{
"path": "ext/StochasticADEnzymeExt.jl",
"chars": 1490,
"preview": "module StochasticADEnzymeExt\n\nusing StochasticAD\nusing Enzyme\n\nfunction enzyme_target(u, X, p, backend)\n # equivalent"
},
{
"path": "src/StochasticAD.jl",
"chars": 2161,
"preview": "module StochasticAD\n\n### Public API\n\nexport stochastic_triple, derivative_contribution, perturbations, smooth_triple,\n "
},
{
"path": "src/algorithms.jl",
"chars": 4957,
"preview": "abstract type AbstractStochasticADAlgorithm end\n\n\"\"\"\n ForwardAlgorithm(backend::StochasticAD.AbstractFIsBackend) <: A"
},
{
"path": "src/backends/abstract_wrapper.jl",
"chars": 4766,
"preview": "module AbstractWrapperFIsModule\n\nimport ..StochasticAD\n\nexport AbstractWrapperFIs\n\n\"\"\"\n AbstractWrapperFIs{V, FIs} <:"
},
{
"path": "src/backends/dict.jl",
"chars": 5451,
"preview": "module DictFIsModule\n\nexport DictFIsBackend, DictFIs\n\nimport ..StochasticAD\nusing Dictionaries\n\n\"\"\"\n DictFIsBackend <"
},
{
"path": "src/backends/pruned.jl",
"chars": 9481,
"preview": "module PrunedFIsModule\n\nimport ..StochasticAD\n\nexport PrunedFIsBackend, PrunedFIs\n\n\"\"\"\n PrunedFIsBackend <: Stochasti"
},
{
"path": "src/backends/pruned_aggressive.jl",
"chars": 5530,
"preview": "module PrunedFIsAggressiveModule\n\nimport ..StochasticAD\n\nexport PrunedFIsAggressiveBackend, PrunedFIsAggressive\n\n\"\"\"\n "
},
{
"path": "src/backends/smoothed.jl",
"chars": 2948,
"preview": "module SmoothedFIsModule\n\nimport ..StochasticAD\n\nexport SmoothedFIsBackend, SmoothedFIs\n\n\"\"\"\n SmoothedFIsBackend <: S"
},
{
"path": "src/backends/strategy_wrapper.jl",
"chars": 1318,
"preview": "module StrategyWrapperFIsModule\n\nusing ..StochasticAD\nusing ..StochasticAD.AbstractWrapperFIsModule\n\nexport StrategyWrap"
},
{
"path": "src/discrete_randomness.jl",
"chars": 16620,
"preview": "## Helper functions for discrete distributions \n\n# index of the parameter p\n_param_index(::Geometric) = 1\n_param_index(:"
},
{
"path": "src/finite_infinitesimals.jl",
"chars": 2237,
"preview": "# TODO: make this a module, with the interface exported?\n\n## \n\"\"\"\n AbstractFIsBackend\n\nAn abstract type for backend s"
},
{
"path": "src/general_rules.jl",
"chars": 11268,
"preview": "\"\"\"\nOperators which have already been overloaded by StochasticAD. \n\"\"\"\nconst handled_ops = Tuple{DataType, Int}[]\n\n\"\"\"\n "
},
{
"path": "src/misc.jl",
"chars": 632,
"preview": "@doc raw\"\"\"\n StochasticModel(X, p)\n\nCombine stochastic program `X` with parameter `p` into \na trainable model using ["
},
{
"path": "src/prelude.jl",
"chars": 1834,
"preview": "const AMBIGUOUS_TYPES = (AbstractFloat, Irrational, Integer, Rational, Real, RoundingMode)\n\nconst UNARY_PREDICATES = [is"
},
{
"path": "src/propagate.jl",
"chars": 5000,
"preview": "\"\"\"\nA version of `value`` that allows unrecognized args to pass through. \n\"\"\"\nfunction get_value(arg)\n if arg isa Sto"
},
{
"path": "src/smoothing.jl",
"chars": 4195,
"preview": "### Particle resampling\n\n@doc raw\"\"\"\n new_weight(p::Real)\n\n Simulate a Bernoulli variable whose primal output is a"
},
{
"path": "src/stochastic_triple.jl",
"chars": 10323,
"preview": "\"\"\" \n StochasticTriple{T, V <: Real, FIs <: AbstractFIs{V}}\n\nStores the primal value of the computation, alongside a "
},
{
"path": "test/game_of_life.jl",
"chars": 375,
"preview": "using StochasticAD\nusing Test\nusing Statistics\n\ninclude(\"../tutorials/game_of_life/core.jl\")\nusing .GoLCore: fd_clever, "
},
{
"path": "test/random_walk.jl",
"chars": 423,
"preview": "using StochasticAD\nusing Test\nusing Statistics\nusing ForwardDiff: derivative\n\ninclude(\"../tutorials/random_walk/core.jl\""
},
{
"path": "test/resampling.jl",
"chars": 2144,
"preview": "using StochasticAD\nusing Random, Test\nusing Distributions\nusing LinearAlgebra\nusing ForwardDiff\n\n# test forward-mode AD "
},
{
"path": "test/runtests.jl",
"chars": 599,
"preview": "using SafeTestsets\nusing Test, Pkg\nimport Random\n\nRandom.seed!(1234)\n\nconst GROUP = get(ENV, \"GROUP\", \"All\")\nconst is_AP"
},
{
"path": "test/triples.jl",
"chars": 27352,
"preview": "using StochasticAD\nusing Test\nusing Distributions\nusing ForwardDiff\nusing OffsetArrays\nusing ChainRulesCore\nusing Random"
},
{
"path": "tutorials/Project.toml",
"chars": 1214,
"preview": "[deps]\nBenchmarkTools = \"6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf\"\nChainRulesCore = \"d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4\"\nD"
},
{
"path": "tutorials/README.md",
"chars": 133,
"preview": "The raw source code for our tutorials. You likely want to look at the documentation instead, where they are presented mo"
},
{
"path": "tutorials/game_of_life/core.jl",
"chars": 2234,
"preview": "module GoLCore\n\nusing Random\nusing Distributions\nusing LinearAlgebra\nusing StochasticAD\nusing StaticArrays\nusing OffsetA"
},
{
"path": "tutorials/game_of_life/plot_board.jl",
"chars": 863,
"preview": "include(\"core.jl\")\nusing Plots\nusing Statistics\nusing BenchmarkTools\n\np = 0.5\n_, board, history = stochastic_triple(p ->"
},
{
"path": "tutorials/particle_filter/benchmark.jl",
"chars": 1383,
"preview": "include(\"core.jl\")\ninclude(\"model.jl\")\nusing Plots, LaTeXStrings\nusing BenchmarkTools\nusing Measurements\n\n# Benchmark fo"
},
{
"path": "tutorials/particle_filter/bias.jl",
"chars": 1281,
"preview": "include(\"core.jl\")\ninclude(\"model.jl\")\nusing Plots, LaTeXStrings\nusing Random\n\n# Comparison of the derivative of the par"
},
{
"path": "tutorials/particle_filter/core.jl",
"chars": 9989,
"preview": "module ParticleFilterCore\n\n# load dependencies\nusing Distributions\nusing DistributionsAD\nusing Random\nusing Statistics\nu"
},
{
"path": "tutorials/particle_filter/model.jl",
"chars": 1897,
"preview": "# ParticleFilter Model\n\nusing Random, LinearAlgebra, GaussianDistributions, Distributions\n\n# particle filter core functi"
},
{
"path": "tutorials/particle_filter/variance.jl",
"chars": 3198,
"preview": "include(\"core.jl\")\ninclude(\"model.jl\")\nusing Plots, LaTeXStrings\nusing Random\n\nRandom.seed!(seed)\n# Comparison of the va"
},
{
"path": "tutorials/particle_filter/visualize.jl",
"chars": 875,
"preview": "include(\"core.jl\")\ninclude(\"model.jl\")\nusing Plots, LaTeXStrings\n\n# visualization of stochastic process (observations an"
},
{
"path": "tutorials/random_walk/compare_score.jl",
"chars": 2292,
"preview": "include(\"core.jl\")\nusing Plots, LaTeXStrings\nusing Statistics\nusing StochasticAD\nusing ForwardDiff: derivative\nusing Pro"
},
{
"path": "tutorials/random_walk/core.jl",
"chars": 2890,
"preview": "module RandomWalkCore\n\nusing Random\nusing Statistics\nusing Distributions\nusing LinearAlgebra\nusing StochasticAD\nusing St"
},
{
"path": "tutorials/random_walk/show_unbiased.jl",
"chars": 2733,
"preview": "include(\"core.jl\")\nprintln(\"## Exact computation\\n\")\n\nusing ForwardDiff: derivative\nusing BenchmarkTools\nusing .RandomWa"
},
{
"path": "tutorials/reverse_example/reverse_demo.jl",
"chars": 3701,
"preview": "#text # Simple reverse mode example \n\n#text ```@setup random_walk\n#text import Pkg\n#text Pkg.activate(\"../../../tutorial"
},
{
"path": "tutorials/toy_optimizations/Project.toml",
"chars": 315,
"preview": "[deps]\nCairoMakie = \"13f3f980-e62b-5c42-98c6-ff1f3baf88f0\"\nDistributions = \"31c24e10-a181-5473-b8eb-7969acd0382f\"\nOptimi"
},
{
"path": "tutorials/toy_optimizations/igarch.jl",
"chars": 2063,
"preview": "# Poisson autoregression\ncd(@__DIR__)\nusing StochasticAD, Distributions\nusing Optimisers\nimport Random\nRandom.seed!(1234"
},
{
"path": "tutorials/toy_optimizations/intro.jl",
"chars": 1453,
"preview": "# Toy expectation optimization problem \ncd(@__DIR__)\nusing StochasticAD, Distributions, Optimisers\nimport Random # hide\n"
},
{
"path": "tutorials/toy_optimizations/variational.jl",
"chars": 1618,
"preview": "# Toy variational problem: Find Poisson(p) close to NegativeBinomial(10, 1-30/(10+30))\n# by minimization of the Kullback"
}
]
About this extraction
This page contains the full source code of the gaurav-arya/StochasticAD.jl GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 72 files (217.3 KB), approximately 65.4k tokens. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — a free GitHub-repo-to-text converter for AI. Built by Nikandr Surkov.