Repository: gaurav-arya/StochasticAD.jl Branch: main Commit: 24788b5bc82d Files: 72 Total size: 217.3 KB Directory structure: gitextract__gvecf0r/ ├── .JuliaFormatter.toml ├── .git-blame-ignore-revs ├── .github/ │ └── workflows/ │ ├── CI.yml │ ├── CompatHelper.yml │ ├── Documentation.yml │ ├── FormatCheck.yml │ ├── TagBot.yml │ └── benchmark.yml ├── .gitignore ├── CITATION.bib ├── LICENSE ├── Project.toml ├── README.md ├── benchmark/ │ ├── benchmarks.jl │ ├── game_of_life.jl │ ├── iteration.jl │ ├── random_walk.jl │ ├── runbenchmarks.jl │ ├── simple_ops.jl │ └── utils.jl ├── docs/ │ ├── Project.toml │ ├── make.jl │ └── src/ │ ├── assets/ │ │ └── extra_styles.css │ ├── devdocs.md │ ├── index.md │ ├── limitations.md │ ├── public_api.md │ └── tutorials/ │ ├── game_of_life.md │ ├── optimizations.md │ ├── particle_filter.md │ ├── random_walk.md │ └── reverse_demo.md ├── ext/ │ └── StochasticADEnzymeExt.jl ├── src/ │ ├── StochasticAD.jl │ ├── algorithms.jl │ ├── backends/ │ │ ├── abstract_wrapper.jl │ │ ├── dict.jl │ │ ├── pruned.jl │ │ ├── pruned_aggressive.jl │ │ ├── smoothed.jl │ │ └── strategy_wrapper.jl │ ├── discrete_randomness.jl │ ├── finite_infinitesimals.jl │ ├── general_rules.jl │ ├── misc.jl │ ├── prelude.jl │ ├── propagate.jl │ ├── smoothing.jl │ └── stochastic_triple.jl ├── test/ │ ├── game_of_life.jl │ ├── random_walk.jl │ ├── resampling.jl │ ├── runtests.jl │ └── triples.jl └── tutorials/ ├── Project.toml ├── README.md ├── game_of_life/ │ ├── core.jl │ └── plot_board.jl ├── particle_filter/ │ ├── benchmark.jl │ ├── bias.jl │ ├── core.jl │ ├── model.jl │ ├── variance.jl │ └── visualize.jl ├── random_walk/ │ ├── compare_score.jl │ ├── core.jl │ └── show_unbiased.jl ├── reverse_example/ │ └── reverse_demo.jl └── toy_optimizations/ ├── Project.toml ├── igarch.jl ├── intro.jl └── variational.jl ================================================ FILE CONTENTS ================================================ ================================================ FILE: 
.JuliaFormatter.toml ================================================ style = "sciml" ================================================ FILE: .git-blame-ignore-revs ================================================ # Run this command to always ignore these in local `git blame`: # git config blame.ignoreRevsFile .git-blame-ignore-revs # Run formatter 70fd432667fb431e08ba52728734108d822a1922 # Run formatter 21038a047c023330876feb9259cd5c92add3ca81 # Run formatter after bracket alignment removal 799277f9652258282a91ecfe976df5fb8ab64c82 # Format db4333c604cc23c3c36420f09aa998d01ef0214b ================================================ FILE: .github/workflows/CI.yml ================================================ name: CI on: pull_request: push: branches: - main tags: '*' jobs: unittest: runs-on: ubuntu-latest strategy: matrix: group: - Core version: - '1' - '1.7' steps: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.version }} - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 - uses: julia-actions/julia-processcoverage@v1 - uses: codecov/codecov-action@v2 with: file: lcov.info ================================================ FILE: .github/workflows/CompatHelper.yml ================================================ name: CompatHelper on: schedule: - cron: 0 0 * * * workflow_dispatch: jobs: CompatHelper: runs-on: ubuntu-latest steps: - name: Pkg.add("CompatHelper") run: julia -e 'using Pkg; Pkg.add("CompatHelper")' - name: CompatHelper.main() env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }} run: julia -e 'using CompatHelper; CompatHelper.main()' ================================================ FILE: .github/workflows/Documentation.yml ================================================ name: Documentation on: push: branches: - main tags: '*' pull_request: jobs: build: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - uses: 
julia-actions/setup-julia@v1 with: version: '1' - name: Install dependencies run: julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()' - name: Build and deploy env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # For authentication with GitHub Actions token DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # For authentication with SSH deploy key DATADEPS_ALWAYS_ACCEPT: true run: julia --project=docs/ docs/make.jl ================================================ FILE: .github/workflows/FormatCheck.yml ================================================ name: format-check on: push: branches: - 'main' - 'release-' tags: '*' pull_request: jobs: build: runs-on: ${{ matrix.os }} strategy: matrix: julia-version: [1] julia-arch: [x86] os: [ubuntu-latest] steps: - uses: julia-actions/setup-julia@latest with: version: ${{ matrix.julia-version }} - uses: actions/checkout@v1 - name: Install JuliaFormatter and format # This will use the latest version by default but you can set the version like so: # # julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter", version="0.13.0"))' run: | julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' julia -e 'using JuliaFormatter; format(".", verbose=true)' - name: Format check run: | julia -e ' out = Cmd(`git diff`) |> read |> String if out == "" exit(0) else @error "Some files have not been formatted !!!" 
write(stdout, out) exit(1) end' ================================================ FILE: .github/workflows/TagBot.yml ================================================ name: TagBot on: issue_comment: types: - created workflow_dispatch: jobs: TagBot: if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot' runs-on: ubuntu-latest steps: - uses: JuliaRegistries/TagBot@v1 with: token: ${{ secrets.GITHUB_TOKEN }} ssh: ${{ secrets.DOCUMENTER_KEY }} ================================================ FILE: .github/workflows/benchmark.yml ================================================ name: Benchmarks on: pull_request: push: branches: - main tags: '*' jobs: benchmark: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@latest with: version: 1 - name: Install dependencies run: julia -e 'using Pkg; Pkg.activate("tutorials"); Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate();' - name: Run benchmarks run: julia --project=tutorials --color=yes benchmark/runbenchmarks.jl ================================================ FILE: .gitignore ================================================ Manifest.toml ================================================ FILE: CITATION.bib ================================================ @inproceedings{arya2022automatic, author = {Arya, Gaurav and Schauer, Moritz and Sch\"{a}fer, Frank and Rackauckas, Christopher}, booktitle = {Advances in Neural Information Processing Systems}, editor = {S. Koyejo and S. Mohamed and A. Agarwal and D. Belgrave and K. Cho and A. 
Oh}, pages = {10435--10447}, publisher = {Curran Associates, Inc.}, title = {Automatic Differentiation of Programs with Discrete Randomness}, url = {https://proceedings.neurips.cc/paper_files/paper/2022/file/43d8e5fc816c692f342493331d5e98fc-Paper-Conference.pdf}, volume = {35}, year = {2022} } ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2022 Gaurav Arya and contributors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
================================================ FILE: Project.toml ================================================ name = "StochasticAD" uuid = "e4facb34-4f7e-4bec-b153-e122c37934ac" authors = ["Gaurav Arya and contributors"] version = "0.1.26" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesOverloadGeneration = "f51149dc-2911-5acf-81fc-2076a2a81d4f" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [weakdeps] Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" [extensions] StochasticADEnzymeExt = "Enzyme" [compat] ChainRulesCore = "1.15" ChainRulesOverloadGeneration = "0.1" Dictionaries = "0.3" Distributions = "0.25" DistributionsAD = "0.6" ExprTools = "0.1" ForwardDiff = "0.10" Functors = "0.4.3" julia = "1" [extras] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" GaussianDistributions = "43dcc890-d446-5863-8d1a-14597580bb8d" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] test = ["LinearAlgebra", "Pkg", "Printf", "Test", "Statistics", "SafeTestsets", 
"OffsetArrays", "StaticArrays", "Zygote", "ForwardDiff", "GaussianDistributions", "Measurements", "UnPack", "StatsBase", "DiffResults", "ChainRulesCore"] ================================================ FILE: README.md ================================================ ![](docs/src/images/path_skeleton.png#gh-light-mode-only) ![](docs/src/images/path_skeleton_dark.png#gh-dark-mode-only) # StochasticAD [![Build Status](https://github.com/gaurav-arya/StochasticAD.jl/workflows/CI/badge.svg?branch=main)](https://github.com/gaurav-arya/StochasticAD.jl/actions?query=workflow:CI) [![](https://img.shields.io/badge/docs-main-blue.svg)](https://gaurav-arya.github.io/StochasticAD.jl/dev/) [![arXiv article](https://img.shields.io/badge/article-arXiv%3A10.48550-B31B1B)](https://arxiv.org/abs/2210.08572) StochasticAD is an experimental, research package for automatic differentiation (AD) of stochastic programs. It implements AD algorithms for handling programs that can contain *discrete* randomness, based on the methodology developed in [this NeurIPS 2022 paper](https://doi.org/10.48550/arXiv.2210.08572). We're still working on docs and code cleanup! ## Installation The package can be installed with the Julia package manager: ```julia julia> using Pkg; julia> Pkg.add("StochasticAD"); ``` ## Citation ``` @inproceedings{arya2022automatic, author = {Arya, Gaurav and Schauer, Moritz and Sch\"{a}fer, Frank and Rackauckas, Christopher}, booktitle = {Advances in Neural Information Processing Systems}, editor = {S. Koyejo and S. Mohamed and A. Agarwal and D. Belgrave and K. Cho and A. 
Oh}, pages = {10435--10447}, publisher = {Curran Associates, Inc.}, title = {Automatic Differentiation of Programs with Discrete Randomness}, url = {https://proceedings.neurips.cc/paper_files/paper/2022/file/43d8e5fc816c692f342493331d5e98fc-Paper-Conference.pdf}, volume = {35}, year = {2022} } ``` ================================================ FILE: benchmark/benchmarks.jl ================================================ using BenchmarkTools include("random_walk.jl") include("game_of_life.jl") include("iteration.jl") include("simple_ops.jl") const SUITE = BenchmarkGroup() SUITE["random_walk"] = RandomWalkBenchmark.suite SUITE["game_of_life"] = GoLBenchmark.suite SUITE["iteration"] = IterationBenchmark.suite SUITE["simple_ops"] = SimpleOpsBenchmark.suite ================================================ FILE: benchmark/game_of_life.jl ================================================ module GoLBenchmark using BenchmarkTools using StochasticAD using Statistics using ForwardDiff: derivative include("../tutorials/game_of_life/core.jl") using .GoLCore: play, p const suite = BenchmarkGroup() suite["original"] = @benchmarkable $play($p) suite["PrunedFIs"] = @benchmarkable derivative_estimate($play, $p; backend = PrunedFIsBackend()) suite["PrunedFIsAggressive"] = @benchmarkable derivative_estimate($play, $p; backend = PrunedFIsAggressiveBackend()) suite["SmoothedFIs"] = @benchmarkable derivative_estimate($play, $p; backend = SmoothedFIsBackend()) end ================================================ FILE: benchmark/iteration.jl ================================================ """ In the library we have tried to avoid generated functions, instead reductions from base with the hope that the iteration will be optimized to be zero-cost. This suite tests the performance of iteration on small nested structures, which crop up when `propagate` is called on small structures of scalars. The `couple` and `combine` operations of FIss, which use iteration, are benchmarked. 
""" module IterationBenchmark using BenchmarkTools using StochasticAD using StaticArrays const suite = BenchmarkGroup() # Examples consist of flat and non-flat versions of structures, to test zero-cost iteration. tups = Dict("easy" => (ntuple(identity, 3), (1, (2, 3))), "hard" => (ntuple(identity, 9), (1, (2, 3), (4, (5, (6, 7, 8), 9))))) SAs = Dict("easy" => (SA[1, 2, 3], (1, SA[2, 3])), "hard" => (SA[1, 2, 3, 4, 5, 6, 7, 8, 9], (1, SA[2, 3], (4, (5, SA[6, 7, 8], 9))))) for (setname, set) in (("tups", tups), ("SAs", SAs)) suite[setname] = BenchmarkGroup() setsuite = suite[setname] for case in ["easy", "hard"] casesuite = setsuite[case] = BenchmarkGroup() for isflat in [false, true] flatsuite = casesuite[isflat ? "flat" : "not flat"] = BenchmarkGroup() values = set[case][isflat ? 1 : 2] flatsuite["make_iterate_values"] = @benchmarkable StochasticAD.structural_iterate($values) iter_values = StochasticAD.structural_iterate(values) flatsuite["foldl_values"] = @benchmarkable foldl(+, $(iter_values)) flatsuite["iterate_values"] = @benchmarkable for i in $(iter_values) end for backend in [PrunedFIsBackend(), PrunedFIsAggressiveBackend()] FIs_suite = flatsuite[backend] = BenchmarkGroup() Δs = StochasticAD.create_Δs(backend, Int) Δs1 = StochasticAD.similar_new(Δs, 1, 1) Δs_all = StochasticAD.structural_map(x -> map(Δ -> x, Δs1), values) FIs_suite["make_iterate_Δs"] = @benchmarkable StochasticAD.structural_iterate($Δs_all) # We don't interpolate backend directly in below (i.e. do $FIs) because string interpolating a type # seems to lead to slow benchmarks. 
FIs_suite["couple_same"] = @benchmarkable StochasticAD.couple(typeof($Δs), $Δs_all) FIs_suite["combine_same"] = @benchmarkable StochasticAD.combine( typeof($Δs), $Δs_all) end end end end end ================================================ FILE: benchmark/random_walk.jl ================================================ module RandomWalkBenchmark using BenchmarkTools using StochasticAD using Statistics using ForwardDiff: derivative include("../tutorials/random_walk/core.jl") using .RandomWalkCore: n, p, nsamples using .RandomWalkCore: fX, get_dfX const suite = BenchmarkGroup() suite["original"] = @benchmarkable $(fX)($p) suite["PrunedFIs"] = @benchmarkable derivative_estimate($fX, $p; backend = PrunedFIsBackend()) suite["PrunedFIsAggressive"] = @benchmarkable derivative_estimate($fX, $p; backend = PrunedFIsAggressiveBackend()) suite["SmoothedFIs"] = @benchmarkable derivative_estimate($fX, $p; backend = SmoothedFIsBackend()) forwarddiff_func = p -> fX(p; hardcode_leftright_step = true) suite["ForwardDiff_smoothing"] = @benchmarkable derivative($forwarddiff_func, $p) end ================================================ FILE: benchmark/runbenchmarks.jl ================================================ using PkgBenchmark include("utils.jl") using .Utils results = benchmarkpkg(dirname(@__DIR__), BenchmarkConfig(env = Dict("JULIA_NUM_THREADS" => "1", "OMP_NUM_THREADS" => "1")), resultfile = joinpath(@__DIR__, "result.json")) @show results = print_group(results.benchmarkgroup) ================================================ FILE: benchmark/simple_ops.jl ================================================ module SimpleOpsBenchmark using BenchmarkTools using StochasticAD const suite = BenchmarkGroup() suite["add"] = BenchmarkGroup() suite["add_via_propagate_nodeltas"] = BenchmarkGroup() suite["add_via_propagate"] = BenchmarkGroup() suite["add"]["original"] = @benchmarkable +(0.5, 0.5) suite["add_via_propagate_nodeltas"]["original"] = @benchmarkable StochasticAD.propagate(+, 0.5, 
0.5) suite["add_via_propagate"]["original"] = @benchmarkable StochasticAD.propagate(+, 0.5, 0.5; keep_deltas = Val{ true, }) for backend in [PrunedFIsBackend(), PrunedFIsAggressiveBackend()] suite["add"][backend] = @benchmarkable +(st, st) setup=(st = stochastic_triple(0.5; backend = $backend)) suite["add_via_propagate_nodeltas"][backend] = @benchmarkable StochasticAD.propagate(+, st, st) setup=(st = stochastic_triple(0.5; backend = $backend)) suite["add_via_propagate"][backend] = @benchmarkable StochasticAD.propagate(+, st, st; keep_deltas = Val{ true, }) setup=(st = stochastic_triple(0.5; backend = $backend)) end end ================================================ FILE: benchmark/utils.jl ================================================ module Utils export print_group using Functors using BenchmarkTools ## Printing # Type piracy, fine since just in benchmarking. (design of Functors should probably allow for user-customized functors) @functor BenchmarkTools.BenchmarkGroup function print_trial(t) ptime = BenchmarkTools.prettytime(time(t)) pallocs = "$(allocs(t)) allocs" return "$ptime, $pallocs" end function print_group(b) fmap(t -> (t isa BenchmarkTools.Trial ? 
print_trial(t) : t), b) end end ================================================ FILE: docs/Project.toml ================================================ [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocThemeIndigo = "8bac0ac5-51bf-41f9-885e-2bf1ac2bec5f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" StochasticAD = "e4facb34-4f7e-4bec-b153-e122c37934ac" ================================================ FILE: docs/make.jl ================================================ using Pkg using Documenter using StochasticAD using DocThemeIndigo using Literate ### Formatting indigo = DocThemeIndigo.install(StochasticAD) format = Documenter.HTML(prettyurls = false, assets = [indigo, "assets/extra_styles.css"], repolink = "https://github.com/gaurav-arya/StochasticAD.jl", edit_link = "main") ### Pagination pages = [ "Overview" => "index.md", "Tutorials" => [ "tutorials/random_walk.md", "tutorials/game_of_life.md", "tutorials/particle_filter.md", "tutorials/optimizations.md", "tutorials/reverse_demo.md" ], "Public API" => "public_api.md", "Developer documentation" => "devdocs.md", "Limitations" => "limitations.md" ] ### Prepare literate tutorials # TODO (for now they are manually built into docs/src/tutorials and checked into repo) ### Make docs makedocs(sitename = "StochasticAD.jl", authors = "Gaurav Arya and other contributors", modules = [StochasticAD], format = format, pages = pages, warnonly = [:missing_docs]) try deploydocs(repo = "github.com/gaurav-arya/StochasticAD.jl", devbranch = "main", push_preview = true) catch e println("Error encountered while deploying docs:") showerror(stdout, e) end ================================================ FILE: docs/src/assets/extra_styles.css ================================================ .display-light-only {display: block;} .display-dark-only {display: none;} .theme--documenter-dark .display-light-only {display: none;} .theme--documenter-dark .display-dark-only 
{display: block;} ================================================ FILE: docs/src/devdocs.md ================================================ # Developer documentation (WIP) ## Writing a custom rule for stochastic triples ### via `StochasticAD.propagate` To handle a deterministic discrete construct that `StochasticAD` does not automatically handle (e.g. branching via `if`, boolean comparisons), it is often sufficient to simply add a dispatch rule that calls out to `StochasticAD.propagate`. ```@docs StochasticAD.propagate ``` ### via a custom dispatch If a function does not meet the conditions of `StochasticAD.propagate` and is not already supported, a custom dispatch may be necessary. For example, consider the following function which manually implements a geometric random variable: ```@example rule import Random Random.seed!(1234) # hide using Distributions # make rng input explicit function mygeometric(rng, p) x = 0 while !(rand(rng, Bernoulli(p))) x += 1 end return x end mygeometric(p) = mygeometric(Random.default_rng(), p) ``` This is equivalent to `rand(Geometric(p))` which is already supported, but for pedagogical purposes we will implement our own rule from scratch. Using the stochastic derivative formulas from [Automatic Differentiation of Programs with Discrete Randomness](https://doi.org/10.48550/arXiv.2210.08572), the right stochastic derivative of this program is given by ```math Y_R = X - 1, w_R = \frac{x}{p(1-p)}, ``` and the left stochastic derivative of this program is given by ```math Y_L = X + 1, w_L = -\frac{x+1}{p}. 
``` Using these expressions, we can now write the dispatch rule for stochastic triples: ```@example rule using StochasticAD import StochasticAD: StochasticTriple, similar_new, similar_empty, combine function mygeometric(rng, p_st::StochasticTriple{T}) where {T} p = p_st.value rng_copy = copy(rng) # save a copy for coupling later x = mygeometric(rng, p) # Form the new discrete perturbations (combinations of weight w and perturbation Y - X) Δs1 = if p_st.δ > 0 # right stochastic derivative w = p_st.δ * x / (p * (1 - p)) x > 0 ? similar_new(p_st.Δs, -1, w) : similar_empty(p_st.Δs, Int) elseif p_st.δ < 0 # left stochastic derivative w = -p_st.δ * (x + 1) / p # positive since the negativity of p_st.δ cancels out the negativity of w_L similar_new(p_st.Δs, 1, w) else similar_empty(p_st.Δs, Int) end # Propagate any existing perturbations to p through the function function map_func(Δ) # Couple the samples by using the same RNG. (A simpler strategy would have been independent sampling, i.e. mygeometric(p + Δ) - x) mygeometric(copy(rng_copy), p + Δ) - x end Δs2 = map(map_func, p_st.Δs) # Return the output stochastic triple StochasticTriple{T}(x, zero(x), combine((Δs2, Δs1))) end ``` In the above, we used some of the interface functions supported by a collection of perturbations `Δs::StochasticAD.AbstractFIs`. These were `similar_empty(Δs, V)`, which created an empty perturbation of type `V`, `similar_new(Δs, Δ, w)`, which created a new perturbation of size `Δ` and weight `w`, `map(map_func, Δs)`, which propagates a collection of perturbations through a mapping function, and `combine((Δs2, Δs1)))` which combines multiple collections of perturbations together. 
We can test out our rule: ```@example rule @show stochastic_triple(mygeometric, 0.1) # try feeding an input that already has a pertrubation f(x) = mygeometric(2 * x + 0.1 * rand(Bernoulli(x)))^2 @show stochastic_triple(f, 0.1) # verify against black-box finite differences N = 1000000 samples_stochad = [derivative_estimate(f, 0.1) for i in 1:N] samples_fd = [(f(0.105) - f(0.095)) / 0.01 for i in 1:N] println("Stochastic AD: $(mean(samples_stochad)) ± $(std(samples_stochad) / sqrt(N))") println("Finite differences: $(mean(samples_fd)) ± $(std(samples_fd) / sqrt(N))") nothing # hide ``` ## Distribution-specific customization of differentiation algorithm ```@docs randst InversionMethodDerivativeCoupling ``` ================================================ FILE: docs/src/index.md ================================================ ```@raw html ``` # StochasticAD [StochasticAD](https://github.com/gaurav-arya/StochasticAD.jl) is an experimental, research package for automatic differentiation (AD) of stochastic programs. It implements AD algorithms for handling programs that can contain *discrete* randomness, based on the methodology developed in [this NeurIPS 2022 paper](https://doi.org/10.48550/arXiv.2210.08572). ## Introduction Derivatives are all about how functions are affected by a tiny change `ε` in their input. First, let's imagine perturbing the input of a deterministic, differentiable function such as $f(p) = p^2$ at $p = 2$. ```@example continuous using StochasticAD f(p) = p^2 stochastic_triple(f, 2) # Feeds 2 + ε into f ``` The output tells us that if we change the input `2` by a tiny amount `ε`, the output of `f` will change by approximately `4ε`. This is the case we're familiar with: we can get the value `4` by applying the chain rule, $\frac{\mathrm{d}}{\mathrm{d} p} p^2 = 2p = 4$. Thinking in terms of tiny changes, the output above looks a lot like a [dual number](https://en.wikipedia.org/wiki/Dual_number). But what happens with a discrete random function? 
Let's give it a try. ```@example discrete import Random # hide Random.seed!(4321) # hide using StochasticAD, Distributions f(p) = rand(Bernoulli(p)) # 1 with probability p, 0 otherwise stochastic_triple(f, 0.5) # Feeds 0.5 + ε into f ``` The output of a [Bernoulli variable](https://en.wikipedia.org/wiki/Bernoulli_distribution) cannot change by a tiny amount: it is either `0` or `1`. But in the probabilistic world, there is another way to change by a tiny amount *on average*: jump by a large amount, with tiny probability. `StochasticAD` introduces a stochastic triple object, which generalizes dual numbers by including a *third* component to describe these perturbations. Here, the stochastic triple says that the original random output was `0`, but given a small change `ε` in the input, the output will jump up to `1` with probability approximately `2ε`. Stochastic triples can be used to construct a new random program whose average is the derivative of the average of the original program. We simply propagate stochastic triples through the program via [`stochastic_triple`](@ref), and then sum up the "dual" and "triple" components at the end via [`derivative_contribution`](@ref). This process is packaged together in the function [`derivative_estimate`](@ref). Let's try a crazier example, where we mix discrete and continuous randomness! ```@example estimate using StochasticAD, Distributions import Random # hide Random.seed!(1234) # hide function X(p) a = p * (1 - p) b = rand(Binomial(10, p)) c = 2 * b + 3 * rand(Bernoulli(p)) return a * c * rand(Normal(b, a)) end st = @show stochastic_triple(X, 0.6) # sample a single stochastic triple at p = 0.6 @show derivative_contribution(st) # which produces a single derivative estimate... 
samples = [derivative_estimate(X, 0.6) for i in 1:1000] # many samples from derivative program derivative = mean(samples) uncertainty = std(samples) / sqrt(1000) println("derivative of 𝔼[X(p)] = $derivative ± $uncertainty") ``` ## Index See the [public API](public_api.md) for a walkthrough of the API, and the tutorials on differentiating a [random walk](tutorials/random_walk.md), a [stochastic game of life](tutorials/game_of_life.md), and a [particle filter](tutorials/particle_filter.md), and solving [stochastic optimization and variational problems](tutorials/optimizations.md) with discrete randomness. This is a prototype package with a number of [limitations](limitations.md). ================================================ FILE: docs/src/limitations.md ================================================ # Limitations of StochasticAD `StochasticAD` has a number of limitations that are important to be aware of: * `StochasticAD` uses operator-overloading just like [ForwardDiff](https://juliadiff.org/ForwardDiff.jl/stable/), so all of the [limitations](https://juliadiff.org/ForwardDiff.jl/stable/user/limitations/) listed there apply here too. Also note that some useful features of `ForwardDiff`, such as chunking for greater efficiency with a large number of parameters, have not yet been implemented here. * We have limited support for reverse-mode AD via [smoothing](public_api.md#Smoothing), which cannot be guaranteed to be unbiased in all cases. * We do not yet support `if` statements with discrete random input. A workaround can be to use array indexing to express discrete random choices (see [the random walk tutorial](tutorials/random_walk.md) for an example). * We do not yet support non-real values as intermediate values (e.g. a function such as `length(A[rand(Bernoulli(p))])` where `A` is an array of strings is in theory differentiable). * We do not support discrete random variables that are implicitly implemented using continuous random variables, e.g. 
`rand() < p`. * We support a limited assortment of discrete random variables: currently `Bernoulli`, `Binomial`, `Geometric`, `Poisson`, and `Categorical` from [Distributions](https://juliastats.org/Distributions.jl/). We are working on increasing coverage across `Distributions` as well as other libraries providing discrete random samplers such as [MeasureTheory](https://cscherrer.github.io/MeasureTheory.jl/stable/). * Higher-order differentiation is not supported. `StochasticAD` is still in active development! PRs are welcome. ================================================ FILE: docs/src/public_api.md ================================================ # API walkthrough The function [`derivative_estimate`](@ref) transforms a stochastic program containing discrete randomness into a new program whose average is the derivative of the original. ```@docs derivative_estimate ``` While [`derivative_estimate`](@ref) is self-contained, we can also use the functions below to work with stochastic triples directly. ```@docs StochasticAD.stochastic_triple StochasticAD.derivative_contribution StochasticAD.value StochasticAD.delta StochasticAD.perturbations ``` Note that [`derivative_estimate`](@ref) is simply the composition of [`stochastic_triple`](@ref) and [`derivative_contribution`](@ref). We also provide a convenience function for mimicking the behaviour of standard AD, where derivatives of discrete random steps are dropped: ```@docs StochasticAD.dual_number ``` ## Algorithms ```@docs StochasticAD.ForwardAlgorithm StochasticAD.EnzymeReverseAlgorithm ``` ## Smoothing What happens if we were to run [`derivative_contribution`](@ref) after each step, instead of only at the end? This is *smoothing*, which combines the second and third components of a single stochastic triple into a single dual component. Smoothing no longer has a guarantee of unbiasedness, but is surprisingly accurate in a number of situations. 
For example, the popular [straight through gradient estimator](https://stackoverflow.com/questions/38361314/the-concept-of-straight-through-estimator-ste) can be viewed as a special case of smoothing. Forward smoothing rules are provided through `ForwardDiff`, and backward rules through `ChainRules`, so that e.g. `Zygote.gradient` and `ForwardDiff.derivative` will use smoothed rules for discrete random variables rather than dropping the gradients entirely. Currently, special discrete->discrete constructs such as array indexing are not supported for smoothing. ## Optimization We also provide utilities to make it easier to get started with forming and training a model via stochastic gradient descent: ```@docs StochasticAD.StochasticModel StochasticAD.stochastic_gradient ``` These are used in the [tutorial on stochastic optimization](tutorials/optimizations.md). ================================================ FILE: docs/src/tutorials/game_of_life.md ================================================ # Stochastic Game of Life We consider a stochastic version of [Conway's Game of Life](https://en.wikipedia.org/wiki/Conway%27s_Game_of_Life), played on a two-dimensional board. We shall use the following packages, ```@setup game_of_life import Pkg Pkg.activate("../../../tutorials") Pkg.develop(path="../../..") Pkg.instantiate() ``` ```@example game_of_life using Distributions using StochasticAD using OffsetArrays using StaticArrays ``` ## Setting up the stochastic Game of Life Each turn, the standard Game of Life applies the following rules to each cell, ```math \text{dead and 3 neighbours alive} \to \text{ alive}, \\ \text{alive and 0, 1, or 4 neighbours alive} \to \text{ dead}. ``` The cell's status does not change otherwise. In our stochastic version, these rules instead occur with probability `1-θ`, while the opposite event has probability `θ`. To initialize the board at the beginning of the game, we randomly set each cell alive with probability `p`. 
The following high level function sets up the probabilities and provides them to `play_game_of_life`. ```@example game_of_life function play(p, θ=0.1, N=12, T=10; log=false) # N is the board half-length, T are game time steps low = θ high = 1-θ birth_probs = SA[low, low, low, high, low] # 0, 1, 2, 3, 4 neighbours death_probs = SA[high, high, low, low, high] # 0, 1, 2, 3, 4 neighbours return play_game_of_life(p, vcat(birth_probs, death_probs), N, T; log) end ``` We can now implement the Game of Life based on the specification. At the end of the game, we return the total number of alive cells. ```@example game_of_life # A single turn of the game function update_state(all_probs, N, board_new, board_old) for i in -N:N for j in -N:N neighbours = board_old[i+1, j] + board_old[i-1, j] + board_old[i, j-1] + board_old[i, j+1] index = board_new[i,j] * 5 + neighbours + 1 b = rand(Bernoulli(all_probs[index])) board_new[i,j] += (1 - 2 * board_new[i,j]) * b end end end function play_game_of_life(p, all_probs, N, T; log=false) dual_type = promote_type(typeof(rand(Bernoulli(p))), typeof.(rand.(Bernoulli.(all_probs)))...) # a hacky way of getting the correct array type board = OffsetArray(zeros(dual_type, 2*N + 3, 2*N + 3), -(N+1):(N+1), -(N+1):(N+1)) # center board at (0,0), pad by 1 # initialize the board for i in -N:N for j in -N:N board[i,j] = rand(Bernoulli(p)) end end board_old = similar(board) log && (history = []) # play the game for time_step in 1:T copy!(board_old, board) update_state(all_probs, N, board, board_old) log && push!(history, copy(board)) end if !log return sum(board) else return sum(board), board, history end end play(0.5, 0.1) # play the game with p = 0.5 and θ = 0.1 ``` !!! note Note that we did have to be careful to write this program to be compatible with the [current capabilities of `StochasticAD`](../limitations.md). 
For example, we concatenated `birth_probs` and `death_probs` into a single array `all_probs` and used the index `board[i, j] * 5 + neighbours + 1` to find the probability, rather than use the more natural `if alive... else...` syntax. ## Differentiating the Game of Life Let's differentiate the Game of Life! ```@example game_of_life @show stochastic_triple(play, 0.5) # let's take a look at a single stochastic triple samples = [derivative_estimate(play, 0.5) for i in 1:10000] # take many samples derivative = mean(samples) uncertainty = std(samples) / sqrt(10000) println("derivative of 𝔼[play(p)] = $derivative ± $uncertainty") ``` The following sketch of the final state of the board for a single run gives some insight into what the stochastic triples are doing. The original board is depicted in grey and white for dead and alive, and the cells which flip from dead to alive in the "alternative" path consider by the triples are marked with + signs, while the cells which flip from alive to dead are marked with X signs. ```@raw html ``` ⠀ ================================================ FILE: docs/src/tutorials/optimizations.md ================================================ # Stochastic optimizations with discrete randomness ```@setup random_walk import Pkg Pkg.activate("../../../tutorials/toy_optimizations") Pkg.develop(path="../../..") Pkg.instantiate() ``` In this tutorial, we solve two stochastic optimization problems using `StochasticAD` where the optimization objective is formed using discrete distributions. 
We will need the following packages: ```@example optimizations using Distributions # defines several supported discrete distributions using StochasticAD using CairoMakie # for plotting using Optimisers # for stochastic gradient descent ``` ## Optimizing our toy program Recall the "crazy" program from the intro: ```@example optimizations function X(p) a = p * (1 - p) b = rand(Binomial(10, p)) c = 2 * b + 3 * rand(Bernoulli(p)) return a * c * rand(Normal(b, a)) end ``` Let's maximize $\mathbb{E}[X(p)]$! First, let's setup the problem, using the [`StochasticModel`](@ref) helper utility to create a trainable model: ```@example optimizations p0 = [0.5] # initial value of p, wrapped in an array for use in the stochastic model m = StochasticModel(p -> -X(p[1]), p0) # formulate as minimization problem ``` Now, let's perform stochastic gradient descent using [Adam](https://arxiv.org/abs/1412.6980), where we use [`stochastic_gradient`](@ref) to obtain a gradient of the model. ```@example optimizations iterations = 1000 trace = Float64[] o = Adam() # use Adam for optimization s = Optimisers.setup(o, m) for i in 1:iterations # Perform a gradient step Optimisers.update!(s, m, stochastic_gradient(m)) push!(trace, m.p[]) end p_opt = m.p[] # Our optimized value of p ``` Finally, let's plot the results of our optimization, and also perform a sweep through the parameter space to verify the accuracy of our estimator: ```@example optimizations ## Sweep through parameters to find average and derivative ps = 0.02:0.02:0.98 # values of p to sweep N = 1000 # number of samples at each p avg = [mean(X(p) for _ in 1:N) for p in ps] derivative = [mean(derivative_estimate(X, p) for _ in 1:N) for p in ps] ## Make plots f = Figure() ax = f[1, 1] = Axis(f, title = "Estimates", xlabel="Value of p") lines!(ax, ps, avg, label = "≈ E[X(p)]") lines!(ax, ps, derivative, label = "≈ d/dp E[X(p)]") vlines!(ax, [p_opt], label = "p_opt", color = :green, linewidth = 2.0) hlines!(ax, [0.0], color = :black, 
linewidth = 1.0) ylims!(ax, (-50, 80)) f[1, 2] = Legend(f, ax, framevisible = false) ax = f[2, 1:2] = Axis(f, title = "Optimizer trace", xlabel="Iterations", ylabel="Value of p") lines!(ax, trace, color = :green, linewidth = 2.0) save("crazy_opt.png", f, px_per_unit = 4) # hide nothing # hide ``` ![](crazy_opt.png) ## Solving a variational problem Let's consider a toy variational program: we find a [Poisson distribution](https://en.wikipedia.org/wiki/Poisson_distribution) that is close to the distribution of a [negative Binomial](https://en.wikipedia.org/wiki/Negative_binomial_distribution), via minimization of the [Kullback-Leibler divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) $D_{\mathrm{KL}}$. Concretely, let us solve ```math \underset{p \in \mathbb{R}}{\operatorname{argmin}}\; D_{\mathrm{KL}}\left(\mathrm{Pois}(p) \hspace{.3em}\middle\|\hspace{.3em} \mathrm{NBin}(10, 0.25) \right). ``` The following program produces an unbiased estimate of the objective: ```@example optimizations function X(p) i = rand(Poisson(p)) return logpdf(Poisson(p), i) - logpdf(NegativeBinomial(10, 0.25), i) end ``` We can now optimize the KL-divergence via stochastic gradient descent! 
```@example optimizations # Minimize E[X] = KL(Poisson(p)| NegativeBinomial(10, 0.25)) iterations = 1000 p0 = [10.0] m = StochasticModel(p -> X(p[1]), p0) trace = Float64[] o = Adam(0.1) s = Optimisers.setup(o, m) for i in 1:iterations Optimisers.update!(s, m, stochastic_gradient(m)) push!(trace, m.p[]) end p_opt = m.p[] ``` Let's plot our results in the same way as before: ```@example optimizations ps = 10:0.5:50 N = 1000 avg = [mean(X(p) for _ in 1:N) for p in ps] derivative = [mean(derivative_estimate(X, p) for _ in 1:N) for p in ps] f = Figure() ax = f[1, 1] = Axis(f, title = "Estimates", xlabel="Value of p") lines!(ax, ps, avg, label = "≈ E[X(p)]") lines!(ax, ps, derivative, label = "≈ d/dp E[X(p)]") vlines!(ax, [p_opt], label = "p_opt", color = :green, linewidth = 2.0) hlines!(ax, [0.0], color = :black, linewidth = 1.0) ylims!(ax, (-2.5, 5)) f[1, 2] = Legend(f, ax, framevisible = false) ax = f[2, 1:2] = Axis(f, title = "Optimizer trace", ylabel="Value of p", xlabel="Iterations") lines!(ax, trace, color = :green, linewidth = 2.0) save("variational.png", f, px_per_unit = 4) # hide nothing # hide ``` ![](variational.png) ================================================ FILE: docs/src/tutorials/particle_filter.md ================================================ # Differentiable particle filter Using a bootstrap particle sampler, we can approximate the posterior distributions of the states given noisy and partial observations of the state of a hidden Markov model by a cloud of `K` weighted particles with weights `W`. In this tutorial, we are going to: - implement a differentiable particle filter based on `StochasticAD.jl`. - visualize the particle filter in ``d = 2`` dimensions. - compare the gradient based on the differentiable particle filter to a biased gradient estimator as well as to the gradient of a differentiable Kalman filter. - show how to benchmark primal evaluation, forward- and reverse-mode AD of the particle filter. 
## Setup We will make use of several julia packages. For example, we are going to use `Distributions` and `DistributionsAD` that implement the reparameterization trick for Gaussian distributions used in the observation and state-transition model, which we specify below. We also import `GaussianDistributions.jl` to implement the differentiable Kalman filter. ### Package dependencies ```@setup particle_filter import Pkg Pkg.activate("../../../tutorials") Pkg.develop(path="../../..") Pkg.instantiate() ``` ```@example particle_filter # activate tutorial project file # load dependencies using StochasticAD using Distributions using DistributionsAD using Random using Statistics using StatsBase using LinearAlgebra using Zygote using ForwardDiff using GaussianDistributions using GaussianDistributions: correct, ⊕ using Measurements using UnPack using Plots using LaTeXStrings using BenchmarkTools ``` ### Particle filter For convenience, we first introduce the new type `StochasticModel` with the following fields: - `T`: total number of time steps. - `start`: starting distribution for the initial state. For example, in the form of a narrow Gaussian `start(θ) = Gaussian(x0, 0.001 * I(d))`. - `dyn`: pointwise differentiable stochastic program in the form of Markov transition densities. For example, `dyn(x, θ) = MvNormal(reshape(θ, d, d) * x, Q(θ))`, where `Q(θ)` denotes the covariance matrix. - `obs`: observation model having a smooth conditional probability density depending on current state `x` and parameters `θ`. For example, `obs(x, θ) = MvNormal(x, R(θ))`, where `R(θ)` denotes the covariance matrix. For parameters `θ`, `rand(start(θ))` gives a sample from the prior distribution of the starting distribution. For current state `x` and parameters `θ`, `xnew = rand(dyn(x, θ))` samples the new state (i.e. `dyn` gives for each `x, θ` a distribution-like object). Finally, `y = rand(obs(x, θ))` samples an observation. 
We can then define the `ParticleFilter` type that wraps a stochastic model `StochM::StochasticModel`, a sampling strategy (with arguments `p, K, sump=1`) and observational data `ys`. For simplicity, our implementation assumes an observation-likelihood function is available via `pdf(obs(x, θ), y)`.

```@example particle_filter
struct StochasticModel{TType<:Integer,T1,T2,T3}
    T::TType # time steps
    start::T1 # prior
    dyn::T2 # dynamical model
    obs::T3 # observation model
end

struct ParticleFilter{mType<:Integer,MType<:StochasticModel,yType,sType}
    m::mType # number of particles
    StochM::MType # stochastic model
    ys::yType # observations
    sample_strategy::sType # sampling function
end
```

### Kalman filter

We consider a stochastic program that fulfills the assumptions of a Kalman filter. We follow [Kalman.jl](https://github.com/mschauer/Kalman.jl/blob/master/README.md) to implement a differentiable version. Our `KalmanFilter` type wraps a stochastic model `StochM::StochasticModel` and observational data `ys`. It assumes an observation-likelihood function is implemented via `llikelihood(yres, S)`. The Kalman filter contains the following fields:

- `d`: dimension of the state-transition matrix ``\Phi`` according to ``x = \Phi x + w`` with ``w \sim \operatorname{Normal}(0,Q)``.
- `StochM`: Stochastic model of type `StochasticModel`.
- `H`: linear map from the state space into the observed space according to ``y = H x + \nu`` with ``\nu \sim \operatorname{Normal}(0,R)``.
- `R`: covariance matrix entering the observation model according to ``y = H x + \nu`` with ``\nu \sim \operatorname{Normal}(0,R)``.
- `Q`: covariance matrix entering the state-transition model according to ``x = \Phi x + w`` with ``w \sim \operatorname{Normal}(0,Q)``.
- `ys`: observations.
```@example particle_filter llikelihood(yres, S) = GaussianDistributions.logpdf(Gaussian(zero(yres), Symmetric(S)), yres) struct KalmanFilter{dType<:Integer,MType<:StochasticModel,HType,RType,QType,yType} # H, R = obs # θ, Q = dyn d::dType StochM::MType # stochastic model H::HType # observation model, maps the true state space into the observed space R::RType # observation model, covariance matrix Q::QType # dynamical model, covariance matrix ys::yType # observations end ``` To get observations `ys` from the latent states `xs` based on the (true, potentially unknown) parameters `θ`, we simulate a single particle from the forward model returning a vector of observations (no resampling steps). ```@example particle_filter function simulate_single(StochM::StochasticModel, θ) @unpack T, start, dyn, obs = StochM x = rand(start(θ)) y = rand(obs(x, θ)) xs = [x] ys = [y] for t in 2:T x = rand(dyn(x, θ)) y = rand(obs(x, θ)) push!(xs, x) push!(ys, y) end xs, ys end ``` A particle filter becomes efficient if resampling steps are included. Resampling is numerically attractive because particles with small weight are discarded, so computational resources are not wasted on particles with vanishing weight. Here, let us implement a stratified resampling strategy, see for example [Murray (2012)](https://arxiv.org/abs/1202.6163), where `p` denotes the probabilities of `K` particles with `sump = sum(p)`. ```@example particle_filter function sample_stratified(p, K, sump=1) n = length(p) U = rand() is = zeros(Int, K) i = 1 cw = p[1] for k in 1:K t = sump * (k - 1 + U) / K while cw < t && i < n i += 1 @inbounds cw += p[i] end is[k] = i end return is end ``` This sampling strategy can be used within a differentiable resampling step in our particle filter using the `use_new_weight` function as implemented in `StochasticAD.jl`. The `resample` function below returns the states `X_new` and weights `W_new` of the resampled particles. - `m`: number of particles. - `X`: current particle states. 
- `W`: current weight vector of the particles. - `ω == sum(W)` is an invariant. - `sample_strategy`: specific resampling strategy to be used. For example, `sample_stratified`. - `use_new_weight=true`: Allows one to switch between biased, stop-gradient method and differentiable resampling step. ```@example particle_filter function resample(m, X, W, ω, sample_strategy, use_new_weight=true) js = Zygote.ignore(() -> sample_strategy(W, m, ω)) X_new = X[js] if use_new_weight # differentiable resampling W_chosen = W[js] W_new = map(w -> ω * new_weight(w / ω) / m, W_chosen) else # stop gradient, biased approach W_new = fill(ω / m, m) end X_new, W_new end ``` Note that we added a `if` condition that allows us to switch between the differentiable resampling step and the stop-gradient approach. We're now equipped with all primitive operations to set up the particle filter, which propagates particles with weights `W` preserving the invariant `ω == sum(W)`. We never normalize `W` and, therefore, `ω` in the code below contains likelihood information. The particle-filter implementation defaults to return particle positions and weights at `T` if `store_path=false` and takes the following input arguments: - `θ`: parameters for the stochastic program (state-transition and observation model). - `store_path=false`: Option to store the path of the particles, e.g. to visualize/inspect their trajectories. - `use_new_weight=true`: Option to switch between the stop-gradient and our differentiable resampling step method. Defaults to using differentiable resampling. - `s`: controls the number of resampling steps according to `t > 1 && t < T && (t % s == 0)`. 
```@example particle_filter function (F::ParticleFilter)(θ; store_path=false, use_new_weight=true, s=1) # s controls the number of resampling steps @unpack m, StochM, ys, sample_strategy = F @unpack T, start, dyn, obs = StochM X = [rand(start(θ)) for j in 1:m] # particles W = [1 / m for i in 1:m] # weights ω = 1 # total weight store_path && (Xs = [X]) for (t, y) in zip(1:T, ys) # update weights & likelihood using observations wi = map(x -> pdf(obs(x, θ), y), X) W = W .* wi ω_old = ω ω = sum(W) # resample particles if t > 1 && t < T && (t % s == 0) # && 1 / sum((W / ω) .^ 2) < length(W) ÷ 32 X, W = resample(m, X, W, ω, sample_strategy, use_new_weight) end # update particle states if t < T X = map(x -> rand(dyn(x, θ)), X) store_path && Zygote.ignore(() -> push!(Xs, X)) end end (store_path ? Xs : X), W end ``` Following [Kalman.jl](https://github.com/mschauer/Kalman.jl/blob/master/README.md), we implement a differentiable Kalman filter to check the ground-truth gradient. Our Kalman filter returns an updated posterior state estimate and the log-likelihood and takes the parameters of the stochastic program as an input. 
```@example particle_filter function (F::KalmanFilter)(θ) @unpack d, StochM, H, R, Q = F @unpack start = StochM x = start(θ) Φ = reshape(θ, d, d) x, yres, S = GaussianDistributions.correct(x, ys[1] + R, H) ll = llikelihood(yres, S) xs = Any[x] for i in 2:length(ys) x = Φ * x ⊕ Q x, yres, S = GaussianDistributions.correct(x, ys[i] + R, H) ll += llikelihood(yres, S) push!(xs, x) end xs, ll end ``` For both filters, it is straightforward to obtain the log-likelihood via: ```@example particle_filter function log_likelihood(F::ParticleFilter, θ, use_new_weight=true, s=1) _, W = F(θ; store_path=false, use_new_weight=use_new_weight, s=s) log(sum(W)) end ``` and ```@example particle_filter function log_likelihood(F::KalmanFilter, θ) _, ll = F(θ) ll end ``` For convenience, we define functions for - forward-mode AD (and differentiable resampling step) to compute the gradient of the log-likelihood of the particle filter. - reverse-mode AD (and differentiable resampling step) to compute the gradient of the log-likelihood of the particle filter. - forward-mode AD (and stop-gradient method) to compute the gradient of the log-likelihood of the particle filter (without the `new_weight` function). - forward-mode AD to compute the gradient of the log-likelihood of the Kalman filter. ```@example particle_filter forw_grad(θ, F::ParticleFilter; s=1) = ForwardDiff.gradient(θ -> log_likelihood(F, θ, true, s), θ) back_grad(θ, F::ParticleFilter; s=1) = Zygote.gradient(θ -> log_likelihood(F, θ, true, s), θ)[1] forw_grad_biased(θ, F::ParticleFilter; s=1) = ForwardDiff.gradient(θ -> log_likelihood(F, θ, false, s), θ) forw_grad_Kalman(θ, F::KalmanFilter) = ForwardDiff.gradient(θ -> log_likelihood(F, θ), θ) ``` ## Model Having set up all core functionalities, we can now define the specific stochastic model. 
We consider the following system with a ``d``-dimensional latent process, ```math \begin{aligned} x_i &= \Phi x_{i-1} + w_i &\text{ with } w_i \sim \operatorname{Normal}(0,Q),\\ y_i &= x_i + \nu_i &\text{ with } \nu_i \sim \operatorname{Normal}(0,R), \end{aligned} ``` where ``\Phi`` is a ``d``-dimensional rotation matrix. ```@example particle_filter seed = 423897 ### Define model # here: n-dimensional rotation matrix Random.seed!(seed) T = 20 # time steps d = 2 # dimension # generate a rotation matrix M = randn(d, d) c = 0.3 # scaling O = exp(c * (M - transpose(M)) / 2) @assert det(O) ≈ 1 @assert transpose(O) * O ≈ I(d) θtrue = vec(O) # true parameter # observation model R = 0.01 * collect(I(d)) obs(x, θ) = MvNormal(x, R) # y = H x + ν with ν ~ Normal(0, R) # dynamical model Q = 0.02 * collect(I(d)) dyn(x, θ) = MvNormal(reshape(θ, d, d) * x, Q) # x = Φ*x + w with w ~ Normal(0,Q) # starting position x0 = randn(d) # prior distribution start(θ) = Gaussian(x0, 0.001 * collect(I(d))) # put it all together stochastic_model = StochasticModel(T, start, dyn, obs) # relevant corresponding Kalman filterng defs H_Kalman = collect(I(d)) R_Kalman = Gaussian(zeros(Float64, d), R) # Φ_Kalman = O Q_Kalman = Gaussian(zeros(Float64, d), Q) ### ### simulate model Random.seed!(seed) xs, ys = simulate_single(stochastic_model, θtrue) ``` ## Visualization Using `particle_filter(θ; store_path=true)` and `kalman_filter(θ)`, it is straightforward to visualize both filters for our observed data. 
```@example particle_filter m = 1000 kalman_filter = KalmanFilter(d, stochastic_model, H_Kalman, R_Kalman, Q_Kalman, ys) particle_filter = ParticleFilter(m, stochastic_model, ys, sample_stratified) ``` ```@example particle_filter ### run and visualize filters Xs, W = particle_filter(θtrue; store_path=true) fig = plot(getindex.(xs, 1), getindex.(xs, 2), legend=false, xlabel=L"x_1", ylabel=L"x_2") # x1 and x2 are bad names..conflicting notation scatter!(fig, getindex.(ys, 1), getindex.(ys, 2)) for i in 1:min(m, 100) # note that Xs has obs noise. local xs = [Xs[t][i] for t in 1:T] scatter!(fig, getindex.(xs, 1), getindex.(xs, 2), marker_z=1:T, color=:cool, alpha=0.1) # color to indicate time step end xs_Kalman, ll_Kalman = kalman_filter(θtrue) plot!(getindex.(mean.(xs_Kalman), 1), getindex.(mean.(xs_Kalman), 2), legend=false, color="red") png("pf_1") # hide ``` ![](pf_1.png) ## Bias We can also investigate the distribution of the gradients from the particle filter with and without differentiable resampling step, as compared to the gradient computed by differentiating the Kalman filter. 
```@example particle_filter ### compute gradients Random.seed!(seed) X = [forw_grad(θtrue, particle_filter) for i in 1:200] # gradient of the particle filter *with* differentiation of the resampling step Random.seed!(seed) Xbiased = [forw_grad_biased(θtrue, particle_filter) for i in 1:200] # Gradient of the particle filter *without* differentiation of the resampling step # pick an arbitrary coordinate index = 1 # take derivative with respect to first parameter (2-dimensional example has a rotation matrix with four parameters in total) # plot histograms for the sampled derivative values fig = plot(normalize(fit(Histogram, getindex.(X, index), nbins=20), mode=:pdf), legend=false) # ours plot!(normalize(fit(Histogram, getindex.(Xbiased, index), nbins=20), mode=:pdf)) # biased vline!([mean(X)[index]], color=1) vline!([mean(Xbiased)[index]], color=2) # add derivative of differentiable Kalman filter as a comparison XK = forw_grad_Kalman(θtrue, kalman_filter) vline!([XK[index]], color="black") png("pf_2") # hide ``` ![](pf_2.png) The estimator using the `new_weight` function agrees with the gradient value from the Kalman filter and the [particle filter AD scheme developed by Ścibior and Wood](https://arxiv.org/abs/2106.10314), unlike biased estimators that neglect the contribution of the derivative from the resampling step. However, the biased estimator displays a smaller variance. ## Benchmark Finally, we can use `BenchmarkTools.jl` to benchmark the run times of the primal pass with respect to forward-mode and reverse-mode AD of the particle filter. As expected, forward-mode AD outperforms reverse-mode AD for the small number of parameters considered here. 
```@example particle_filter # secs for how long the benchmark should run, see https://juliaci.github.io/BenchmarkTools.jl/stable/ secs = 1 suite = BenchmarkGroup() suite["scaling"] = BenchmarkGroup(["grads"]) suite["scaling"]["primal"] = @benchmarkable log_likelihood(particle_filter, θtrue) suite["scaling"]["forward"] = @benchmarkable forw_grad(θtrue, particle_filter) suite["scaling"]["backward"] = @benchmarkable back_grad(θtrue, particle_filter) tune!(suite) results = run(suite, verbose=true, seconds=secs) t1 = measurement(mean(results["scaling"]["primal"].times), std(results["scaling"]["primal"].times) / sqrt(length(results["scaling"]["primal"].times))) t2 = measurement(mean(results["scaling"]["forward"].times), std(results["scaling"]["forward"].times) / sqrt(length(results["scaling"]["forward"].times))) t3 = measurement(mean(results["scaling"]["backward"].times), std(results["scaling"]["backward"].times) / sqrt(length(results["scaling"]["backward"].times))) @show t1 t2 t3 ts = (t1, t2, t3) ./ 10^6 # ms @show ts ``` ================================================ FILE: docs/src/tutorials/random_walk.md ================================================ # Random walk ```@setup random_walk import Pkg Pkg.activate("../../../tutorials") Pkg.develop(path="../../..") Pkg.instantiate() ``` In this tutorial, we differentiate a random walk over the integers using `StochasticAD`. We will need the following packages, ```@example random_walk using Distributions # defines several supported discrete distributions using StochasticAD using StaticArrays # for more efficient small arrays ``` ## Setting up the random walk Let's define a function for simulating the walk. ```@example random_walk function simulate_walk(probs, steps, n) state = 0 for i in 1:n probs_here = probs(state) # transition probabilities for possible steps step_index = rand(Categorical(probs_here)) # which step do we take? 
step = steps[step_index] # get size of step state += step end return state end ``` Here, `steps` is a (1-indexed) array of the possible steps we can take. Each of these steps has a certain probability. To make things more interesting, we take in a *function* `probs` to produce these probabilities that can depend on the current state of the random walk. Let's zoom in on the two lines where discrete randomness is involved. ``` step_index = rand(Categorical(probs_here)) # which step do we take? step = steps[step_index] # get size of step ``` This is a cute pattern for making a discrete choice. First, we sample from a `Categorical` distribution from `Distributions.jl`, using the probabilities `probs_here` at our current position. This gives us an index between `1` and `length(steps)`, which we can use to pick the actual step to take. Stochastic triples propagate through both steps! ## Differentiating the random walk Let's define a toy problem. We consider a random walk with `-1` and `+1` steps, where the probability of `+1` starts off high but decays exponentially with a decay length of `p`. We take `n = 100` steps and set `p = 50`. ```@example random_walk using StochasticAD const steps = SA[-1, 1] # move left or move right make_probs(p) = X -> SA[1 - exp(-X / p), exp(-X / p)] f(p, n) = simulate_walk(make_probs(p), steps, n) @show f(50, 100) # let's run a single random walk with p = 50 @show stochastic_triple(p -> f(p, 100), 50) # let's see how a single stochastic triple looks like at p = 50 ``` Time to differentiate! For fun, let's differentiate the *square* of the output of the random walk. 
```@example random_walk
f_squared(p, n) = f(p, n)^2
samples = [derivative_estimate(p -> f_squared(p, 100), 50) for i in 1:1000] # many samples from derivative program at p = 50
derivative = mean(samples)
uncertainty = std(samples) / sqrt(1000)
println("derivative of 𝔼[f_squared] = $derivative ± $uncertainty")
```

## Computing variance

A crucial figure of merit for a derivative estimator is its variance. We compute the standard deviation (square root of the variance) of our estimator over a range of `n`.

```@example random_walk
n_range = 10:10:100 # range for testing asymptotic variance behaviour
p_range = 2 .* n_range
nsamples = 10000
stds_triple = Float64[]
for (n, p) in zip(n_range, p_range)
    std_triple = std(derivative_estimate(p -> f_squared(p, n), p) for i in 1:(nsamples))
    push!(stds_triple, std_triple)
end
@show stds_triple
```

For comparison with other unbiased estimators, we also compute `stds_score` and `stds_score_baseline` for the [score function gradient estimator](https://arxiv.org/pdf/1906.10652.pdf), both without and with a variance-reducing batch-average control variate (CV). (For details, see [`core.jl`](https://github.com/gaurav-arya/StochasticAD.jl/blob/main/tutorials/random_walk/core.jl) and [`compare_score.jl`](https://github.com/gaurav-arya/StochasticAD.jl/blob/main/tutorials/random_walk/compare_score.jl).)
We can now graph the standard deviation of each estimator versus $n$, observing lower variance in the unbiased derivative estimate produced by stochastic triples: ```@raw html ``` ⠀ ================================================ FILE: docs/src/tutorials/reverse_demo.md ================================================ ```@meta EditURL = "../../../tutorials/reverse_example/reverse_demo.jl" ``` # Simple reverse mode example ```@setup random_walk import Pkg Pkg.activate("../../../tutorials") Pkg.develop(path="../../..") Pkg.instantiate() import Random Random.seed!(1234) ``` Load our packages ````@example reverse_demo using StochasticAD using Distributions using Enzyme using LinearAlgebra ```` Let us define our target function. ````@example reverse_demo # Define a toy `StochasticAD`-differentiable function for computing an integer value from a string. string_value(strings, index) = Int(sum(codepoint, strings[index])) string_value(strings, index::StochasticTriple) = StochasticAD.propagate(index -> string_value(strings, index), index) function f(θ; derivative_coupling = StochasticAD.InversionMethodDerivativeCoupling()) strings = ["cat", "dog", "meow", "woofs"] index = randst(Categorical(θ); derivative_coupling) return string_value(strings, index) end θ = [0.1, 0.5, 0.3, 0.1] @show f(θ) nothing ```` First, let's compute the sensitivity of `f` in a particular direction via forward-mode Stochastic AD. ````@example reverse_demo u = [1.0, 2.0, 4.0, -7.0] @show derivative_estimate(f, θ, StochasticAD.ForwardAlgorithm(PrunedFIsBackend()); direction = u) nothing ```` Now, let's do the same with reverse-mode. ````@example reverse_demo @show derivative_estimate(f, θ, StochasticAD.EnzymeReverseAlgorithm(PrunedFIsBackend(Val(:wins)))) ```` Let's verify that our reverse-mode gradient is consistent with our forward-mode directional derivative. 
````@example reverse_demo forward() = derivative_estimate(f, θ, StochasticAD.ForwardAlgorithm(PrunedFIsBackend()); direction = u) reverse() = derivative_estimate(f, θ, StochasticAD.EnzymeReverseAlgorithm(PrunedFIsBackend(Val(:wins)))) N = 40000 directional_derivs_fwd = [forward() for i in 1:N] derivs_bwd = [reverse() for i in 1:N] directional_derivs_bwd = [dot(u, δ) for δ in derivs_bwd] println("Forward mode: $(mean(directional_derivs_fwd)) ± $(std(directional_derivs_fwd) / sqrt(N))") println("Reverse mode: $(mean(directional_derivs_bwd)) ± $(std(directional_derivs_bwd) / sqrt(N))") @assert isapprox(mean(directional_derivs_fwd), mean(directional_derivs_bwd), rtol = 3e-2) nothing ```` --- *This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).* ================================================ FILE: ext/StochasticADEnzymeExt.jl ================================================ module StochasticADEnzymeExt using StochasticAD using Enzyme function enzyme_target(u, X, p, backend) # equivalent to derivative_estimate(X, p; backend, direction = u), but specialize to real output to make Enzyme happier st = StochasticAD.stochastic_triple_direction(X, p, u; backend) if !(StochasticAD.valtype(st) <: Real) error("EnzymeReverseAlgorithm only supports real-valued outputs.") end return derivative_contribution(st) end function StochasticAD.derivative_estimate(X, p, alg::StochasticAD.EnzymeReverseAlgorithm; direction = nothing, alg_data = (; forward_u = nothing)) if !isnothing(direction) error("EnzymeReverseAlgorithm does not support keyword argument `direction`") end if p isa AbstractVector Δu = zeros(float(eltype(p)), length(p)) u = isnothing(alg_data.forward_u) ? rand(StochasticAD.RNG, float(eltype(p)), length(p)) : alg_data.forward_u autodiff(Enzyme.Reverse, enzyme_target, Active, Duplicated(u, Δu), Const(X), Const(p), Const(alg.backend)) return Δu elseif p isa Real u = isnothing(alg_data.forward_u) ? 
rand(StochasticAD.RNG, float(typeof(p))) : alg_data.forward_u # fix: was bare `forward_u` (undefined here); the caller-supplied direction lives in `alg_data.forward_u`
        # Reverse-mode transpose of the forward directional derivative: seed the
        # scalar direction `u` as Active and read back its adjoint `du`.
        ((du, _, _, _),) = autodiff(Enzyme.Reverse, enzyme_target, Active, Active(u),
            Const(X), Const(p), Const(alg.backend))
        return du
    else
        error("EnzymeReverseAlgorithm only supports p::Real or p::AbstractVector")
    end
end

end

================================================
FILE: src/StochasticAD.jl
================================================
module StochasticAD

### Public API

export stochastic_triple, derivative_contribution, perturbations, smooth_triple,
       dual_number, StochasticTriple # For working with stochastic triples
export derivative_estimate, StochasticModel, stochastic_gradient # Higher level functionality
export new_weight # Particle resampling
export PrunedFIsBackend, PrunedFIsAggressiveBackend, DictFIsBackend, SmoothedFIsBackend,
       StrategyWrapperFIsBackend
export PrunedFIs, PrunedFIsAggressive, DictFIs, SmoothedFIs, StrategyWrapperFIs
export randst
export InversionMethodDerivativeCoupling

### Imports

using Random
using Distributions
using DistributionsAD
using ChainRulesCore
using ChainRulesOverloadGeneration
using ExprTools
using ForwardDiff
using Functors

import ChainRulesCore # resolve conflicts while this code exists in both.
const on_new_rule = ChainRulesOverloadGeneration.on_new_rule
const refresh_rules = ChainRulesOverloadGeneration.refresh_rules

const RNG = copy(Random.default_rng())

### Files responsible for backends

include("finite_infinitesimals.jl")
include("backends/pruned.jl")
include("backends/pruned_aggressive.jl")
include("backends/dict.jl")
include("backends/smoothed.jl")
include("backends/abstract_wrapper.jl")
include("backends/strategy_wrapper.jl")
using .PrunedFIsModule
using .PrunedFIsAggressiveModule
using .DictFIsModule
using .SmoothedFIsModule
using .AbstractWrapperFIsModule
using .StrategyWrapperFIsModule

include("prelude.jl") # Defines global constants
include("smoothing.jl") # Smoothing rules.
Placed before general rules so that new_weight frule is caught by overload generation. include("stochastic_triple.jl") # Defines stochastic triple object and higher level functions include("general_rules.jl") # Defines rules for propagation through deterministic functions include("discrete_randomness.jl") # Defines rules for propagation through discrete random functions include("propagate.jl") # Experimental generalized forward propagation functionality include("algorithms.jl") # Add algorithm-based higher-level interface include("misc.jl") # Miscellaneous functions that do not fit in the usual flow end ================================================ FILE: src/algorithms.jl ================================================ abstract type AbstractStochasticADAlgorithm end """ ForwardAlgorithm(backend::StochasticAD.AbstractFIsBackend) <: AbstractStochasticADAlgorithm A differentiation algorithm relying on forward propagation of stochastic triples. The `backend` argument controls the algorithm used by the third component of the stochastic triples. !!! note The required computation time for forward-mode AD scales linearly with the number of parameters in `p` (but is unaffected by the number of parameters in `X(p)`). """ struct ForwardAlgorithm{B <: StochasticAD.AbstractFIsBackend} <: AbstractStochasticADAlgorithm backend::B end """ EnzymeReverseAlgorithm(backend::StochasticAD.AbstractFIsBackend) <: AbstractStochasticADAlgorithm A differentiation algorithm relying on transposing the propagation of stochastic triples to produce a reverse-mode algorithm. The transposition is performed by [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl), which must be loaded for the algorithm to run. Currently, only real- and vector-valued inputs are supported, and only real-valued outputs are supported. The `backend` argument controls the algorithm used by the third component of the stochastic triples. 
In the call to `derivative_estimate`, this algorithm optionally accepts `alg_data` with the field `forward_u`, which specifies the directional derivative used in the forward pass that will be transposed. If `forward_u` is not provided, it is randomly generated. !!! warning For the reverse-mode algorithm to yield correct results, the employed `backend` cannot use input-dependent pruning strategies. A suggested reverse-mode compatible backend is `PrunedFIsBackend(Val(:wins))`. Additionally, this algorithm relies on the ability of `Enzyme.jl` to differentiate the forward stochastic triple run. It is recommended to check that the primal function `X` is type stable for its input `p` using a tool such as [JET.jl](https://github.com/aviatesk/JET.jl), with all code executed in a function with no global state. In addition, sometimes `X` may be type stable but stochastic triples introduce additional type instabilities. This can be debugged by checking type stability of Enzyme's target, which is `Base.get_extension(StochasticAD, :StochasticADEnzymeExt).enzyme_target(u, X, p, backend)`, where `u` is a test direction. !!! note For more details on the reverse-mode approach, see the following papers and talks: * ["You Only Linearize Once: Tangents Transpose to Gradients"](https://arxiv.org/abs/2204.10923), Radul et al. 2022. * ["Reverse mode ADEV via YOLO: tangent estimators transpose to gradient estimators"](https://www.youtube.com/watch?v=pnPmk-leSsE), Becker et al. 2024 * ["Probabilistic Programming with Programmable Variational Inference"](https://pldi24.sigplan.org/details/pldi-2024-papers/87/Probabilistic-Programming-with-Programmable-Variational-Inference), Becker et al. 
2024
"""
# CONSISTENCY FIX: the docstring above (and the `derivative_estimate` dispatch design)
# declare this type as an `AbstractStochasticADAlgorithm`, matching `ForwardAlgorithm`;
# the supertype was missing from the struct declaration. Adding it is backward
# compatible for all existing callers.
struct EnzymeReverseAlgorithm{B <: StochasticAD.AbstractFIsBackend} <:
       AbstractStochasticADAlgorithm
    backend::B
end

function derivative_estimate(
        X, p, alg::ForwardAlgorithm; direction = nothing, alg_data::NamedTuple = (;))
    # `alg_data` is accepted for interface uniformity across algorithms; the forward
    # algorithm takes no additional data, so it is ignored here.
    return derivative_estimate(X, p; backend = alg.backend, direction)
end

@doc raw"""
    derivative_estimate(X, p, alg::AbstractStochasticADAlgorithm = ForwardAlgorithm(PrunedFIsBackend()); direction=nothing, alg_data::NamedTuple = (;))

Compute an unbiased estimate of ``\frac{\mathrm{d}\mathbb{E}[X(p)]}{\mathrm{d}p}``,
the derivative of the expectation of the random function `X(p)` with respect to its input `p`.

Both `p` and `X(p)` can be any object supported by [`Functors.jl`](https://fluxml.ai/Functors.jl/stable/),
e.g. scalars or abstract arrays.
The output of `derivative_estimate` has the same outer structure as `p`, but with each
scalar in `p` replaced by a derivative estimate of `X(p)` with respect to that entry.
For example, if `X(p) <: AbstractMatrix` and `p <: Real`, then the output would be a matrix.

The `alg` argument specifies the [algorithm](public_api.md#Algorithms) used to compute
the derivative estimate. For backward compatibility, an additional signature
`derivative_estimate(X, p; backend, direction=nothing)` is supported, which uses
`ForwardAlgorithm` by default with the supplied `backend`.
The `alg_data` keyword argument can specify any additional data that specific algorithms
accept or require.

When `direction` is provided, the output is only differentiated with respect to a
perturbation of `p` in that direction.

# Example
```jldoctest
julia> using Distributions, Random, StochasticAD; Random.seed!(4321);

julia> derivative_estimate(rand ∘ Bernoulli, 0.5) # A random quantity that averages to the true derivative.
2.0 julia> derivative_estimate(x -> [rand(Bernoulli(x * i/4)) for i in 1:3], 0.5) 3-element Vector{Float64}: 0.2857142857142857 0.6666666666666666 0.0 ``` """ derivative_estimate ================================================ FILE: src/backends/abstract_wrapper.jl ================================================ module AbstractWrapperFIsModule import ..StochasticAD export AbstractWrapperFIs """ AbstractWrapperFIs{V, FIs} <: StochasticAD.AbstractFIs{V} A convenience type for backend strategies that wrap another backend. A subtype `WrapperFIs <: AbstractWrapperFIs` should have a field called Δs containing the wrapped backend, and should also define the following methods: * `StochasticAD.similar_type(::Type{<:WrapperFIs}, V, FIs)`: return the type of a new `WrapperFIs` with value type `V` and wrapped backend type `FIs`, * `AbstractWrapperFIsModule.reconstruct_wrapper(wrapper_Δs::WrapperFIs, Δs::AbstractFIs)`: construct a new `WrapperFIs` wrapping `Δs` given an existing wrapped instace `wrapper_Δs`. * `AbstractWrapperFIsModule.reconstruct_wrapper(::Type{<:WrapperFIs}, Δs::AbstractFIs)`: construct a new `WrapperFIs` wrapping `Δs` given the type of an existing `WrapperFIs`. Then, all other methods will generically be forwarded to the inner backend, except those overloaded by the specific wrapper type. 
""" abstract type AbstractWrapperFIs{V, FIs} <: StochasticAD.AbstractFIs{V} end function reconstruct_wrapper end function StochasticAD.similar_new(Δs::AbstractWrapperFIs, Δ, w) reconstruct_wrapper(Δs, StochasticAD.similar_new(Δs.Δs, Δ, w)) end function StochasticAD.similar_empty(Δs::AbstractWrapperFIs, V) reconstruct_wrapper(Δs, StochasticAD.similar_empty(Δs.Δs, V)) end function StochasticAD.similar_type(WrapperFIs::Type{<:AbstractWrapperFIs{V0, FIs}}, V) where {V0, FIs} return StochasticAD.similar_type(WrapperFIs, V, StochasticAD.similar_type(FIs, V)) end StochasticAD.valtype(Δs::AbstractWrapperFIs) = StochasticAD.valtype(Δs.Δs) function StochasticAD.couple(WrapperFIs::Type{<:AbstractWrapperFIs{V, FIs}}, Δs_all; rep = nothing, kwargs...) where {V, FIs} _Δs_all = StochasticAD.structural_map(Δs -> Δs.Δs, Δs_all) _rep_kwarg = !isnothing(rep) ? (; rep = rep.Δs) : (;) return reconstruct_wrapper(StochasticAD.get_any(Δs_all), StochasticAD.couple(FIs, _Δs_all; _rep_kwarg..., kwargs...)) end function StochasticAD.combine(WrapperFIs::Type{<:AbstractWrapperFIs{V, FIs}}, Δs_all; rep = nothing, kwargs...) where {V, FIs} _Δs_all = StochasticAD.structural_map(Δs -> Δs.Δs, Δs_all) _rep_kwarg = !isnothing(rep) ? (; rep = rep.Δs) : (;) return reconstruct_wrapper(StochasticAD.get_any(Δs_all), StochasticAD.combine(FIs, _Δs_all; _rep_kwarg..., kwargs...)) end function StochasticAD.get_rep(WrapperFIs::Type{<:AbstractWrapperFIs{V, FIs}}, Δs_all; kwargs...) where {V, FIs} _Δs_all = StochasticAD.structural_map(Δs -> Δs.Δs, Δs_all) return reconstruct_wrapper(StochasticAD.get_any(Δs_all), StochasticAD.get_rep(FIs, _Δs_all; kwargs...)) end function StochasticAD.scalarize(Δs::AbstractWrapperFIs; rep = nothing, kwargs...) _rep_kwarg = !isnothing(rep) ? 
(; rep = rep.Δs) : (;) return StochasticAD.structural_map(StochasticAD.scalarize( Δs.Δs; _rep_kwarg..., kwargs...)) do _Δs reconstruct_wrapper(Δs, _Δs) end end function StochasticAD.derivative_contribution(Δs::AbstractWrapperFIs, Δs_all; kwargs...) StochasticAD.derivative_contribution(Δs.Δs, Δs_all; kwargs...) end StochasticAD.alltrue(f, Δs::AbstractWrapperFIs) = StochasticAD.alltrue(f, Δs.Δs) StochasticAD.perturbations(Δs::AbstractWrapperFIs) = StochasticAD.perturbations(Δs.Δs) function StochasticAD.filter_state(Δs::AbstractWrapperFIs, state) StochasticAD.filter_state(Δs.Δs, state) end function StochasticAD.weighted_map_Δs(f, Δs::AbstractWrapperFIs; kwargs...) reconstruct_wrapper(Δs, StochasticAD.weighted_map_Δs(f, Δs.Δs; kwargs...)) end StochasticAD.new_Δs_strategy(Δs::AbstractWrapperFIs) = StochasticAD.new_Δs_strategy(Δs.Δs) function Base.empty(WrapperFIs::Type{<:AbstractWrapperFIs{V, FIs}}) where {V, FIs} return reconstruct_wrapper(WrapperFIs, empty(FIs)) end Base.empty(Δs::AbstractWrapperFIs) = reconstruct_wrapper(Δs, empty(Δs.Δs)) Base.isempty(Δs::AbstractWrapperFIs) = isempty(Δs.Δs) Base.length(Δs::AbstractWrapperFIs) = length(Δs.Δs) Base.iszero(Δs::AbstractWrapperFIs) = iszero(Δs.Δs) function StochasticAD.derivative_contribution(Δs::AbstractWrapperFIs) StochasticAD.derivative_contribution(Δs.Δs) end function Base.convert(::Type{<:AbstractWrapperFIs{V}}, Δs::AbstractWrapperFIs) where {V} reconstruct_wrapper(Δs, convert(StochasticAD.similar_type(typeof(Δs.Δs), V), Δs.Δs)) end function StochasticAD.send_signal( Δs::AbstractWrapperFIs, signal::StochasticAD.AbstractPerturbationSignal) reconstruct_wrapper(Δs, StochasticAD.send_signal(Δs.Δs, signal)) end function Base.show(io::IO, Δs::AbstractWrapperFIs) return show(io, Δs.Δs) end end ================================================ FILE: src/backends/dict.jl ================================================ module DictFIsModule export DictFIsBackend, DictFIs import ..StochasticAD using Dictionaries """ 
DictFIsBackend <: StochasticAD.AbstractFIsBackend A dictionary backend algorithm which keeps entries for each perturbation that has occurred without pruning. Currently very unoptimized. """ struct DictFIsBackend <: StochasticAD.AbstractFIsBackend end """ DictFIsState State maintained by dictionary backend. """ mutable struct DictFIsState tag_count::Int64 valid::Bool DictFIsState(valid = true) = new(0, valid) end struct InfinitesimalEvent tag::Any # unique identifier w::Float64 # weight (infinitesimal probability wε) end Base.:<(event1::InfinitesimalEvent, event2::InfinitesimalEvent) = event1.tag < event2.tag function Base.:(==)(event1::InfinitesimalEvent, event2::InfinitesimalEvent) event1.tag == event2.tag end Base.:isless(event1::InfinitesimalEvent, event2::InfinitesimalEvent) = event1 < event2 """ DictFIs{V} <: StochasticAD.AbstractFIs{V} The implementing backend structure for DictFIsBackend. """ struct DictFIs{V} <: StochasticAD.AbstractFIs{V} dict::Dictionary{InfinitesimalEvent, V} state::DictFIsState end state(Δs::DictFIs) = Δs.state ### Empty / no perturbation function DictFIs{V}(state::DictFIsState) where {V} DictFIs{V}(Dictionary{InfinitesimalEvent, V}(), state) end StochasticAD.similar_empty(Δs::DictFIs, V::Type) = DictFIs{V}(Δs.state) Base.empty(Δs::DictFIs{V}) where {V} = StochasticAD.similar_empty(Δs::DictFIs, V::Type) function Base.empty(::Type{<:DictFIs{V}}) where {V} DictFIs{V}(DictFIsState(false)) end ### Create a new perturbation with infinitesimal probability function new_perturbation(Δ::V, w::Real, state::DictFIsState) where {V} state.tag_count += 1 event = InfinitesimalEvent(state.tag_count, w) DictFIs{V}(Dictionary([event], [Δ]), state) end function StochasticAD.similar_new(Δs::DictFIs, Δ::V, w::Real) where {V} new_perturbation(Δ, w, Δs.state) end ### Create Δs backend for the first stochastic triple of computation StochasticAD.create_Δs(::DictFIsBackend, V) = DictFIs{V}(DictFIsState()) ### Convert type of a backend function 
Base.convert(::Type{DictFIs{V}}, Δs::DictFIs) where {V} DictFIs{V}(convert(Dictionary{InfinitesimalEvent, V}, Δs.dict), Δs.state) end ### Getting information about Δs Base.isempty(Δs::DictFIs) = isempty(Δs.dict) Base.length(Δs::DictFIs) = length(Δs.dict) Base.iszero(Δs::DictFIs) = isempty(Δs) || all(iszero.(Δs.dict)) function StochasticAD.derivative_contribution(Δs::DictFIs{V}) where {V} sum((Δ * event.w for (event, Δ) in pairs(Δs.dict)), init = zero(V) * 0.0) end function StochasticAD.perturbations(Δs::DictFIs) [(; Δ, weight = event.w, state = event) for (event, Δ) in pairs(Δs.dict)] end ### Unary propagation function StochasticAD.weighted_map_Δs(f, Δs::DictFIs; kwargs...) # Pass key as state in map mapped_values_and_weights = map(f, collect(Δs.dict), keys(Δs.dict)) mapped_values = first.(mapped_values_and_weights) mapped_weights = last.(mapped_values_and_weights) scaled_events = map((event, a) -> InfinitesimalEvent(event.tag, event.w * a), keys(Δs.dict), mapped_weights) # TODO: should original events (with old tag) also be modified? dict = Dictionary(scaled_events, mapped_values) DictFIs(dict, Δs.state) end StochasticAD.alltrue(f, Δs::DictFIs) = all(map(f, collect(Δs.dict))) ### Coupling function StochasticAD.get_rep(::Type{<:DictFIs}, Δs_all) for Δs in StochasticAD.structural_iterate(Δs_all) if Δs.state.valid return Δs end end return first(Δs_all) end function StochasticAD.couple(FIs::Type{<:DictFIs}, Δs_all; rep = StochasticAD.get_rep(FIs, Δs_all), out_rep = nothing, kwargs...) all_keys = Iterators.map(StochasticAD.structural_iterate(Δs_all)) do Δs keys(Δs.dict) end distinct_keys = unique(all_keys |> Iterators.flatten) Δs_coupled_dict = [StochasticAD.structural_map( Δs -> isassigned(Δs.dict, key) ? Δs.dict[key] : zero(eltype(Δs.dict)), Δs_all) for key in distinct_keys] DictFIs(Dictionary(distinct_keys, Δs_coupled_dict), rep.state) end function StochasticAD.combine(FIs::Type{<:DictFIs}, Δs_all; rep = StochasticAD.get_rep(FIs, Δs_all), kwargs...) 
Δs_dicts = Iterators.map(Δs -> Δs.dict, StochasticAD.structural_iterate(Δs_all)) Δs_combined_dict = reduce(Δs_dicts) do Δs_dict1, Δs_dict2 mergewith((x, y) -> StochasticAD.structural_map(+, x, y), Δs_dict1, Δs_dict2) end DictFIs(Δs_combined_dict, rep.state) end function StochasticAD.scalarize(Δs::DictFIs; out_rep = nothing) # TODO: use vcat here? tupleify(Δ1, Δ2) = StochasticAD.structural_map(tuple, Δ1, Δ2) Δ_all_allkeys = foldl(tupleify, values(Δs.dict)) Δ_all_rep = first(values(Δs.dict)) _keys = keys(Δs.dict) return StochasticAD.structural_map(Δ_all_rep, Δ_all_allkeys) do _, Δ_allkeys return DictFIs(Dictionary(_keys, Δ_allkeys), Δs.state) end end function StochasticAD.filter_state(Δs::DictFIs{V}, key) where {V} haskey(Δs.dict, key) ? Δs.dict[key] : zero(V) end ### Miscellaneous StochasticAD.similar_type(::Type{<:DictFIs}, V::Type) = DictFIs{V} StochasticAD.valtype(::Type{<:DictFIs{V}}) where {V} = V end ================================================ FILE: src/backends/pruned.jl ================================================ module PrunedFIsModule import ..StochasticAD export PrunedFIsBackend, PrunedFIs """ PrunedFIsBackend <: StochasticAD.AbstractFIsBackend A backend algorithm that prunes between perturbations as soon as they clash (e.g. added together). Currently chooses uniformly between all perturbations. """ struct PrunedFIsBackend{M <: Val} <: StochasticAD.AbstractFIsBackend pruning_mode::M function PrunedFIsBackend(pruning_mode::M = Val(:weights)) where {M} if pruning_mode isa Val{:weights} || pruning_mode isa Val{:wins} return new{M}(pruning_mode) else error("Unsupported pruning_mode $pruning_mode for `PrunedFIsBackend.") end end end """ PrunedFIsState State maintained by pruning backend. """ mutable struct PrunedFIsState{M, W} tag::Int32 weight::Float64 valid::Bool # TODO: generalize (wins, pruning_mode) into a general interface for accumulating state # that informs future pruning decisions. 
wins::W pruning_mode::M function PrunedFIsState(pruning_mode::M, valid = true) where {M <: Val} wins = pruning_mode isa Val{:wins} ? (valid ? 1 : 0) : nothing state::PrunedFIsState = new{M, typeof(wins)}(0, 0.0, valid, wins) state.tag = objectid(state) % typemax(Int32) return state end end Base.:(==)(state1::PrunedFIsState, state2::PrunedFIsState) = state1.tag == state2.tag # c.f. https://github.com/JuliaLang/julia/blob/61c3521613767b2af21dfa5cc5a7b8195c5bdcaf/base/hashing.jl#L38C45-L38C51 Base.hash(state::PrunedFIsState) = state.tag """ PrunedFIs{V} <: StochasticAD.AbstractFIs{V} The implementing backend structure for PrunedFIsBackend. """ struct PrunedFIs{V, S <: PrunedFIsState} <: StochasticAD.AbstractFIs{V} Δ::V state::S end ### Empty / no perturbation PrunedFIs{V}(Δ::V, state::S) where {V, S <: PrunedFIsState} = PrunedFIs{V, S}(Δ, state) PrunedFIs{V}(state::PrunedFIsState) where {V} = PrunedFIs{V}(zero(V), state) # TODO: avoid allocations here function StochasticAD.similar_empty(Δs::PrunedFIs, V::Type) PrunedFIs{V}(PrunedFIsState(Δs.state.pruning_mode, false)) end Base.empty(Δs::PrunedFIs{V}) where {V} = StochasticAD.similar_empty(Δs::PrunedFIs, V::Type) # we truly have no clue what the state is here, so use an invalidated state function Base.empty(::Type{<:PrunedFIs{V, S}}) where {V, M, S <: PrunedFIsState{M}} PrunedFIs{V}(PrunedFIsState(M(), false)) end ### Create a new perturbation with infinitesimal probability function StochasticAD.similar_new(Δs::PrunedFIs, Δ::V, w::Real) where {V} if iszero(w) return StochasticAD.similar_empty(Δs, V) end state = PrunedFIsState(Δs.state.pruning_mode) state.weight += w Δs = PrunedFIs{V}(Δ, state) return Δs end ### Create Δs backend for the first stochastic triple of computation function StochasticAD.create_Δs(backend::PrunedFIsBackend, V) PrunedFIs{V}(PrunedFIsState(backend.pruning_mode, false)) end ### Convert type of a backend function Base.convert(::Type{<:PrunedFIs{V}}, Δs::PrunedFIs) where {V} PrunedFIs{V}(convert(V, 
Δs.Δ), Δs.state) end ### Getting information about perturbations # "empty" here means no perturbation or a perturbation that has been pruned away Base.isempty(Δs::PrunedFIs) = !Δs.state.valid Base.length(Δs::PrunedFIs) = isempty(Δs) ? 0 : 1 function Base.iszero(Δs::PrunedFIs) isempty(Δs) || all(iszero, StochasticAD.structural_iterate(Δs.Δ)) end Base.iszero(Δs::PrunedFIs{<:Real}) = isempty(Δs) || iszero(Δs.Δ) Base.iszero(Δs::PrunedFIs{<:Tuple}) = isempty(Δs) || all(iszero.(Δs.Δ)) isapproxzero(Δs::PrunedFIs) = isempty(Δs) || isapprox(Δs.Δ, zero(Δs.Δ)) # we lazily prune, so check if empty first function pruned_value(Δs::PrunedFIs{V}) where {V} isempty(Δs) ? StochasticAD.structural_map(zero, Δs.Δ) : Δs.Δ end pruned_value(Δs::PrunedFIs{<:Real}) = isempty(Δs) ? zero(Δs.Δ) : Δs.Δ pruned_value(Δs::PrunedFIs{<:Tuple}) = isempty(Δs) ? zero.(Δs.Δ) : Δs.Δ pruned_value(Δs::PrunedFIs{<:AbstractArray}) = isempty(Δs) ? zero.(Δs.Δ) : Δs.Δ StochasticAD.derivative_contribution(Δs::PrunedFIs) = pruned_value(Δs) * Δs.state.weight function StochasticAD.perturbations(Δs::PrunedFIs) ((; Δ = pruned_value(Δs), weight = Δs.state.weight, state = Δs.state),) end ### Unary propagation function StochasticAD.weighted_map_Δs(f, Δs::PrunedFIs; kwargs...) 
Δ_out, weight_out = f(pruned_value(Δs), Δs.state) # TODO: we could add a direct overload for map_Δs that elides the below line Δs.state.weight *= weight_out PrunedFIs(Δ_out, Δs.state) end StochasticAD.alltrue(f, Δs::PrunedFIs) = f(pruned_value(Δs)) ### Coupling function StochasticAD.get_rep(FIs::Type{<:PrunedFIs}, Δs_all) return empty(FIs) #StochasticAD.get_any(Δs_all) end function get_pruned_state(Δs_all; Δ_func = nothing, rep, out_rep = nothing) if !isnothing(Δ_func) && isnothing(out_rep) error("Specifying Δ_func requires out_rep to be specified.") end function op(cur_state, Δs) # lazy pruning optimization temporarily disabled with custom Δ_func # (because custom Δ_func's may prefer not to lazily prune) (isnothing(Δ_func) && isapproxzero(Δs)) && return cur_state candidate_state = Δs.state if !candidate_state.valid || (candidate_state == cur_state) return cur_state end if !cur_state.valid return candidate_state end # Compute "strength" of each perturbation for pruning proposal if !isnothing(Δ_func) # TODO: structural_map for each state can take asymptotically more time than necessary when combining many distinct states candidate_Δ = StochasticAD.structural_map( Base.Fix2(StochasticAD.filter_state, candidate_state), Δs_all) candidate_Δ_func::Float64 = Δ_func(candidate_Δ, candidate_state, out_rep) cur_Δ = StochasticAD.structural_map( Base.Fix2(StochasticAD.filter_state, cur_state), Δs_all) cur_Δ_func::Float64 = Δ_func(cur_Δ, cur_state, out_rep) else candidate_Δ_func = 1.0 cur_Δ_func = 1.0 end candidate_intrinsic_strength = Δs.state.pruning_mode isa Val{:wins} ? candidate_state.wins : abs(candidate_state.weight) cur_intrinsic_strength = Δs.state.pruning_mode isa Val{:wins} ? 
cur_state.wins : abs(cur_state.weight) candidate_strength = candidate_intrinsic_strength * candidate_Δ_func cur_strength = cur_intrinsic_strength * cur_Δ_func both_states_bad = iszero(candidate_strength) && iszero(cur_strength) if both_states_bad cur_state.valid = false candidate_state.valid = false return cur_state end # Prune between perturbations total_strength = cur_strength + candidate_strength p = candidate_strength / total_strength if isone(p) || (rand(StochasticAD.RNG) < p) cur_state.valid = false if Δs.state.pruning_mode isa Val{:wins} candidate_state.wins += 1 end candidate_state.weight *= 1 / p return candidate_state else candidate_state.valid = false if Δs.state.pruning_mode isa Val{:wins} cur_state.wins += 1 end cur_state.weight *= 1 / (1 - p) return cur_state end end dummy_state = PrunedFIsState(rep.state.pruning_mode, false) # For type stability, as well as retval if no better state found. TODO: can this be avoided? _new_state = foldl(op, StochasticAD.structural_iterate(Δs_all); init = dummy_state) return _new_state::PrunedFIsState end # for pruning, coupling amounts to getting rid of perturbed values that have been # lazily kept around even after (aggressive or lazy) pruning made the perturbation invalid. function StochasticAD.couple( FIs::Type{<:PrunedFIs}, Δs_all; rep = StochasticAD.get_rep(FIs, Δs_all), out_rep = nothing, Δ_func = nothing, kwargs...) state = get_pruned_state(Δs_all; rep, Δ_func) Δ_coupled = StochasticAD.structural_map(pruned_value, Δs_all) # TODO: perhaps a performance optimization possible here PrunedFIs(Δ_coupled, state) end # basically couple combined with a sum. function StochasticAD.combine( FIs::Type{<:PrunedFIs}, Δs_all; rep = StochasticAD.get_rep(FIs, Δs_all), Δ_func = nothing, out_rep = nothing, kwargs...) state = get_pruned_state(Δs_all; rep, out_rep, Δ_func = !isnothing(Δ_func) ? 
(Δ, state, val) -> Δ_func(sum(Δ), state, val) : Δ_func) Δ_combined = sum(pruned_value, StochasticAD.structural_iterate(Δs_all)) PrunedFIs(Δ_combined, state) end function StochasticAD.scalarize(Δs::PrunedFIs; out_rep = nothing) return StochasticAD.structural_map(Δs.Δ) do Δ return PrunedFIs(Δ, Δs.state) end end function StochasticAD.filter_state(Δs::PrunedFIs{V}, state) where {V} Δs.state == state ? pruned_value(Δs) : zero(V) end ### Miscellaneous function StochasticAD.similar_type(::Type{<:PrunedFIs{V0, M}}, V::Type) where {V0, M} PrunedFIs{V, M} end StochasticAD.valtype(::Type{<:PrunedFIs{V}}) where {V} = V function Base.show(io::IO, Δs::PrunedFIs{V}) where {V} print(io, "$(pruned_value(Δs)) with probability $(Δs.state.weight)ε") end end ================================================ FILE: src/backends/pruned_aggressive.jl ================================================ module PrunedFIsAggressiveModule import ..StochasticAD export PrunedFIsAggressiveBackend, PrunedFIsAggressive """ PrunedFIsAggressiveBackend <: StochasticAD.AbstractFIsBackend A backend algorithm that aggressively prunes between perturbations as soon as they are created. """ struct PrunedFIsAggressiveBackend <: StochasticAD.AbstractFIsBackend end """ PrunedFIsAggressiveState State maintained by aggressive pruning backend. """ mutable struct PrunedFIsAggressiveState active_tag::Int64 # 0 is always a dummy tag weight::Float64 tag_count::Int64 valid::Bool PrunedFIsAggressiveState(valid = true) = new(0, 0.0, 0, valid) end """ PrunedFIsAggressive{V} <: StochasticAD.AbstractFIs{V} The implementing backend structure for PrunedFIsAggressiveBackend. 
""" struct PrunedFIsAggressive{V} <: StochasticAD.AbstractFIs{V} Δ::V tag::Int state::PrunedFIsAggressiveState # directly called when propagating an existing perturbation end ### Empty / no perturbation function PrunedFIsAggressive{V}(state::PrunedFIsAggressiveState) where {V} PrunedFIsAggressive{V}(zero(V), -1, state) end function StochasticAD.similar_empty(Δs::PrunedFIsAggressive, V::Type) PrunedFIsAggressive{V}(Δs.state) end function Base.empty(Δs::PrunedFIsAggressive{V}) where {V} StochasticAD.similar_empty(Δs, V) end # we truly have no clue what the state is here, so use an invalidated state function Base.empty(::Type{<:PrunedFIsAggressive{V}}) where {V} PrunedFIsAggressive{V}(PrunedFIsAggressiveState(false)) end ### Create a new perturbation with infinitesimal probability function new_perturbation(Δ::V, w::Real, state::PrunedFIsAggressiveState) where {V} total_weight = state.weight + w if rand(StochasticAD.RNG) * total_weight < state.weight state.weight += w return PrunedFIsAggressive{V}(state) else state.tag_count += 1 state.active_tag = state.tag_count state.weight += w return PrunedFIsAggressive{V}(Δ, state.active_tag, state) end end function StochasticAD.similar_new(Δs::PrunedFIsAggressive, Δ::V, w::Real) where {V} new_perturbation(Δ, w, Δs.state) end ### Create Δs backend for the first stochastic triple of computation function StochasticAD.create_Δs(::PrunedFIsAggressiveBackend, V) PrunedFIsAggressive{V}(PrunedFIsAggressiveState()) end ### Convert type of a backend function Base.convert(::Type{PrunedFIsAggressive{V}}, Δs::PrunedFIsAggressive) where {V} PrunedFIsAggressive{V}(convert(V, Δs.Δ), Δs.tag, Δs.state) end ### Getting information about perturbations # "empty" here means no perturbation or a perturbation that has been pruned away Base.isempty(Δs::PrunedFIsAggressive) = Δs.tag != Δs.state.active_tag Base.length(Δs::PrunedFIsAggressive) = isempty(Δs) ? 
0 : 1 Base.iszero(Δs::PrunedFIsAggressive) = isempty(Δs) || iszero(Δs.Δ) # we lazily prune, so check if empty first pruned_value(Δs::PrunedFIsAggressive{V}) where {V} = isempty(Δs) ? zero(V) : Δs.Δ function StochasticAD.derivative_contribution(Δs::PrunedFIsAggressive) pruned_value(Δs) * Δs.state.weight end function StochasticAD.perturbations(Δs::PrunedFIsAggressive) ((; Δ = pruned_value(Δs), weight = Δs.state.weight, state = Δs.state),) end ### Unary propagation function StochasticAD.weighted_map_Δs(f, Δs::PrunedFIsAggressive; kwargs...) Δ_out, weight_out = f(Δs.Δ, nothing) Δs.state.weight *= weight_out PrunedFIsAggressive(Δ_out, Δs.tag, Δs.state) end StochasticAD.alltrue(f, Δs::PrunedFIsAggressive) = f(Δs.Δ) ### Coupling function StochasticAD.get_rep(::Type{<:PrunedFIsAggressive}, Δs_all) # Get some Δs with a valid state, or any if all are invalid. return reduce((Δs1, Δs2) -> Δs1.state.valid ? Δs1 : Δs2, StochasticAD.structural_iterate(Δs_all)) end # for pruning, coupling amounts to getting rid of perturbed values that have been # lazily kept around even after (aggressive or lazy) pruning made the perturbation invalid. function StochasticAD.couple(FIs::Type{<:PrunedFIsAggressive}, Δs_all; rep = StochasticAD.get_rep(FIs, Δs_all), out_rep = nothing, kwargs...) state = rep.state Δ_coupled = StochasticAD.structural_map(pruned_value, Δs_all) # TODO: perhaps a performance optimization possible here PrunedFIsAggressive(Δ_coupled, state.active_tag, state) end # basically couple combined with a sum. function StochasticAD.combine(FIs::Type{<:PrunedFIsAggressive}, Δs_all; rep = StochasticAD.get_rep(FIs, Δs_all), kwargs...) 
state = rep.state Δ_combined = sum(pruned_value, StochasticAD.structural_iterate(Δs_all)) PrunedFIsAggressive(Δ_combined, state.active_tag, state) end function StochasticAD.scalarize(Δs::PrunedFIsAggressive; out_rep = nothing) return StochasticAD.structural_map(Δs.Δ) do Δ return PrunedFIsAggressive(Δ, Δs.tag, Δs.state) end end StochasticAD.filter_state(Δs::PrunedFIsAggressive, _) = pruned_value(Δs) ### Miscellaneous StochasticAD.similar_type(::Type{<:PrunedFIsAggressive}, V::Type) = PrunedFIsAggressive{V} StochasticAD.valtype(::Type{<:PrunedFIsAggressive{V}}) where {V} = V # should I have a mime input? function Base.show(io::IO, mime::MIME"text/plain", Δs::PrunedFIsAggressive{V}) where {V} print(io, "$(pruned_value(Δs)) with probability $(Δs.state.weight)ε, tag $(Δs.tag)") end function Base.show(io::IO, Δs::PrunedFIsAggressive{V}) where {V} print(io, "$(pruned_value(Δs)) with probability $(Δs.state.weight)ε, tag $(Δs.tag)") end end ================================================ FILE: src/backends/smoothed.jl ================================================ module SmoothedFIsModule import ..StochasticAD export SmoothedFIsBackend, SmoothedFIs """ SmoothedFIsBackend <: StochasticAD.AbstractFIsBackend A backend algorithm that smooths perturbations togethers. """ struct SmoothedFIsBackend <: StochasticAD.AbstractFIsBackend end """ SmoothedFIs{V} <: StochasticAD.AbstractFIs{V} The implementing backend structure for SmoothedFIsBackend. """ # TODO: make type of δ generic struct SmoothedFIs{V, V_float} <: StochasticAD.AbstractFIs{V} δ::V_float function SmoothedFIs{V}(δ) where {V} # hardcode Float64 representation for now, for simplicity. 
δ_f64 = StochasticAD.structural_map(Base.Fix1(convert, Float64), δ) return new{V, typeof(δ_f64)}(δ_f64) end end ### Empty / no perturbation StochasticAD.similar_empty(::SmoothedFIs, V::Type) = SmoothedFIs{V}(0.0) Base.empty(::Type{<:SmoothedFIs{V}}) where {V} = SmoothedFIs{V}(0.0) Base.empty(Δs::SmoothedFIs) = empty(typeof(Δs)) ### Create a new perturbation with infinitesimal probability function StochasticAD.similar_new(::SmoothedFIs, Δ::V, w::Real) where {V} SmoothedFIs{V}(Δ * w) end StochasticAD.new_Δs_strategy(::SmoothedFIs) = StochasticAD.TwoSidedStrategy() ### Create Δs backend for the first stochastic triple of computation StochasticAD.create_Δs(::SmoothedFIsBackend, V) = SmoothedFIs{V}(0.0) ### Convert type of a backend function Base.convert(FIs::Type{<:SmoothedFIs{V}}, Δs::SmoothedFIs) where {V} SmoothedFIs{V}(Δs.δ)::FIs end ### Getting information about perturbations Base.isempty(Δs::SmoothedFIs) = false Base.iszero(Δs::SmoothedFIs) = iszero(Δs.δ) Base.iszero(Δs::SmoothedFIs{<:Tuple}) = all(iszero.(Δs.δ)) StochasticAD.derivative_contribution(Δs::SmoothedFIs) = Δs.δ ### Unary propagation function StochasticAD.weighted_map_Δs(f, Δs::SmoothedFIs; deriv, out_rep, kwargs...) SmoothedFIs{typeof(out_rep)}(deriv(Δs.δ)) end StochasticAD.alltrue(f, Δs::SmoothedFIs) = true ### Coupling StochasticAD.get_rep(::Type{<:SmoothedFIs}, Δs_all) = StochasticAD.get_any(Δs_all) function StochasticAD.couple( ::Type{<:SmoothedFIs}, Δs_all; rep = nothing, out_rep, kwargs...) SmoothedFIs{typeof(out_rep)}(StochasticAD.structural_map(Δs -> Δs.δ, Δs_all)) end function StochasticAD.combine(::Type{<:SmoothedFIs}, Δs_all; rep = nothing, kwargs...) 
V_out = StochasticAD.valtype(first(StochasticAD.structural_iterate(Δs_all))) Δ_combined = sum(Δs -> Δs.δ, StochasticAD.structural_iterate(Δs_all)) SmoothedFIs{V_out}(Δ_combined) end function StochasticAD.scalarize(Δs::SmoothedFIs; out_rep) return StochasticAD.structural_map(out_rep, Δs.δ) do out, δ return SmoothedFIs{typeof(out)}(δ) end end ### Miscellaneous StochasticAD.similar_type(::Type{<:SmoothedFIs}, V::Type) = SmoothedFIs{V, Float64} StochasticAD.valtype(::Type{<:SmoothedFIs{V}}) where {V} = V function Base.show(io::IO, Δs::SmoothedFIs) print(io, "$(Δs.δ)ε") end end ================================================ FILE: src/backends/strategy_wrapper.jl ================================================ module StrategyWrapperFIsModule using ..StochasticAD using ..StochasticAD.AbstractWrapperFIsModule export StrategyWrapperFIsBackend, StrategyWrapperFIs struct StrategyWrapperFIsBackend{ B <: StochasticAD.AbstractFIsBackend, S <: StochasticAD.AbstractPerturbationStrategy } <: StochasticAD.AbstractFIsBackend backend::B strategy::S end struct StrategyWrapperFIs{ V, FIs <: StochasticAD.AbstractFIs{V}, S <: StochasticAD.AbstractPerturbationStrategy } <: AbstractWrapperFIs{V, FIs} Δs::FIs strategy::S end function StochasticAD.create_Δs(backend::StrategyWrapperFIsBackend, V) return StrategyWrapperFIs(StochasticAD.create_Δs(backend.backend, V), backend.strategy) end function StochasticAD.similar_type(::Type{<:StrategyWrapperFIs{V0, FIs0, S}}, V, FIs) where {V0, FIs0, S} return StrategyWrapperFIs{V, FIs, S} end function AbstractWrapperFIsModule.reconstruct_wrapper(wrapper_Δs::StrategyWrapperFIs, Δs) return StrategyWrapperFIs(Δs, wrapper_Δs.strategy) end function AbstractWrapperFIsModule.reconstruct_wrapper( ::Type{ <:StrategyWrapperFIs{V, FIs, S}, }, Δs) where {V, FIs, S} return StrategyWrapperFIs(Δs, S()) end StochasticAD.new_Δs_strategy(Δs::StrategyWrapperFIs) = Δs.strategy end ================================================ FILE: src/discrete_randomness.jl 
================================================
## Helper functions for discrete distributions

# Index (within Distributions.params(d)) of the parameter p that we treat as
# the differentiable parameter of each supported distribution.
_param_index(::Geometric) = 1
_param_index(::Bernoulli) = 1
_param_index(::Binomial) = 2 # params(Binomial) == (n, p), so p is second
_param_index(::Poisson) = 1
_param_index(::Categorical) = 1
# Extract the differentiable parameter of d.
_get_parameter(d) = params(d)[_param_index(d)]

# Map each supported distribution to its constructor, so a distribution can be
# rebuilt from perturbed parameters.
for dist in [:Geometric, :Bernoulli, :Binomial, :Poisson, :Categorical]
    @eval _constructor(::$dist) = $dist
end

# Reconstruct probability distribution with new parameter value p, keeping all
# other parameters of d unchanged.
function _reconstruct(d, p)
    i = _param_index(d)
    return _constructor(d)(params(d)[1:(i - 1)]..., p, params(d)[(i + 1):end]...)
end

# Support of probability distribution: only Bernoulli/Binomial/Categorical are
# treated as having finite support here.
_has_finite_support(d) = false
_has_finite_support(d::Union{Bernoulli, Binomial, Categorical}) = true
_get_support(d::Union{Bernoulli, Binomial, Categorical}) = minimum(d):maximum(d)
# manual overloads to ensure that static-ness is preserved for Bernoulli's and Categoricals with static arrays.
# since mapping over the range above could result in allocating vectors.
_get_support(::Bernoulli) = (0, 1)
# the map below looks a bit silly, but it gives us a collection of the categories with the same structure as probs(d).
_get_support(d::Categorical) = map((val, prob) -> val, 1:ncategories(d), probs(d))

## Derivative couplings

# Derivative coupling approaches, determining which weighted perturbations to consider
abstract type AbstractDerivativeCoupling end

"""
    InversionMethodDerivativeCoupling(; mode::Val = Val(:positive_weight),
        handle_zeroprob::Val = Val(true))

Specifies an inversion method coupling for generating perturbations from a univariate distribution.
Valid choices of `mode` are `Val(:positive_weight)`, `Val(:always_right)`, and `Val(:always_left)`.

# Example

```jldoctest
julia> using StochasticAD, Distributions, Random; Random.seed!(4321);

julia> function X(p)
           return randst(Bernoulli(1 - p);
               derivative_coupling = InversionMethodDerivativeCoupling(;
                   mode = Val(:always_right)))
       end
X (generic function with 1 method)

julia> stochastic_triple(X, 0.5)
StochasticTriple of Int64:
0 + 0ε + (1 with probability -2.0ε)
```
"""
Base.@kwdef struct InversionMethodDerivativeCoupling{M, HZP}
    mode::M = Val(:positive_weight)
    handle_zeroprob::HZP = Val(true)
end

# Strategies for precisely which perturbations to form given a derivative coupling
struct SingleSidedStrategy <: AbstractPerturbationStrategy end
struct TwoSidedStrategy <: AbstractPerturbationStrategy end
struct SmoothedStraightThroughStrategy <: AbstractPerturbationStrategy end
struct StraightThroughStrategy <: AbstractPerturbationStrategy end
struct IgnoreDiscreteStrategy <: AbstractPerturbationStrategy end
# Fallback: backends that do not declare a strategy get the single-sided one.
new_Δs_strategy(Δs) = SingleSidedStrategy()

# Derivative coupling high-level interface

"""
    δtoΔs(d, val, δ, Δs::AbstractFIs, derivative_coupling)

Given a value `val` sampled from a distribution `d` and an infinitesimal change `δ`
in the distribution's parameter, return the discrete change in the output, with a
similar representation to `Δs`. Dispatches on the perturbation strategy associated
with `Δs` (obtained via `new_Δs_strategy`).
"""
δtoΔs(d, val, δ, Δs, derivative_coupling) = δtoΔs(
    d, val, δ, Δs, derivative_coupling, new_Δs_strategy(Δs))

# Single-sided: form perturbations only in the direction induced by δ.
function δtoΔs(d, val, δ, Δs, derivative_coupling, ::SingleSidedStrategy)
    _δtoΔs(d, val, δ, Δs, derivative_coupling)
end

# Two-sided: average the +δ and -δ one-sided perturbations with weights
# (+1/2, -1/2).
function δtoΔs(d, val, δ, Δs, derivative_coupling, ::TwoSidedStrategy)
    Δs1 = _δtoΔs(d, val, δ, Δs, derivative_coupling)
    Δs2 = _δtoΔs(d, val, -δ, Δs, derivative_coupling)
    return combine((scale(Δs1, 0.5), scale(Δs2, -0.5)))
end

# TODO: implement this ST for other distributions and couplings, if meaningful?
# Straight-through strategy for Bernoulli/Binomial: mix the +δ and -δ one-sided
# perturbations with weights (1 - p) and -p respectively.
function δtoΔs(d::Union{Bernoulli, Binomial}, val, δ, Δs,
        derivative_coupling::InversionMethodDerivativeCoupling,
        ::StraightThroughStrategy)
    p = succprob(d)
    Δs1 = _δtoΔs(d, val, δ, Δs, derivative_coupling)
    Δs2 = _δtoΔs(d, val, -δ, Δs, derivative_coupling)
    return combine((scale(Δs1, 1 - p), scale(Δs2, -p)))
end

# Ignore-discrete strategy: contribute no perturbations at all.
function δtoΔs(d, val::V, δ, Δs, derivative_coupling, ::IgnoreDiscreteStrategy) where {V}
    similar_empty(Δs, V)
end

# Implement straight through strategy, works for all distrs, but does something that is only
# meaningful for smoothed backends (using one(val))
function δtoΔs(d, val, δ, Δs, derivative_coupling, ::SmoothedStraightThroughStrategy)
    p = _get_parameter(d)
    # Derivative of the mean of d with respect to a parameter shift along δ.
    δout = ForwardDiff.derivative(a -> mean(_reconstruct(d, p + a * δ)), 0.0)
    return similar_new(Δs, one(val), δout)
end

# Derivative coupling low-level implementations
#
# Each _δtoΔs method below forms a single weighted perturbation via the
# inversion method. In :positive_weight mode, the direction (right/left) is
# picked from the sign of δ so that the perturbation weight is positive;
# :always_right / :always_left force the direction (possibly giving a
# negative weight, as in the docstring example above).

function _δtoΔs(d::Geometric, val::V, δ::Real, Δs::AbstractFIs,
        derivative_coupling::InversionMethodDerivativeCoupling) where {V <: Signed}
    p = succprob(d)
    if (derivative_coupling.mode isa Val{:positive_weight} && δ > 0) ||
       (derivative_coupling.mode isa Val{:always_right})
        # Rightward coupling steps val down by 1; no perturbation possible at val == 0.
        return val > 0 ? similar_new(Δs, -one(V), δ * val / p / (1 - p)) :
               similar_empty(Δs, V)
    elseif (derivative_coupling.mode isa Val{:positive_weight} && δ < 0) ||
           (derivative_coupling.mode isa Val{:always_left})
        # Leftward coupling steps val up by 1.
        return similar_new(Δs, one(V), -δ * (val + 1) / p)
    else
        return similar_empty(Δs, V)
    end
end

function _δtoΔs(d::Bernoulli, val::V, δ::Real, Δs::AbstractFIs,
        derivative_coupling::InversionMethodDerivativeCoupling) where {V <: Signed}
    p = succprob(d)
    if (derivative_coupling.mode isa Val{:positive_weight} && δ > 0) ||
       (derivative_coupling.mode isa Val{:always_right})
        # 0 -> 1 flip with weight δ / (1 - p); no flip possible when already 1.
        return isone(val) ? similar_empty(Δs, V) : similar_new(Δs, one(V), δ / (1 - p))
    elseif (derivative_coupling.mode isa Val{:positive_weight} && δ < 0) ||
           (derivative_coupling.mode isa Val{:always_left})
        # 1 -> 0 flip with weight -δ / p; no flip possible when already 0.
        return isone(val) ?
               similar_new(Δs, -one(V), -δ / p) : similar_empty(Δs, V)
    else
        return similar_empty(Δs, V)
    end
end

function _δtoΔs(d::Binomial, val::V, δ::Real, Δs::AbstractFIs,
        derivative_coupling::InversionMethodDerivativeCoupling) where {V <: Signed}
    p = succprob(d)
    n = ntrials(d)
    if (derivative_coupling.mode isa Val{:positive_weight} && δ > 0) ||
       (derivative_coupling.mode isa Val{:always_right})
        # Step up by 1 unless already at the maximum n.
        return val == n ? similar_empty(Δs, V) :
               similar_new(Δs, one(V), δ * (n - val) / (1 - p))
    elseif (derivative_coupling.mode isa Val{:positive_weight} && δ < 0) ||
           (derivative_coupling.mode isa Val{:always_left})
        # Step down by 1 unless already at 0.
        return !iszero(val) ? similar_new(Δs, -one(V), -δ * val / p) :
               similar_empty(Δs, V)
    else
        return similar_empty(Δs, V)
    end
end

function _δtoΔs(d::Poisson, val::V, δ::Real, Δs::AbstractFIs,
        derivative_coupling::InversionMethodDerivativeCoupling) where {V <: Signed}
    p = mean(d) # rate
    if (derivative_coupling.mode isa Val{:positive_weight} && δ > 0) ||
       (derivative_coupling.mode isa Val{:always_right})
        # NOTE(review): uses literal 1/-1 rather than one(V)/-one(V) as the
        # other methods do — presumably equivalent for Signed V; confirm.
        return similar_new(Δs, 1, δ)
    elseif (derivative_coupling.mode isa Val{:positive_weight} && δ < 0) ||
           (derivative_coupling.mode isa Val{:always_left})
        return val > 0 ? similar_new(Δs, -1, -δ * val / p) : similar_empty(Δs, V)
    else
        return similar_empty(Δs, V)
    end
end

function _δtoΔs(d::Categorical, val::V, δs, Δs::AbstractFIs,
        derivative_coupling::InversionMethodDerivativeCoupling) where {V <: Signed}
    p = params(d)[1]
    # NB: Although we might expect sum(δs) = 0, it is useful to handle things more generally, viewing δs
    # as perturbing the Categorical distribution locally along some direction in the space of general measures.
    # The below formulation gets things right in this case too.
    # Perturbation of the CDF strictly below val, and up to and including val.
    left_sum = sum(δs[1:(val - 1)], init = zero(eltype(δs)))
    right_sum = sum(δs[1:val], init = zero(eltype(δs)))
    if (derivative_coupling.mode isa Val{:positive_weight} && left_sum > 0) ||
       (derivative_coupling.mode isa Val{:always_left} && !iszero(left_sum))
        # compute left_nonzero: nearest category to the left to jump to
        if derivative_coupling.handle_zeroprob isa Val{true}
            stop = rand() * left_sum
            upto = zero(eltype(δs))
            # The "upto" logic handles an edge case of probability 0 events that have non-zero derivative.
            # It's a lot of logic to handle an edge case, but hopefully it's optimized away.
            left_nonzero = val
            for i in (val - 1):-1:1
                if !iszero(p[i]) || ((upto += δs[i]) > stop)
                    left_nonzero = i
                    break
                end
            end
        else
            left_nonzero = val - 1
        end
        Δs_left = similar_new(Δs, left_nonzero - val, left_sum / p[val])
    else
        Δs_left = similar_empty(Δs, typeof(val))
    end
    if (derivative_coupling.mode isa Val{:positive_weight} && right_sum < 0) ||
       (derivative_coupling.mode isa Val{:always_right} && !iszero(right_sum))
        # compute right_nonzero: nearest category to the right to jump to
        if derivative_coupling.handle_zeroprob isa Val{true}
            stop = -rand() * right_sum
            upto = zero(eltype(δs))
            right_nonzero = val
            for i in (val + 1):length(p)
                if !iszero(p[i]) || ((upto += δs[i]) > stop)
                    right_nonzero = i
                    break
                end
            end
        else
            right_nonzero = val + 1
        end
        Δs_right = similar_new(Δs, right_nonzero - val, -right_sum / p[val])
    else
        Δs_right = similar_empty(Δs, typeof(val))
    end
    return combine((Δs_left, Δs_right); rep = Δs)
end

## Propagation couplings

abstract type AbstractPropagationCoupling end

"""
    InversionMethodPropagationCoupling

Specifies an inversion method coupling for propagating perturbations.
"""
struct InversionMethodPropagationCoupling <: AbstractPropagationCoupling end

# Given a finite perturbation Δ of the parameter of d, sample the corresponding
# finite change in the output by inverting the perturbed CDF at a uniform point
# consistent with the original draw val.
function _map_func(d, val, Δ, ::InversionMethodPropagationCoupling)
    # construct alternative distribution
    p = _get_parameter(d)
    alt_d = _reconstruct(d, p + Δ)
    # compute bounds on original ω
    low = cdf(d, val - 1)
    high = cdf(d, val)
    # sample alternative value
    alt_val = quantile(alt_d, rand(RNG) * (high - low) + low)
    return convert(Signed, alt_val - val)
end

# Enumerate all possible output changes under the perturbed distribution,
# together with their conditional probabilities. Requires finite support.
function _map_enumeration(d, val, Δ, ::InversionMethodPropagationCoupling)
    # construct alternative distribution
    p = _get_parameter(d)
    alt_d = _reconstruct(d, p + Δ)
    # compute bounds on original ω
    low = cdf(d, val - 1)
    high = cdf(d, val)
    if _has_finite_support(alt_d)
        map(_get_support(alt_d)) do alt_val
            # interval intersect of (cdf(alt_d, alt_val - 1), cdf(alt_d, alt_val)) and (low, high)
            alt_low = cdf(alt_d, alt_val - 1)
            alt_high = cdf(alt_d, alt_val)
            prob_alt = max(0.0, min(alt_high, high) - max(alt_low, low)) / (high - low)
            return (alt_val - val, prob_alt)
        end
    else
        error("enumeration not supported for distribution $d. Does $d have finite support?")
    end
end

## Overloading of random sampling

# Define randst interface
"""
    randst(rng, d::Distributions.Sampleable; kwargs...)

When no keyword arguments are provided, `randst` behaves identically to `rand(rng, d)`
in both ordinary computation and for stochastic triple dispatches. However, `randst` also
allows the user to provide various keyword arguments for customizing the differentiation
logic. The set of allowed keyword arguments depends on the type of `d`: a couple common
ones are `derivative_coupling` and `propagation_coupling`.

For developers: if you wish to accept custom keyword arguments in a stochastic triple
dispatch, you should overload `randst`, and redirect `rand` to your `randst` method.
If you do not, it suffices to just overload `rand`.
"""
randst(rng, d::Distributions.Sampleable; kwargs...) = rand(rng, d)
randst(d::Distributions.Sampleable; kwargs...) = randst(Random.default_rng(), d; kwargs...)

# Define stochastic triple rules
for dist in [:Geometric, :Bernoulli, :Binomial, :Poisson]
    # rand on a triple-parameterized distribution redirects to randst.
    @eval function Base.rand(rng::AbstractRNG,
            d_st::$dist{StochasticTriple{T, V, FIs}}) where {T, V, FIs}
        return randst(rng, d_st)
    end
    @eval function randst(rng::AbstractRNG,
            d_st::$dist{StochasticTriple{T, V, FIs}};
            Δ_kwargs = (;),
            derivative_coupling = InversionMethodDerivativeCoupling(),
            propagation_coupling = InversionMethodPropagationCoupling()) where {T, V, FIs}
        st = _get_parameter(d_st)
        # Sample from the primal distribution.
        d = _reconstruct(d_st, st.value)
        val = convert(Signed, rand(rng, d))
        # New perturbations induced by the infinitesimal parameter change st.δ.
        Δs1 = δtoΔs(d, val, st.δ, st.Δs, derivative_coupling)
        # Propagate pre-existing finite perturbations of the parameter.
        Δs2 = map(Δ -> _map_func(d, val, Δ, propagation_coupling), st.Δs;
            enumeration = (Δ, _) -> _map_enumeration(d, val, Δ, propagation_coupling),
            deriv = δ -> smoothed_delta(d, val, δ, derivative_coupling),
            out_rep = val, Δ_kwargs...)
        StochasticTriple{T}(val, zero(val), combine((Δs2, Δs1); rep = Δs1)) # ensure that tags are in order in combine, in case backend wishes to exploit this
    end
end

# currently handle Categorical separately since parameter is a vector
# what if some elements in vector are not stochastic triples... promotion should take care of that?
function Base.rand(rng::AbstractRNG,
        d_st::Categorical{StochasticTriple{T, V, FIs}}) where {T, V, FIs}
    return randst(rng, d_st)
end
function randst(rng::AbstractRNG,
        d_st::Categorical{<:StochasticTriple{T},
            <:AbstractVector{<:StochasticTriple{T, V}}};
        Δ_kwargs = (;),
        derivative_coupling = InversionMethodDerivativeCoupling(),
        propagation_coupling = InversionMethodPropagationCoupling()) where {T, V}
    sts = _get_parameter(d_st) # stochastic triple for each probability
    p = map(st -> st.value, sts) # try to keep the same type. e.g. static array -> static array. TODO: avoid allocations
    d = _reconstruct(d_st, p)
    val = convert(Signed, rand(rng, d))
    Δs_all = map(st -> st.Δs, sts)
    Δs_rep = get_rep(Δs_all)
    # New perturbations from the infinitesimal change of the probability vector.
    Δs1 = δtoΔs(d, val, map(st -> st.δ, sts), Δs_rep, derivative_coupling)
    Δs_coupled = couple(Δs_all; rep = Δs_rep, out_rep = p) # TODO: again, there are possible allocations here
    # Propagate pre-existing finite perturbations of the probability vector.
    Δs2 = map(Δ -> _map_func(d, val, Δ, propagation_coupling), Δs_coupled;
        enumeration = (Δ, _) -> _map_enumeration(d, val, Δ, propagation_coupling),
        deriv = δ -> smoothed_delta(d, val, δ, derivative_coupling),
        out_rep = val, Δ_kwargs...)
    Δs = combine((Δs2, Δs1); rep = Δs1, out_rep = val, Δ_kwargs...)
    StochasticTriple{T}(val, zero(val), Δs)
end

## Handling finite perturbation to Binomial number of trials

"""
    DiscreteDeltaStochasticTriple{T, V, FIs <: AbstractFIs}

An experimental discrete stochastic triple type used internally for representing
perturbations to non-real quantities. Currently only used to represent a finite
perturbation to the Binomial parameter n.

## Constructor

- `value`: the primal value.
- `Δs`: some representation of the perturbation to the primal, which can have an
    unconventional interpretation depending on `T`.
"""
struct DiscreteDeltaStochasticTriple{T, V, FIs <: AbstractFIs}
    value::V
    Δs::FIs
    function DiscreteDeltaStochasticTriple{T, V, FIs}(value::V,
            Δs::FIs) where {T, V, FIs <: AbstractFIs}
        new{T, V, FIs}(value, Δs)
    end
end

function DiscreteDeltaStochasticTriple{T}(val::V, Δs::FIs) where {T, V, FIs <: AbstractFIs}
    DiscreteDeltaStochasticTriple{T, V, FIs}(val, Δs)
end

# Constructing a Binomial with a stochastic-triple n: keep the primal Binomial
# and carry n's perturbations alongside it.
function Distributions.Binomial(n::StochasticTriple{T}, p::Real) where {T}
    return DiscreteDeltaStochasticTriple{T}(Binomial(n.value, p), n.Δs)
end

# TODO: Support functions other than `rand` called on a perturbed Binomial.
function Base.rand(rng::AbstractRNG, d_st::DiscreteDeltaStochasticTriple{T, <:Binomial}) where {T} return randst(rng, d_st) end function randst(rng::AbstractRNG, d_st::DiscreteDeltaStochasticTriple{T, <:Binomial}) where {T} d = d_st.value val = rand(rng, d) function map_func(Δ) if Δ >= 0 return rand(StochasticAD.RNG, Binomial(Δ, value(succprob(d)))) else return -rand(StochasticAD.RNG, Hypergeometric(value(val), ntrials(d) - value(val), -Δ)) end end Δs = map(map_func, d_st.Δs) if val isa StochasticTriple return StochasticTriple{T}(val.value, val.δ, combine((Δs, val.Δs); rep = Δs)) else return StochasticTriple{T}(val, zero(val), Δs) end end ================================================ FILE: src/finite_infinitesimals.jl ================================================ # TODO: make this a module, with the interface exported? ## """ AbstractFIsBackend An abstract type for backend strategies of Finite perturbations that occur with Infinitesimal probability (FIs). """ abstract type AbstractFIsBackend end """ AbstractFIs{V} An abstract type for concrete backend representations of Finite Infinitesimals. """ abstract type AbstractFIs{V} end ### Some of the necessary interface notes below. # TODO: document function create_Δs end function similar_new end function similar_empty end function similar_type end valtype(Δs::AbstractFIs) = valtype(typeof(Δs)) # TODO: typeof ∘ first is a loose check, should make more robust. # TODO: perhaps deprecate these methods in favor of an explicit first argument? couple(Δs_all; kwargs...) = couple(typeof(first(Δs_all)), Δs_all; kwargs...) combine(Δs_all; kwargs...) = combine(typeof(first(Δs_all)), Δs_all; kwargs...) get_rep(Δs_all; kwargs...) = get_rep(typeof(first(Δs_all)), Δs_all; kwargs...) function scalarize end function derivative_contribution end function alltrue end function perturbations end function filter_state end function weighted_map_Δs end function map_Δs(f, Δs::AbstractFIs; kwargs...) 
StochasticAD.weighted_map_Δs((Δs, state) -> (f(Δs, state), 1.0), Δs; kwargs...) end function Base.map(f, Δs::AbstractFIs; kwargs...) StochasticAD.map_Δs((Δs, _) -> f(Δs), Δs; kwargs...) end # We also add a scale to deriv for scaling smoothed perturbations function scale(Δs::AbstractFIs, a::Real) StochasticAD.weighted_map_Δs((Δ, state) -> (Δ, a), Δs; deriv = Base.Fix1(*, a), out_rep = Δs) end function new_Δs_strategy end # utility function useful e.g. for get_rep in some backends function get_any(Δs_all) # The code below is a bit ridiculous, but it's faster than `first` for small structures:) foldl((Δs1, Δs2) -> Δs1, StochasticAD.structural_iterate(Δs_all)) end abstract type AbstractPerturbationStrategy end abstract type AbstractPerturbationSignal end function send_signal end # Ignore signals by default since they do not change semantics. function StochasticAD.send_signal( Δs::StochasticAD.AbstractFIs, ::StochasticAD.AbstractPerturbationSignal) return Δs end ================================================ FILE: src/general_rules.jl ================================================ """ Operators which have already been overloaded by StochasticAD. """ const handled_ops = Tuple{DataType, Int}[] """ define_triple_overload(sig) Given the signature type-type of the primal function, define operator overloading rules for stochastic triples. Currently supports functions with all-real inputs and one real output. 
""" # TODO: special case optimizations # TODO: generalizations to not-all-real inputs and/or not-one-real output function define_triple_overload(sig) opT, argTs = Iterators.peel(ExprTools.parameters(sig)) opT <: Type{<:Type} && return # not handling constructors sig <: Tuple{Type, Vararg{Any}} && return opT <: Core.Builtin && return false # can't do operator overloading for builtins isabstracttype(opT) || fieldcount(opT) == 0 || return false # not handling functors isempty(argTs) && return false # we are an operator overloading AD, need operands all(argT isa Type && Real <: argT for argT in argTs) || return N = length(ExprTools.parameters(sig)) - 1 # skip the op # Skip already-handled ops, as well as ops that will be handled manually later (and more correctly, see #79). if (opT, N) in handled_ops || (opT.instance in UNARY_TYPEFUNCS_WRAP) return end push!(handled_ops, (opT, N)) if opT.instance in UNARY_PREDICATES && (N == 1) @eval function (f::$opT)(st::StochasticTriple) val = value(st) out = f(val) if !alltrue(Δ -> (f(val + Δ) == out), st.Δs) error("Output of boolean predicate cannot depend on input (unsupported by StochasticAD)") end return out end elseif opT.instance in BINARY_PREDICATES && (N == 2) # Special case equality comparisons as in https://github.com/JuliaDiff/ForwardDiff.jl/pull/481 if opT.instance == Base.:(==) return_value_real = quote out && iszero(delta(st)) end return_value_st = quote out2 = out && (delta(st1) == delta(st2)) end else return_value_real = quote out end return_value_st = quote out end end @eval function (f::$opT)(st::StochasticTriple, x::Real) val = value(st) out = f(val, x) if !alltrue(Δ -> (f(val + Δ, x) == out), st.Δs) error("Output of boolean predicate cannot depend on input (unsupported by StochasticAD)") end return $return_value_real end @eval function (f::$opT)(x::Real, st::StochasticTriple) val = value(st) out = f(x, val) if !alltrue(Δ -> (f(x, val + Δ) == out), st.Δs) error("Output of boolean predicate cannot depend on input 
(unsupported by StochasticAD)") end return $return_value_real end @eval function (f::$opT)(st1::StochasticTriple, st2::StochasticTriple) val1 = value(st1) val2 = value(st2) out = f(val1, val2) Δs_coupled = couple((st1.Δs, st2.Δs); out_rep = (val1, val2)) safe_perturb = alltrue(Δs -> f(val1 + Δs[1], val2 + Δs[2]) == out, Δs_coupled) if !safe_perturb error("Output of boolean predicate cannot depend on input (unsupported by StochasticAD)") end return $return_value_st end elseif N == 1 if Base.return_types(frule, (Tuple{NoTangent, Real}, opT, Real))[1] <: Tuple{Any, NoTangent} return end @eval function (f::$opT)(st::StochasticTriple{T}; kwargs...) where {T} run_frule = δ -> begin args_tangent = (NoTangent(), δ) return frule(args_tangent, f, value(st); kwargs...) end val, δ0 = run_frule(delta(st)) δ::typeof(val) = (δ0 isa ZeroTangent || δ0 isa NoTangent) ? zero(value(st)) : δ0 if !iszero(st.Δs) Δs = map(Δ -> f(st.value + Δ; kwargs...) - val, st.Δs; deriv = last ∘ run_frule, out_rep = val) else Δs = similar_empty(st.Δs, typeof(val)) end return StochasticTriple{T}(val, δ, Δs) end elseif N == 2 if Base.return_types(frule, (Tuple{NoTangent, Real, Real}, opT, Real, Real))[1] <: Tuple{Any, NoTangent} return end for R in AMBIGUOUS_TYPES @eval function (f::$opT)(st::StochasticTriple{T}, x::$R; kwargs...) where {T} run_frule = δ -> begin args_tangent = (NoTangent(), δ, zero(x)) return frule(args_tangent, f, value(st), x; kwargs...) end val, δ0 = run_frule(delta(st)) δ::typeof(val) = (δ0 isa ZeroTangent || δ0 isa NoTangent) ? zero(value(st)) : δ0 if !iszero(st.Δs) Δs = map(Δ -> f(st.value + Δ, x; kwargs...) - val, st.Δs; deriv = last ∘ run_frule, out_rep = val) else Δs = similar_empty(st.Δs, typeof(val)) end return StochasticTriple{T}(val, δ, Δs) end @eval function (f::$opT)(x::$R, st::StochasticTriple{T}; kwargs...) where {T} run_frule = δ -> begin args_tangent = (NoTangent(), zero(x), δ) return frule(args_tangent, f, x, value(st); kwargs...) 
end val, δ0 = run_frule(delta(st)) δ::typeof(val) = (δ0 isa ZeroTangent || δ0 isa NoTangent) ? zero(value(st)) : δ0 if !iszero(st.Δs) Δs = map(Δ -> f(x, st.value + Δ; kwargs...) - val, st.Δs; deriv = last ∘ run_frule, out_rep = val) else Δs = similar_empty(st.Δs, typeof(val)) end return StochasticTriple{T}(val, δ, Δs) end end @eval function (f::$opT)(sts::Vararg{StochasticTriple{T}, 2}; kwargs...) where {T} run_frule = δs -> begin args_tangent = (NoTangent(), δs...) args = (f, value.(sts)...) return frule(args_tangent, args...; kwargs...) end val, δ0 = run_frule(delta.(sts)) δ::typeof(val) = (δ0 isa ZeroTangent || δ0 isa NoTangent) ? zero(value(st)) : δ0 Δs_all = map(st -> getfield(st, :Δs), sts) if all(iszero.(Δs_all)) Δs = similar_empty(first(sts).Δs, typeof(val)) else vals_in = value.(sts) Δs_coupled = couple(Tuple(Δs_all); out_rep = vals_in) mapfunc = let vals_in = vals_in Δ -> (f((vals_in .+ Δ)...; kwargs...) - val) end Δs = map(mapfunc, Δs_coupled; deriv = last ∘ run_frule, out_rep = val) end return StochasticTriple{T}(val, δ, Δs) end end end on_new_rule(define_triple_overload, frule) ### Extra overloads # TODO: generalize the below logic to compactly handle a wider range of functions. # See also https://github.com/JuliaDiff/ForwardDiff.jl/blob/master/src/dual.jl. function Base.hash(st::StochasticTriple, hsh::UInt) if !isempty(st.Δs) error("Hashing a stochastic triple with perturbations not yet supported.") end hash(StochasticAD.value(st), hsh) end #= This is a hacky experimental way to convert a float-like stochastic triple into an integer-like one, to facilitate generic coding. 
=# function Base.round(I::Type{<:Integer}, st::StochasticTriple{T, V}) where {T, V} return StochasticTriple{T}(round(I, st.value), map(Δ -> round(I, st.value + Δ), st.Δs)) end for op in UNARY_TYPEFUNCS_NOWRAP function (::typeof(op))(::Type{<:StochasticTriple{T, V, FIs}}) where {T, V, FIs} return op(V) end end for op in UNARY_TYPEFUNCS_WRAP function (::typeof(op))(::Type{StochasticTriple{T, V, FIs}}) where {T, V, FIs} return StochasticTriple{T, V, FIs}(op(V), zero(V), empty(FIs)) end function (::typeof(op))(st::StochasticTriple) return op(typeof(st)) end end for op in RNG_TYPEFUNCS_WRAP function (::typeof(op))(rng::AbstractRNG, ::Type{StochasticTriple{T, V, FIs}}) where {T, V, FIs} return StochasticTriple{T, V, FIs}(op(rng, V), zero(V), empty(FIs)) end end #= The short-circuit "x == y" case in Base.isapprox is bad for us because it could unnecessarily lead to a boolean-predicate depends on output error where StochasticAD cannot prove correctness. We patch up the rule by removing the short-circuit, allowing some common cases to work. In the future, we will ideally handle the overloading rule in a more general way. (E.g. by catching the chain rule for isapprox and recursively calling isapprox on the values.) =# function Base.isapprox(st1::StochasticTriple, st2::StochasticTriple; atol::Real = 0, rtol::Real = Base.rtoldefault(st1, st2, atol), nans::Bool = false, norm::Function = abs) (isfinite(st1) && isfinite(st2) && norm(st1 - st2) <= max(atol, rtol * max(norm(st1), norm(st2)))) || (nans && isnan(st1) && isnan(st2)) end function Base.isapprox(st1::StochasticTriple, x::Real; atol::Real = 0, rtol::Real = Base.rtoldefault(st1, x, atol), nans::Bool = false, norm::Function = abs) (isfinite(st1) && isfinite(x) && norm(st1 - x) <= max(atol, rtol * max(norm(st1), norm(x)))) || (nans && isnan(st1) && isnan(x)) end function Base.isapprox(x::Real, st::StochasticTriple; kwargs...) return Base.isapprox(st, x; kwargs...) 
end # Alternate version of _isassigned that does not fall back on try/catch. _isassigned(C, i) = (i in eachindex(C)) """ Base.getindex(C::AbstractArray, st::StochasticTriple{T}) A simple prototype rule for array indexing. Assumes that underlying type of `st` can index into collection C. """ # TODO: support multiple indices, cartesian indices, non abstract array indexables, other use cases... # Example to fix: A[:, :, st] function Base.getindex(C::AbstractArray, st::StochasticTriple{T, V, FIs}) where {T, V, FIs} val = C[st.value] do_map = (Δ, state) -> begin return value(C[st.value + Δ], state) - value(val, state) end # TODO: below doesn't support sparse arrays, use something like nextind deriv = δ -> begin scale = if _isassigned(C, st.value + 1) && _isassigned(C, st.value - 1) 1 / 2 * (value(C[st.value + 1]) - value(C[st.value - 1])) elseif _isassigned(C, st.value + 1) value(C[st.value + 1]) - value(C[st.value]) elseif _isassigned(C, st.value - 1) value(C[st.value]) - value(C[st.value - 1]) else zero(eltype(C)) end return scale * δ end Δs = StochasticAD.map_Δs(do_map, st.Δs; deriv, out_rep = value(val)) if val isa StochasticTriple Δs = combine((Δs, val.Δs)) end return StochasticTriple{T}(value(val), delta(val), Δs) end ================================================ FILE: src/misc.jl ================================================ @doc raw""" StochasticModel(X, p) Combine stochastic program `X` with parameter `p` into a trainable model using [Functors](https://fluxml.ai/Functors.jl/stable/), where `p <: AbstractArray`. Formulate as a minimization problem, i.e. find ``p`` that minimizes ``\mathbb{E}[X(p)]``. """ struct StochasticModel{S <: AbstractArray, T} X::T p::S end @functor StochasticModel (p,) """ stochastic_gradient(m::StochasticModel) Compute gradient with respect to the trainable parameter `p` of `StochasticModel(X, p)`. 
""" function stochastic_gradient(m::StochasticModel) fmap(p -> derivative_estimate(m.X, p), m) end ================================================ FILE: src/prelude.jl ================================================ const AMBIGUOUS_TYPES = (AbstractFloat, Irrational, Integer, Rational, Real, RoundingMode) const UNARY_PREDICATES = [isinf, isnan, isfinite, iseven, isodd, isreal, isinteger] const BINARY_PREDICATES = [ isequal, isless, <, >, ==, !=, <=, >= ] const UNARY_TYPEFUNCS_NOWRAP = [Base.rtoldefault] const UNARY_TYPEFUNCS_WRAP = [ typemin, typemax, floatmin, floatmax, zero, one ] const RNG_TYPEFUNCS_WRAP = [rand, randn, randexp] """ structural_iterate(args) Internal helper function for iterating through the scalar values of a functor, where AbstractFIs are also counted as scalars. """ function structural_iterate(args) make_iterator(x) = x isa AbstractArray ? x : (x,) exclude(x) = Functors.isleaf(x) || (x isa AbstractFIs) iter = fmap(make_iterator, args; walk = Functors.IterateWalk(), cache = nothing, exclude) return iter end structural_iterate(args::NTuple{N, Union{Real, AbstractFIs}}) where {N} = args structural_iterate(args::AbstractArray{T}) where {T <: Union{Real, AbstractFIs}} = args structural_iterate(args::T) where {T <: Real} = (args,) """ structural_map(f, args) Internal helper function for a structure-preserving map, often to be used on a function's input/output arguments. Currently uses [fmap](https://fluxml.ai/Functors.jl/stable/api/#Functors.fmap) from Functors.jl as a backend. """ function structural_map(f, args...; only_vals = nothing) walk = if only_vals isa Val{true} Functors.StructuralWalk() elseif (only_vals isa Val{false}) || isnothing(only_vals) Functors.DefaultWalk() else error("Unsupported argument only_vals = $only_vals") end fmap((args...) -> args[1] isa AbstractArray ? f.(args...) 
: f(args...), args...; cache = nothing, walk) end ================================================ FILE: src/propagate.jl ================================================ """ A version of `value`` that allows unrecognized args to pass through. """ function get_value(arg) if arg isa StochasticTriple return value(arg) else # potentially dangerous, see also note in get_Δs return arg end end function get_Δs(arg, FIs) if arg isa StochasticTriple return arg.Δs else #= this case is a bit dangerous: perturbations could be dropped here if a leaf of a functor somehow contains a type that is not one of the two above. =# return empty(similar_type(FIs, typeof(arg))) end end function strip_Δs(arg; use_dual = Val(true)) if arg isa StochasticTriple # TODO: replace check below with a more robust notion of discreteness. if valtype(arg) <: Integer return value(arg) else if use_dual isa Val{true} return ForwardDiff.Dual{tag(arg)}(value(arg), delta(arg)) else return StochasticAD.StochasticTriple{tag(arg)}( value(arg), delta(arg), empty(arg.Δs)) end end else return arg end end """ propagate(f, args...; keep_deltas = Val(false)) Propagates `args` through a function `f`, handling stochastic triples by independently running `f` on the primal and the alternatives, rather than by inspecting the internals of `f` (which may possibly be unsupported by `StochasticAD`). Currently handles deterministic functions `f` with any input and output that is `fmap`-able by `Functors.jl`. If `f` has a continuously differentiable component, provide `keep_deltas = Val(true)`. This functionality is orthogonal to dispatch: the idea is for this function to be the "backend" for operator overloading rules based on dispatch. For example: ```jldoctest using StochasticAD, Distributions import Random # hide Random.seed!(4321) # hide function mybranch(x) str = repr(x) # string-valued intermediate! 
if length(str) < 2 return 3 else return 7 end end function f(x) return mybranch(9 + rand(Bernoulli(x))) end # stochastic_triple(f, 0.5) # this would fail # Add a dispatch rule for mybranch using StochasticAD.propagate mybranch(x::StochasticAD.StochasticTriple) = StochasticAD.propagate(mybranch, x) stochastic_triple(f, 0.5) # now works # output StochasticTriple of Int64: 3 + 0ε + (4 with probability 2.0ε) ``` !!! warning This function is experimental and subject to change. """ function propagate(f, args...; keep_deltas = Val(false), provided_st_rep = nothing, deriv = nothing) # TODO: support kwargs to f (or just use kwfunc in macro) #= TODO: maybe don't iterate through every scalar of array below, but rather have special array dispatch =# st_rep = if provided_st_rep === nothing args_iter = structural_iterate(args) function args_fold(arg1, arg2) if arg1 isa StochasticTriple if (arg2 isa StochasticTriple) && (tag(arg1) !== tag(arg2)) throw(ArgumentError("Tags of combined stochastic triples do not match!")) end return arg1 else return arg2 end end foldl(args_fold, args_iter) else provided_st_rep end if !(st_rep isa StochasticTriple) return f(args...) end primal_args = structural_map(get_value, args) input_args = keep_deltas isa Val{false} ? primal_args : structural_map(strip_Δs, args) #= TODO: the below is dangerous is general. It should be safe so long as f does not close over stochastic triples. (If f is a closure, the parameters of f should be treated like any other parameters; if they are stochastic triples and we are ignoring them, dangerous in general.) =# out = f(input_args...) val = structural_map(value, out) # TODO: what does the only_vals do in the below and why? Δs_all = structural_map(Base.Fix2(get_Δs, backendtype(st_rep)), args; only_vals = Val{true}()) # TODO: Coupling approach below needs to handle non-perturbable objects. 
Δs_coupled = couple(backendtype(st_rep), Δs_all; rep = st_rep.Δs, out_rep = val) function map_func(Δ_coupled) perturbed_args = structural_map(+, primal_args, Δ_coupled) #= TODO: for f discrete random with randomness independent of params, could couple here. But difficult without a splittable RNG. =# alt = f(perturbed_args...) return structural_map((x, y) -> value(x) - y, alt, val) end Δs = map(map_func, Δs_coupled; out_rep = val, deriv) # TODO: make sure all FI backends support interface needed below new_out = structural_map(out, scalarize(Δs; out_rep = val)) do leaf_out, leaf_Δs StochasticAD.StochasticTriple{tag(st_rep)}(value(leaf_out), delta(leaf_out), leaf_Δs) end return new_out end ================================================ FILE: src/smoothing.jl ================================================ ### Particle resampling @doc raw""" new_weight(p::Real) Simulate a Bernoulli variable whose primal output is always 1. Uses a smoothing rule for use in forward and reverse-mode AD, which is exactly unbiased when the quantity is only used in linear functions (e.g. used as an [importance weight](https://en.wikipedia.org/wiki/Importance_sampling)). """ new_weight(p::Real) = 1 function new_weight(p::ForwardDiff.Dual{T}) where {T} Δp = ForwardDiff.partials(p) val_p = ForwardDiff.value(p) val_p = max(1e-5, val_p) # TODO: is this necessary? ForwardDiff.Dual{T}(one(val_p), Δp / val_p) end function ChainRulesCore.frule((_, Δp), ::typeof(new_weight), p::Real) val_p = max(1e-5, p) # TODO: is this necessary? return one(p), Δp / val_p end function ChainRulesCore.rrule(::typeof(new_weight), p) function new_weight_pullback(∇Ω) return (ChainRulesCore.NoTangent(), ∇Ω / p) end return (one(p), new_weight_pullback) end # Smoothed rules for univariate single-parameter distributions. 
# Collapse an input perturbation direction `δ` of distribution `d` at sampled
# output `val` into a single scalar derivative contribution, by pushing δ
# through `δtoΔs` (seeded with an empty SmoothedFIs collection) and then taking
# the derivative contribution of the resulting finite perturbations.
function smoothed_delta(d, val, δ, derivative_coupling)
    Δs_empty = SmoothedFIs{typeof(val)}(0.0)
    return derivative_contribution(δtoΔs(d, val, δ, Δs_empty, derivative_coupling))
end

#= Generate smoothed AD rules (a ForwardDiff.Dual overload of rand/randst plus
ChainRules frule/rrule) for each supported distribution. In each tuple, `dist`
is the distribution type, `i` is the positional index of the differentiated
parameter in `params(d)`, and `field` is that parameter's field name (used to
build the reverse-mode Tangent). =#
for (dist, i, field) in [
    (:Geometric, :1, :p),
    (:Bernoulli, :1, :p),
    (:Binomial, :2, :p),
    (:Poisson, :1, :λ),
    (:Categorical, :1, :p)
]
    # i = index of parameter p
    # dual overloading
    @eval function Base.rand(rng::AbstractRNG,
            d_dual::$dist{<:ForwardDiff.Dual{T}}) where {T}
        return randst(rng, d_dual)
    end
    @eval function randst(rng::AbstractRNG, d_dual::$dist{<:ForwardDiff.Dual{T}};
            derivative_coupling = InversionMethodDerivativeCoupling()) where {T}
        dual = params(d_dual)[$i]
        # dual could represent an array of duals or a single one; map handles both cases.
        p = map(value, dual)
        # Generate a δ for each partial component.
        partials_indices = ntuple(identity, length(first(dual).partials))
        δs = map(i -> map(d -> ForwardDiff.partials(d)[i], dual), partials_indices)
        # Rebuild a plain (non-dual) distribution carrying the primal parameter value.
        d = $dist(params(d_dual)[1:($i - 1)]..., p, params(d_dual)[($i + 1):end]...)
        val = convert(Signed, rand(rng, d))
        # One smoothed delta per partial component of the dual parameter.
        partials = ForwardDiff.Partials(map(
            δ -> smoothed_delta(d, val, δ, derivative_coupling), δs))
        ForwardDiff.Dual{T}(val, partials)
    end
    # frule
    @eval function ChainRulesCore.frule(Δargs, ::typeof(rand), rng::AbstractRNG, d::$dist)
        return frule(Δargs, randst, rng, d)
    end
    @eval function ChainRulesCore.frule((_, _, Δd), ::typeof(randst), rng::AbstractRNG,
            d::$dist;
            derivative_coupling = InversionMethodDerivativeCoupling())
        val = convert(Signed, rand(rng, d))
        Δval = smoothed_delta(d, val, Δd, derivative_coupling)
        return (val, Δval)
    end
    # rrule
    @eval function ChainRulesCore.rrule(::typeof(rand), rng::AbstractRNG, d::$dist)
        return rrule(randst, rng, d)
    end
    @eval function ChainRulesCore.rrule(::typeof(randst), rng::AbstractRNG, d::$dist;
            derivative_coupling = InversionMethodDerivativeCoupling())
        val = convert(Signed, rand(rng, d))
        function rand_pullback(∇out)
            p = params(d)[$i]
            if p isa Real
                # Scalar parameter: a single smoothed delta suffices.
                Δp = smoothed_delta(d, val, one(val), derivative_coupling)
            else
                # TODO: this rule is O(length(p)^2), whereas we should be able to do O(length(p)) by reversing through δtoΔs.
                # Vector parameter (e.g. Categorical): probe each coordinate with a
                # one-hot direction to assemble the gradient.
                I = eachindex(p)
                V = eltype(p)
                onehot(i) = map(j -> j == i ? one(V) : zero(V), I)
                Δp = map(i -> smoothed_delta(d, val, onehot(i), derivative_coupling), I)
            end
            # rrule_via_ad approach below not used because slow.
            # Δp = rrule_via_ad(config, smoothed_delta, d, val, map(one, p))[2](∇out)[4]
            Δd = ChainRulesCore.Tangent{typeof(d)}(; $field = ∇out * Δp)
            return (ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), Δd)
        end
        return (val, rand_pullback)
    end
end

================================================
FILE: src/stochastic_triple.jl
================================================

"""
    StochasticTriple{T, V <: Real, FIs <: AbstractFIs{V}}

Stores the primal value of the computation, alongside a "dual" component representing an infinitesimal change, and a "triple" component that tracks discrete change(s) with infinitesimal probability.

Pretty printed as "value + δε + ({pretty print of Δs})".

## Constructor

- `value`: the primal value.
- `δ`: the value of the almost-sure derivative, i.e. the rate of "infinitesimal" change.
- `Δs`: alternate values with associated weights, i.e. Finite perturbations with Infinitesimal probability, represented by a backend `FIs <: AbstractFIs`.
"""
struct StochasticTriple{T, V <: Real, FIs <: AbstractFIs{V}} <: Real
    value::V
    δ::V # infinitesimal change
    Δs::FIs # finite changes with infinitesimal probabilities
    # (Δ = 3, p = 1*h)

    function StochasticTriple{T, V, FIs}(value::V, δ::V,
            Δs::FIs) where {T, V, FIs <: AbstractFIs{V}}
        new{T, V, FIs}(value, δ, Δs)
    end
end

"""
    value(st::StochasticTriple)

Return the primal value of `st`.
"""
value(x::Real, state = nothing) = x
value(st::StochasticTriple) = st.value
# Experimental method for obtaining the alternate value of a stochastic triple
# associated with a certain backend state.
function value(st::StochasticTriple, state)
    st.value + filter_state(st.Δs, state)
end
#= Support ForwardDiff.Dual for internal usage. Assumes batch size is 1. =#
value(d::ForwardDiff.Dual, state = nothing) = ForwardDiff.value(d)

"""
    delta(st::StochasticTriple)

Return the almost-sure derivative of `st`, i.e. the rate of infinitesimal change.
"""
delta(x::Real) = zero(x)
delta(st::StochasticTriple) = st.δ
# Support ForwardDiff.Dual for internal usage.
delta(d::ForwardDiff.Dual) = ForwardDiff.partials(d)[1]

"""
    perturbations(st::StochasticTriple)

Return the finite perturbation(s) of `st`, in a format dependent on the [backend](devdocs.md) used for storing perturbations.
"""
perturbations(x::Real) = ()
perturbations(st::StochasticTriple) = perturbations(st.Δs)

"""
    send_signal(st::StochasticTriple, signal::AbstractPerturbationSignal)
    send_signal(Δs::StochasticAD.AbstractFIs, signal::AbstractPerturbationSignal)

Send a certain signal to a stochastic triple's perturbation collection `st.Δs` (or to a `Δs` directly), which the backend may process as it wishes.

Semantically, unbiasedness should not be affected by the sending of the signal.

The new version of the first argument (`st` or `Δs`) after signal processing is returned.
"""
send_signal(st::Real, ::AbstractPerturbationSignal) = st

function send_signal(st::StochasticTriple{T}, signal::AbstractPerturbationSignal) where {T}
    new_Δs = send_signal(st.Δs, signal)
    return StochasticTriple{T}(st.value, st.δ, new_Δs)
end

"""
    derivative_contribution(st::StochasticTriple)

Return the derivative estimate given by combining the dual and triple components of `st`.
"""
derivative_contribution(x::Real) = zero(x)
derivative_contribution(d::ForwardDiff.Dual) = delta(d)
derivative_contribution(st::StochasticTriple) = st.δ + derivative_contribution(st.Δs)

"""
    tag(st::StochasticTriple)
    tag(::Type{<:StochasticTriple})

Get the tag of a stochastic triple.
"""
tag(::Type{<:StochasticTriple{T}}) where {T} = T
tag(::Type{<:ForwardDiff.Dual{T}}) where {T} = T
tag(::StochasticTriple{T}) where {T} = T
tag(::ForwardDiff.Dual{T}) where {T} = T

"""
    valtype(st::StochasticTriple)
    valtype(st::Type{<:StochasticTriple})

Get the underlying type of the value tracked by a stochastic triple.
"""
valtype(st::StochasticTriple) = valtype(typeof(st))
valtype(::Type{<:StochasticTriple{T, V}}) where {T, V} = V

"""
    backendtype(st::StochasticTriple)
    backendtype(st::Type{<:StochasticTriple})

Get the backend type of a stochastic triple.
"""
backendtype(st::StochasticTriple) = backendtype(typeof(st))
backendtype(::Type{<:StochasticTriple{T, V, FIs}}) where {T, V, FIs} = FIs

"""
    smooth_triple(st::StochasticTriple)

Smooth the dual and triple components of a stochastic triple into a single dual component. Useful for avoiding unnecessary pruning when running multilinear functions on triples.
"""
smooth_triple(x::Real) = x
function smooth_triple(st::StochasticTriple{T, V, FIs}) where {T, V, FIs}
    return StochasticTriple{T}(value(st), derivative_contribution(st), empty(FIs))
end

### Extra constructors of stochastic triples

# Tag-only constructor: infer V and FIs from the arguments.
function StochasticTriple{T}(value::V, δ::V, Δs::FIs) where {T, V, FIs <: AbstractFIs{V}}
    StochasticTriple{T, V, FIs}(value, δ, Δs)
end

# Default the dual component δ to zero.
function StochasticTriple{T}(value::V, Δs::FIs) where {T, V, FIs <: AbstractFIs{V}}
    StochasticTriple{T}(value, zero(value), Δs)
end

# Promote value, δ, and the backend's value type to a common V before constructing.
function StochasticTriple{T}(value::A, δ::B,
        Δs::FIs) where {T, A, B, C, FIs <: AbstractFIs{C}}
    V = promote_type(A, B, C)
    StochasticTriple{T}(convert(V, value), convert(V, δ), convert(similar_type(FIs, V), Δs))
end

### Conversion rules

# TODO: is this the right thing to do? Maybe, different from the promote case
# because there V was guaranteed to be an ancestor.
# Also, bad to do when already same type?
function Base.convert(::Type{StochasticTriple{T1, V, FIs}},
        x::StochasticTriple{T2}) where {T1, T2, V, FIs}
    (T1 !== T2) && throw(ArgumentError("Tags of combined stochastic triples do not match."))
    StochasticTriple{T1, V, FIs}(convert(V, x.value), convert(V, x.δ), convert(FIs, x.Δs))
end

# TODO: ForwardDiff's promotion rules are a little more complicated, see
# https://github.com/JuliaDiff/ForwardDiff.jl/issues/322
# May need to look into why and possibly use them here too.
function Base.promote_rule(::Type{StochasticTriple{T, V1, FIs}}, ::Type{StochasticTriple{T, V2, FIs2}}) where {T, V1, FIs, V2, FIs2} V = promote_type(V1, V2) StochasticTriple{T, V, similar_type(FIs, V)} end function Base.promote_rule(::Type{StochasticTriple{T, V1, FIs}}, ::Type{V2}) where {T, V1, FIs, V2 <: Real} V = promote_type(V1, V2) StochasticTriple{T, V, similar_type(FIs, V)} end function Base.convert(::Type{StochasticTriple{T, V, FIs}}, x::Real) where {T, V, FIs} StochasticTriple{T, V, FIs}(convert(V, x), zero(V), empty(FIs)) end ### Creating the first stochastic triple in a computation function StochasticTriple{T}(value::V, δ::V, backend::AbstractFIsBackend) where {T, V} StochasticTriple{T}(value, δ, create_Δs(backend, V)) end function StochasticTriple{T}(value::V, backend::AbstractFIsBackend) where {T, V} StochasticTriple{T}(value, zero(V), backend) end function StochasticTriple{T}(value::A, δ::B, backend::AbstractFIsBackend) where {T, A, B} V = promote_type(A, B) StochasticTriple{T}(convert(V, value), convert(V, δ), backend) end ### Showing a stochastic triple function Base.summary(::StochasticTriple{T, V}) where {T, V} return "StochasticTriple of $V" end function Base.show(io::IO, ::MIME"text/plain", st::StochasticTriple) println(io, "$(summary(st)):") show(io, st) end function Base.show(io::IO, st::StochasticTriple) print(io, "$(st.value) + $(st.δ)ε") if (!isempty(st.Δs)) print(io, " + ($(repr(st.Δs)))") end end ### Higher level functions struct Tag{F, V} end function stochastic_triple_direction(f, p::V, direction; backend) where {V} Δs = create_Δs(backend, Int) # TODO: necessity of hardcoding some type here suggests interface improvements sts = structural_map(p, direction) do p_i, direction_i StochasticTriple{Tag{typeof(f), V}}(p_i, direction_i, similar_empty(Δs, typeof(p_i))) end return f(sts) end """ stochastic_triple(X, p; backend=PrunedFIsBackend(), direction=nothing) stochastic_triple(p; backend=PrunedFIsBackend(), direction=nothing) For any `p` 
that is supported by [`Functors.jl`](https://fluxml.ai/Functors.jl/stable/), e.g. scalars or abstract arrays, differentiate the output with respect to each value of `p`, returning an output of similar structure to `p`, where a particular value contains the stochastic-triple output of `X` when perturbing the corresponding value in `p` (i.e. replacing the original value `x` with `x + ε`). When `direction` is provided, return only the stochastic-triple output of `X` with respect to a perturbation of `p` in that particular direction. When `X` is not provided, the identity function is used. The `backend` keyword argument describes the algorithm used by the third component of the stochastic triple, see [technical details](devdocs.md) for more details. # Example ```jldoctest julia> using Distributions, Random, StochasticAD; Random.seed!(4321); julia> stochastic_triple(rand ∘ Bernoulli, 0.5) StochasticTriple of Int64: 0 + 0ε + (1 with probability 2.0ε) ``` """ function stochastic_triple( f, p; direction = nothing, backend::AbstractFIsBackend = PrunedFIsBackend()) if direction !== nothing return stochastic_triple_direction(f, p, direction; backend) end counter = begin c = 0 (_) -> begin c += 1 return c end end indices = structural_map(counter, p) map_func = perturbed_index -> begin direction = structural_map(indices, p) do i, p_i i == perturbed_index ? one(p_i) : zero(p_i) end stochastic_triple_direction(f, p, direction; backend) end return structural_map(map_func, indices) end stochastic_triple(p; kwargs...) = stochastic_triple(identity, p; kwargs...) """ dual_number(X, p; backend=PrunedFIsBackend(), direction=nothing) dual_number(p; backend=PrunedFIsBackend(), direction=nothing) A lightweight wrapper around [`stochastic_triple`](#StochasticAD.stochastic_triple) that entirely ignores the derivative contribution of all discrete random components, so that it behaves like a regular dual number. 
Mostly for fun -- this, of course, leads to a useless derivative estimate for discrete random functions! """ function dual_number(f, p; backend = PrunedFIsBackend(), kwargs...) backend = StrategyWrapperFIsBackend(backend, IgnoreDiscreteStrategy()) stochastic_triple(f, p; backend, kwargs...) end dual_number(p; kwargs...) = dual_number(identity, p; kwargs...) function derivative_estimate(f, p; kwargs...) StochasticAD.structural_map(derivative_contribution, stochastic_triple(f, p; kwargs...)) end ================================================ FILE: test/game_of_life.jl ================================================ using StochasticAD using Test using Statistics include("../tutorials/game_of_life/core.jl") using .GoLCore: fd_clever, play, p, nsamples @testset "AD and Finite Differences" begin samples_fd_clever = [fd_clever(p) for i in 1:nsamples] samples_st = [derivative_estimate(play, p) for i in 1:nsamples] @test mean(samples_st)≈mean(samples_fd_clever) rtol=5e-2 end ================================================ FILE: test/random_walk.jl ================================================ using StochasticAD using Test using Statistics using ForwardDiff: derivative include("../tutorials/random_walk/core.jl") using .RandomWalkCore: n, p, nsamples using .RandomWalkCore: fX, get_dfX @testset "Check unbiasedness" begin fX_deriv = derivative(p -> get_dfX(p, n), p) fX_deriv_estimate = mean(derivative_estimate(fX, p) for i in 1:nsamples) @test isapprox(fX_deriv, fX_deriv_estimate; rtol = 1e-2) end ================================================ FILE: test/resampling.jl ================================================ using StochasticAD using Random, Test using Distributions using LinearAlgebra using ForwardDiff # test forward-mode AD and reverse-mode AD on the particle filter ### Particle Filter Functions include("../tutorials/particle_filter/core.jl") seed = 237347578 ### Define model Random.seed!(seed) T = 3 d = 2 A(θ, a = 0.01) = [exp(-a)*cos(θ[]) exp(-a)*sin(θ[]) 
-exp(-a)*sin(θ[]) exp(-a)*cos(θ[])] obs(x, θ) = MvNormal(x, 0.01 * collect(I(d))) dyn(x, θ) = MvNormal(A(θ) * x, 0.02 * collect(I(d))) x0 = [2.0, 0.0] # start value of the simulation start(θ) = Dirac(x0) θtrue = [0.20] # put it all together stochastic_model = ParticleFilterCore.StochasticModel(T, start, dyn, obs) ### simulate model Random.seed!(seed) xs, ys = ParticleFilterCore.simulate_single(stochastic_model, θtrue) ### ### initialize sampler m = 1000 particle_filter = ParticleFilterCore.ParticleFilter(m, stochastic_model, ys, ParticleFilterCore.sample_stratified) ### @testset "new weight" begin p = 0.5 st = stochastic_triple(p) d = ForwardDiff.Dual(p, (1.0, 2.0)) @test new_weight(p) == one(p) @test StochasticAD.value(new_weight(st)) == one(p) @test StochasticAD.delta(new_weight(st)) == 1.0 / p @test ForwardDiff.value(new_weight(d)) == one(p) @test collect(ForwardDiff.partials(new_weight(d))) == [1.0 / p, 2.0 / p] end @testset "forward-mode and reverse-mode AD: single run" begin Random.seed!(seed) grad_forw = ParticleFilterCore.forw_grad(θtrue, particle_filter) Random.seed!(seed) grad_back = ParticleFilterCore.back_grad(θtrue, particle_filter) @test grad_forw ≈ grad_back end @testset "AD and Finite Differences" begin h = 0.02 # finite diff N = 500 # number of samples grad_fw = [ParticleFilterCore.forw_grad(θtrue, particle_filter)[1] for i in 1:N] # grad_bw = @time [back_grad(θtrue, particle_filter) for i in 1:N] grad_fd = [(ParticleFilterCore.log_likelihood(particle_filter, θtrue .+ h) - ParticleFilterCore.log_likelihood(particle_filter, θtrue .- h)) / (2h) for i in 1:N] @test mean(grad_fd)≈mean(grad_fw) rtol=5e-2 end ================================================ FILE: test/runtests.jl ================================================ using SafeTestsets using Test, Pkg import Random Random.seed!(1234) const GROUP = get(ENV, "GROUP", "All") const is_APPVEYOR = Sys.iswindows() && haskey(ENV, "APPVEYOR") @time begin if GROUP == "All" @time @safetestset "Triples" 
begin include("triples.jl") end @time @safetestset "Game of life" begin include("game_of_life.jl") end @time @safetestset "Random walk" begin include("random_walk.jl") end @time @safetestset "Resampling" begin include("resampling.jl") end end end ================================================ FILE: test/triples.jl ================================================ using StochasticAD using Test using Distributions using ForwardDiff using OffsetArrays using ChainRulesCore using Random using Zygote const backends = [ PrunedFIsBackend(), PrunedFIsAggressiveBackend(), DictFIsBackend() ] const backends_smoothed = [ SmoothedFIsBackend(), StrategyWrapperFIsBackend(SmoothedFIsBackend(), StochasticAD.TwoSidedStrategy()) ] @testset "Distributions w.r.t. continuous parameter" begin for backend in vcat(backends, backends_smoothed, :smoothing_autodiff) MAX = 10000 nsamples = 100000 rtol = 5e-2 # friendly tolerance for stochastic comparisons. TODO: more motivated choice of tolerance. ### Make test cases distributions = [ Bernoulli, Geometric, Poisson, (p -> Categorical([p^2, 1 - p^2])), (p -> Categorical([0, p^2, 0, 0, 1 - p^2])), # check that 0's are skipped over (p -> Categorical([1.0, exp(p)] ./ (1.0 + exp(p)))), # test fix for #38 (floating point comparisons in Categorical logic) (p -> Binomial(3, p)), (p -> Binomial(20, p)) ] p_ranges = [(0.2, 0.8) for _ in 1:8] out_ranges = [0:1, 0:MAX, 0:MAX, 1:2, 1:5, 1:2, 0:3, 0:20] test_cases = collect(zip(distributions, p_ranges, out_ranges)) test_funcs = [x -> 7 * x - 3, x -> (x + 1)^2, x -> sqrt(x + 1)] if backend isa DictFIsBackend # Only test dictionary backend on Bernoulli to speed things up. Should still cover interface. test_cases = test_cases[1:1] elseif backend == :smoothing_autodiff || backend in backends_smoothed # Only test smoothing backend on each unique distribution once to seed tests up. 
test_cases = vcat(test_cases[1:4], test_cases[7]) # Only test unbiasedness of smoothing for linear function test_funcs = test_funcs[1:1] end for (distr, p_range, out_range) in test_cases for f in test_funcs function get_mean(p) dp = distr(p) sum(pdf(dp, i) * f(i) for i in out_range) end low_p, high_p = p_range for g in [p -> p, p -> high_p + low_p - p] # test both sides of derivative full_func = f ∘ rand ∘ distr ∘ g p = low_p + (high_p - low_p) * rand() exact_deriv = ForwardDiff.derivative(p -> get_mean(g(p)), p) if backend == :smoothing_autodiff batched_full_func(p) = mean([full_func(p) for i in 1:nsamples]) # The array input used for ForwardDiff below is a trick to test multiple partials triple_deriv_forward = mean(ForwardDiff.gradient( arr -> batched_full_func(sum(arr)), [2 * p, -p])) triple_deriv_backward = Zygote.gradient(batched_full_func, p)[1] @test isapprox(triple_deriv_forward, exact_deriv, rtol = rtol) @test isapprox(triple_deriv_backward, exact_deriv, rtol = rtol) else get_deriv = () -> derivative_estimate(full_func, p; backend) triple_deriv = mean(get_deriv() for i in 1:nsamples) @test isapprox(triple_deriv, exact_deriv, rtol = rtol) end end end end end end @testset "Perturbing n of binomial" begin function get_triple_deriv(Δ) # Manually create a finite perturbation to avoid any randomness in its creation Δs = StochasticAD.similar_new(StochasticAD.create_Δs(PrunedFIsBackend(), Int), Δ, 3.5) st = StochasticAD.StochasticTriple{0}(5, 0, Δs) st_continuous = stochastic_triple(0.5) return derivative_contribution(rand(Binomial(st, st_continuous))) end for Δ in -2:2 triple_deriv = mean(get_triple_deriv(Δ) for i in 1:100000) exact_deriv = 3.5 * 0.5 * Δ + 5 @test isapprox(triple_deriv, exact_deriv, rtol = 5e-2) end end @testset "Nested binomials" begin binbin = p -> rand(Binomial(rand(Binomial(10, p)), p)) # ∼ Binomial(10, p^2) for p in [0.3, 0.7] triple_deriv = mean(derivative_estimate(binbin, p) for i in 1:100000) exact_deriv = 10 * 2 * p @test 
isapprox(triple_deriv, exact_deriv, rtol = 5e-2) end end @testset "Boolean comparisons" begin for backend in backends tested = falses(2) while !(all(tested)) st = stochastic_triple(rand ∘ Bernoulli, 0.5; backend) x = StochasticAD.value(st) if x == 0 # Ensure errors on unsafe/unsupported boolean comparisons @test_throws Exception st>0.5 @test_throws Exception 0.5 0.5 @test 0.5 < st @test st == 1 end tested[x + 1] = true end @test stochastic_triple(1.0; backend) != 1 end end @testset "Array indexing" begin for backend in vcat(backends, backends_smoothed) p = 0.3 # Test indexing into array of floats with stochastic triple index arr = [3.5, 5.2, 8.4] (backend in backends_smoothed) && (arr[3] = 6.9) # make linear for smoothing test function array_index(p) index = rand(Categorical([p / 2, p / 2, 1 - p])) return arr[index] end array_index_mean(p) = sum([p / 2, p / 2, (1 - p)] .* arr) triple_array_index_deriv = mean(derivative_estimate(array_index, p; backend) for i in 1:50000) exact_array_index_deriv = ForwardDiff.derivative(array_index_mean, p) @test isapprox(triple_array_index_deriv, exact_array_index_deriv, rtol = 5e-2) # Don't run subsequent tests with smoothing backend (backend in backends_smoothed) && continue # Test indexing into array of stochastic triples with stochastic triple index function array_index2(p) arr2 = [rand(Bernoulli(p)), rand(Bernoulli(p)), rand(Bernoulli(p))] .* arr index = rand(Categorical([p / 2, p / 2, 1 - p])) return arr2[index] end array_index2_mean(p) = sum([p / 2 * p, p / 2 * p, (1 - p) * p] .* arr) triple_array_index2_deriv = mean(derivative_estimate(array_index2, p; backend) for i in 1:50000) exact_array_index2_deriv = ForwardDiff.derivative(array_index2_mean, p) @test isapprox(triple_array_index2_deriv, exact_array_index2_deriv, rtol = 5e-2) # Test case where triple and alternate array value are coupled function array_index3(p) st = rand(Bernoulli(p)) arr2 = [-5, st] return arr2[st + 1] end array_index3_mean(p) = -5 * (1 - p) + 1 * p 
triple_array_index3_deriv = mean(derivative_estimate(array_index3, p; backend) for i in 1:50000) exact_array_index3_deriv = ForwardDiff.derivative(array_index3_mean, p) @test isapprox(triple_array_index3_deriv, exact_array_index3_deriv, rtol = 5e-2) end end @testset "Array/functor inputs to higher level functions" begin for backend in backends # Try a deterministic test function to compare to ForwardDiff f(x) = (x[1] * x[2] * sin(x[3]) + exp(x[1] * x[2])) / x[3] x = [1, 2, π / 2] stochastic_ad_grad = derivative_estimate(f, x; backend) stochastic_ad_grad2 = derivative_contribution.(stochastic_triple(f, x; backend)) stochastic_ad_grad_firsttwo = derivative_estimate( f, x; direction = [1.0, 1.0, 0.0], backend) fd_grad = ForwardDiff.gradient(f, x) @test stochastic_ad_grad ≈ fd_grad @test stochastic_ad_grad ≈ stochastic_ad_grad2 @test stochastic_ad_grad[1] + stochastic_ad_grad[2] ≈ stochastic_ad_grad_firsttwo # Try an OffsetArray too f_off(x) = (x[0] * x[1] * sin(x[2]) + exp(x[0] * x[1])) / x[2] x_off = OffsetArray([1, 2, π / 2], 0:2) stochastic_ad_grad_off = derivative_estimate(f_off, x_off) @test stochastic_ad_grad_off ≈ OffsetArray(stochastic_ad_grad, 0:2) # Try a Functor f_func(x) = (x[1] * x[2][1] * sin(x[2][2]) + exp(x[1] * x[2][1])) / x[2][2] x_func = (1, [2, π / 2]) stochastic_ad_grad_func = derivative_estimate(f_func, x_func) stochastic_ad_grad_func_expected = (stochastic_ad_grad[1], stochastic_ad_grad[2:3]) compare_grad_funcs = StochasticAD.structural_map(≈, stochastic_ad_grad_func, stochastic_ad_grad_func_expected) @test all(compare_grad_funcs |> StochasticAD.structural_iterate) # Test StochasticModel + stochastic_gradient combination m = StochasticModel(f, x) @test stochastic_gradient(m).p ≈ stochastic_ad_grad end end @testset "Propagation using frule with ZeroTangent" begin st = stochastic_triple(0.5) # Verify that the rule for imag indeed gives a ZeroTangent value = StochasticAD.value(st) δ = StochasticAD.delta(st) @test frule((NoTangent(), δ), imag, 
value)[2] isa ZeroTangent # Test that stochastic triples flow through this rule out_st = imag(st) @test StochasticAD.value(out_st) ≈ 0 @test StochasticAD.delta(out_st) ≈ 0 @test isempty(out_st.Δs) end @testset "Unary functions converting type to fixed instance" begin for val in [0.5, 1] st = stochastic_triple(val) for op in StochasticAD.UNARY_TYPEFUNCS_WRAP f = getfield(Base, Symbol(op)) out_st = f(st) @test out_st isa StochasticAD.StochasticTriple @test StochasticAD.value(out_st) ≈ f(val) ≈ f(typeof(val)) @test StochasticAD.delta(out_st) ≈ 0 @test isempty(out_st.Δs) @test f(typeof(st)) == out_st end #= It so happens that the UNARY_TYPEFUNCS_WRAP funcs all support both instances and types whereas UNARY_TYPEFUNCS_NOWRAP only supports types, so we only test types in the below, but this is a coincidence that may not hold in the future. =# for op in StochasticAD.UNARY_TYPEFUNCS_NOWRAP f = getfield(Base, Symbol(op)) out = f(typeof(st)) @test out isa typeof(val) @test out ≈ f(typeof(val)) end RNG = copy(Random.GLOBAL_RNG) for op in StochasticAD.RNG_TYPEFUNCS_WRAP f = getfield(Random, Symbol(op)) out_st = f(copy(RNG), typeof(st)) @test out_st isa StochasticAD.StochasticTriple @test StochasticAD.value(out_st) ≈ f(copy(RNG), typeof(val)) @test StochasticAD.delta(out_st) ≈ 0 @test isempty(out_st.Δs) end end end @testset "Hashing" begin st = stochastic_triple(3.0) @test_nowarn hash(st) @test_nowarn hash(st, UInt(5)) d = Dict() @test_nowarn d[st] = 5 @test d[st] == 5 @test d[3] == 5 # Test that we get an error with discrete random dictionary indices, # since this isn't supported and we want to avoid silent failures. 
Δs = StochasticAD.similar_new(StochasticAD.create_Δs(PrunedFIsBackend(), Int), 1.0, 1.0) st = StochasticAD.StochasticTriple{0}(1.0, 0, Δs) @test_throws ErrorException d[rand(Bernoulli(st))] end @testset "Coupled comparison" begin Δs_1 = StochasticAD.similar_new(StochasticAD.create_Δs(PrunedFIsBackend(), Int), 1.0, 1.0) Δs_2 = StochasticAD.similar_new(StochasticAD.create_Δs(PrunedFIsBackend(), Int), 1.0, 1.0) st_1 = StochasticAD.StochasticTriple{0}(1.0, 0, Δs_1) st_2 = StochasticAD.StochasticTriple{0}(1.0, 0, Δs_2) @test st_1 == st_1 @test_throws ErrorException st_1==st_2 end @testset "Converting float stochastic triples to integer triples" begin st = stochastic_triple(0.6) @test round(Int, st) isa StochasticAD.StochasticTriple @test StochasticAD.delta(round(Int, st)) ≈ 0 @test round(Int, st) ≈ 1 end @testset "Approximate comparisons" begin st = stochastic_triple(0.5) @test st ≈ st # Check that the rtol is indeed reasonable @test st ≈ st + 1e-14 @test !(st ≈ st + 1) @test_broken stochastic_triple(Inf) ≈ stochastic_triple(Inf) end @testset "Error on unmatched tags" begin st1 = stochastic_triple(0.5) st2 = stochastic_triple(x -> x^2, 0.5) @test_throws ArgumentError convert(typeof(st1), st2) end @testset "Finite perturbation backend interface" begin for backend in vcat(backends, backends_smoothed) # this boolean may need to become more fine-grained in the future is_smoothed_backend = backend in backends_smoothed #= Test the backend interface across the finite perturbation backends, which is currently a bit implicitly defined. =# V0 = Int V1 = Float64 #= All four of the below approaches should create an empty backend, although the backend's internal state management may differ. 
=# Δs0 = StochasticAD.create_Δs(backend, V0) # used to create first triple in computation FIs = typeof(Δs0) Δs1 = empty(Δs0) Δs2 = empty(typeof(Δs0)) Δs3 = StochasticAD.similar_empty(Δs0, V1) for (Δs, V) in ((Δs0, V0), (Δs1, V0), (Δs2, V0), (Δs3, V1)) @test StochasticAD.valtype(Δs) === V @test Δs isa StochasticAD.similar_type(FIs, V) !is_smoothed_backend && @test isempty(Δs) @test iszero(derivative_contribution(Δs)) end # Test creation of a single perturbation for Δ in (1, 3.0) Δs0 = StochasticAD.create_Δs(backend, V0) Δs1 = StochasticAD.similar_new(Δs0, Δ, 3.0) @test StochasticAD.valtype(Δs1) === typeof(Δ) @test Δs1 isa StochasticAD.similar_type(FIs, typeof(Δ)) !is_smoothed_backend && @test !isempty(Δs1) @test derivative_contribution(Δs1) == 3Δ # Test StochasticAD.alltrue @test StochasticAD.alltrue(_Δ -> true, Δs1) @test !StochasticAD.alltrue(_Δ -> false, Δs1) || is_smoothed_backend # Test map # We use a dummy deriv here and below. TODO: use a more interesting dummy for better testing. 
Δs1_map = Base.map(Δ -> Δ^2, Δs1; deriv = identity, out_rep = Δ) !is_smoothed_backend && @test derivative_contribution(Δs1_map) ≈ Δ^2 * 3.0 # Test map with weight (make a new copy so that original does not get reweighted) Δs2 = StochasticAD.similar_new(StochasticAD.create_Δs(backend, V0), Δ, 3.0) Δs2_weight_map = StochasticAD.weighted_map_Δs((Δ, _) -> (Δ^2, 2.0), Δs2; deriv = identity, out_rep = Δ) !is_smoothed_backend && @test derivative_contribution(Δs2_weight_map) ≈ Δ^2 * 3.0 * 2.0 # Also test scale w2 = derivative_contribution(Δs2) Δs2_scaled = StochasticAD.scale(Δs2, 2.0) w2_scaled = derivative_contribution(Δs2_scaled) @test w2_scaled ≈ 2.0 * w2 # Test map_Δs with filter state if !is_smoothed_backend Δs1_plus_Δs0 = StochasticAD.map_Δs( (Δ, state) -> Δ + StochasticAD.filter_state(Δs0, state), Δs1) @test derivative_contribution(Δs1_plus_Δs0) ≈ Δ * 3.0 Δs1_plus_mapped = StochasticAD.map_Δs( (Δ, state) -> Δ + StochasticAD.filter_state(Δs1, state), Δs1_map) @test derivative_contribution(Δs1_plus_mapped) ≈ Δ * 3.0 + Δ^2 * 3.0 end end # Test coupling Δ_coupleds = (3, [4.0, 5.0], (2, [3.0, 4.0])) for Δ_coupled in Δ_coupleds function get_Δs_coupled(; do_combine = false, use_get_rep = false) Δs0 = StochasticAD.create_Δs(backend, Int) Δs1 = StochasticAD.similar_new(Δs0, 1, 3.0) # perturbation 1 Δs2 = StochasticAD.similar_new(Δs0, 1, 2.0) # perturbation 2 # A group of perturbations that all stem from perturbation 1. Δs_all1 = StochasticAD.structural_map(Δ_coupled) do Δ Base.map(_Δ -> Δ, Δs1; deriv = identity, out_rep = Δ) end # A group of perturbations that all stem from perturbation 2. Δs_all2 = StochasticAD.structural_map(Δ_coupled) do Δ Base.map(_Δ -> 2 * Δ, Δs2; deriv = (δ -> 2δ), out_rep = Δ) end # Join them into a single structure that should be coupled Δs_all = (Δs_all1, Δs_all2) kwargs = use_get_rep ? (; rep = StochasticAD.get_rep(FIs, Δs_all)) : (;) if do_combine return StochasticAD.combine(FIs, Δs_all; kwargs...) 
else return StochasticAD.couple(FIs, Δs_all; out_rep = (Δ_coupled, Δ_coupled), kwargs...) end end #= As a test function to apply to the coupled perturbation, we apply a matmul followed by a sigmoid activation function and a sum. =# l = 2 * length(collect(StochasticAD.structural_iterate(Δ_coupled))) A = rand(l, l) function mapfunc(Δ_coupled) arr = collect(StochasticAD.structural_iterate(Δ_coupled)) sum(x -> 1 / (1 + exp(-x)), A * arr) end # Test the above function, and also a simple sum. for use_get_rep in (false, true) Δs_coupled = get_Δs_coupled(; use_get_rep) @test StochasticAD.valtype(Δs_coupled) == typeof((Δ_coupled, Δ_coupled)) for (mapfunc, check_combine) in ((mapfunc, false), (Δ_coupled -> sum(StochasticAD.structural_iterate(Δ_coupled)), true)) function get_contribution() Δs_coupled = get_Δs_coupled(; use_get_rep) Δs_coupled_mapped = map(mapfunc, Δs_coupled; deriv = (δ -> 1.0), out_rep = 0.0) return derivative_contribution(Δs_coupled_mapped) end zero_Δ_coupled = StochasticAD.structural_map(zero, Δ_coupled) expected_contribution1 = 3.0 * mapfunc((Δ_coupled, zero_Δ_coupled)) expected_contribution2 = 2.0 * mapfunc((zero_Δ_coupled, StochasticAD.structural_map(x -> 2x, Δ_coupled))) expected_contribution = expected_contribution1 + expected_contribution2 if !is_smoothed_backend @test isapprox(mean(get_contribution() for i in 1:1000), expected_contribution; rtol = 5e-2) end # For a simple sum, this should be equivalent to the combine behaviour. 
if check_combine && !is_smoothed_backend @test isapprox( mean(derivative_contribution(get_Δs_coupled(; do_combine = true)) for i in 1:1000), expected_contribution; rtol = 5e-2) end # Check scalarize Δs_coupled2 = StochasticAD.couple(FIs, StochasticAD.scalarize(Δs_coupled; out_rep = (Δ_coupled, Δ_coupled)), out_rep = (Δ_coupled, Δ_coupled)) @test derivative_contribution(map(mapfunc, Δs_coupled; deriv = (δ -> 1.0), out_rep = 0.0)) ≈ derivative_contribution(map(mapfunc, Δs_coupled2; deriv = (δ -> 1.0), out_rep = 0.0)) end end end end end @testset "Getting information about stochastic triples" begin for backend in vcat(backends, backends_smoothed) Random.seed!(4321) f(x) = rand(Bernoulli(x)) + x st = stochastic_triple(f, 0.5; backend) # Expected: 0.5 + 1.0ε + (1.0 with probability 2.0ε) dual = ForwardDiff.Dual(0.5, 1.0) @test StochasticAD.value(0.5) == 0.5 @test StochasticAD.value(st) == 0.5 @test StochasticAD.value(dual) == 0.5 @test iszero(StochasticAD.delta(0.5)) @test StochasticAD.delta(st) == 1.0 @test StochasticAD.delta(dual) == 1.0 if !(backend in backends_smoothed) #= NB: since the implementation of perturbations can be backend-specific, the below property need not hold in general, but does for the current non-smoothed backends. =# p = only(perturbations(st)) @test p.Δ == 1 && p.weight == 2.0 @test derivative_contribution(st) == 3.0 else # Since smoothed algorithm uses the two-sided strategy, we get a different derivative contribution. 
@test derivative_contribution(st) == 2.0 end @test StochasticAD.tag(st) === StochasticAD.Tag{typeof(f), Float64} @test StochasticAD.valtype(st) === Float64 @test StochasticAD.valtype(st.Δs) === Float64 end end @testset "Propagation via StochasticAD.propagate" begin for backend in backends function form_triple(primal, δ, Δ, Δs_base) Δs = map(_Δ -> Δ, Δs_base) return StochasticAD.StochasticTriple{0}(primal, δ, Δs) end function test_propagate(f, primals, Δs; test_deltas = false) Δs_base = StochasticAD.similar_new(StochasticAD.create_Δs(backend, Int), 0, 1.0) _form_triple(x, δ, Δ) = form_triple(x, δ, Δ, Δs_base) out = f(primals...) out_Δ_expected = StochasticAD.structural_map(-, f(StochasticAD.structural_map(+, primals, Δs)...), f(primals...)) if test_deltas duals = StochasticAD.structural_map(primals) do x x isa AbstractFloat ? ForwardDiff.Dual{0}(x, rand(typeof(x))) : x end δs = StochasticAD.structural_map(StochasticAD.delta, duals) out_δ_expected = StochasticAD.structural_map(StochasticAD.delta, f(duals...)) else δs = StochasticAD.structural_map(zero, primals) out_δ_expected = StochasticAD.structural_map(zero, out) end input_sts = StochasticAD.structural_map(_form_triple, primals, δs, Δs) out_st = StochasticAD.propagate(f, input_sts...; keep_deltas = Val{test_deltas}) # Test type StochasticAD.structural_map(out_st, out, out_δ_expected, out_Δ_expected) do x_st, x, δ, Δ @test x_st isa StochasticAD.StochasticTriple{0, typeof(x)} @test StochasticAD.value(x_st) == x @test StochasticAD.delta(x_st) ≈ δ p = only(perturbations(x_st)) @test p.Δ == Δ && p.weight == 1.0 end end #= Test propagation through some simple functions. f1: a simple if statement. f2: involves array-containing-fucntor input and output. f3: involves array-containing-functor input, but real output. f4: length ∘ repr (real or array input, real output). f5: mutates input array! Broken since unsupported. f6: the first-arg (blob) should just be passed through without attempting to perturb. 
Broken since unsupported. f7: involves matrix-containing-functor input and output. =# function f1(x) if x == 0 return 1 elseif x == 3 return 2 else return 5 end end @test StochasticAD.propagate(f1, 0) === f1(0) for (primal, Δ) in [(0, 3), (0, 4), (3, -1)] test_propagate(f1, (primal,), (Δ,)) end function f2(arr, scalar) if sum(arr) + scalar <= 5 return arr .* scalar, sum(arr) * scalar else return arr .- scalar, sum(arr) - scalar end end f3(arr, scalar) = f2(arr, scalar)[2] primals1 = ([1, 1], 2) Δs1 = ([2, 3], 5) primals2 = ([1, 2], 1) Δs2 = ([1, -2], 1) primals3 = ([5, 2], -1) Δs3 = ([-3, 1], 0) for (primals, Δs) in [(primals1, Δs1), (primals2, Δs2), (primals3, Δs3)] for test_deltas in (false, true) if test_deltas primals = StochasticAD.structural_map(float, primals) Δs = StochasticAD.structural_map(float, Δs) end test_propagate(f2, primals, Δs; test_deltas) test_propagate(f3, primals, Δs; test_deltas) end end f4(x) = Base.length(repr(x)) for (primals, Δs) in [(2, 11), (([3, 14],), ([14, -152],))] test_propagate(f4, primals, Δs) end function f5(arr) if arr == [1, 2] arr .+= 1 else arr .-= 1 end end # Tests for f6 skipped (would break) for (primals, Δs) in [([1, 2], [1, -1]), ([2, 4], [-1, -2]), ([2, 4], [-1, -1])] @test_skip "propagate f5" # test_propagate(f5, primals, Δs) end f6(blob, arr) = blob, f5(arr) # Tests for f6 missing (would break) @test_skip "propagate f6" function f7(mat, scalar) return mat * scalar, scalar + sum(mat) end test_propagate(f7, (rand(2, 2), 4.0), (rand(2, 2), 1.0); test_deltas = true) end end @testset "zero'ing of Inf/NaN (#79)" begin st = stochastic_triple(0.5) st_zero = zero(1 / zero(st)) @test iszero(StochasticAD.value(st_zero)) @test iszero(StochasticAD.delta(st_zero)) end @testset "smooth_triple" begin f(p) = sum(rand(Bernoulli(p)) * i for i in 1:100) f2(p) = sum(smooth_triple(rand(Bernoulli(p))) * i for i in 1:100) p = 0.6 f_est = mean(derivative_estimate(f, p) for i in 1:10000) f2_est = mean(derivative_estimate(f2, p) for i in 
1:10000) @test f_est≈f2_est rtol=5e-2 end @testset "No unnecessary float promotion" begin f(p) = rand(Bernoulli(p))^2 st = stochastic_triple(f, 0.5) @test StochasticAD.valtype(st) == typeof(convert(Signed, f(0.5))) end ================================================ FILE: tutorials/Project.toml ================================================ [deps] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" GR = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71" GaussianDistributions = "43dcc890-d446-5863-8d1a-14597580bb8d" GeometryBasics = "5c1252a2-5f33-56bf-86c9-59e7332b4326" LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StochasticAD = "e4facb34-4f7e-4bec-b153-e122c37934ac" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" ================================================ FILE: tutorials/README.md ================================================ The raw source code for our tutorials. You likely want to look at the documentation instead, where they are presented more clearly! 
# Advance the board by one synchronous step: every cell flips state with a
# probability looked up in `all_probs`, indexed by its current state and its
# live-neighbour count on the *previous* board (`board_old`).
function update_state!(all_probs, N, board, board_old)
    for i in (-N):N
        for j in (-N):N
            # von Neumann neighbourhood: the four orthogonal neighbours,
            # read from the snapshot so the update is synchronous
            live_neighbours = board_old[i + 1, j] + board_old[i - 1, j] +
                              board_old[i, j - 1] + board_old[i, j + 1]
            # entries 1:5 of all_probs are birth probs (cell dead),
            # entries 6:10 are death probs (cell alive) — see `play`
            prob_index = board[i, j] * 5 + live_neighbours + 1
            # trick necessary because we do not have implementation support
            # for stochastic triple not <: Real
            flip = rand(Bernoulli(all_probs[prob_index]))
            # branch-free flip: adds +1 when the cell is dead, -1 when alive
            board[i, j] += (1 - 2 * board[i, j]) * flip
        end
    end
end
# Central finite-difference estimate of d/dp play(p) using "common random
# numbers": the RNG state is snapshotted and restored between the two runs so
# both evaluations consume identical randomness. More accurate for checking,
# albeit with a finite step size h such that there is weight degeneracy as h → 0.
function fd_clever(p, h = 0.01)
    rng_snapshot = copy(Random.default_rng())
    value_plus = play(p + h)
    copy!(Random.default_rng(), rng_snapshot) # replay the exact same randomness
    value_minus = play(p - h)
    return (value_plus - value_minus) / (2h)
end
ParticleFilterCore.back_grad(θtrue, particle_filter) tune!(suite) results = run(suite, verbose = true, seconds = secs) t1 = measurement(mean(results["scaling"]["primal"].times), std(results["scaling"]["primal"].times) / sqrt(length(results["scaling"]["primal"].times))) t2 = measurement(mean(results["scaling"]["forward"].times), std(results["scaling"]["forward"].times) / sqrt(length(results["scaling"]["forward"].times))) t3 = measurement(mean(results["scaling"]["backward"].times), std(results["scaling"]["backward"].times) / sqrt(length(results["scaling"]["backward"].times))) @show t1 t2 t3 ts = (t1, t2, t3) ./ 10^6 # ms @show ts BenchmarkTools.save("benchmark_data_" * string(d) * ".json", results) ================================================ FILE: tutorials/particle_filter/bias.jl ================================================ include("core.jl") include("model.jl") using Plots, LaTeXStrings using Random # Comparison of the derivative of the particle filter with and without differentiating the resampling step. 
### compute gradients
# FIX: `seed` is defined only *locally* inside `generate_system` (model.jl), so
# referencing it at top level raised UndefVarError. Define it here explicitly
# (same value the model uses) so both gradient runs are seeded identically.
seed = 423897
Random.seed!(seed)
# gradient of the particle filter *with* differentiation of the resampling step
X = [ParticleFilterCore.forw_grad(θtrue, particle_filter) for i in 1:1000]
Random.seed!(seed)
# gradient of the particle filter *without* differentiation of the resampling step
Xbiased = [ParticleFilterCore.forw_grad_biased(θtrue, particle_filter) for i in 1:1000]

# pick an arbitrary coordinate
index = 1 # take derivative with respect to first parameter (2-dimensional example has a rotation matrix with four parameters in total)

# plot histograms for the sampled derivative values
fig = plot(normalize(fit(Histogram, getindex.(X, index), nbins = 50), mode = :pdf),
    legend = false) # ours
plot!(normalize(fit(Histogram, getindex.(Xbiased, index), nbins = 50), mode = :pdf)) # biased
vline!([mean(X)[index]], color = 1)
vline!([mean(Xbiased)[index]], color = 2)

# add derivative of differentiable Kalman filter as a comparison
XK = ParticleFilterCore.forw_grad_Kalman(θtrue, kalman_filter)
vline!([XK[index]], color = "black")

display(fig)
savefig(fig, "tails.pdf")
- `start`: starting distribution for the initial state. For example, in the form of a narrow Gaussian `start(θ) = Gaussian(x0, 0.001 * I(d))`. - `dyn`: pointwise differentiable stochastic program in the form of Markov transition densities. For example, `dyn(x, θ) = MvNormal(reshape(θ, d, d) * x, Q(θ))`, where `Q(θ)` denotes the covariance matrix. - `obs`: observation model having a smooth conditional probability density depending on current state `x` and parameters `θ`. For example, `obs(x, θ) = MvNormal(x, R(θ))`, where `R(θ)` denotes the covariance matrix. """ struct StochasticModel{TType <: Integer, T1, T2, T3} T::TType # time steps start::T1 # prior dyn::T2 # dynamical model obs::T3 # observation model end # Particle filter """ ParticleFilter{mType<:Integer,MType<:StochasticModel,yType,sType} Wraps a stochastic model `StochM::StochasticModel` and observational data `ys`. Assumes a observation-likelihood is available via `pdf(obs(x, θ), y)`. ## Constructor - `m`: number of particles. - `StochM`: stochastic model of type `StochasticModel`. - `ys`: observations. - `sample_strategy`: strategy for the resampling step of the particle filter. For example, stratified sampling as implemented in `sample_stratified`. """ struct ParticleFilter{mType <: Integer, MType <: StochasticModel, yType, sType} m::mType # number of particles StochM::MType # stochastic model ys::yType # observations sample_strategy::sType # sampling function end # Kalman filter """ KalmanFilter{dType<:Integer,MType<:StochasticModel,HType,RType,QType,yType} Differentiable Kalman filter following https://github.com/mschauer/Kalman.jl/blob/master/README.md. Wraps a stochastic model `StochM::StochasticModel` and observational data `ys`. Assumes a observation-likelihood is implemented via `llikelihood(yres, S)`. 
"""
    simulate_single(StochM::StochasticModel, θ)

Draw one trajectory from the forward model under parameters `θ`, with no
resampling steps. Returns `(xs, ys)`: the latent states and the corresponding
observations, e.g.
```
Random.seed!(seed)
xs, ys = simulate_single(StochM, θtrue)
```
to get observations `ys` from the latent states `xs` based on the (true,
potentially unknown) parameters `θ`.
"""
function simulate_single(StochM::StochasticModel, θ)
    T, start, dyn, obs = StochM.T, StochM.start, StochM.dyn, StochM.obs
    # initial state and observation from the prior
    x = rand(start(θ))
    xs = [x]
    ys = [rand(obs(x, θ))]
    for t in 2:T
        x = rand(dyn(x, θ)) # propagate the latent state
        push!(xs, x)
        push!(ys, rand(obs(x, θ))) # observe it
    end
    return xs, ys
end
# Stratified resampling (see e.g. https://arxiv.org/abs/1202.6163): one shared
# uniform offset places a single draw inside each of the K equal strata of the
# cumulative weight axis. `p` holds the (unnormalized) probabilities of the
# particles, with `sump = sum(p)`. Returns K nondecreasing particle indices.
function sample_stratified(p, K, sump = 1)
    n = length(p)
    offset = rand() # the one shared uniform offset — makes the draws stratified
    chosen = zeros(Int, K)
    idx = 1
    cumulative = p[1]
    for k in 1:K
        # threshold of the k-th stratum, scaled to the total weight
        threshold = sump * (k - 1 + offset) / K
        while cumulative < threshold && idx < n
            idx += 1
            @inbounds cumulative += p[idx]
        end
        chosen[k] = idx
    end
    return chosen
end
"""
    (F::KalmanFilter)(θ)

Run the differentiable Kalman filter. Returns the sequence of posterior state
estimates and the accumulated log likelihood.

## args
- `θ`: parameters for the stochastic program (state-transition and observation model).
"""
function (F::KalmanFilter)(θ)
    d, StochM = F.d, F.StochM
    H, R, Q, ys = F.H, F.R, F.Q, F.ys
    x = StochM.start(θ)
    Φ = reshape(θ, d, d) # state-transition matrix, flattened into θ
    # correct against the first observation
    x, yres, S = GaussianDistributions.correct(x, ys[1] + R, H)
    ll = llikelihood(yres, S)
    xs = Any[x]
    for t in 2:length(ys)
        x = Φ * x ⊕ Q # predict step: propagate mean and inflate covariance
        x, yres, S = GaussianDistributions.correct(x, ys[t] + R, H) # update step
        ll += llikelihood(yres, S)
        push!(xs, x)
    end
    return xs, ll
end
"""
    log_likelihood(F::ParticleFilter, θ, use_new_weight=true, s=1)

Compute the log-likelihood estimate of the particle sampler: the log of the
total (unnormalized) particle weight after a full filtering pass. See
`ParticleFilter` for `use_new_weight` and `s`.

## args
- `θ`: parameters for the stochastic program (state-transition and observation model).
"""
function log_likelihood(F::ParticleFilter, θ, use_new_weight = true, s = 1)
    _, weights = F(θ; store_path = false, use_new_weight = use_new_weight, s = s)
    return log(sum(weights))
end
MvNormal(reshape(θ, d, d) * x, Q) # x = Φ*x + w with w ~ Normal(0,Q) # starting position x0 = randn(d) # prior distribution start(θ) = Gaussian(x0, 0.001 * collect(I(d))) # put it all together stochastic_model = ParticleFilterCore.StochasticModel(T, start, dyn, obs) # relevant corresponding Kalman filterng defs H_Kalman = collect(I(d)) R_Kalman = Gaussian(zeros(Float64, d), R) # Φ_Kalman = O Q_Kalman = Gaussian(zeros(Float64, d), Q) ### simulate model Random.seed!(seed) xs, ys = ParticleFilterCore.simulate_single(stochastic_model, θtrue) ### ### initialize filters m = 1000 # number of particles kalman_filter = ParticleFilterCore.KalmanFilter( d, stochastic_model, H_Kalman, R_Kalman, Q_Kalman, ys) particle_filter = ParticleFilterCore.ParticleFilter(m, stochastic_model, ys, ParticleFilterCore.sample_stratified) return θtrue, xs, ys, stochastic_model, kalman_filter, particle_filter end θtrue, xs, ys, stochastic_model, kalman_filter, particle_filter = generate_system(d, T) ================================================ FILE: tutorials/particle_filter/variance.jl ================================================ include("core.jl") include("model.jl") using Plots, LaTeXStrings using Random Random.seed!(seed) # Comparison of the variance of the particle filter with and without differentiating the resampling step *as a function of the time steps*. vars_pf = [] vars_pf_biased = [] Ts = 5:5:30 for T in Ts # Random.seed!(seed) is fixed in model! 
θtrue, xs, ys, stochastic_model, kalman_filter, particle_filter = generate_system(d, T) xs, ys = ParticleFilterCore.simulate_single(stochastic_model, θtrue) particle_filter = ParticleFilterCore.ParticleFilter(m, stochastic_model, ys, ParticleFilterCore.sample_stratified) ### compute var of gradients # Gradient of the particle filter *with* differentiation of the resampling step var_pf = @time var([ParticleFilterCore.forw_grad(θtrue, particle_filter) for i in 1:100]) # Gradient of the particle filter *without* differentiation of the resampling step var_pf_biased = @time var([ParticleFilterCore.forw_grad_biased(θtrue, particle_filter) for i in 1:100]) push!(vars_pf, var_pf) push!(vars_pf_biased, var_pf_biased) end @show vars_pf @show vars_pf_biased # pick an arbitrary coordinate index = 1 # take derivative with respect to first parameter fig = plot(Ts, getindex.(vars_pf, index), color = 1, label = "unbiased", size = (300, 250), xlabel = L"n", ylabel = "variance", legend = :topleft, y_scale = :log) scatter!(Ts, getindex.(vars_pf, index), color = 1, label = false) plot!(Ts, getindex.(vars_pf_biased, index), color = 2, label = "biased") scatter!(Ts, getindex.(vars_pf_biased, index), color = 2, label = false) display(fig) savefig(fig, "particle_filter_variance_steps.pdf") # Comparison of the variance of the particle filter with and without differentiating the resampling step *as a function of the system size*. vars_pf = [] vars_pf_biased = [] ds = 2:1:6 for d in ds # Random.seed!(seed) is fixed in model! 
θtrue, xs, ys, stochastic_model, kalman_filter, particle_filter = generate_system(d, 10) xs, ys = ParticleFilterCore.simulate_single(stochastic_model, θtrue) particle_filter = ParticleFilterCore.ParticleFilter(m, stochastic_model, ys, ParticleFilterCore.sample_stratified) ### compute var of gradients # Gradient of the particle filter *with* differentiation of the resampling step var_pf = @time var([ParticleFilterCore.forw_grad(θtrue, particle_filter) for i in 1:50]) # Gradient of the particle filter *without* differentiation of the resampling step var_pf_biased = @time var([ParticleFilterCore.forw_grad_biased(θtrue, particle_filter) for i in 1:50]) push!(vars_pf, var_pf) push!(vars_pf_biased, var_pf_biased) end fig = plot(ds, getindex.(vars_pf, index), color = 1, label = "unbiased", size = (300, 250), xlabel = L"d", ylabel = "variance", legend = :topleft, y_scale = :log) scatter!(ds, getindex.(vars_pf, index), color = 1, label = false) plot!(ds, getindex.(vars_pf_biased, index), color = 2, label = "biased") scatter!(ds, getindex.(vars_pf_biased, index), color = 2, label = false) display(fig) savefig(fig, "particle_filter_variance_size.pdf") ================================================ FILE: tutorials/particle_filter/visualize.jl ================================================ include("core.jl") include("model.jl") using Plots, LaTeXStrings # visualization of stochastic process (observations and latent states), particle filter, and Kalman filter ### run and visualize filters Xs, W = particle_filter(θtrue; store_path = true) fig = plot(getindex.(xs, 1), getindex.(xs, 2), legend = false, xlabel = L"x_1", ylabel = L"x_2") # x1 and x2 are bad names..conflicting notation scatter!(fig, getindex.(ys, 1), getindex.(ys, 2)) for i in 1:min(m, 100) # note that Xs has obs noise. 
local xs = [Xs[t][i] for t in 1:T] scatter!(fig, getindex.(xs, 1), getindex.(xs, 2), marker_z = 1:T, color = :cool, alpha = 0.1) # color to indicate time step end xs_Kalman, ll_Kalman = kalman_filter(θtrue) plot!(getindex.(mean.(xs_Kalman), 1), getindex.(mean.(xs_Kalman), 2), legend = false, color = "red") display(fig) savefig(fig, "filter.pdf") ================================================ FILE: tutorials/random_walk/compare_score.jl ================================================ include("core.jl") using Plots, LaTeXStrings using Statistics using StochasticAD using ForwardDiff: derivative using ProgressMeter begin stds_triple = Float64[] stds_smoothed = Float64[] stds_score = Float64[] stds_score_baseline = Float64[] @showprogress for (n, p) in zip(RandomWalkCore.n_range, RandomWalkCore.p_range) std_triple = std(derivative_estimate(p -> RandomWalkCore.fX(p, n), p) for i in 1:(RandomWalkCore.nsamples)) std_smoothed = std(derivative( p -> RandomWalkCore.fX(p, n; hardcode_leftright_step = true), p) for i in 1:(RandomWalkCore.nsamples)) std_score = std(RandomWalkCore.score_fX_deriv(p, n, 0.0) for i in 1:(RandomWalkCore.nsamples)) avg = mean(RandomWalkCore.fX(p, n) for i in 1:10000) std_score_baseline = std(RandomWalkCore.score_fX_deriv(p, n, avg) for i in 1:(RandomWalkCore.nsamples)) push!(stds_triple, std_triple) push!(stds_score, std_score) push!(stds_score_baseline, std_score_baseline) push!(stds_smoothed, std_smoothed) end end @show stds_triple @show stds_score @show stds_score_baseline @show stds_smoothed begin show_smoothed = false fig = plot(RandomWalkCore.n_range, stds_score, color = 2, label = "score-function", size = (300, 250), xlabel = L"n", ylabel = "standard deviation", legend = :topleft) scatter!(RandomWalkCore.n_range, stds_score, color = 2, label = false) plot!(RandomWalkCore.n_range, stds_score_baseline, color = 3, label = "score-function w/ CV") scatter!(RandomWalkCore.n_range, stds_score_baseline, color = 3, label = false) 
## Simulate
# Run an n-step random walk whose transition probabilities depend on the
# current position X; returns the final position.
function simulate_walk(probs, steps, n; debug = false, hardcode_leftright_step = false)
    X = 0
    for _ in 1:n
        probs_X = probs(X) # transition probabilities at the current position
        debug && @show probs_X
        step_index = rand(Categorical(probs_X)) # produces an integer-valued StochasticTriple
        debug && @show step_index
        # Either map the index to ±1 arithmetically (index 1 ↦ -1, 2 ↦ +1),
        # or differentiate through the array indexing itself.
        step = hardcode_leftright_step ? 2 * (step_index - 1) - 1 : steps[step_index]
        X += step
        debug && @show X
    end
    return X
end
## Exactly compute transition matrix M
# Build the one-step transition matrix of the walk over the truncated state
# range `range`, column j holding the outgoing probabilities of state range[j].
function get_M(p)
    probs = make_probs(p)
    # element type follows the probability type so duals/triples propagate
    elt = eltype(first(probs(range[range_start])))
    M = zeros(elt, length(range), length(range))
    offset = minimum(range)
    for x in range
        for (step, prob) in zip(steps, probs(x))
            # transitions that would leave the truncated range are dropped
            if (x + step) in range
                M[x + step - offset + 1, x - offset + 1] = prob
            end
        end
    end
    return M
end
Stochastic triple computation\n") @btime fX(p) @btime derivative_estimate(fX, p; backend = PrunedFIsAggressiveBackend()) @btime derivative_estimate(fX, p; backend = PrunedFIsBackend()) triple_X_derivs = [derivative_estimate(X, p) for i in 1:nsamples] triple_fX_derivs = [derivative_estimate(fX, p) for i in 1:nsamples] println("Stochastic triple X derivative mean: $(mean(triple_X_derivs))") println("Stochastic triple X derivative std : $(std(triple_X_derivs))") println("Stochastic triple f(X) derivative mean: $(mean(triple_fX_derivs))") println("Stochastic triple f(X) derivative std: $(std(triple_fX_derivs))") println() smoothed_X_derivs = [derivative(p -> X(p; hardcode_leftright_step = true), p) for i in 1:nsamples] smoothed_fX_derivs = [derivative(p -> fX(p; hardcode_leftright_step = true), p) for i in 1:nsamples] println("Smoothed X derivative mean: $(mean(smoothed_X_derivs))") println("Smoothed X derivative std : $(std(smoothed_X_derivs))") println("Smoothed f(X) derivative mean: $(mean(smoothed_fX_derivs))") println("Smoothed f(X) derivative std: $(std(smoothed_fX_derivs))") println() println("## Score function computation\n") # baseline avg_X = mean(X(p) for i in 1:10000) avg_fX = mean(fX(p) for i in 1:10000) score_X_derivs = [score_X_deriv(p; avg = avg_X) for i in 1:nsamples] score_fX_derivs = [score_fX_deriv(p; avg = avg_fX) for i in 1:nsamples] println("Score X derivative mean: $(mean(score_X_derivs))") println("Score X derivative std: $(std(score_X_derivs))") println("Score f(X) derivative mean: $(mean(score_fX_derivs))") println("Score f(X) derivative std: $(std(score_fX_derivs))") println() println("## Finite differences\n") function fd(X, p, h = 10) state = copy(Random.default_rng()) run1 = X(p - h / 2) copy!(Random.default_rng(), state) run2 = X(p + h / 2) (run2 - run1) / h end fd_X_derivs = [fd(X, p) for i in 1:nsamples] fd_fX_derivs = [fd(f ∘ X, p) for i in 1:nsamples] println("FD X derivative mean: $(mean(fd_X_derivs))") println("FD X derivative std: 
$(std(fd_X_derivs))") println("FD f(X) derivative mean: $(mean(fd_fX_derivs))") println("FD f(X) derivative std: $(std(fd_fX_derivs))") println() ================================================ FILE: tutorials/reverse_example/reverse_demo.jl ================================================ #text # Simple reverse mode example #text ```@setup random_walk #text import Pkg #text Pkg.activate("../../../tutorials") #text Pkg.develop(path="../../..") #text Pkg.instantiate() #text #text import Random #text Random.seed!(1234) #text ``` import Random #src Random.seed!(1234) #src ##cell #text Load our packages using StochasticAD using Distributions using Enzyme using LinearAlgebra ##cell #text Let us define our target function. # Define a toy `StochasticAD`-differentiable function for computing an integer value from a string. string_value(strings, index) = Int(sum(codepoint, strings[index])) function string_value(strings, index::StochasticTriple) StochasticAD.propagate(index -> string_value(strings, index), index) end function f(θ; derivative_coupling = StochasticAD.InversionMethodDerivativeCoupling()) strings = ["cat", "dog", "meow", "woofs"] index = randst(Categorical(θ); derivative_coupling) return string_value(strings, index) end θ = [0.1, 0.5, 0.3, 0.1] @show f(θ) nothing ##cell #text First, let's compute the sensitivity of `f` in a particular direction via forward-mode Stochastic AD. u = [1.0, 2.0, 4.0, -7.0] @show derivative_estimate( f, θ, StochasticAD.ForwardAlgorithm(PrunedFIsBackend()); direction = u) nothing ##cell #text Now, let's do the same with reverse-mode, via [`EnzymeReverseAlgorithm`](@ref). @show derivative_estimate( f, θ, StochasticAD.EnzymeReverseAlgorithm(PrunedFIsBackend(Val(:wins)))) ##cell #text Let's verify that our reverse-mode gradient is consistent with our forward-mode directional derivative. 
# Monte-Carlo comparison of the two modes over N samples. #src
function forward()
    derivative_estimate(
        f, θ, StochasticAD.ForwardAlgorithm(PrunedFIsBackend()); direction = u)
end
# NOTE(review): `reverse` shadows Base.reverse in this script's scope. #src
function reverse()
    derivative_estimate(
        f, θ, StochasticAD.EnzymeReverseAlgorithm(PrunedFIsBackend(Val(:wins))))
end
N = 40000
directional_derivs_fwd = [forward() for i in 1:N]
derivs_bwd = [reverse() for i in 1:N]
# Project each reverse-mode gradient onto the forward direction u for comparison. #src
directional_derivs_bwd = [dot(u, δ) for δ in derivs_bwd]
println("Forward mode: $(mean(directional_derivs_fwd)) ± $(std(directional_derivs_fwd) / sqrt(N))")
println("Reverse mode: $(mean(directional_derivs_bwd)) ± $(std(directional_derivs_bwd) / sqrt(N))")
@assert isapprox(mean(directional_derivs_fwd), mean(directional_derivs_bwd), rtol = 3e-2)
nothing
##cell
#! format: off #src
# Build the markdown tutorial from this file with Literate.jl. #src
using Literate #src
do_documenter = true #src
# Rewrite this file's markers line-by-line before Literate runs: keep lines #src
# ending in `#src`, turn `##cell` into `#src`, `#text` into `#`, and escape #src
# other comments so Literate does not treat them as markdown. #src
function preprocess(content) #src
    new_lines = map(split(content, "\n")) do line #src
        if endswith(line, "#src") #src
            line #src
        elseif startswith(line, "##cell") #src
            "#src" #src
        elseif startswith(line, "#text") #src
            replace(line, "#text" => "#") #src
        # try and save comments; strip necessary since Literate.jl also treats indented comments on their own line as markdown. #src
        elseif startswith(strip(line), "#") && !startswith(strip(line), "#=") && !startswith(strip(line), "#-") #src
            # TODO: should be replace first occurrence only? #src
            replace(line, "#" => "##") #src
        else #src
            line #src
        end #src
    end #src
    return join(new_lines, "\n") #src
end #src
withenv("JULIA_DEBUG" => "Literate") do #src
    dir = joinpath(dirname(dirname(pathof(StochasticAD))), "docs", "src", "tutorials") #src
    if do_documenter #src
        @time Literate.markdown( @__FILE__, dir; execute = false, flavor = Literate.DocumenterFlavor(), preprocess = preprocess, documenter = true) #src
    else #src
        @time Literate.markdown(@__FILE__, dir; execute = true, flavor = Literate.CommonMark(), preprocess = preprocess) #src
    end #src
end #src
================================================ FILE: tutorials/toy_optimizations/Project.toml ================================================
[deps]
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StochasticAD = "e4facb34-4f7e-4bec-b153-e122c37934ac"
Tilde = "73a6ac3c-4b34-4cca-a813-308f7589d80d"
================================================ FILE: tutorials/toy_optimizations/igarch.jl ================================================
# Poisson autoregression
cd(@__DIR__)
using StochasticAD, Distributions
using Optimisers
import Random
Random.seed!(1234)
Random.seed!(StochasticAD.RNG, 1234)
PLOT = true
if PLOT
    using CairoMakie
end

# Poisson autoregression model, returning end value after `n` iterations
# Each iteration samples z ~ Poisson(λ) and updates λ = a + b*z + c*λ;
# returns the final (λ, z) pair.
function igarch(a, b, c, n, λ)
    z = rand(Poisson(λ))
    λ = a + b * z + c * λ
    for i in 2:n
        z = rand(Poisson(λ))
        λ = a + b * z + c * λ
    end
    return λ, z
end

λ0 = 5.42 # true starting value

## Generate observations
n = 10
a, b, c = [0.25, 0.9, 0.51]
_, z_obs = igarch(a, b, c, n, λ0) # 140 in first run

# Posterior density estimate of parameter p=λ0 given z_obs=140 (assume we don't know)
function X(p, z_obs = 140, n = 10)
    a, b, c = [0.25, 0.9, 0.51]
    λ, _ = igarch(a, b, c, n - 1, p)
    pdf(Exponential(100.0), λ) * pdf(Poisson(λ), z_obs)
end

# Maximize posterior with
# Adam and Optimize
p0 = [20.5]
iterations = 5000
m = StochasticAD.StochasticModel(p0, x -> -X(x)) # Formulate as minimization problem
trace = Float64[]
o = Adam(0.08)
s = Optimisers.setup(o, m)
# Stochastic gradient descent on the negated posterior density, recording p. #
for i in 1:iterations
    Optimisers.update!(s, m, StochasticAD.stochastic_gradient(m))
    push!(trace, m.p[])
end
p_opt = m.p[]

if PLOT
    # Monte-Carlo curves of X's expectation and derivative over a parameter grid.
    ps = range(0, 10, length = 50)
    N = 1000
    expected = [mean(X(p) for _ in 1:N) for p in ps]
    slope = [mean(derivative_estimate(X, p) for _ in 1:N) for p in ps]
    f = Figure()
    ax = f[1, 1] = Axis(f, title = "Estimates")
    lines!(ax, ps, expected, label = "≈ E X(p)")
    lines!(ax, ps, slope, label = "≈ (E X(p))'")
    vlines!(ax, [p_opt], label = "p_opt", color = :green, linewidth = 2.0)
    vlines!(ax, [λ0], linestyle = :dot, linewidth = 2.0)
    hlines!(ax, [0.0], color = :black, linewidth = 1.0)
    f[1, 2] = Legend(f, ax, framevisible = false)
    ylims!(ax, (-1e-5, 2e-5))
    ax = f[2, 1:2] = Axis(f, title = "Optimizer trace")
    lines!(ax, trace, color = :green, linewidth = 2.0)
    hlines!(ax, [λ0], linestyle = :dot, linewidth = 2.0)
    ylims!(ax, (0, 20))
    save("igarch.png", f)
    display(f)
end
================================================ FILE: tutorials/toy_optimizations/intro.jl ================================================
# Toy expectation optimization problem
cd(@__DIR__)
using StochasticAD, Distributions, Optimisers
import Random # hide
Random.seed!(1234) # hide
PLOT = true
if PLOT
    using CairoMakie
end

# The "crazy" stochastic program from the introduction
function X(p)
    a = p * (1 - p)
    b = rand(Binomial(10, p))
    c = 2 * b + 3 * rand(Bernoulli(p))
    return a * c * rand(Normal(b, a))
end

# Maximize E[X(p)] using Adam and Optimize
p0 = [0.5]
iterations = 5000
m = StochasticAD.StochasticModel(p0, x -> -X(x)) # Formulate as minimization problem
trace = Float64[]
o = Adam()
s = Optimisers.setup(o, m)
for i in 1:iterations
    Optimisers.update!(s, m, StochasticAD.stochastic_gradient(m))
    push!(trace, m.p[])
end
p_opt = m.p[]

if PLOT
    # Plot the estimated expectation and its derivative across p ∈ (0, 1).
    dp = 1 / 50
    N = 1000
    ps = dp:dp:(1 - dp)
    avg = [mean(X(p) for _ in 1:N) for p in ps]
    derivative = [mean(derivative_estimate(X, p) for _ in 1:N) for p in ps]
    f = Figure()
    ax = f[1, 1] = Axis(f, title = "Estimates")
    lines!(ax, ps, avg, label = "≈ E[X(p)]")
    lines!(ax, ps, derivative, label = "≈ d/dp E[X(p)]")
    vlines!(ax, [p_opt], label = "p_opt", color = :green, linewidth = 2.0)
    hlines!(ax, [0.0], color = :black, linewidth = 1.0)
    f[1, 2] = Legend(f, ax, framevisible = false)
    ylims!(ax, (-50, 80))
    ax = f[2, 1:2] = Axis(f, title = "Optimizer trace")
    lines!(ax, trace, color = :green, linewidth = 2.0)
    save("intro.png", f)
    display(f)
end
================================================ FILE: tutorials/toy_optimizations/variational.jl ================================================
# Toy variational problem: Find Poisson(p) close to NegativeBinomial(10, 1-30/(10+30))
# by minimization of the Kullback Leibler distance
cd(@__DIR__)
using StochasticAD, Distributions, Optimisers
import Random # hide
Random.seed!(1234) # hide
PLOT = true
if PLOT
    using CairoMakie
end

# Sample the likelihood ratio. E[X(p)] is the Kullback-Leibler distance between the models
function X(p)
    i = rand(Poisson(p))
    return logpdf(Poisson(p), i) - logpdf(NegativeBinomial(10, 1 - 30 / (10 + 30)), i)
end

# Minimize E[X] = KL(Poisson(p)| NegativeBinomial(10, 1-30/(10+30))) using Adam and Optimize.jl
iterations = 5000
p0 = [10.0]
m = StochasticAD.StochasticModel(p0, X) # Formulate as minimization problem
trace = Float64[]
o = Adam(0.1)
s = Optimisers.setup(o, m)
for i in 1:iterations
    Optimisers.update!(s, m, StochasticAD.stochastic_gradient(m))
    push!(trace, m.p[])
end
p_opt = m.p[]

if PLOT
    # Plot the estimated KL divergence and its derivative across p ∈ [10, 50].
    dp = 1 / 2
    N = 1000
    ps = 10:dp:50
    avg = [mean(X(p) for _ in 1:N) for p in ps]
    derivative = [mean(derivative_estimate(X, p) for _ in 1:N) for p in ps]
    f = Figure()
    ax = f[1, 1] = Axis(f, title = "Estimates")
    lines!(ax, ps, avg, label = "≈ E[X(p)]")
    lines!(ax, ps, derivative, label = "≈ d/dp E[X(p)]")
    vlines!(ax, [p_opt], label = "p_opt", color = :green, linewidth = 2.0)
    hlines!(ax, [0.0], color = :black, linewidth = 1.0)
    f[1, 2] = Legend(f, ax, framevisible = false)
    ylims!(ax, (-10, 10))
    ax = f[2, 1:2] = Axis(f, title = "Optimizer trace")
    lines!(ax, trace, color = :green, linewidth = 2.0)
    save("variational.png", f)
    display(f)
end