Repository: FluxML/NNlib.jl Branch: master Commit: d3c38a49d9cf Files: 141 Total size: 634.0 KB Directory structure: gitextract_0xi0432k/ ├── .buildkite/ │ └── pipeline.yml ├── .codecov.yml ├── .github/ │ ├── copilot-instructions.md │ ├── dependabot.yml │ └── workflows/ │ ├── BenchmarkTrigger.yml │ ├── CompatHelper.yml │ ├── Downstream.yml │ ├── TagBot.yml │ ├── ci.yml │ ├── clean_preview.yml │ └── pr_comment.yml ├── .gitignore ├── LICENSE.md ├── Project.toml ├── README.md ├── benchmark/ │ ├── Project.toml │ ├── benchmarks.jl │ ├── perf_report.jl │ └── runbenchmarks.jl ├── docs/ │ ├── .gitignore │ ├── Project.toml │ ├── make.jl │ └── src/ │ ├── assets/ │ │ ├── flux.css │ │ └── jfk.flac │ ├── audio.md │ ├── index.md │ └── reference.md ├── ext/ │ ├── NNlibAMDGPUExt/ │ │ ├── NNlibAMDGPUExt.jl │ │ ├── activations.jl │ │ ├── batched_mul.jl │ │ ├── conv.jl │ │ └── pool.jl │ ├── NNlibCUDACUDNNExt/ │ │ ├── NNlibCUDACUDNNExt.jl │ │ ├── activations.jl │ │ ├── batchnorm.jl │ │ ├── conv.jl │ │ ├── pooling.jl │ │ └── softmax.jl │ ├── NNlibCUDAExt/ │ │ ├── NNlibCUDAExt.jl │ │ ├── activations.jl │ │ ├── batchedadjtrans.jl │ │ ├── batchedmul.jl │ │ ├── ctc.jl │ │ ├── sampling.jl │ │ ├── scatter.jl │ │ └── utils.jl │ ├── NNlibEnzymeCoreExt/ │ │ └── NNlibEnzymeCoreExt.jl │ ├── NNlibFFTWExt/ │ │ ├── NNlibFFTWExt.jl │ │ └── stft.jl │ ├── NNlibForwardDiffExt.jl │ ├── NNlibMetalExt.jl │ └── NNlibSpecialFunctionsExt.jl ├── src/ │ ├── NNlib.jl │ ├── activations.jl │ ├── attention.jl │ ├── audio/ │ │ ├── mel.jl │ │ ├── spectrogram.jl │ │ └── stft.jl │ ├── batched/ │ │ ├── batchedadjtrans.jl │ │ └── batchedmul.jl │ ├── bias_act.jl │ ├── conv.jl │ ├── conv_bias_act.jl │ ├── ctc.jl │ ├── deprecations.jl │ ├── dim_helpers/ │ │ ├── ConvDims.jl │ │ ├── DenseConvDims.jl │ │ ├── DepthwiseConvDims.jl │ │ └── PoolDims.jl │ ├── dim_helpers.jl │ ├── dropout.jl │ ├── fold.jl │ ├── functions.jl │ ├── gather.jl │ ├── gemm.jl │ ├── impl/ │ │ ├── conv_direct.jl │ │ ├── conv_im2col.jl │ │ ├── 
depthwiseconv_direct.jl │ │ ├── depthwiseconv_im2col.jl │ │ ├── padding_edges.jl │ │ └── pooling_direct.jl │ ├── normalization.jl │ ├── padding.jl │ ├── pooling.jl │ ├── rotation.jl │ ├── sampling.jl │ ├── scatter.jl │ ├── softmax.jl │ ├── upsample.jl │ └── utils.jl └── test/ ├── Project.toml ├── activations.jl ├── attention.jl ├── batchedmul.jl ├── bias_act.jl ├── conv.jl ├── conv_bias_act.jl ├── ctc.jl ├── dropout.jl ├── ext_amdgpu/ │ ├── activations.jl │ ├── attention.jl │ ├── batched_mul.jl │ ├── batched_repr.jl │ ├── conv.jl │ ├── dropout.jl │ ├── pool.jl │ ├── runtests.jl │ ├── softmax.jl │ └── storage_type.jl ├── ext_cuda/ │ ├── activations.jl │ ├── batchedadjtrans.jl │ ├── batchedmul.jl │ ├── batchnorm.jl │ ├── conv.jl │ ├── ctc.jl │ ├── dropout.jl │ ├── fold.jl │ ├── gather.jl │ ├── pooling.jl │ ├── runtests.jl │ ├── sampling.jl │ ├── scatter.jl │ ├── softmax.jl │ └── test_utils.jl ├── ext_metal/ │ ├── activations.jl │ └── runtests.jl ├── functions.jl ├── inference.jl ├── padding.jl ├── pooling.jl ├── runtests.jl ├── sampling.jl ├── softmax.jl ├── test_utils.jl ├── testsuite/ │ ├── fold.jl │ ├── gather.jl │ ├── rotation.jl │ ├── scatter.jl │ ├── spectral.jl │ └── upsample.jl └── utils.jl ================================================ FILE CONTENTS ================================================ ================================================ FILE: .buildkite/pipeline.yml ================================================ steps: - label: ":julia: Julia {{matrix.julia}} - CUDA GPU" command: - echo 'CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"' >> test/Project.toml - echo 'cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"' >> test/Project.toml plugins: - JuliaCI/julia#v1: version: "{{matrix.julia}}" - JuliaCI/julia-test#v1: test_args: "--quickfail" - JuliaCI/julia-coverage#v1: codecov: true dirs: - src - ext agents: queue: "juliagpu" cuda: "*" env: JULIA_NUM_THREADS: 4 NNLIB_TEST_CUDA: "true" NNLIB_TEST_CPU: "false" if: build.message !~ /\[skip tests\]/ 
timeout_in_minutes: 180 matrix: setup: julia: - "1.10" - "1" - "nightly" adjustments: - with: julia: "nightly" soft_fail: true - label: ":julia: Julia {{matrix.julia}} - AMD GPU" command: - echo 'AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"' >> test/Project.toml plugins: - JuliaCI/julia#v1: version: "1" - JuliaCI/julia-test#v1: test_args: "--quickfail" - JuliaCI/julia-coverage#v1: codecov: true dirs: - src - ext agents: queue: "juliagpu" rocm: "*" rocmgpu: "*" timeout_in_minutes: 180 env: JULIA_AMDGPU_CORE_MUST_LOAD: "1" JULIA_AMDGPU_HIP_MUST_LOAD: "1" JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" NNLIB_TEST_AMDGPU: "true" NNLIB_TEST_CPU: "false" JULIA_NUM_THREADS: 4 matrix: setup: julia: # - "1.10" - "1" # - "nightly" # adjustments: # - with: # julia: "nightly" # soft_fail: true - label: ":julia: Julia {{matrix.julia}} - Metal GPU" command: - echo 'Metal = "dde4c033-4e86-420c-a63e-0dd931031962"' >> test/Project.toml plugins: - JuliaCI/julia#v1: version: "{{matrix.julia}}" - JuliaCI/julia-test#v1: test_args: "--quickfail" - JuliaCI/julia-coverage#v1: codecov: true dirs: - src - ext agents: queue: "juliaecosystem" os: "macos" arch: "aarch64" timeout_in_minutes: 180 env: NNLIB_TEST_METAL: "true" NNLIB_TEST_CPU: "false" JULIA_NUM_THREADS: 4 matrix: setup: julia: # - "1.10" - "1" # - "nightly " # adjustments: # - with: # julia: "nightly" # soft_fail: true - label: "Benchmarks" plugins: - JuliaCI/julia#v1: version: 1 env: JULIA_NUM_THREADS: 4 command: - julia --project=benchmark -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()' - julia --project=benchmark benchmark/runbenchmarks.jl - printf '%b\n' "$(cat benchmark/report.md)" | buildkite-agent annotate --style 'info' agents: queue: "juliagpu" if: build.pull_request.labels includes "benchmark" timeout_in_minutes: 30 env: SECRET_CODECOV_TOKEN: 
"IlEMvDI6RciJQr5eX7qBBpHYFAe8+Svf3lNJh9gZi0MeJZQvMZWzHfW/lVncA9d9K+gDBBTv/zwqF86xOaIFLuACNdcGZiGgHS+NGeXN5CEppjqLnqKuaeHmLgJ43jygxRwgF88LhwTGcHG7pmESIp1Bn3Jd23UUv4t8hJLBDF+KJLZMefzCXnEVzfwJYxhJktnKJPA4dOv59w33Vj1x5uCYZbQlLP54IJPBm8UGdXS+JrUX8Z7lhxbkJUi6c+R6cvVBw27uRjF0pUJY26mt1frx8MzTGTOweXTpi+Kc5JhzlokMlan17j6T/b7qMC13IuKopfqu1GhkSBQD3ZhQqA==;U2FsdGVkX19l7JMB48k4oJHLoaqC7/MmvQWmaiBxRN472ZC6AcQ0uCBRy6Fw8tI0YcjIxKDScaBnJ2v/deOfhg==" ================================================ FILE: .codecov.yml ================================================ comment: false ================================================ FILE: .github/copilot-instructions.md ================================================ # NNlib.jl Copilot Instructions ## Repository Overview NNlib.jl is a library providing fundamental neural network operations and primitives for Julia. It is primarily used by Flux.jl but can be used independently. The library provides: - Activation functions (sigmoid, relu, gelu, etc.) - Convolution and pooling operations - Attention mechanisms - Batched matrix operations - Neural network utilities (dropout, normalization, etc.) - GPU acceleration support (CUDA, AMDGPU) ## Project Structure ``` NNlib.jl/ ├── src/ # Core library implementation │ ├── NNlib.jl # Main module file │ ├── activations.jl # Activation functions │ ├── attention.jl # Attention mechanisms │ ├── conv.jl # Convolution operations │ ├── pooling.jl # Pooling operations │ ├── batched/ # Batched operations │ └── impl/ # Implementation details ├── ext/ # Package extensions for GPU backends │ ├── NNlibCUDAExt/ # CUDA-specific implementations │ ├── NNlibAMDGPUExt/ # AMDGPU-specific implementations │ └── NNlibCUDACUDNNExt/ # cuDNN-specific implementations ├── test/ # Test suite └── docs/ # Documentation ``` ## Julia Version - Minimum Julia version: 1.10 - CI tests on: minimum julia version, latest stable (1.x), and pre-release versions ## Coding Standards ### Julia Conventions 1. 
**Naming**: - Functions: lowercase with underscores (e.g., `dot_product_attention`) - Types: PascalCase (e.g., `ConvDims`, `PoolDims`) - Constants: UPPERCASE with underscores (e.g., `ACTIVATIONS`) 2. **Documentation**: - Use Julia docstrings (""" ... """) for all exported functions - Include examples in docstrings where appropriate - Keep documentation up-to-date with implementation changes 3. **Type Annotations**: - Use type parameters and abstract types for generic implementations - Leverage Julia's multiple dispatch for specialized implementations - Define clear type hierarchies (e.g., `DenseConvDims`, `DepthwiseConvDims`) 4. **Performance**: - Prefer in-place operations where appropriate (functions ending with `!`) - Use `@inbounds` judiciously when bounds checking is verified - Consider thread safety for multi-threaded operations - Use `NNlib.@disallow_spawns` to control threading behavior ### Code Organization 1. **Core Implementations**: CPU implementations go in `src/` 2. **GPU Extensions**: GPU-specific code belongs in `ext/` as package extensions 3. **Tests**: Mirror the structure of `src/` in `test/` 4. **Gradients**: Define gradients using ChainRules.jl (`rrule` functions) ## Testing ### Test Infrastructure - Uses the standard Julia `Test` framework - Tests are organized to mirror the source structure - GPU tests are conditional (controlled by environment variables) ### Running Tests ```julia # Run all CPU tests julia --project -e 'using Pkg; Pkg.test()' # Run tests with threading JULIA_NUM_THREADS=4 julia --project -e 'using Pkg; Pkg.test()' ``` ### Test Patterns 1. **Activation Functions**: Test at specific values (0.0, 1.0, -1.0) and verify expected outputs 2. **Gradient Tests**: Use `ChainRulesTestUtils` for gradient correctness 3. **Type Stability**: Use `@inferred` where appropriate 4. 
**GPU Tests**: Conditional testing based on environment variables: - `ENV["NNLIB_TEST_CUDA"]` for CUDA tests - `ENV["NNLIB_TEST_AMDGPU"]` for AMDGPU tests ### Writing New Tests - Include tests for edge cases (zero inputs, negative values, boundary conditions) - Test both forward pass and gradients (using ChainRulesTestUtils) - For array operations, test multiple dimensions and batch sizes - Include tests for type stability when performance-critical ## Dependencies ### Core Dependencies - **ChainRulesCore**: For automatic differentiation support - **KernelAbstractions**: For GPU kernel abstractions - **Adapt**: For moving data between CPU/GPU - **GPUArraysCore**: GPU array interface ### Weak Dependencies (Extensions) - **CUDA.jl/cuDNN**: NVIDIA GPU support - **AMDGPU.jl**: AMD GPU support - **FFTW**: Fast Fourier transforms - **ForwardDiff**: Forward-mode AD support - **EnzymeCore**: Enzyme AD support ### Adding New Dependencies - Consider whether the dependency should be a weak dependency (extension) - Update `Project.toml` with version constraints - Ensure compatibility with supported Julia versions - Run full test suite after adding dependencies ## GPU Support NNlib uses Julia's package extension system for GPU backends: 1. **CUDA**: Load with `using NNlib, CUDA, cuDNN` 2. 
**AMDGPU**: Load with `using NNlib, AMDGPU` ### GPU Implementation Guidelines - Keep GPU-specific code in appropriate extensions (`ext/` directory) - Provide CPU fallback implementations in `src/` - Test GPU implementations separately (conditional on hardware availability) - Use KernelAbstractions for portable GPU kernels when possible ## Build and CI/CD ### Continuous Integration - **CI Workflow**: `.github/workflows/ci.yml` - Tests on Linux (always), Windows, and macOS - Tests with different Julia versions (LTS, stable, pre-release) - Tests with different thread counts ### Additional Workflows - **TagBot**: Automatic release tagging - **CompatHelper**: Dependency compatibility updates - **Downstream**: Tests dependent packages - **BenchmarkTrigger**: Performance regression testing ## Common Tasks ### Adding a New Activation Function 1. Add function to `src/activations.jl` 2. Add to `ACTIVATIONS` tuple for automatic export 3. Define gradient with `@scalar_rule` or `rrule` 4. Add tests in `test/activations.jl` at key values 5. Document with docstring and example ### Adding a New Operation 1. Implement in appropriate file in `src/` 2. Export from `src/NNlib.jl` 3. Define gradients using ChainRules 4. Add comprehensive tests 5. Add GPU implementations in extensions if applicable 6. Document in appropriate file in `docs/src/` ### Modifying Existing Functions 1. Check for dependent code in Flux.jl and other downstream packages 2. Maintain backward compatibility or document breaking changes 3. Update tests to cover new behavior 4. Update gradients if needed 5. Consider performance implications ## Performance Considerations 1. **Memory Allocation**: Minimize allocations in hot paths 2. **Threading**: NNlib uses Julia threads for parallel operations - Control with `NNlib.@disallow_spawns` if needed - Thread count controlled by `JULIA_NUM_THREADS` 3. **GPU Kernels**: Optimize kernel launch parameters and memory access patterns 4. 
**Type Stability**: Ensure type-stable code for performance-critical paths ## Documentation - Documentation source: `docs/src/` - Built with Documenter.jl - Includes API reference, examples, and guides - Documentation tests run via `DocTestSetup` ## Related Projects - **Flux.jl**: Primary consumer of NNlib - **Zygote.jl**: Automatic differentiation (uses ChainRules) - **ChainRules.jl**: Gradient definitions - **KernelAbstractions.jl**: GPU kernel abstraction ## Getting Help - Documentation: https://fluxml.ai/NNlib.jl/dev/ - Issues: https://github.com/FluxML/NNlib.jl/issues - FluxML community: https://github.com/FluxML ================================================ FILE: .github/dependabot.yml ================================================ # To get started with Dependabot version updates, you'll need to specify which # package ecosystems to update and where the package manifests are located. # Please see the documentation for all configuration options: # https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file version: 2 updates: - package-ecosystem: "github-actions" directory: "/" # Location of package manifests schedule: interval: "weekly" ================================================ FILE: .github/workflows/BenchmarkTrigger.yml ================================================ name: Benchmark Trigger on: pull_request_target: types: [ labeled ] workflow_dispatch: inputs: pr_id: type: string description: id of the pull request that triggers this workflow target_url: type: string description: url of target baseline_url: type: string description: url of baseline jobs: benchmark_trigger: if: ${{ github.event.label.name == 'benchmark' }} runs-on: ubuntu-latest env: REPOSITORY: ${{ github.event.repository.full_name }} PR_ID: ${{ github.event.inputs.pr_id || github.event.pull_request.number }} TARGET_URL: ${{ github.event.inputs.target_url || format('{0}#{1}', 
github.event.pull_request.head.repo.html_url, github.event.pull_request.head.sha) }} BASELINE_URL: ${{ github.event.inputs.baseline_url || format('{0}#{1}', github.event.pull_request.base.repo.html_url, github.event.pull_request.base.sha) }} steps: - name: Get app installation token (ghs) id: get-app-token uses: tibdex/github-app-token@v2 with: app_id: ${{ secrets.BENCH_APP_ID }} installation_id: ${{ secrets.BENCH_INSTALL_ID }} private_key: ${{ secrets.BENCH_PRIVATE_KEY }} - uses: benc-uk/workflow-dispatch@v1 with: repo: FluxML/FluxMLBenchmarks.jl ref: refs/heads/main workflow: Benchmark.yml token: ${{ steps.get-app-token.outputs.token }} inputs: '{ "repository": "${{ env.REPOSITORY }}", "pr_id": "${{ env.PR_ID }}", "target_url": "${{ env.TARGET_URL }}", "baseline_url": "${{ env.BASELINE_URL }}" }' ================================================ FILE: .github/workflows/CompatHelper.yml ================================================ name: CompatHelper on: schedule: - cron: 0 0 * * * workflow_dispatch: permissions: contents: write pull-requests: write jobs: CompatHelper: runs-on: ubuntu-latest steps: - name: Check if Julia is already available in the PATH id: julia_in_path run: which julia continue-on-error: true - name: Install Julia, but only if it is not already available in the PATH uses: julia-actions/setup-julia@v2 with: version: '1' arch: ${{ runner.arch }} if: steps.julia_in_path.outcome != 'success' - name: "Add the General registry via Git" run: | import Pkg ENV["JULIA_PKG_SERVER"] = "" Pkg.Registry.add("General") shell: julia --color=yes {0} - name: "Install CompatHelper" run: | import Pkg name = "CompatHelper" uuid = "aa819f21-2bde-4658-8897-bab36330d9b7" version = "3" Pkg.add(; name, uuid, version) shell: julia --color=yes {0} - name: "Run CompatHelper" run: | import CompatHelper CompatHelper.main() shell: julia --color=yes {0} env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }} # COMPATHELPER_PRIV: ${{ 
secrets.COMPATHELPER_PRIV }} ================================================ FILE: .github/workflows/Downstream.yml ================================================ name: IntegrationTest on: push: branches: [master] tags: [v*] pull_request: # needed to allow julia-actions/cache to delete old caches that it has created permissions: actions: write contents: read jobs: test: name: ${{ matrix.package.repo }}/${{ matrix.package.group }} runs-on: ${{ matrix.os }} env: GROUP: ${{ matrix.package.group }} strategy: fail-fast: false matrix: julia-version: [1] os: [ubuntu-latest] package: - {user: FluxML, repo: Flux.jl, group: All} - {user: FluxML, repo: Tracker.jl, group: All} - {user: LuxDL, repo: Lux.jl, group: All} steps: - uses: actions/checkout@v6 - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.julia-version }} arch: x64 - uses: julia-actions/cache@v2 - uses: julia-actions/julia-buildpkg@latest - name: Clone Downstream uses: actions/checkout@v6 with: repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} path: downstream - name: Load this and run the downstream tests shell: julia --color=yes --project=downstream {0} run: | using Pkg try # force it to use this PR's version of the package Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps Pkg.update() Pkg.test() # resolver may fail with test time deps catch err err isa Pkg.Resolve.ResolverError || rethrow() # If we can't resolve that means this is incompatible by SemVer and this is fine # It means we marked this as a breaking change, so we don't need to worry about # Mistakenly introducing a breaking change, as we have intentionally made one @info "Not compatible with this release. No problem." 
exception=err exit(0) # Exit immediately, as a success end env: RETESTITEMS_NWORKERS: 4 BACKEND_GROUP: CPU # for Lux.jl ================================================ FILE: .github/workflows/TagBot.yml ================================================ name: TagBot on: issue_comment: types: - created workflow_dispatch: inputs: lookback: default: 3 permissions: actions: read checks: read contents: write deployments: read issues: read discussions: read packages: read pages: read pull-requests: read repository-projects: read security-events: read statuses: read jobs: TagBot: if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot' runs-on: ubuntu-latest steps: - uses: JuliaRegistries/TagBot@v1 with: token: ${{ secrets.GITHUB_TOKEN }} # Edit the following line to reflect the actual name of the GitHub Secret containing your private key ssh: ${{ secrets.DOCUMENTER_KEY }} # ssh: ${{ secrets.NAME_OF_MY_SSH_PRIVATE_KEY_SECRET }} ================================================ FILE: .github/workflows/ci.yml ================================================ name: CI on: push: branches: - master - staging - trying tags: '*' pull_request: # needed to allow julia-actions/cache to delete old caches that it has created permissions: actions: write contents: read defaults: run: shell: bash jobs: test: name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.julia-threads }} thread(s) runs-on: ${{ matrix.os }} env: JULIA_NUM_THREADS: ${{ matrix.julia-threads }} strategy: fail-fast: false matrix: version: - '1.10' # uncomment when julia 1.10 is out - '1' # automatically expands to the latest stable 1.x release of Julia - 'nightly' os: - ubuntu-latest # - macOS-latest # - windows-latest julia-threads: - '1' include: - os: windows-latest version: '1' julia-threads: '1' - os: macOS-latest version: '1' julia-threads: '1' - os: ubuntu-latest version: '1' julia-threads: '2' steps: - uses: actions/checkout@v6 - uses: julia-actions/setup-julia@v2 with: version: 
${{ matrix.version }} - uses: julia-actions/cache@v2 - uses: julia-actions/julia-buildpkg@v1 - name: "Run test without coverage" uses: julia-actions/julia-runtest@v1 if: ${{ !contains(fromJson('["1"]'), matrix.version) || matrix.os != 'ubuntu-latest' }} with: coverage: false - name: "Run test with coverage" uses: julia-actions/julia-runtest@v1 if: contains(fromJson('["1"]'), matrix.version) && matrix.os == 'ubuntu-latest' - uses: julia-actions/julia-processcoverage@v1 if: contains(fromJson('["1"]'), matrix.version) && matrix.os == 'ubuntu-latest' - uses: codecov/codecov-action@v5 if: contains(fromJson('["1"]'), matrix.version) && matrix.os == 'ubuntu-latest' with: file: lcov.info docs: name: Documentation runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 - uses: julia-actions/setup-julia@v2 with: version: '1.10' - uses: julia-actions/cache@v2 - run: | julia --project=docs -e ' using Pkg Pkg.develop(PackageSpec(path=pwd())) Pkg.instantiate()' - run: julia --project=docs docs/make.jl env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} ================================================ FILE: .github/workflows/clean_preview.yml ================================================ # from https://github.com/CliMA/ClimaTimeSteppers.jl name: Doc Preview Cleanup on: pull_request: types: [closed] jobs: doc-preview-cleanup: runs-on: ubuntu-latest steps: - name: Checkout gh-pages branch uses: actions/checkout@v6 with: ref: gh-pages - name: Delete preview and history + push changes run: | if [ -d "previews/PR$PRNUM" ]; then git config user.name "Documenter.jl" git config user.email "documenter@juliadocs.github.io" git rm -rf "previews/PR$PRNUM" git commit -m "delete preview" git branch gh-pages-new $(echo "delete history" | git commit-tree HEAD^{tree}) git push --force origin gh-pages-new:gh-pages fi env: PRNUM: ${{ github.event.number }} ================================================ FILE: .github/workflows/pr_comment.yml 
================================================ name: pr_comment on: pull_request: types: [labeled] jobs: pr_comment: runs-on: ubuntu-latest steps: - name: Create PR comment if: github.event_name == 'pull_request' && github.repository == github.event.pull_request.head.repo.full_name && github.event.label.name == 'documentation' # if this is a pull request build AND the pull request is NOT made from a fork uses: thollander/actions-comment-pull-request@24bffb9b452ba05a4f3f77933840a6a841d1b32b with: message: 'Once the documentation build has completed, you can preview any updated documentation at this URL: https://fluxml.ai/NNlib.jl/previews/PR${{ github.event.number }}/' GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} ================================================ FILE: .gitignore ================================================ *.jl.cov *.jl.*.cov *.jl.mem *.o *.so *.dylib *.dll *~ \#* deps/usr deps.jl *.log .vscode/ /Manifest.toml test/Manifest.toml benchmark/Manifest.toml benchmark/*.json benchmark/report.md ================================================ FILE: LICENSE.md ================================================ The NNlib.jl package is licensed under the MIT "Expat" License: > Copyright (c) 2017-19: Julia Computing, Inc., Mike J Innes, and Contributors > > Permission is hereby granted, free of charge, to any person obtaining a copy > of this software and associated documentation files (the "Software"), to deal > in the Software without restriction, including without limitation the rights > to use, copy, modify, merge, publish, distribute, sublicense, and/or sell > copies of the Software, and to permit persons to whom the Software is > furnished to do so, subject to the following conditions: > > The above copyright notice and this permission notice shall be included in all > copies or substantial portions of the Software. 
> > THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR > IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, > FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE > AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER > LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, > OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE > SOFTWARE. > ================================================ FILE: Project.toml ================================================ name = "NNlib" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" version = "0.9.34" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [extensions] NNlibAMDGPUExt = "AMDGPU" NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] NNlibCUDAExt = "CUDA" NNlibEnzymeCoreExt = "EnzymeCore" NNlibFFTWExt = "FFTW" NNlibForwardDiffExt = "ForwardDiff" NNlibMetalExt = "Metal" NNlibSpecialFunctionsExt = "SpecialFunctions" [compat] AMDGPU = "1, 2" Adapt = "3.2, 4" Atomix = "0.1, 1" CUDA = "4, 5, 6" ChainRulesCore = "1.25" EnzymeCore = "0.7, 0.8" FFTW = "1.8.0" ForwardDiff = "1" GPUArraysCore 
= "0.2" KernelAbstractions = "0.9.2" LinearAlgebra = "1" Metal = "1.6" Random = "1" ScopedValues = "1.3.0" SpecialFunctions = "2" Statistics = "1" cuDNN = "1, 6" julia = "1.10" ================================================ FILE: README.md ================================================ # NNlib.jl [![Documentation][docs-dev-img]][docs-dev-url] [![CI](https://github.com/FluxML/NNlib.jl/actions/workflows/ci.yml/badge.svg)](https://github.com/FluxML/NNlib.jl/actions/workflows/ci.yml) [![Coverage](https://codecov.io/gh/FluxML/NNlib.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/FluxML/NNlib.jl) [docs-stable-img]: https://img.shields.io/badge/docs-stable-blue.svg [docs-stable-url]: https://fluxml.ai/NNlib.jl/stable/ [docs-dev-img]: https://img.shields.io/badge/docs-latest-blue.svg [docs-dev-url]: https://fluxml.ai/NNlib.jl/dev/ This package provides a library of functions useful for neural networks, such as softmax, sigmoid, batched multiplication, convolutions and pooling. Many of these are used by [Flux.jl](https://github.com/FluxML/Flux.jl), which loads this package, but they may be used independently. For use with automatic differentiation, this package defines gradients using [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl). These will be seen by various packages including [Zygote.jl](https://github.com/FluxML/Zygote.jl). GPU support is provided as package extensions (see the `ext/` folder). In order to load the extensions, use the imports ```julia using NNlib, CUDA, cuDNN ``` for CUDA support, or ```julia using NNlib, AMDGPU ``` for AMDGPU support. 
================================================ FILE: benchmark/Project.toml ================================================ [deps] ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63" BenchmarkCI = "20533458-34a3-403d-a444-e18f38190b5b" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d" [compat] # No compat bounds for NNlib because we may test breaking versions ArgParse = "1" BenchmarkCI = "0.1" BenchmarkTools = "1.3" PkgBenchmark = "0.2" julia = "1.6" ================================================ FILE: benchmark/benchmarks.jl ================================================ using BenchmarkTools using NNlib using NNlib.ChainRulesCore: rrule using Random Random.seed!(1234567890) const SUITE = BenchmarkGroup() SUITE["activations"] = BenchmarkGroup() for et in (Float16, Float32, Float64) et_suite = BenchmarkGroup() SUITE["activations"][string(et)] = et_suite let x = rand(et, 1024, 1024), y = similar(x) for f in NNlib.ACTIVATIONS act = @eval($f) et_suite[string(f)] = @benchmarkable broadcast!($act, $y, $x) end end end for (fn!, fn_bw) in [(softmax!, NNlib.∇softmax_data), (logsoftmax!, NNlib.∇logsoftmax_data)] fn_suite = BenchmarkGroup() SUITE[rstrip(string(fn!), '!')] = fn_suite let SIZES = [ (128, 384, 8), (512, 784, 8), (768, 1024, 4), (1024, 2048, 4), (2048, 2048, 2), (4096, 2048, 2), (4096, 4096, 2), (12288, 2048, 1) ] for et in (Float16, Float32) et_suite = BenchmarkGroup("fw" => BenchmarkGroup(), "bw" => BenchmarkGroup()) fn_suite[string(et)] = et_suite for sz in SIZES x = randn(et, sz) y = similar(x) dy = zero(x) fn!(y, x) et_suite["fw"][string(sz)] = @benchmarkable $fn!($y, $x) et_suite["bw"][string(sz)] = @benchmarkable $fn_bw($dy, $y) end end end end ================================================ FILE: benchmark/perf_report.jl ================================================ using JLD2, NNlib, BenchmarkTools # TODO organize and compare benchmarks 
using BenchmarkGroups # We need things to go quickly here BenchmarkTools.DEFAULT_PARAMETERS.samples = 20 BenchmarkTools.DEFAULT_PARAMETERS.seconds = 2.5 results = Dict() function add_result(val, keys...) r = results for k in keys[1:end-1] if !haskey(r, k) r[k] = Dict() end r = r[k] end r[keys[end]] = val return r end # Modify these as needed for rank in (2,), N in (20, 40, 80), C_in in (1,), C_out in (1,), K in (3,), stride in (1,), dilation in (1,), padding in (0, 2) benchmark_items = [ (NNlib.conv_direct!, NNlib.∇conv_data_direct!, NNlib.∇conv_filter_direct!, DenseConvDims, "direct"), (NNlib.conv_im2col!, NNlib.∇conv_data_im2col!, NNlib.∇conv_filter_im2col!, DenseConvDims, "im2col"), (NNlib.depthwiseconv_direct!, NNlib.∇depthwiseconv_data_direct!, NNlib.∇depthwiseconv_filter_direct!, DepthwiseConvDims, "direct"), (NNlib.depthwiseconv_im2col!, NNlib.∇depthwiseconv_data_im2col!, NNlib.∇depthwiseconv_filter_im2col!, DepthwiseConvDims, "im2col"), ] for (conv!, ∇conv_data!, ∇conv_filter!, cT, backend) in benchmark_items x = zeros(Float32, repeat([N], rank)..., C_in, 1) if cT == DenseConvDims w = zeros(Float32, repeat([K], rank)..., C_in, C_out) else w = zeros(Float32, repeat([K], rank)..., C_out, C_in) end cdims = try cT(x, w; stride=stride, dilation=dilation, padding=padding) catch continue end if cT == DenseConvDims y = zeros(Float32, NNlib.output_size(cdims)..., C_out, 1) else y = zeros(Float32, NNlib.output_size(cdims)..., C_out*C_in, 1) end dx = similar(x) dw = similar(w) dy = similar(y) t_fwd = @benchmark $(conv!)($y, $x, $w, $cdims) t_dx = @benchmark $(∇conv_data!)($dx, $y, $w, $cdims) t_dw = @benchmark $(∇conv_filter!)($dw, $x, $y, $cdims) add_result(t_fwd, "conv$(rank)d", backend, cdims) add_result(t_dx, "conv$(rank)d_data", backend, cdims) add_result(t_dw, "conv$(rank)d_filter", backend, cdims) @show(cdims) @save "results.jld2" results end end # Modify these as needed for rank in (2,), N in (20,), K in (2, 4), stride in (1, 2, 4) x = zeros(Float32, 
repeat([N], rank)..., 1, 1)
    pdims = PoolDims(x, K; stride=stride)
    y = zeros(Float32, NNlib.output_size(pdims)..., 1, 1)
    dx = similar(x)

    for (pool, ∇pool, name) in (
            (NNlib.maxpool!, NNlib.∇maxpool!, "maxpool"),
            (NNlib.meanpool!, NNlib.∇meanpool!, "meanpool"),
            (NNlib.lpnormpool!, NNlib.∇lpnormpool!, "lpnormpool"),
        )
        t_fwd  = @benchmark $(pool)( $y, $x, $pdims)
        t_data = @benchmark $(∇pool)($dx, $y, $y, $x, $pdims)

        add_result(t_fwd, "$(name)$(rank)d", "direct", pdims)
        add_result(t_data, "$(name)$(rank)d_data", "direct", pdims)

        @show(pdims)
        @save "results.jld2" results
    end
end

================================================ FILE: benchmark/runbenchmarks.jl ================================================

# Adapted from
# https://github.com/kul-forbes/ProximalOperators.jl/tree/master/benchmark

using ArgParse
using PkgBenchmark
using BenchmarkCI: displayjudgement, printresultmd, CIResult
using Markdown

# Render a PkgBenchmark judgement as a Markdown string, replacing the GitHub
# emoji shortcodes emitted by BenchmarkCI with the literal characters so the
# report displays correctly outside GitHub.
function markdown_report(judgement)
    md = sprint(printresultmd, CIResult(judgement = judgement))
    md = replace(md, ":x:" => "❌")
    md = replace(md, ":white_check_mark:" => "✅")
    return md
end

# Parse command-line options for the benchmark runner.
# Returns a Dict with keys "target", "baseline" and "retune".
function parse_commandline()
    s = ArgParseSettings()
    @add_arg_table! s begin
        "--target"
            help = "the branch/commit/tag to use as target"
            default = "HEAD"
        "--baseline"
            help = "the branch/commit/tag to use as baseline"
            default = "master"
        "--retune"
            help = "force re-tuning (ignore existing tuning data)"
            # Fix: was `action = :store_false`, which inverted the flag —
            # omitting `--retune` forced re-tuning and passing it disabled
            # re-tuning, contradicting the help text. `:store_true` makes
            # the flag opt-in (default false, present => true).
            action = :store_true
    end
    return parse_args(s)
end

function main()
    parsed_args = parse_commandline()

    # Benchmark configuration factory: inherit the thread count from the
    # environment (defaulting to 1) so target and baseline run identically.
    mkconfig(; kwargs...) =
        BenchmarkConfig(
            env = Dict(
                "JULIA_NUM_THREADS" => get(ENV, "JULIA_NUM_THREADS", "1"),
            );
            kwargs...
) target = parsed_args["target"] group_target = benchmarkpkg( dirname(@__DIR__), mkconfig(id = target), resultfile = joinpath(@__DIR__, "result-$(target).json"), retune = parsed_args["retune"], ) baseline = parsed_args["baseline"] group_baseline = benchmarkpkg( dirname(@__DIR__), mkconfig(id = baseline), resultfile = joinpath(@__DIR__, "result-$(baseline).json"), ) judgement = judge(group_target, group_baseline) report_md = markdown_report(judgement) write(joinpath(@__DIR__, "report.md"), report_md) display(Markdown.parse(report_md)) end main() ================================================ FILE: docs/.gitignore ================================================ build/ site/ Manifest.toml ================================================ FILE: docs/Project.toml ================================================ [deps] CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FLAC = "abae9e3b-a9a0-4778-b5c6-ca109b507d99" FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" ================================================ FILE: docs/make.jl ================================================ using Documenter, NNlib DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using FFTW, NNlib, UnicodePlots); recursive = true) makedocs(modules = [NNlib], sitename = "NNlib.jl", doctest = true, pages = ["Home" => "index.md", "Reference" => "reference.md", "Audio" => "audio.md"], format = Documenter.HTML( canonical = "https://fluxml.ai/NNlib.jl/stable/", # analytics = "UA-36890222-9", assets = ["assets/flux.css"], prettyurls = get(ENV, "CI", nothing) == "true"), warnonly=[:missing_docs,] ) deploydocs(repo = "github.com/FluxML/NNlib.jl.git", target = "build", push_preview = true) ================================================ FILE: docs/src/assets/flux.css 
================================================ @import url('https://fonts.googleapis.com/css?family=Lato:400,400i'); body { font-family: Lato, "Segoe UI",Roboto,"Helvetica Neue",Arial,sans-serif; } nav.toc { padding-top: 0; background: rgb(240, 240, 240); line-height: 2em; cursor: default; user-select: none; } h1+h2 { margin-top: 0; } /* Green banner in ToC */ nav.toc > h1 { margin-top: 0; padding-top: 0.4em; padding-bottom: 0.5em; border-bottom: 5px solid white; box-shadow: 0px -2px 5px rgb(60,60,60); margin-bottom: 0.5em; background: rgb(60, 150, 60); font-style: italic; font-weight: normal; font-size: 50pt; text-transform: lowercase; text-shadow: 2px 2px 5px rgba(0,0,0,0.2); color: white; } /* Reduce ToC font size */ .toctext { font-size: 10pt; } /* Fade out non-clickable ToC headers */ nav.toc ul span.toctext { color: rgb(180, 180, 180); } nav.toc ul .toctext { color: rgb(100, 100, 100); } nav.toc ul a.toctext:hover { color: inherit; background: rgb(220, 220, 220); cursor: default; } nav.toc li.current > .toctext { background: linear-gradient(90deg, rgb(245,245,245) 0%, white 90%); font-weight: normal; } nav.toc ul.internal li.toplevel { font-weight: normal; } /* Content */ article { max-width: none; } article > p, article > ul { max-width: 45em; } /* Links */ a, a:visited { color: rgb(0, 120, 0); } article p a { border-bottom: 1px solid rgb(200, 230, 200); } a:hover, a:visited:hover { color: rgb(0, 80, 0); } /* Article Links */ article p a { border-bottom: 1px solid rgb(200, 230, 200); } article p a:hover, article a:visited:hover { color: rgb(0, 120, 0); } article p a:hover { border-bottom: 1px solid rgb(150, 200, 150); } /* Doctstrings */ article section.docstring { padding: 0.5em 0; border-left: none; border-right: none; border-bottom: none; } /* Code */ article pre, article p > code { background: rgb(245, 250, 245); } article pre { border: none; max-width: none; padding: 1em; border-radius: 10px 0px 0px 10px; } .hljs-comment { font-style: italic; } 
.hljs-number { color: rgb(0, 150, 150); } ================================================ FILE: docs/src/audio.md ================================================ # Audio !!! note Spectral functions require importing `FFTW` package to enable them. ## Window functions ```@docs hann_window hamming_window ``` ## Spectral ```@docs stft istft NNlib.power_to_db NNlib.db_to_power ``` ## Spectrogram ```@docs melscale_filterbanks spectrogram ``` Example: ```@example 1 using FFTW # <- required for STFT support. using NNlib using FileIO using Makie, CairoMakie CairoMakie.activate!() waveform, sampling_rate = load("./assets/jfk.flac") fig = lines(reshape(waveform, :)) save("waveform.png", fig) # Spectrogram. n_fft = 1024 spec = spectrogram(waveform; n_fft, hop_length=n_fft ÷ 4, window=hann_window(n_fft)) fig = heatmap(transpose(NNlib.power_to_db(spec)[:, :, 1])) save("spectrogram.png", fig) # Mel-scale spectrogram. n_freqs = n_fft ÷ 2 + 1 fb = melscale_filterbanks(; n_freqs, n_mels=128, sample_rate=Int(sampling_rate)) mel_spec = permutedims(spec, (2, 1, 3)) ⊠ fb # (time, n_mels) fig = heatmap(NNlib.power_to_db(mel_spec)[:, :, 1]) save("mel-spectrogram.png", fig) nothing # hide ``` |Waveform|Spectrogram|Mel Spectrogram| |:---:|:---:|:---:| |![](waveform.png)|![](spectrogram.png)|![](mel-spectrogram.png)| ================================================ FILE: docs/src/index.md ================================================ # NNlib.jl `NNlib` provides a library of functions useful for neural networks, such as softmax, sigmoid, batched multiplication, convolutions and pooling. Many of these are used by [Flux.jl](https://github.com/FluxML/Flux.jl), which loads this package, but they may be used independently. For use with automatic differentiation, this package defines gradients using [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl). These will be seen by various packages including [Zygote.jl](https://github.com/FluxML/Zygote.jl).
GPU support is provided as package extensions. In order to load the extensions, use the imports ```julia using NNlib, CUDA, cuDNN ``` for CUDA support, or ```julia using NNlib, AMDGPU ``` for AMDGPU support. ## Threading Various `NNlib` functions utilize available julia threads on divisible workloads. To disable this use the `ScopedValue`-backed switch `NNlib.@disallow_spawns` i.e. ```julia NNlib.@disallow_spawns function_that_uses_nnlib() ``` ================================================ FILE: docs/src/reference.md ================================================ # Reference The API reference of `NNlib`. ## Activation Functions Non-linearities that go between layers of your model. Note that, unless otherwise stated, activation functions operate on scalars. To apply them to an array you can call `σ.(xs)`, `relu.(xs)` and so on. ```@docs celu elu gelu gelu_tanh gelu_sigmoid gelu_erf hardsigmoid sigmoid_fast hardtanh tanh_fast leakyrelu lisht logcosh logsigmoid mish relu relu6 rrelu selu sigmoid softplus softshrink softsign swish hardswish tanhshrink trelu ``` ## Attention ```@docs dot_product_attention dot_product_attention_scores make_causal_mask ``` ## Softmax `Flux`'s `logitcrossentropy` uses `NNlib.softmax` internally. ```@docs softmax logsoftmax ``` ## Pooling `Flux`'s `AdaptiveMaxPool`, `AdaptiveMeanPool`, `GlobalMaxPool`, `GlobalMeanPool`, `MaxPool`, `MeanPool` and `lpnormpool` use `NNlib.PoolDims`, `NNlib.maxpool`, `NNlib.meanpool` and `NNlib.lpnormpool` as their backend. ```@docs PoolDims maxpool meanpool lpnormpool ``` ## Padding ```@docs pad_reflect pad_symmetric pad_circular pad_repeat pad_constant pad_zeros ``` ## Convolution `Flux`'s `Conv` and `CrossCor` layers use `NNlib.DenseConvDims` and `NNlib.conv` internally. `NNlib.conv` supports complex datatypes on CPU and CUDA devices. !!! note "AMDGPU MIOpen supports only cross-correlation (`flipkernel=true`)." 
Therefore for every regular convolution (`flipkernel=false`) kernel is flipped before calculation. For better performance, use cross-correlation (`flipkernel=true`) and manually flip the kernel before `NNlib.conv` call. `Flux` handles this automatically, this is only required for direct calls. ```@docs conv ConvDims depthwiseconv DepthwiseConvDims DenseConvDims NNlib.unfold NNlib.fold ``` ## Upsampling `Flux`'s `Upsample` layer uses `NNlib.upsample_nearest`, `NNlib.upsample_bilinear`, and `NNlib.upsample_trilinear` as its backend. Additionally, `Flux`'s `PixelShuffle` layer uses `NNlib.pixel_shuffle` as its backend. ```@docs upsample_nearest ∇upsample_nearest upsample_linear ∇upsample_linear upsample_bilinear ∇upsample_bilinear upsample_trilinear ∇upsample_trilinear pixel_shuffle ``` ## Rotation Rotate images in the first two dimensions of an array. ```@docs imrotate ∇imrotate ``` ## Batched Operations `Flux`'s `Bilinear` layer uses `NNlib.batched_mul` internally. ```@docs batched_mul batched_mul! batched_adjoint batched_transpose batched_vec ``` ## Gather and Scatter `Flux`'s `Embedding` layer uses `NNlib.gather` as its backend. ```@docs NNlib.gather NNlib.gather! NNlib.scatter NNlib.scatter! ``` ## Sampling ```@docs grid_sample ∇grid_sample ``` ## Losses ```@docs ctc_loss ``` ## Miscellaneous ```@docs logsumexp NNlib.glu NNlib.within_gradient bias_act! 
``` ================================================ FILE: ext/NNlibAMDGPUExt/NNlibAMDGPUExt.jl ================================================ module NNlibAMDGPUExt using Adapt using AMDGPU using ChainRulesCore using NNlib using NNlib: BatchedAdjoint, BatchedTranspose, BatchedAdjOrTrans using NNlib: DenseConvDims, PoolDims const MIOPENFloat = Union{Float16, Float32} const ROCBatchedAdjoint{T} = BatchedAdjoint{T, <: ROCArray{T}} const ROCBatchedTranspose{T} = BatchedTranspose{T, <: ROCArray{T}} const ROCBatchedAdjOrTrans{T} = Union{ROCBatchedAdjoint{T}, ROCBatchedTranspose{T}} const WrappedROCBatchedAdjOrTrans{T, N} = Adapt.WrappedArray{T, N, ROCBatchedAdjOrTrans{T}, ROCBatchedAdjOrTrans{T}} const AnyROCBatchedAdjOrTrans = Union{ROCBatchedAdjOrTrans, WrappedROCBatchedAdjOrTrans} function Base.convert(::Type{T}, b::AnyROCBatchedAdjOrTrans) where {T <: Array} Base.convert(T, adapt(Array, b)) end function Base.Array{T, N}(b::AnyROCBatchedAdjOrTrans) where {T, N} Array{T, N}(adapt(Array, b)) end Base.collect(b::AnyROCBatchedAdjOrTrans) = collect(adapt(Array, b)) function Base.show( io::IO, mime::MIME{Symbol("text/plain")}, x::AnyROCBatchedAdjOrTrans, ) show(io, mime, adapt(Array, x)) end Base.show(io::IO, x::AnyROCBatchedAdjOrTrans) = show(io, adapt(Array, x)) Base.display(x::AnyROCBatchedAdjOrTrans) = display(adapt(Array, x)) function nnlib_padding(dims) pd = NNlib.padding(dims) if !all(pd[1:2:end] .== pd[2:2:end]) @warn """ MIOpen does not support asymmetric padding, defaulting to symmetric choice: $pd -> $(pd[1:2:end]). """ maxlog=1 end pd[1:2:end] end include("batched_mul.jl") @static if AMDGPU.functional(:MIOpen) using AMDGPU.MIOpen include("conv.jl") include("pool.jl") include("activations.jl") else @warn """ ROCm MIOpen is not available for AMDGPU. NNlib has limited functionality for AMDGPU. 
""" end end ================================================ FILE: ext/NNlibAMDGPUExt/activations.jl ================================================ for (f, op) in [ NNlib.relu => MIOpen.relu, NNlib.relu6 => x -> MIOpen.clippedrelu(x, 6), NNlib.softplus => MIOpen.softrelu, NNlib.σ => MIOpen.sigmoid, Base.tanh => MIOpen.tanh, # TODO define for leakyrelu, elu, etc.? ], N in 1:5 @eval function Base.materialize( bc::Broadcast.Broadcasted{<:Any,<:Any,typeof($f),<:Tuple{ROCArray{<:MIOPENFloat,$N}}} ) return $op(bc.args[1]) end end Base.broadcasted(::typeof(identity), x::ROCArray{T}) where {T<:MIOPENFloat} = x ================================================ FILE: ext/NNlibAMDGPUExt/batched_mul.jl ================================================ function _blas_at(x) Base.stride(x, 1) == 1 && return x, 'N' Base.stride(x, 2) == 1 && return batched_transpose(x), 'T' throw(ArgumentError(""" Unsupported array layout for batched mul. - Size: $(size(x)) - Strides: $(strides(x)) """)) end function NNlib._batched_mul!( ::Type{AT}, C, A, B, α::Float16, β::Float16, ) where AT <: ROCArray{Float16} blasA, transA = _blas_at(A) blasB, transB = _blas_at(B) NNlib._batched_gemm!(AT, transA, transB, α, blasA, blasB, β, C) C end function NNlib._batched_gemm!( ::Type{<:ROCArray{T}}, transA::Char, transB::Char, α::T, A, B, β::T, C, ) where T <: Union{MIOPENFloat, Float64} AMDGPU.rocBLAS.gemm_batched!(transA, transB, α, A, B, β, C) end ================================================ FILE: ext/NNlibAMDGPUExt/conv.jl ================================================ function NNlib.conv!( y::ROCArray{T, N}, x::ROCArray{T, N}, w::ROCArray{T, N}, cdims::DenseConvDims, ) where {T <: MIOPENFloat, N} if !NNlib.flipkernel(cdims) @warn """ MIOpen supports only cross-correlation (flipkernel=true). Therefore for every regular convolution (flipkernel=false) kernel is flipped before calculation. 
For better performance, use cross-correlation (flipkernel=true) and manually flip the kernel before `NNlib.conv` call. """ maxlog=1 flip_dims = ntuple( i -> (i ≤ ndims(w) - 2) ? (size(w, i):-1:1) : Colon(), ndims(w)) w = w[flip_dims...] end nd = max(0, 4 - N) ncdims = NNlib.insert_singleton_spatial_dimension(cdims, nd) MIOpen.convolution!( NNlib.insert_singleton_spatial_dimension(y, nd), NNlib.insert_singleton_spatial_dimension(x, nd), NNlib.insert_singleton_spatial_dimension(w, nd); padding=nnlib_padding(ncdims), stride=NNlib.stride(ncdims), dilation=NNlib.dilation(ncdims), groups=NNlib.groupcount(ncdims)) return y end function NNlib.∇conv_data!( dx::ROCArray{T, N}, dy::ROCArray{T, N}, w::ROCArray{T, N}, cdims::DenseConvDims, ) where {T <: MIOPENFloat, N} if !NNlib.flipkernel(cdims) @warn """ MIOpen supports only cross-correlation (flipkernel=true). Therefore for every regular convolution (flipkernel=false) kernel is flipped before calculation. For better performance, use cross-correlation (flipkernel=true) and manually flip the kernel before `NNlib.conv` call. """ maxlog=1 flip_dims = ntuple( i -> (i ≤ ndims(w) - 2) ? (size(w, i):-1:1) : Colon(), ndims(w)) w = w[flip_dims...] 
end nd = max(0, 4 - N) ncdims = NNlib.insert_singleton_spatial_dimension(cdims, nd) MIOpen.∇convolution_data!( NNlib.insert_singleton_spatial_dimension(dx, nd), NNlib.insert_singleton_spatial_dimension(dy, nd), NNlib.insert_singleton_spatial_dimension(w, nd); padding=nnlib_padding(ncdims), stride=NNlib.stride(ncdims), dilation=NNlib.dilation(ncdims), groups=NNlib.groupcount(ncdims)) return dx end function NNlib.∇conv_filter!( dw::ROCArray{T, N}, x::ROCArray{T, N}, dy::ROCArray{T, N}, cdims::DenseConvDims, ) where {T <: MIOPENFloat, N} nd = max(0, 4 - N) ncdims = NNlib.insert_singleton_spatial_dimension(cdims, nd) MIOpen.∇convolution_weight!( NNlib.insert_singleton_spatial_dimension(dw, nd), NNlib.insert_singleton_spatial_dimension(dy, nd), NNlib.insert_singleton_spatial_dimension(x, nd); padding=nnlib_padding(ncdims), stride=NNlib.stride(ncdims), dilation=NNlib.dilation(ncdims), groups=NNlib.groupcount(ncdims)) if !NNlib.flipkernel(cdims) @warn """ MIOpen supports only cross-correlation (flipkernel=true). Therefore for every regular convolution (flipkernel=false) kernel is flipped before calculation. For better performance, use cross-correlation (flipkernel=true) and manually flip the kernel before `NNlib.conv` call. """ maxlog=1 flip_dims = ntuple( i -> (i ≤ ndims(dw) - 2) ? (size(dw, i):-1:1) : Colon(), ndims(dw)) dw = dw[flip_dims...] 
end return dw end ================================================ FILE: ext/NNlibAMDGPUExt/pool.jl ================================================ for poolname in (:maxpool, :meanpool) @eval function NNlib.$(poolname)( x::ROCArray{T, N}, pdims::PoolDims, ) where {T <: MIOPENFloat, N} y = similar(x, NNlib.output_size(pdims)..., NNlib.channels_out(pdims), size(x, N)) nd = max(0, 4 - N) npdims = NNlib.insert_singleton_spatial_dimension(pdims, nd) MIOpen.$(Symbol("$(poolname)!"))( NNlib.insert_singleton_spatial_dimension(y, nd), NNlib.insert_singleton_spatial_dimension(x, nd); dims=NNlib.kernel_size(npdims), padding=nnlib_padding(npdims), stride=NNlib.stride(npdims), do_backward=false) return y end @eval function ChainRulesCore.rrule( ::typeof(NNlib.$(poolname)), x::ROCArray{T, N}, pdims::PoolDims, ) where {T <: MIOPENFloat, N} y = similar(x, NNlib.output_size(pdims)..., NNlib.channels_out(pdims), size(x, N)) nd = max(0, 4 - N) npdims = NNlib.insert_singleton_spatial_dimension(pdims, nd) # `workspace` is used in the pullback. 
_, workspace = MIOpen.$(Symbol("$(poolname)!"))( NNlib.insert_singleton_spatial_dimension(y, nd), NNlib.insert_singleton_spatial_dimension(x, nd); dims=NNlib.kernel_size(npdims), padding=nnlib_padding(npdims), stride=NNlib.stride(npdims)) function _pooling_pullback(Δ) dx = similar(x) MIOpen.$(Symbol("∇$(poolname)!"))( NNlib.insert_singleton_spatial_dimension(dx, nd), NNlib.insert_singleton_spatial_dimension(unthunk(Δ), nd), NNlib.insert_singleton_spatial_dimension(y, nd), NNlib.insert_singleton_spatial_dimension(x, nd); dims=NNlib.kernel_size(npdims), padding=nnlib_padding(npdims), stride=NNlib.stride(npdims), workspace) return NoTangent(), dx, NoTangent() end y, _pooling_pullback end end ================================================ FILE: ext/NNlibCUDACUDNNExt/NNlibCUDACUDNNExt.jl ================================================ module NNlibCUDACUDNNExt using NNlib using cuDNN using CUDA using Random, Statistics using cuDNN: handle, with_workspace, cudnnTensorDescriptor, cudnnFilterDescriptor, cudnnDataType, math_mode, CUDNN_DEFAULT_REORDER, CUDNN_CROSS_CORRELATION, CUDNN_NOT_PROPAGATE_NAN, CUDNN_TENSOR_NCHW, dim4 cudnnversion() = cuDNN.version() function nnlibPadding(dims) pd = NNlib.padding(dims) if !all(pd[1:2:end] .== pd[2:2:end]) @warn "cuDNN does not support asymmetric padding; defaulting to symmetric choice" maxlog=1 end return pd[1:2:end] end include("conv.jl") include("pooling.jl") include("softmax.jl") include("activations.jl") include("batchnorm.jl") end # module ================================================ FILE: ext/NNlibCUDACUDNNExt/activations.jl ================================================ # Activation using Base.Broadcast using cuDNN: cudnnActivationForward!, cudnnOpTensor!, CUDNN_ACTIVATION_TANH, CUDNN_ACTIVATION_SIGMOID, CUDNN_ACTIVATION_ELU, CUDNN_ACTIVATION_RELU, CUDNN_ACTIVATION_CLIPPED_RELU, CUDNN_OP_TENSOR_MAX, CUDNN_ACTIVATION_IDENTITY for (f, op) in [ CUDA.tanh => (src,dst)->cudnnActivationForward!(dst, src, 
mode=CUDNN_ACTIVATION_TANH), NNlib.σ => (src,dst)->cudnnActivationForward!(dst, src, mode=CUDNN_ACTIVATION_SIGMOID), NNlib.elu => (src,dst)->cudnnActivationForward!(dst, src, mode=CUDNN_ACTIVATION_ELU), NNlib.relu => (src,dst)->cudnnActivationForward!(dst, src, mode=CUDNN_ACTIVATION_RELU), # NNlib.relu6 => (src,dst)->cudnnActivationForward!(dst, src, mode=CUDNN_ACTIVATION_CLIPPED_RELU, coef=6.0), # NNlib.leakyrelu => (src,dst)->cudnnOpTensor!(dst, src, src; op=CUDNN_OP_TENSOR_MAX, alpha1=0.01), ] @eval begin # in-place function Base.materialize!(dst::DenseCuArray{<:CUDNNFloat}, bc::Broadcast.Broadcasted{<:Any,<:Any,typeof($f),<:Tuple{DenseCuArray{<:CUDNNFloat}}}) $op(bc.args[1], dst) return dst end # out of place function Base.materialize(bc::Broadcast.Broadcasted{<:Any,<:Any,typeof($f),<:Tuple{DenseCuArray{<:CUDNNFloat}}}) ElType = Broadcast.combine_eltypes(bc.f, bc.args) dst = similar(bc, ElType) $op(bc.args[1], dst) return dst end end end # CUDNN_ACTIVATION_IDENTITY does not work with cudnnActivationForward # FIXME: put this optimization in GPUArrays' `copyto!` (like Base.Broadcast's `copyto!`) Base.broadcasted(::typeof(identity), x::DenseCuArray{T}) where {T<:CUDNNFloat} = x ================================================ FILE: ext/NNlibCUDACUDNNExt/batchnorm.jl ================================================ using cuDNN: CUDNN_BN_MIN_EPSILON, cudnnBatchNormalizationBackward, cudnnBatchNormalizationForwardInference, CUDNN_BATCHNORM_SPATIAL, cudnnBatchNormalizationForwardTraining import NNlib: batchnorm, ∇batchnorm # TODO: replace with new cudnn normalization interface # https://github.com/JuliaGPU/CUDA.jl/blob/master/lib/cudnn/normalization.jl mutable struct BNCache mean ivar end BNCache() = BNCache(nothing, nothing) @inline _wsize(x::AbstractArray{<:Any,N}) where N = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N) function batchnorm(g::Nothing, b::Nothing, x::DenseCuArray, running_mean, running_var, momentum; kws...) 
affine_sz = _wsize(x) g = fill!(similar(x, affine_sz), 1) b = fill!(similar(x, affine_sz), 0) return batchnorm(g, b, x, running_mean, running_var, momentum; kws...) end # NOTE: CuDNN supports only 4D and 5D Tensors for BatchNorm Operations # so reshape a 2D Tensor into 4D function batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T,2}, running_mean, running_var, momentum; kws...) where T<:CUDNNFloat x = reshape(x, 1, 1, size(x, 1), size(x, 2)) y = batchnorm(g, b, x, running_mean, running_var, momentum; kws...) return dropdims(y, dims = (1, 2)) end function batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::Union{DenseCuArray{T,4},DenseCuArray{T,5}}, running_mean, running_var, momentum; kws...) where T<:CUDNNFloat cudnnBNForward!(similar(x), g, b, x, running_mean, running_var, momentum; kws...) end function cudnnBNForward!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T}, running_mean, running_var, momentum; cache = nothing, alpha = T(1), beta = T(0), eps = T(1e-5), training = true, affine = true, track_stats = true) where T<:CUDNNFloat dims = _wsize(x) if eps < CUDNN_BN_MIN_EPSILON @warn "eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON" eps = CUDNN_BN_MIN_EPSILON end if running_mean === nothing || running_var === nothing running_mean !== running_var && throw(ArgumentError("both or neither of running_mean and running_var must be nothing")) if track_stats || !training running_mean = fill!(similar(x, dims), 0) running_var = fill!(similar(x, dims), 1) end end xd = cudnnTensorDescriptor(x) yd = cudnnTensorDescriptor(y) gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(dims)), dim4(dims,Val(CUDNN_TENSOR_NCHW))) if training if !track_stats running_mean = CU_NULL running_var = CU_NULL end if cache !== nothing mean = fill!(similar(x, dims), 0) ivar = fill!(similar(x, dims), 1) else mean = CU_NULL ivar = CU_NULL end cudnnBatchNormalizationForwardTraining(handle(), 
CUDNN_BATCHNORM_SPATIAL, scalingParameter(T, alpha), scalingParameter(T, beta), xd, x, yd, y, gd, g, b, momentum, running_mean, running_var, eps, mean, ivar) if cache !== nothing cache.mean = mean cache.ivar = ivar end else if track_stats cudnnBatchNormalizationForwardInference(handle(), CUDNN_BATCHNORM_SPATIAL, scalingParameter(T, alpha), scalingParameter(T, beta), xd, x, yd, y, gd, g, b, running_mean, running_var, eps) else # cudnnBatchNormalizationForwardInference does not accept CV_NULL for running_mean # and running_var. We could calculate mean and var of `x` here, but instead use # cudnnBatchNormalizationFowardTraining. cudnnBatchNormalizationForwardTraining does # accept CV_NULL and will calculate mean and var itself. cudnnBatchNormalizationForwardTraining(handle(), CUDNN_BATCHNORM_SPATIAL, scalingParameter(T, alpha), scalingParameter(T, beta), xd, x, yd, y, gd, g, b, momentum, CU_NULL, CU_NULL, eps, CU_NULL, CU_NULL) end end return y end function ∇batchnorm(g::Nothing, b::Nothing, x::DenseCuArray, dy::DenseCuArray, running_mean, running_var, momentum; kws...) affine_sz = _wsize(x) g = fill!(similar(x, affine_sz), 1) b = fill!(similar(x, affine_sz), 0) return ∇batchnorm(g, b, x, dy, running_mean, running_var, momentum; kws...) end function ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T, 2}, dy::DenseCuArray{T, 2}, running_mean, running_var, momentum; kws...) where T<:CUDNNFloat dg, db, dx = ∇batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), reshape(dy, 1, 1, size(dy, 1), size(dy, 2)), running_mean, running_var, momentum; kws...) (dg, db, dropdims(dx, dims = (1, 2))) end function ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T}, running_mean, running_var, momentum; affine=true, kws...) where T<:CUDNNFloat dg = similar(g) db = similar(b) dx = similar(x) cudnnBNBackward!(dg, g, db, dx, x, dy, running_mean, running_var, T(momentum); kws...) 
if affine (dg, db, dx) else # cuDNN always calculates dg and db, therefore we just have to drop them (nothing, nothing, dx) end end function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuArray{T}, dx::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T}, running_mean, running_var, momentum; cache = nothing, eps = T(1e-5), alpha = T(1), beta = T(0), dalpha = T(1), dbeta = T(0), training = true, track_stats = true) where T<:CUDNNFloat if !track_stats running_mean = CU_NULL running_var = CU_NULL end xd = cudnnTensorDescriptor(x) dyd = cudnnTensorDescriptor(dy) dxd = cudnnTensorDescriptor(dx) gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(_wsize(x))), dim4(_wsize(x),Val(CUDNN_TENSOR_NCHW))) if cache !== nothing @debug "fetching mean and ivar from the cache" mean, ivar = cache.mean, cache.ivar else mean, ivar = CU_NULL, CU_NULL end if eps < CUDNN_BN_MIN_EPSILON @warn "eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON" eps = CUDNN_BN_MIN_EPSILON end cudnnBatchNormalizationBackward(handle(), CUDNN_BATCHNORM_SPATIAL, scalingParameter(T, alpha), scalingParameter(T, beta), scalingParameter(T, dalpha), scalingParameter(T, dbeta), xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps, mean, ivar) end ================================================ FILE: ext/NNlibCUDACUDNNExt/conv.jl ================================================ using NNlib: DenseConvDims import NNlib: conv!, ∇conv_filter!, ∇conv_data!, conv_bias_act! 
using cuDNN: scalingParameter, CUDNN_CONVOLUTION, convdims,
             cudnnConvolutionBwdDataAlgoPerf, cudnnConvolutionForward!,
             cudnnConvolutionBwdFilterAlgoPerf, cudnnConvolutionBackwardData,
             cudnnConvolutionBackwardFilter, cudnnConvolutionBackwardBias
import cuDNN: cudnnConvolutionDescriptor

# Element types cuDNN convolution kernels accept (real and complex).
const CUDNNFloat = Union{Float16,Float32,Float64}
const CUDNNComplexFloat = Union{ComplexF16,ComplexF32,ComplexF64}

"""
    cudnnConvolutionDescriptorAndPaddedInput(cdims, x)

Returns `(descriptor, x′, depad)` where `x′` is `x` (symmetric padding) or a
zero-padded copy of `x` (asymmetric padding), and `depad` maps an array of
`x′`'s shape back to `x`'s shape (`identity` in the symmetric case).
"""
function cudnnConvolutionDescriptorAndPaddedInput(cdims::DenseConvDims, x::DenseCuArray{T}) where T
    # The main purpose of this function is to catch asymmetric padding which cudnn does not support
    # If we find asymmetric padding we'll make a copy of x which is manually padded so that we can
    # call cudnn with symmetric padding.
    pad = NNlib.padding(cdims)
    sdims = NNlib.spatial_dims(cdims)
    # Symmetric per-dimension padding: cuDNN handles it directly, no copy needed.
    all(i -> pad[i] .== pad[i+1], 1:2:2sdims) && return (cudnnConvolutionDescriptor(cdims, x), x, identity)

    # Naive implementation, is there a faster way?

    # How much we need to pad x manually: The absolute difference between pad_left and pad_right, pad_top
    # and pad_bottom etc. respectively. We keep the sign here though because we use it below to figure out
    # which side of x to pad. Oh, and we use a CartesianIndex as we will mainly use this to index in x
    pad_manual = CartesianIndex(ntuple(i -> i > sdims ? 0 : pad[2(i-1)+1] - pad[2(i-1)+2], ndims(x)))
    # How much we can let cudnn pad: The smallest padding amount between pad_left and pad_right, pad_top
    # and pad_bottom etc. respectively
    pad_cudnn = ntuple(i -> min(pad[2(i-1)+1], pad[2(i-1)+2]), sdims)

    x_padded_size = ntuple(i -> i <= sdims ? size(x, i) + abs(pad_manual[i]) : size(x ,i), ndims(x))
    x_padded = similar(x, x_padded_size)
    fill!(x_padded, 0)

    # This is a bit yucky, but we are basically figuring out where in x_padded we shall insert x
    # Haven't benchmarked if this has any advantages over a more readable solution, e.g. writing dim
    # by dim to an array in a loop
    xIs = CartesianIndices(x)
    xI_first = first(xIs)
    xI_last = last(xIs)
    # Shift the insertion window toward the high side when pad_manual is positive.
    xIs_pad = max(xI_first, xI_first + pad_manual) : max(xI_last, xI_last + pad_manual)
    x_padded[xIs_pad] = x

    return cudnnConvolutionDescriptor(cdims, x_padded, pad_cudnn), x_padded, _x -> _x[xIs_pad]
end

# Build a cuDNN convolution descriptor from NNlib's ConvDims. NNlib's "flipkernel"
# convention maps to cuDNN cross-correlation; otherwise true convolution is used.
function cudnnConvolutionDescriptor(cdims::DenseConvDims, x::DenseCuArray{T}, pad = nnlibPadding(cdims)) where T
    mode=(NNlib.flipkernel(cdims) ? CUDNN_CROSS_CORRELATION : CUDNN_CONVOLUTION)
    cudnnConvolutionDescriptor(convdims(pad, size(x),0),
                               convdims(NNlib.stride(cdims),size(x),1),
                               convdims(NNlib.dilation(cdims),size(x),1),
                               mode,
                               cudnnDataType(real(T)),
                               math_mode(),
                               CUDNN_DEFAULT_REORDER,
                               Cint(NNlib.groupcount(cdims)))
end

# Combine separately-computed real (`yr`) and imaginary (`yi`) parts into the
# complex output `y`, applying `alpha`/`beta` scaling, `bias` and activation `σ`.
@inline function _complex!(y::DenseCuArray{T1}, yr::DenseCuArray{T2}, yi::DenseCuArray{T2};
                           bias=zero(T1), alpha=one(T1), beta=zero(T1), σ=identity) where {T1 <: CUDNNComplexFloat, T2<:CUDNNFloat}
    # if y is from similar(), it may have NaNs, and beta*NaN will propagate.
    if beta != 0
        @. y = σ(alpha*(yr + im*yi) + bias + beta*y)
    else
        @. y = σ(alpha*(yr + im*yi) + bias)
    end
    return y
end

# Real-valued forward convolution via cuDNN.
function conv!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T}, cdims::DenseConvDims;
               alpha=1, beta=0, algo=-1) where T<:CUDNNFloat
    if cudnnversion() < v"6"
        all(x -> x == 1, dilation(cdims)) || error("Only dilation = 1 is supported in cuDNN version < 6")
    end
    if algo != -1
        @warn "algo option has been deprecated, the fastest algo is computed automatically" maxlog=1
    end
    d, x, _ = cudnnConvolutionDescriptorAndPaddedInput(cdims, x)
    cudnnConvolutionForward!(y, w, x, d; alpha, beta, z=y)
end

# Complex convolution with Gauss's trick (1 complex mul === 3 real mul):
# Consider x = xr + im*xi, y = yr + im*yi,
# so x*y = (xr*yr - xi*yi) + im*(xr*yi + xi*yr).
# Let a = xr*yr,
#     b = xi*yi,
#     c = (xr + xi)*(yr + yi) = xr*yr + xr*yi + xi*yr + xi*yi.
# Then,
# x*y = (a - b) + im*(c - a - b).
# Convolution is linear so this multiplication trick translates to convolution.
# Complex input, complex kernel: three real convolutions via Gauss's trick.
function conv!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T}, cdims::DenseConvDims;
               alpha=1, beta=0, algo=-1) where T<:CUDNNComplexFloat
    xr, xi = reim(x)
    wr, wi = reim(w)
    a = conv!(similar(real(y)), xr, wr, cdims; algo=algo)
    b = conv!(similar(a), xi, wi, cdims; algo=algo)
    c = conv!(similar(a), xr + xi, wr + wi, cdims; algo=algo)
    return _complex!(y, a - b, c - a - b; alpha=alpha, beta=beta)
end

# (xr + im*xi) * w = xr*w + im*(xi*w)
# Complex input, real kernel: two real convolutions.
function conv!(y::DenseCuArray{T1}, x::DenseCuArray{T1}, w::DenseCuArray{T2}, cdims::DenseConvDims;
               alpha=1, beta=0, algo=-1) where {T1<:CUDNNComplexFloat, T2<:CUDNNFloat}
    xr, xi = reim(x)
    yr = conv!(similar(real(y)), xr, w, cdims; algo=algo)
    yi = conv!(similar(yr), xi, w, cdims; algo=algo)
    return _complex!(y, yr, yi; alpha=alpha, beta=beta)
end

# x * (wr + im*wi) = x*wr + im*(x*wi)
# Real input, complex kernel: two real convolutions.
function conv!(y::DenseCuArray{T1}, x::DenseCuArray{T2}, w::DenseCuArray{T1}, cdims::DenseConvDims;
               alpha=1, beta=0, algo=-1) where {T1<:CUDNNComplexFloat, T2<:CUDNNFloat}
    wr, wi = reim(w)
    yr = conv!(similar(real(y)), x, wr, cdims; algo=algo)
    yi = conv!(similar(yr), x, wi, cdims; algo=algo)
    return _complex!(y, yr, yi; alpha=alpha, beta=beta)
end

"""
    conv_bias_act!(y, x, w, cdims, bias, σ=identity; z=y, alpha=1, beta=0, algo=-1)

Fused convolution + bias + activation via cuDNN. Only `relu` and `identity`
are fused on the cuDNN side; any other `σ` is applied as a separate broadcast
over `y` afterwards.
"""
function conv_bias_act!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T},
                        cdims::DenseConvDims, bias::DenseCuArray{T}, σ=identity;
                        z::DenseCuArray{T}=y, alpha=1, beta=0, algo=-1) where T<:CUDNNFloat
    if cudnnversion() < v"6"
        all(x -> x == 1, dilation(cdims)) || error("Only dilation = 1 is supported in cuDNN version < 6")
    end
    if algo != -1
        @warn "The algo option has been deprecated, the fastest algo is computed automatically" maxlog=1
    end
    d, x, _ = cudnnConvolutionDescriptorAndPaddedInput(cdims, x)

    # only relu and identity are supported by cudnnConvolutionForward!
    activation = (σ == NNlib.relu ? CUDNN_ACTIVATION_RELU : CUDNN_ACTIVATION_IDENTITY)
    cudnnConvolutionForward!(y, w, x, d; z, bias, activation, alpha, beta)
    if activation === CUDNN_ACTIVATION_IDENTITY && σ ∉ (nothing, identity)
        # Fall back to a plain broadcast for activations cuDNN cannot fuse.
        @. y = σ(y)
    end
    return y
end

# Complex fused conv+bias+act: Gauss's trick for the convolution, then
# bias/scaling/activation applied in _complex! when assembling the result.
function conv_bias_act!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T},
                        cdims::DenseConvDims, bias::DenseCuArray{T}, σ=identity;
                        z::DenseCuArray{T}=y, alpha=1, beta=0, algo=-1) where T<:CUDNNComplexFloat
    xr, xi = reim(x)
    wr, wi = reim(w)
    a = conv!(similar(real(y)), xr, wr, cdims; alpha=1, beta=0, algo=algo)
    b = conv!(similar(a), xi, wi, cdims; alpha=1, beta=0, algo=algo)
    c = conv!(similar(a), xr + xi, wr + wi, cdims; alpha=1, beta=0, algo=algo)
    return _complex!(y, a - b, c - a - b; bias=bias, alpha=alpha, beta=beta, σ=σ)
end

# Real-valued gradient w.r.t. the convolution input. Note that with asymmetric
# padding, `dx` is first computed at the padded size and then cropped by `depad`.
function ∇conv_data!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, w::DenseCuArray{T},
                     cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where T<:CUDNNFloat
    if cudnnversion() < v"6"
        all(x -> x == 1, dilation(cdims)) || error("Only dilation = 1 is supported in cuDNN version < 6")
    end
    if algo != -1
        @warn "The algo option has been deprecated, the fastest algo is computed automatically" maxlog=1
    end
    alpha, beta = scalingParameter(T,alpha), scalingParameter(T,beta);
    convDesc, dx, depad = cudnnConvolutionDescriptorAndPaddedInput(cdims, dx)
    xDesc, yDesc, wDesc = cudnnTensorDescriptor(dx), cudnnTensorDescriptor(dy), cudnnFilterDescriptor(w)
    # Let cuDNN pick the fastest backward-data algorithm for these shapes.
    p = cudnnConvolutionBwdDataAlgoPerf(wDesc, w, yDesc, dy, convDesc, xDesc, dx, beta!=0)
    with_workspace(p.memory) do workspace
        cudnnConvolutionBackwardData(handle(), alpha, wDesc, w, yDesc, dy, convDesc,
                                     p.algo, workspace, sizeof(workspace), beta, xDesc, dx)
    end
    return depad(dx)
end

# Complex gradient w.r.t. input, again via Gauss's trick.
function ∇conv_data!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, w::DenseCuArray{T},
                     cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where T<:CUDNNComplexFloat
    dyr, dyi = reim(dy)
    wr, wi = reim(w)
    # note: w is conjugated, i.e.
# wi is negated below
    a = ∇conv_data!(similar(real(dx)), dyr, wr, cdims; alpha=1, beta=0, algo=algo)
    b = ∇conv_data!(similar(a), dyi, -wi, cdims; alpha=1, beta=0, algo=algo)
    c = ∇conv_data!(similar(a), dyr + dyi, wr - wi, cdims; alpha=1, beta=0, algo=algo)
    return _complex!(dx, a - b, c - a - b; alpha=alpha, beta=beta)
end

# dx = (dyr + im*dyi)*w = dyr*w + im*(dyi*w)
# Complex output gradient, real kernel: two real backward-data passes.
function ∇conv_data!(dx::DenseCuArray{T1}, dy::DenseCuArray{T1}, w::DenseCuArray{T2},
                     cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where {T1<:CUDNNComplexFloat, T2<:CUDNNFloat}
    dyr, dyi = reim(dy)
    dxr = ∇conv_data!(similar(real(dx)), dyr, w, cdims; alpha=1, beta=0, algo=algo)
    dxi = ∇conv_data!(similar(dxr), dyi, w, cdims; alpha=1, beta=0, algo=algo)
    return _complex!(dx, dxr, dxi; alpha=alpha, beta=beta)
end

# Real-valued gradient w.r.t. the convolution filter.
function ∇conv_filter!(dw::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
                       cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where T<:CUDNNFloat
    if cudnnversion() < v"6"
        all(x -> x == 1, dilation(cdims)) || error("Only dilation = 1 is supported in cuDNN version < 6")
    end
    if algo != -1
        @warn "The algo option has been deprecated, the fastest algo is computed automatically" maxlog=1
    end
    alpha, beta = scalingParameter(T,alpha), scalingParameter(T,beta);
    convDesc, x, _ = cudnnConvolutionDescriptorAndPaddedInput(cdims, x)
    xDesc, yDesc, wDesc = cudnnTensorDescriptor(x), cudnnTensorDescriptor(dy), cudnnFilterDescriptor(dw)
    # Let cuDNN pick the fastest backward-filter algorithm for these shapes.
    p = cudnnConvolutionBwdFilterAlgoPerf(xDesc, x, yDesc, dy, convDesc, wDesc, dw, beta!=0);
    with_workspace(p.memory) do workspace
        cudnnConvolutionBackwardFilter(handle(), alpha, xDesc, x, yDesc, dy, convDesc,
                                       p.algo, workspace, sizeof(workspace), beta, wDesc, dw);
    end
    return dw
end

# Complex gradient w.r.t. filter via Gauss's trick.
function ∇conv_filter!(dw::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
                       cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where T<:CUDNNComplexFloat
    xr, xi = reim(x)
    dyr, dyi = reim(dy)
    # note: x is conjugated, i.e. xi is negated below
    a = ∇conv_filter!(similar(real(dw)), xr, dyr, cdims; alpha=1, beta=0, algo=algo)
    b = ∇conv_filter!(similar(a), -xi, dyi, cdims; alpha=1, beta=0, algo=algo)
    c = ∇conv_filter!(similar(a), xr - xi, dyr + dyi, cdims; alpha=1, beta=0, algo=algo)
    return _complex!(dw, a - b, c - a - b; alpha=alpha, beta=beta)
end

# dw = x*(dyr + im*dyi) = x*dyr + im*(x*dyi)
# Real input, complex output gradient: two real backward-filter passes.
function ∇conv_filter!(dw::DenseCuArray{T1}, x::DenseCuArray{T2}, dy::DenseCuArray{T1},
                       cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where {T1<:CUDNNComplexFloat, T2<:CUDNNFloat}
    dyr, dyi = reim(dy)
    dwr = ∇conv_filter!(similar(real(dw)), x, dyr, cdims; alpha=1, beta=0, algo=algo)
    dwi = ∇conv_filter!(similar(dwr), x, dyi, cdims; alpha=1, beta=0, algo=algo)
    return _complex!(dw, dwr, dwi; alpha=alpha, beta=beta)
end

# Gradient w.r.t. the bias: sum of `dy` over all but the channel dimension,
# delegated to cuDNN.
function ∇conv_bias!(db::DenseCuArray{T}, dy::DenseCuArray{T}; alpha=1, beta=0) where T<:CUDNNFloat
    alpha,beta = scalingParameter(T,alpha), scalingParameter(T,beta)
    bDesc, yDesc = cudnnTensorDescriptor.((db,dy))
    cudnnConvolutionBackwardBias(handle(), alpha, yDesc, dy, beta, bDesc, db)
    return db
end

# Complex bias gradient: real and imaginary parts reduced independently.
function ∇conv_bias!(db::DenseCuArray{T}, dy::DenseCuArray{T}; alpha=1, beta=0) where T<:CUDNNComplexFloat
    dyr, dyi = reim(dy)
    dbr = ∇conv_bias!(similar(real(db)), dyr; alpha=1, beta=0)
    dbi = ∇conv_bias!(similar(dbr), dyi; alpha=1, beta=0)
    return _complex!(db, dbr, dbi; alpha=alpha, beta=beta)
end

================================================
FILE: ext/NNlibCUDACUDNNExt/pooling.jl
================================================

using cuDNN: cudnnPoolingMode_t, CUDNN_POOLING_MAX, CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING,
             cudnnPoolingForward!, pooldims, cudnnPoolingBackward
import NNlib: maxpool!, ∇maxpool!, meanpool!, ∇meanpool!
import cuDNN: cudnnPoolingDescriptor function cudnnPoolingDescriptor(pdims::PoolDims, x::DenseCuArray{T}, mode::cudnnPoolingMode_t) where T window, padding, stride = NNlib.kernel_size(pdims), nnlibPadding(pdims), NNlib.stride(pdims) nanOpt = CUDNN_NOT_PROPAGATE_NAN cudnnPoolingDescriptor(mode, nanOpt, Cint(ndims(x)-2), pooldims(window,size(x)), pooldims(padding,size(x)), pooldims(stride,size(x))) end function maxpool!(y::DenseCuArray{T}, x::DenseCuArray{T}, pdims::PoolDims) where T<:CUDNNFloat d = cudnnPoolingDescriptor(pdims, x, CUDNN_POOLING_MAX) cudnnPoolingForward!(y, x, d) end function ∇maxpool!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, y::DenseCuArray{T}, x::DenseCuArray{T}, pdims::PoolDims) where T<:CUDNNFloat xDesc, yDesc = cudnnTensorDescriptor.((x, y)) d = cudnnPoolingDescriptor(pdims, x, CUDNN_POOLING_MAX) alpha, beta = scalingParameter(T,1), scalingParameter(T,0) cudnnPoolingBackward(handle(), d, alpha, yDesc, y, yDesc, dy, xDesc, x, beta, xDesc, dx) return dx end function meanpool!(y::DenseCuArray{T}, x::DenseCuArray{T}, pdims::PoolDims) where T<:CUDNNFloat d = cudnnPoolingDescriptor(pdims, x, CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING) cudnnPoolingForward!(y, x, d) end function ∇meanpool!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, y::DenseCuArray{T}, x::DenseCuArray{T}, pdims::PoolDims) where T<:CUDNNFloat xDesc, yDesc = cudnnTensorDescriptor.((x, y)) d = cudnnPoolingDescriptor(pdims, x, CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING) alpha, beta = scalingParameter(T,1), scalingParameter(T,0) cudnnPoolingBackward(handle(), d, alpha, yDesc, y, yDesc, dy, xDesc, x, beta, xDesc, dx) return dx end ### Since CUDA.jl does not support 1D pooling, we have to convert to 2d add1d(x) = reshape(x, 1, size(x)...) 
function fix_pooldims_1d(pdims::PoolDims{1,K,S,P,D}) where {K,S,P,D} PoolDims{2, K + 1, S + 1, P + 2, D + 1}((1, NNlib.input_size(pdims)...), (1, NNlib.kernel_size(pdims)...), NNlib.channels_in(pdims), (1, NNlib.stride(pdims)...), (0, 0, NNlib.padding(pdims)...), (1, NNlib.dilation(pdims)...)) end function maxpool!(y::DenseCuArray{T,3}, x::DenseCuArray{T,3}, pdims::PoolDims) where T<:CUDNNFloat maxpool!(add1d(y), add1d(x), fix_pooldims_1d(pdims)) return y end function meanpool!(y::DenseCuArray{T,3}, x::DenseCuArray{T,3}, pdims::PoolDims) where T<:CUDNNFloat meanpool!(add1d(y), add1d(x), fix_pooldims_1d(pdims)) return y end function ∇maxpool!(dx::DenseCuArray{T,3}, dy::DenseCuArray{T,3}, y::DenseCuArray{T,3}, x::DenseCuArray{T,3}, pdims::PoolDims) where T<:CUDNNFloat ∇maxpool!(add1d(dx), add1d(dy), add1d(y), add1d(x), fix_pooldims_1d(pdims)) return dx end function ∇meanpool!(dx::DenseCuArray{T,3}, dy::DenseCuArray{T,3}, y::DenseCuArray{T,3}, x::DenseCuArray{T,3}, pdims::PoolDims) where T<:CUDNNFloat ∇meanpool!(add1d(dx), add1d(dy), add1d(y), add1d(x), fix_pooldims_1d(pdims)) return dx end ================================================ FILE: ext/NNlibCUDACUDNNExt/softmax.jl ================================================ import NNlib: softmax, softmax!, ∇softmax, ∇softmax!, logsoftmax, logsoftmax!, ∇logsoftmax, ∇logsoftmax! using cuDNN: CUDNN_SOFTMAX_LOG, CUDNN_SOFTMAX_MODE_CHANNEL, CUDNN_SOFTMAX_FAST, CUDNN_SOFTMAX_ACCURATE, cudnnSoftmaxForward!, cudnnSoftmaxBackward # Softmax # @denizyuret: do not do inplace operations with softmax/logsoftmax when (1) cpu version is not, (2) one can use softmax! 
# Out-of-place wrappers: allocate the result and forward to the in-place versions.
function softmax(x::T; dims=1) where {T<:DenseCuArray}
    softmax!(similar(x), x; dims)
end

function ∇softmax(dy::T, x::T, y::T; dims=1) where {T<:DenseCuArray}
    ∇softmax!(similar(x), dy, x, y; dims)
end

function logsoftmax(x::T; dims=1) where {T<:DenseCuArray}
    logsoftmax!(similar(x), x; dims)
end

function ∇logsoftmax(dy::T, x::T, y::T; dims=1) where {T<:DenseCuArray}
    ∇logsoftmax!(similar(x), dy, x, y; dims)
end

# @denizyuret: backup implementations for unsupported/slow size/dims combinations:
# Plain broadcast fallbacks, numerically stabilized by subtracting the max.
function _softmax!(y::T, x::T; dims) where {T<:DenseCuArray}
    y .= exp.(x .- maximum(x; dims))
    y ./= sum(y; dims)
end

function _∇softmax!(dx::T, dy::T, x::T, y::T; dims) where {T<:DenseCuArray}
    dx .= y .* (dy .- sum(dy .* y; dims))
end

function _logsoftmax!(y::T, x::T; dims) where {T<:DenseCuArray}
    y .= x .- maximum(x; dims)
    y .-= log.(sum(exp.(y); dims))
end

function _∇logsoftmax!(dx::T, dy::T, x::T, y::T; dims) where {T<:DenseCuArray}
    dx .= dy .- sum(dy; dims) .* exp.(y)
end

# Trick by @norci to use cudnn for softmax dims args that are contiguous:
# If dims=(dmin:dmax) then CUDNN_SOFTMAX_MODE_CHANNEL does the trick with reshape
# (1, prod(size(x)[1:dmin-1]), prod(size(x)[dmin:dmax]), :)
# softmaxdims returns nothing when the backup implementation should be used.
function softmaxdims(x, dims)
    dims === Colon() && return (1, 1, length(x), 1)
    mind,maxd = minimum(dims),maximum(dims)
    all(i in dims for i in mind:maxd) || return nothing  # cannot handle if not contiguous
    stride = dimsize = 1
    for i in 1:(mind-1); stride *= size(x,i); end  # Using size(x,i) assumes trailing dims = 1, robust to maxd > ndims(x)
    for i in mind:maxd; dimsize *= size(x,i); end
    batchsize = length(x)÷(stride*dimsize)
    # Here is a region where cudnn is slower, so we go with the backup:
    batchsize == 1 && 64 <= stride <= 4096 && 64 <= dimsize <= 4096 && return nothing
    return (1, stride, dimsize, batchsize)
end

# Determine softmax algo based on math_mode
softmaxalgo() = (CUDA.math_mode()===CUDA.FAST_MATH ? CUDNN_SOFTMAX_FAST : CUDNN_SOFTMAX_ACCURATE)

# Main implementations:
# Use cuDNN when softmaxdims can express `dims` as a channel-mode reshape,
# otherwise fall back to the broadcast implementations above.
function softmax!(y::T, x::T = y; dims=1) where {T<:DenseCuArray}
    s = softmaxdims(x, dims)
    s === nothing && return _softmax!(y, x; dims)
    cudnnSoftmaxForward!(reshape(y,s), reshape(x,s); mode = CUDNN_SOFTMAX_MODE_CHANNEL, algo = softmaxalgo())
    return y
end

function ∇softmax!(dx::T, dy::T, x::T, y::T; dims=1) where {R,T<:DenseCuArray{R}}
    s = softmaxdims(x, dims)
    s === nothing && return _∇softmax!(dx, dy, x, y; dims)
    xDesc = cudnnTensorDescriptor(reshape(x,s))
    alpha, beta = scalingParameter(R,1), scalingParameter(R,0)
    cudnnSoftmaxBackward(handle(), softmaxalgo(), CUDNN_SOFTMAX_MODE_CHANNEL,
                         alpha, xDesc, y, xDesc, dy, beta, xDesc, dx)
    return dx
end

function logsoftmax!(y::T, x::T = y; dims=1) where {T<:DenseCuArray}
    s = softmaxdims(x, dims)
    s === nothing && return _logsoftmax!(y, x; dims)
    cudnnSoftmaxForward!(reshape(y,s), reshape(x,s); mode = CUDNN_SOFTMAX_MODE_CHANNEL, algo = CUDNN_SOFTMAX_LOG)
    return y
end

function ∇logsoftmax!(dx::T, dy::T, x::T, y::T; dims=1) where {R,T<:DenseCuArray{R}}
    s = softmaxdims(x, dims)
    s === nothing && return _∇logsoftmax!(dx, dy, x, y; dims)
    xDesc = cudnnTensorDescriptor(reshape(x,s))
    alpha, beta = scalingParameter(R,1), scalingParameter(R,0)
    cudnnSoftmaxBackward(handle(), CUDNN_SOFTMAX_LOG, CUDNN_SOFTMAX_MODE_CHANNEL,
                         alpha, xDesc, y, xDesc, dy, beta, xDesc, dx)
    return dx
end

================================================
FILE: ext/NNlibCUDAExt/NNlibCUDAExt.jl
================================================

module NNlibCUDAExt

using NNlib
using CUDA
using Random, Statistics

include("sampling.jl")
include("activations.jl")
include("batchedadjtrans.jl")
include("batchedmul.jl")
include("ctc.jl")
include("scatter.jl")
include("utils.jl")

end # module

================================================
FILE: ext/NNlibCUDAExt/activations.jl
================================================

# Activation functions

# Some of activation functions need a wrapper for GPU support
# https://github.com/JuliaGPU/CuArrays.jl/issues/614 # @cufunc softplus(x::Real) = ifelse(x > 0, x + log1p(exp(-x)), log1p(exp(x))) # @cufunc logσ(x::Real) = -softplus(-x) # @cufunc function gelu(x::Real) # p = oftype(x / 1, π) # λ = oftype(x / 1, √(2 / p)) # α = oftype(x / 1, 0.044715) # h = oftype(x / 1, 0.5) # h * x * (one(x) + tanh(λ * (x + α * x^3))) # end # @cufunc lisht(x::Real) = x * tanh(x) # @cufunc logcosh(x::Real) = x + softplus(-2x) - log(oftype(x, 2)) # @cufunc mish(x::Real) = x * tanh(softplus(x)) # @cufunc tanhshrink(x::Real) = x - tanh(x) ================================================ FILE: ext/NNlibCUDAExt/batchedadjtrans.jl ================================================ using NNlib: BatchedAdjoint, BatchedTranspose, BatchedAdjOrTrans using Adapt using Adapt: WrappedArray const CuBatchedAdjoint{T} = BatchedAdjoint{T, <: CuArray{T}} const CuBatchedTranspose{T} = BatchedTranspose{T, <: CuArray{T}} const CuBatchedAdjOrTrans{T} = Union{CuBatchedAdjoint{T}, CuBatchedTranspose{T}} const WrappedCuBatchedAdjOrTrans{T, N} = WrappedArray{T, N, CuBatchedAdjOrTrans{T}, CuBatchedAdjOrTrans{T}} Base.print_array(io::IO, b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}) = Base.print_array(io, adapt(Array, b)) Base._show_nonempty(io::IO, b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}, prefix::String) = Base._show_nonempty(io, adapt(Array, b), prefix) Base.show_vector(io::IO, b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}, opn, cls) = Base.show_vector(io, adapt(Array, b), opn, cls) Base.convert(::Type{T}, b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}) where {T<:Array} = Base.convert(T, adapt(Array, b)) Base.Array{T, N}(b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}) where {T, N} = Array{T, N}(adapt(Array, b)) Base.collect(b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}) = collect(adapt(Array, b)) ================================================ FILE: ext/NNlibCUDAExt/batchedmul.jl 
================================================

# Batched matrix multiplication
# 1st argument is produced by NNlib.storage_type(A)
NNlib._batched_gemm!(::Type{<:CuArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C) =
    CUBLAS.gemm_strided_batched!(transA, transB, α, A, B, β, C)

# Allow cuBLAS to read through a batched adjoint/transpose wrapper by
# converting the pointer of the parent array.
Base.unsafe_convert(::Type{CuPtr{T}}, A::NNlib.BatchedAdjOrTrans{T}) where {T} =
    Base.unsafe_convert(CuPtr{T}, parent(A))

================================================
FILE: ext/NNlibCUDAExt/ctc.jl
================================================

# CTC loss moved from Flux.jl to NNlib

import NNlib: ctc_loss, ctc_alpha, ∇ctc_loss

## GPU implementation

# a port of the GPU kernels from Baidu's C++ warp-ctc package,
# which itself is Copyright 2015-2016 Baidu USA LLC
# and available under the Apache 2.0 license
#
# Apache 2.0 license: https://www.apache.org/licenses/LICENSE-2.0
# GitHub: https://github.com/baidu-research/warp-ctc/
# paper: https://arxiv.org/pdf/1512.02595.pdf

const MAX_THREADS = 256

# log(exp(p1) + exp(p2)), computed stably in log space.
@inline function log_plus_f(p1, p2)
    isinf(p1) && return p2
    isinf(p2) && return p1
    if p1 < p2
        p1, p2 = p2, p1
    end
    return p1 + log(1+exp(p2 - p1))
end

# Number of positions where a label equals its immediate predecessor.
function count_repeats(A)
    repeats = 0
    for (i,elem) in enumerate(A)
        if i > 1 && A[i] == A[i-1]
            repeats += 1
        end
    end
    return repeats
end

# CUDA kernel: fills the CTC forward (alpha) table in log space, one thread
# per label class per time step.
# NOTE(review): `labelsWithoutBlanks` is accepted but not read in this body —
# only `labelsWithBlanks` is used.
function compute_alpha_kernel(probs, labelSize, uttLength,
                              repeats, labelsWithoutBlanks, labelsWithBlanks,
                              alpha, blankLabel)

    tid = threadIdx().x
    L = labelSize
    T = uttLength
    S = length(labelsWithBlanks)

    # Sequence too short to emit all labels (plus required blanks): nothing to do.
    if L + repeats > T
        return nothing
    end
    labels = labelsWithBlanks

    # Corner-case checking
    start = (L + repeats <= T) ? 0 : 1
    last = S > 1 ? 2 : 1

    # Fill in first column (time step)
    i = tid
    while i <= last - start
        alpha[start+i, 1] = probs[labels[start+i], 1]
        i += blockDim().x
    end
    sync_threads()

    # Fill in coefficients for each time step
    for t=2:T
        # Corner-case checking
        if tid == 1 && !(1 < S - 2*(T-t) - 1)
            if start == 0
                alpha[1, t] = alpha[1, t-1] + probs[blankLabel, t]
            elseif start == 1
                alpha[1, t] = alpha[1, t-1]
            end
        end
        sync_threads()

        # Fill in coefficients for each label class in the target output sequence;
        # each thread will process the calculations for one class
        idx = tid+1
        while idx <= S
            prevSum = log_plus_f(alpha[idx, t-1], alpha[idx-1, t-1])
            # Skip-connection over the preceding blank when labels differ.
            if labels[idx] != blankLabel && idx != 2 && labels[idx] != labels[idx-2]
                prevSum = log_plus_f(prevSum, alpha[idx-2, t-1])
            end
            if idx < S - 2*(T-t) - 1
                # This state can no longer reach the end of the sequence.
                alpha[idx, t] = -Inf32
            else
                alpha[idx, t] = prevSum + probs[labels[idx], t]
            end
            idx += blockDim().x
        end
        sync_threads()
    end
    return nothing
end

# CUDA kernel: fills the backward (beta) table, accumulates alpha-beta
# products per label class, and writes gradients w.r.t. pre-softmax activations.
function compute_beta_and_grad_kernel(probs, labelSize, uttLength,
                                      repeatsInLabel, labelsWithBlanks,
                                      alphas, beta, output, accum,
                                      grad, blankLabel, loss)

    tid = threadIdx().x
    L = labelSize
    T = uttLength
    S = 2*L + 1
    repeats = repeatsInLabel
    labels = labelsWithBlanks

    if (L+repeats) > T
        return nothing
    end

    # Corner-case checking
    start = S > 1 ? S-2 : 0
    last = L + repeats < T ?
S : S-1
    sync_threads()
    i = tid

    # Calculate coefficients for last column (time step)
    # then determine alpha and beta product
    while i <= last - start
        beta[i+start, T] = 0
        output[i+start, T] = beta[i+start, T] + alphas[i+start, T]
        i += blockDim().x
    end
    sync_threads()

    # Fill in `accum` for last column (time step)
    if tid == 1
        for i=1:S
            labelIdx = labels[i]
            accum[labelIdx, T] = log_plus_f(accum[labelIdx, T], output[i, T])
        end
    end
    sync_threads()

    # Fill in `grad` for last column (time step)
    idx = tid
    while idx <= size(grad, 1)
        # Log-normalizer over all CTC states at the final time step.
        s = -Inf32
        for i=1:S
            s = log_plus_f(s, output[i, T])
        end

        # ∂L/∂a (where a is activation before logsoftmax)
        grad[idx, T] = exp(probs[idx, T]) - exp(accum[idx, T] - s)
        idx += blockDim().x
    end
    sync_threads()

    # Fill in the rest of the coefficients
    t = T-1
    while t >= 1
        if t < T
            idx = tid
            while idx <= S
                nextSum = probs[labels[idx], t+1] + beta[idx, t+1]
                if idx < S
                    nextSum = log_plus_f(nextSum,
                        probs[labels[idx+1], t+1] + beta[idx+1, t+1])
                end
                # Skip-connection over the following blank when labels differ.
                if labels[idx] != blankLabel && idx != S-1 && labels[idx] != labels[idx+2]
                    nextSum = log_plus_f(nextSum,
                        probs[labels[idx+2], t+1] + beta[idx + 2, t+1])
                end
                if idx > 2*t
                    # State unreachable from the start by time t.
                    beta[idx, t] = -Inf32
                else
                    beta[idx, t] = nextSum
                end
                idx += blockDim().x
            end
            sync_threads()

            idx = tid
            while idx <= S
                output[idx, t] = alphas[idx, t] + beta[idx, t]
                idx += blockDim().x
            end
            sync_threads()
        end
        sync_threads()

        # Calculate accumulated alpha-beta products for each label class for
        # each time step; used in calculating gradients
        if tid == 1
            for i=1:S
                labelIdx = labels[i]
                accum[labelIdx, t] = log_plus_f(accum[labelIdx, t], output[i, t])
            end
        end
        sync_threads()
        idx = tid

        # Calculate gradients
        while idx <= size(grad, 1)
            # ∂L/∂a (where a is activation before logsoftmax)
            grad[idx, t] = exp(probs[idx, t]) - exp(accum[idx, t] + loss)
            idx += blockDim().x
        end
        sync_threads()
        t -= 1
        sync_threads()
    end
    return nothing
end

"""
    ctc_alpha(ŷ::CuArray, y)

Runs log-softmax on `ŷ` and the CTC forward pass on the GPU. Returns a named
tuple `(loss, alpha, z′, yhat, nRepeats)` where `z′` is the label sequence
interleaved with blanks (the blank index is `size(ŷ, 1)`).
"""
function ctc_alpha(ŷ::CuArray, y)
    ŷ = logsoftmax(ŷ)
    blank = size(ŷ, 1)
    ycu = cu(y)
    # z′ = [blank, y₁, blank, y₂, …, blank]: labels interleaved with blanks.
    z′ = CUDA.fill(blank, 2 * length(y) + 1)
    z′[eachindex(y) .* 2] .= ycu
    T = size(ŷ, 2)
    U′ = 2*length(y) + 1
    alphas = CUDA.fill(log(zero(eltype(ŷ))), U′,T)
    nRepeats = count_repeats(CUDA.adapt(Array, y))
    nThreads = min(U′, MAX_THREADS)
    @cuda blocks=1 threads=nThreads compute_alpha_kernel(ŷ, length(y), T, nRepeats, ycu, z′, alphas, blank)
    return (loss=-1 * logsumexp(alphas[end-1:end]), alpha=alphas, z′=z′, yhat=ŷ, nRepeats=nRepeats)
end

ctc_loss(ŷ::CuArray, y) = ctc_alpha(ŷ::CuArray, y).loss

# Backward pass: consumes the forward-pass results from `ctc_alpha` and
# returns the gradient w.r.t. the pre-softmax activations `ŷ`.
function ∇ctc_loss(ŷ::CuArray, y, out)
    loss, alphas, z′, ŷ, nRepeats = out
    U′, T = size(alphas)
    blank = size(ŷ, 1)
    typed_zero = zero(eltype(ŷ))
    betas = CUDA.fill(log(typed_zero), U′, T)
    output = CUDA.fill(log(typed_zero), U′, T)
    nThreads = min(U′, MAX_THREADS)
    grads = CUDA.fill(log(typed_zero), size(ŷ))
    accum = CUDA.fill(log(typed_zero), size(ŷ))
    @cuda blocks=1 threads=nThreads compute_beta_and_grad_kernel(ŷ, length(y), T, nRepeats, CuArray(z′), alphas, betas, output, accum, grads, blank, loss)
    return grads
end

================================================
FILE: ext/NNlibCUDAExt/sampling.jl
================================================

# Atomic accumulation into the 4-D input-gradient array (threads may collide).
@inline function NNlib._safe_add!(dx::CuDeviceArray{T, 4}, value, ix, iy, c, n) where T
    @inbounds CUDA.@atomic dx[ix, iy, c, n] += value
end

# CUDA kernel for 4-D grid sampling: one thread per (w, h, n) output location,
# delegating the per-location interpolation to NNlib._grid_sample_kernel!.
function grid_sample_kernel!(n_elem, output::AbstractArray{T, 4}, input::AbstractArray{T, 4}, grid::AbstractArray{V, 4}, padding_mode) where {T,V}
    index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x
    if index < n_elem
        iW, iH, iC, _ = size(input)
        _, gW, gH, _ = size(grid)
        # Decompose the linear index into (w, h, n) coordinates.
        w = index % gW + 1
        h = (index ÷ gW) % gH + 1
        n = index ÷ (gW * gH) + 1
        NNlib._grid_sample_kernel!(output, input, grid, padding_mode, w, h, n, iW, iH, iC)
    end
    nothing
end

# CUDA kernel for the 4-D grid-sampling backward pass.
function ∇grid_sample_kernel!(n_elem, dx::AbstractArray{T, 4}, dgrid::AbstractArray{V, 4}, Δ::AbstractArray{T, 4}, input::AbstractArray{T, 4}, grid::AbstractArray{V, 4}, padding_mode) where {T,V}
    index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x
    if index < n_elem
        iW, iH, iC, _ =
size(input) _, gW, gH, _ = size(grid) w = index % gW + 1 h = (index ÷ gW) % gH + 1 n = index ÷ (gW * gH) + 1 NNlib._∇grid_sample_kernel!(dx, dgrid, Δ, input, grid, padding_mode, w, h, n, iW, iH, iC) end nothing end function NNlib.grid_sample(x::CuArray{T, 4}, grid::CuArray{V, 4}; padding_mode = :zeros) where {T, V} pad = Val(padding_mode) _, _, xC, xN = size(x) _, gW, gH, _ = size(grid) n_elem = gW * gH * xN y = similar(x, T, (gW, gH, xC, xN)) kernel = @cuda launch=false grid_sample_kernel!(n_elem, y, x, grid, pad) config = launch_configuration(kernel.fun; max_threads=256) threads = min(n_elem, config.threads) blocks = cld(n_elem, threads) kernel(n_elem, y, x, grid, pad; threads=threads, blocks=blocks) y end function NNlib.∇grid_sample(Δ::CuArray{T, 4}, x::CuArray{T, 4}, grid::CuArray{V, 4}; padding_mode = :zeros) where {T, V} pad = Val(padding_mode) xN = size(x, 4) _, gW, gH, _ = size(grid) n_elem = gW * gH * xN dx, dgrid = CUDA.zeros(T, size(x)), similar(grid) kernel = @cuda launch=false ∇grid_sample_kernel!(n_elem, dx, dgrid, Δ, x, grid, pad) config = launch_configuration(kernel.fun; max_threads=256) threads = min(n_elem, config.threads) blocks = cld(n_elem, threads) kernel(n_elem, dx, dgrid, Δ, x, grid, pad; threads=threads, blocks=blocks) dx, dgrid end @inline function NNlib._safe_add!(dx::CuDeviceArray{T, 5}, value, ix, iy, iz, c, n) where T @inbounds CUDA.@atomic dx[ix, iy, iz, c, n] += value end function grid_sample_kernel!(n_elem, output::AbstractArray{T, 5}, input::AbstractArray{T, 5}, grid::AbstractArray{V, 5}, padding_mode) where {T,V} index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x if index < n_elem iW, iH,iD, iC, _ = size(input) _, gW, gH, gD, _ = size(grid) w = index % gW + 1 h = (index ÷ gW) % gH + 1 d = (index ÷ (gW * gH)) % gD + 1 n = index ÷ (gW * gH * gD) + 1 # n = index ÷ (gW * gH) + 1 # d = (index ÷ (gW * gH * n)) + 1 NNlib._grid_sample_kernel!(output, input, grid, padding_mode, w, h, d, n, iW, iH, iD, iC) end nothing end 
function ∇grid_sample_kernel!(n_elem, dx::AbstractArray{T, 5}, dgrid::AbstractArray{V, 5}, Δ::AbstractArray{T, 5}, input::AbstractArray{T, 5}, grid::AbstractArray{V, 5}, padding_mode) where {T,V}
    index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x
    if index < n_elem
        iW, iH, iD, iC, _ = size(input)
        _, gW, gH, gD, _ = size(grid)
        # Decode the 0-based linear index into 1-based (w, h, d, n) coordinates.
        w = index % gW + 1
        h = (index ÷ gW) % gH + 1
        d = (index ÷ (gW * gH)) % gD + 1
        n = index ÷ (gW * gH * gD) + 1
        NNlib._∇grid_sample_kernel!(dx, dgrid, Δ, input, grid, padding_mode, w, h, d, n, iW, iH, iD, iC)
    end
    nothing
end

function NNlib.grid_sample(x::CuArray{T, 5}, grid::CuArray{V, 5}; padding_mode = :zeros) where {T, V}
    pad = Val(padding_mode)
    _, _, _, xC, xN = size(x)
    _, gW, gH, gD, _ = size(grid)
    n_elem = gW * gH * gD * xN
    y = similar(x, T, (gW, gH, gD, xC, xN))
    # Query an occupancy-based launch configuration, then clamp to the problem size.
    kernel = @cuda launch=false grid_sample_kernel!(n_elem, y, x, grid, pad)
    config = launch_configuration(kernel.fun; max_threads=256)
    threads = min(n_elem, config.threads)
    blocks = cld(n_elem, threads)
    kernel(n_elem, y, x, grid, pad; threads=threads, blocks=blocks)
    y
end

function NNlib.∇grid_sample(Δ::CuArray{T, 5}, x::CuArray{T, 5}, grid::CuArray{V, 5}; padding_mode = :zeros) where {T, V}
    pad = Val(padding_mode)
    xN = size(x, 5)
    _, gW, gH, gD, _ = size(grid)
    n_elem = gW * gH * gD * xN
    # dx must start at zero because the kernel accumulates atomically into it.
    dx, dgrid = CUDA.zeros(T, size(x)), similar(grid)
    kernel = @cuda launch=false ∇grid_sample_kernel!(n_elem, dx, dgrid, Δ, x, grid, pad)
    config = launch_configuration(kernel.fun; max_threads=256)
    threads = min(n_elem, config.threads)
    blocks = cld(n_elem, threads)
    kernel(n_elem, dx, dgrid, Δ, x, grid, pad; threads=threads, blocks=blocks)
    dx, dgrid
end

================================================
FILE: ext/NNlibCUDAExt/scatter.jl
================================================

# supported op: +, -, *, /, max, min, &, |, mean

## TODO support sparse dst/src/idx
## See issue https://github.com/FluxML/NNlib.jl/issues/647
# import CUDA.CUSPARSE: AbstractCuSparseMatrix, CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO, AnyCuSparseVector
# const AnyCuSparseMatrix{Tv,Ti} = Union{
#     AbstractCuSparseMatrix{Tv,Ti},
#     CUDA.CuSparseMatrixCSC{Tv,Ti},  # these types do not inherit from AbstractCuSparseMatrix
#     CUDA.CuSparseMatrixCSR{Tv,Ti},  # but from GPUArrays.AbstractGPUSparseMatrixXXX
#     CUDA.CuSparseMatrixCOO{Tv,Ti},
# }
# const AnyCuSparseArray{Tv,Ti} = Union{AnyCuSparseVector{Tv,Ti},AnyCuSparseMatrix{Tv,Ti}}

# One thread per index entry; `@atomic` serialises concurrent writes to the same destination.
function scatter_kernel!(op::OP, dst, src, idx) where OP
    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
    @inbounds if index <= length(idx)
        CUDA.@atomic dst[idx[index]...] = op(dst[idx[index]...], src[index])
    end
    return nothing
end

# CartesianIndex entries are converted to a linear index because `@atomic` needs a single location.
function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}) where OP
    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
    @inbounds if index <= length(idx)
        li = Base._to_linear_index(dst, Tuple(idx[index])...)
        CUDA.@atomic dst[li] = op(dst[li], src[index])
    end
    return nothing
end

# Variant with leading "slice" dimensions: each index entry addresses a whole sub-slice of `dst`.
function scatter_kernel!(op::OP, dst, src, idx, max_idx, max_dims_idx, dims_size) where OP
    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
    @inbounds if index <= max_idx
        j, k = divrem(index-1, max_dims_idx)
        dims_i = CartesianIndices(dims_size)[k+1]
        CUDA.@atomic dst[Tuple(dims_i)..., idx[j+1]...] = op(dst[Tuple(dims_i)..., idx[j+1]...], src[index])
    end
    return nothing
end

function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, max_idx, max_dims_idx, dims_size) where OP
    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
    @inbounds if index <= max_idx
        j, k = divrem(index-1, max_dims_idx)
        dims_i = CartesianIndices(dims_size)[k+1]
        li = Base._to_linear_index(dst, Tuple(dims_i)..., Tuple(idx[j+1])...)
CUDA.@atomic dst[li] = op(dst[li], src[index])
    end
    return nothing
end

# Dispatches to the matching `scatter_kernel!` method depending on whether `dst`
# has extra leading ("slice") dimensions beyond those addressed by `idx`.
function NNlib.scatter!(op::OP, dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray) where OP
    isempty(idx) && return dst
    dims = NNlib.scatter_dims(dst, src, idx)
    args = if dims == 0
        max_idx = length(idx)
        op, dst, src, idx
    else
        dims_size = size(dst)[1:dims]
        max_dims_idx = prod(dims_size)
        max_idx = max_dims_idx * length(idx)
        op, dst, src, idx, max_idx, max_dims_idx, dims_size
    end

    kernel = @cuda launch=false scatter_kernel!(args...)
    config = launch_configuration(kernel.fun; max_threads=256)
    threads = min(max_idx, config.threads)
    blocks = cld(max_idx, threads)
    kernel(args...; threads=threads, blocks=blocks)
    return dst
end

# mean-scatter: accumulate sums and per-destination counts, then add sum/count into `dst`.
function NNlib.scatter!(op::typeof(mean), dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray)
    Ns = NNlib.scatter!(+, zero(dst), one.(src), idx)
    dst_ = NNlib.scatter!(+, zero(dst), src, idx)
    dst .+= NNlib.safe_div.(dst_, Ns)
    return dst
end

## Gradients

function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx, rev_idx, max_idx, T::Type{TT}) where {OP,TT}
    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
    @inbounds if index <= max_idx
        cart_j = CartesianIndices(idx)[index]
        # get aggregating indices, which are to be aggregated together, and itself index
        inds = rev_idx[idx[cart_j]...]
        # multiply all values to be aggregated but not itself
        x = one(T)
        for k in inds
            x *= src[k]
        end
        x /= src[cart_j]
        # apply `op` on `Δsrc[i, k]` and `x`
        Δsrc[cart_j] = op(Δsrc[cart_j], x)
    end
    return nothing
end

function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, rev_idx, max_idx, T::Type{TT}) where {OP,TT}
    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
    @inbounds if index <= max_idx
        cart_j = CartesianIndices(idx)[index]
        # get aggregating indices, which are to be aggregated together, and itself index
        inds = rev_idx[Tuple(idx[cart_j])...]
        # multiply all values to be aggregated but not itself
        x = one(T)
        for k in inds
            x *= src[k]
        end
        x /= src[cart_j]
        # apply `op` on `Δsrc[i, k]` and `x`
        Δsrc[cart_j] = op(Δsrc[cart_j], x)
    end
    return nothing
end

# Variant with leading "slice" dimensions (see `scatter_kernel!` above).
function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx, rev_idx, pre_cart_idx, max_dims_idx, max_idx, T::Type{TT}) where {OP,TT}
    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
    @inbounds if index <= max_idx
        i, j = fldmod1(index, max_dims_idx)
        cart_i = CartesianIndices(idx)[i]
        cart_j = pre_cart_idx[j]
        # get aggregating indices, which are to be aggregated together, and itself index
        inds = rev_idx[idx[cart_i]...]
        # multiply all values to be aggregated but not itself
        x = one(T)
        for k in inds
            jk = Base._to_linear_index(src, Tuple(cart_j)..., Tuple(k)...)
            x *= src[jk]
        end
        x /= src[index]
        # apply `op` on `Δsrc[i, k]` and `x`
        Δsrc[index] = op(Δsrc[index], x)
    end
    return nothing
end

function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, rev_idx, pre_cart_idx, max_dims_idx, max_idx, T::Type{TT}) where {OP,TT}
    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
    @inbounds if index <= max_idx
        i, j = fldmod1(index, max_dims_idx)
        cart_i = CartesianIndices(idx)[i]
        cart_j = pre_cart_idx[j]
        # get aggregating indices, which are to be aggregated together, and itself index
        inds = rev_idx[Tuple(idx[cart_i])...]
        # multiply all values to be aggregated but not itself
        x = one(T)
        for k in inds
            jk = Base._to_linear_index(src, Tuple(cart_j)..., Tuple(k)...)
            x *= src[jk]
        end
        x /= src[index]
        # apply `op` on `Δsrc[i, k]` and `x`
        Δsrc[index] = op(Δsrc[index], x)
    end
    return nothing
end

# Gradient of scatter wrt `src` for multiplicative ops (* and /).
function NNlib.∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst, src::AnyCuArray, idx::AnyCuArray)
    dims = ndims(src) - ndims(idx)
    Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), src)
    # rev_idx maps each destination location back to the source positions that hit it.
    rev_idx = NNlib.reverse_indices(idx)
    rev_idx = CuArray(map(CUDA.cudaconvert, rev_idx))
    if dims == 0
        max_idx = length(idx)
        args = op, Δsrc, src, idx, rev_idx, max_idx, eltype(src)
    else
        pre_cart_idx = CartesianIndices(axes(src)[1:dims])
        max_dims_idx = length(pre_cart_idx)
        max_idx = max_dims_idx * length(idx)
        args = op, Δsrc, src, idx, rev_idx, pre_cart_idx, max_dims_idx, max_idx, eltype(src)
    end

    kernel = @cuda launch=false ∇scatter_src_kernel!(args...)
    config = launch_configuration(kernel.fun; max_threads=256)
    threads = min(max_idx, config.threads)
    blocks = cld(max_idx, threads)
    kernel(args...; threads=threads, blocks=blocks)

    CUDA.unsafe_free!(rev_idx)
    return Δsrc
end

================================================
FILE: ext/NNlibCUDAExt/utils.jl
================================================

NNlib._rng_from_array(::CuArray) = CUDA.default_rng()

NNlib._rng_compat_array(rng::CUDA.RNG, A::CuArray) = nothing
NNlib._rng_compat_array(rng::AbstractRNG, A::CuArray) = throw(ArgumentError(
    "cannot use rng::$(typeof(rng)) with array::CuArray, only CUDA's own RNG type works"))

# Element-wise in-place division, one thread per element.
function divide_kernel!(xs, ys, max_idx)
    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
    @inbounds if index <= max_idx
        xs[index] = xs[index] / ys[index]
    end
    return nothing
end

# Slice-wise division by per-slice counts (used for mean aggregation).
function divide_kernel!(xs, counts, max_idx, max_dims_idx, dims_size)
    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
    @inbounds if index <= max_idx
        j, k = divrem(index-1, max_dims_idx)
        dims_i = Tuple(CartesianIndices(dims_size)[k+1])
        CUDA.@atomic xs[dims_i..., j+1] = xs[dims_i..., j+1] / counts[j+1]
    end
    return nothing
end

# Builds (on the CPU) a reverse map from destination locations to source positions,
# then transfers each list to the GPU.
function NNlib.reverse_indices(idx::AnyCuArray{<:Any,N}) where N
    max_dims =
NNlib.maximum_dims(idx)
    T = CartesianIndex{N}
    rev = Array{Vector{T}}(undef, max_dims...)
    for i in eachindex(rev)
        rev[i] = T[]
    end
    NNlib.reverse_indices!(rev, idx)
    # Move each per-destination index list to the GPU.
    return map(cu, rev)
end

================================================
FILE: ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl
================================================

module NNlibEnzymeCoreExt

using NNlib
import EnzymeCore
using Random

using EnzymeCore.EnzymeRules

# Custom Enzyme reverse-mode rules for the in-place conv! family.
# Each tuple is (primal function type, function giving grad wrt the data argument,
# function giving grad wrt the filter argument).
for (name, dataname, filtername) in (
        (typeof(NNlib.conv!), NNlib.∇conv_data!, NNlib.∇conv_filter!),
        (typeof(NNlib.depthwiseconv!), NNlib.∇depthwiseconv_data!, NNlib.∇depthwiseconv_filter!),
        (typeof(NNlib.∇conv_data!), NNlib.conv!, NNlib.∇conv_filter!),
        (typeof(NNlib.∇conv_filter!), NNlib.∇conv_data!, NNlib.conv!),
    )
    @eval begin

    function EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{$name}, ::Type{RT},
            y::EnzymeCore.Annotation{<:AbstractArray{yT, N}},
            x::EnzymeCore.Annotation{<:AbstractArray{xT, N}},
            w::EnzymeCore.Annotation{<:AbstractArray{wT, N}},
            cdims; kwargs...) where {RT, yT, xT, wT, N}
        # Run the primal only when its output is actually tracked.
        if typeof(y) <: EnzymeCore.Duplicated || typeof(y) <: EnzymeCore.BatchDuplicated
            func.val(y.val, x.val, w.val, cdims.val; kwargs...)
        end

        primal = if EnzymeRules.needs_primal(config)
            y.val
        else
            nothing
        end
        shadow = if EnzymeRules.needs_shadow(config)
            y.dval
        else
            nothing
        end

        # Cache x if its overwritten and w is active (and thus required)
        cache_x = ( EnzymeRules.overwritten(config)[3]
                    && !(typeof(w) <: EnzymeCore.Const)
                    && !(typeof(y) <: EnzymeCore.Const)
                    ) ? copy(x.val) : nothing

        # Cache w if its overwritten and x is active (and thus required)
        cache_w = ( EnzymeRules.overwritten(config)[4]
                    && !(typeof(x) <: EnzymeCore.Const)
                    && !(typeof(y) <: EnzymeCore.Const)
                    ) ? copy(w.val) : nothing

        cache = (cache_x, cache_w)

        return EnzymeRules.AugmentedReturn(primal, shadow, cache)
    end

    function EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, ::Type{RT}, cache,
            y::EnzymeCore.Annotation{<:AbstractArray{yT, N}},
            x::EnzymeCore.Annotation{<:AbstractArray{xT, N}},
            w::EnzymeCore.Annotation{<:AbstractArray{wT, N}},
            cdims; kwargs...) where {RT, yT, xT, wT, N}
        cache_x, cache_w = cache

        # Don't cache x if not overwritten and w is active (and thus required)
        if !(typeof(w) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const)
            if !EnzymeRules.overwritten(config)[3]
                cache_x = x.val
            end
        end

        # Don't cache w if not overwritten and x is active (and thus required)
        if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const)
            if !EnzymeRules.overwritten(config)[4]
                cache_w = w.val
            end
        end

        dys = y.dval
        dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval
        dws = (typeof(w) <: EnzymeCore.Const) ? dys : w.dval

        if EnzymeRules.width(config) == 1
            dys = (dys,)
            dxs = (dxs,)
            dws = (dws,)
        end

        for (dy, dx, dw) in zip(dys, dxs, dws)
            if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val
                if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val
                    # dx += grad wrt x.val
                    $dataname(dx, $(name != typeof(NNlib.∇conv_filter!) ? :dy : :cache_w), $(name != typeof(NNlib.∇conv_filter!) ? :cache_w : :dy), cdims.val; alpha=xT(1), beta=xT(1), kwargs...)
                end
                if !(typeof(w) <: EnzymeCore.Const) && dw !== w.val
                    # dw += grad wrt w.val
                    $filtername(dw, $(name != typeof(NNlib.∇conv_data!) ? :cache_x : :dy), $(name != typeof(NNlib.∇conv_data!) ? :dy : :cache_x), cdims.val; alpha=wT(1), beta=wT(1), kwargs...)
                end

                dy .= 0
            end
        end

        return (nothing, nothing, nothing, nothing)
    end

    end
end

# Rules for NNlib.gather! — the adjoint of gather is scatter-add at the same indices.
function EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT}
    if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
        func.val(dst.val, src.val, idx.val)
    end

    primal = if EnzymeRules.needs_primal(config)
        dst.val
    else
        nothing
    end
    shadow = if EnzymeRules.needs_shadow(config)
        dst.dval
    else
        nothing
    end

    # Cache idx if its overwritten
    cache_idx = ( EnzymeRules.overwritten(config)[4]
                    && !(typeof(src) <: EnzymeCore.Const)
                    && !(typeof(dst) <: EnzymeCore.Const)
                    ) ? copy(idx.val) : nothing

    return EnzymeRules.AugmentedReturn(primal, shadow, cache_idx)
end

function EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, cache_idx, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT}
    # Don't cache idx if not overwritten
    if !(typeof(src) <: EnzymeCore.Const) && !(typeof(dst) <: EnzymeCore.Const)
        if !EnzymeRules.overwritten(config)[4]
            cache_idx = idx.val
        end
    end

    ddsts = dst.dval
    dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval

    if EnzymeRules.width(config) == 1
        ddsts = (ddsts,)
        dsrcs = (dsrcs,)
    end

    for (ddst, dsrc) in zip(ddsts, dsrcs)
        if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val
            if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val
                NNlib.scatter!(+, dsrc, ddst, cache_idx)
            end

            ddst .= 0
        end
    end

    return (nothing, nothing, nothing)
end

# Rules for NNlib.scatter! with op in (+, -); the adjoint gathers the output
# cotangent back at the same indices.
function EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.scatter!)}, ::Type{RT}, op::EnzymeCore.Const, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT}
    @assert !(OutType <: EnzymeCore.Const)
    if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
        func.val(op.val, dst.val, src.val, idx.val)
    end

    primal = if EnzymeRules.needs_primal(config)
        dst.val
    else
        nothing
    end
    shadow = if EnzymeRules.needs_shadow(config)
        dst.dval
    else
        nothing
    end

    # Cache idx if its overwritten
    cache_idx = ( EnzymeRules.overwritten(config)[4]
                    && !(typeof(src) <: EnzymeCore.Const)
                    && !(typeof(dst) <: EnzymeCore.Const)
                    ) ? copy(idx.val) : nothing

    return EnzymeRules.AugmentedReturn(primal, shadow, cache_idx)
end

function EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.scatter!)}, ::Type{RT}, cache_idx, op::Union{EnzymeCore.Const{typeof(+)},EnzymeCore.Const{typeof(-)}}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT}
    # Don't cache idx if not overwritten
    if !(typeof(src) <: EnzymeCore.Const) && !(typeof(dst) <: EnzymeCore.Const)
        if !EnzymeRules.overwritten(config)[4]
            cache_idx = idx.val
        end
    end

    ddsts = dst.dval
    dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval

    if EnzymeRules.width(config) == 1
        ddsts = (ddsts,)
        dsrcs = (dsrcs,)
    end

    for (ddst, dsrc) in zip(ddsts, dsrcs)
        if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val
            if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val
                if eltype(typeof(op)) == typeof(+)
                    dsrc .+= NNlib.gather(ddst, cache_idx)
                else
                    @assert eltype(typeof(op)) == typeof(-)
                    dsrc .-= NNlib.gather(ddst, cache_idx)
                end
            end
        end
    end

    return (nothing, nothing, nothing, nothing)
end

# Rules for the in-place pooling operations; both the pooled output and the input
# are needed by the corresponding ∇pool! function.
for pool in [:maxpool, :meanpool, :lpnormpool]
    pool! = Symbol(pool, :!)
    ∇pool = Symbol(:∇, pool, :!)

    @eval begin

    function EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof($pool!)}, ::Type{RT}, y::OutType, x, dims; kwargs...) where {OutType, RT}
        if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
            func.val(y.val, x.val, dims.val; kwargs...)
        end

        primal = if EnzymeRules.needs_primal(config)
            y.val
        else
            nothing
        end
        shadow = if EnzymeRules.needs_shadow(config)
            y.dval
        else
            nothing
        end

        cache_y = ( EnzymeRules.overwritten(config)[2]
                    && !(typeof(x) <: EnzymeCore.Const)
                    && !(typeof(y) <: EnzymeCore.Const)
                    ) ? copy(y.val) : nothing
        cache_x = ( EnzymeRules.overwritten(config)[3]
                    && !(typeof(x) <: EnzymeCore.Const)
                    && !(typeof(y) <: EnzymeCore.Const)
                    ) ? copy(x.val) : nothing

        cache = (cache_y, cache_x)

        return EnzymeRules.AugmentedReturn(primal, shadow, cache)
    end

    function EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof($pool!)}, ::Type{RT}, cache, y, x, dims; kwargs...) where {RT}
        cache_y, cache_x = cache

        # Don't cache y if not overwritten
        if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const)
            if !EnzymeRules.overwritten(config)[2]
                cache_y = y.val
            end
        end

        # Don't cache x if not overwritten
        if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const)
            if !EnzymeRules.overwritten(config)[3]
                cache_x = x.val
            end
        end

        dys = y.dval
        dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval

        if EnzymeRules.width(config) == 1
            dys = (dys,)
            dxs = (dxs,)
        end

        for (dy, dx) in zip(dys, dxs)
            if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val
                if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val
                    NNlib.$(∇pool)(dx, dy, cache_y, cache_x, dims.val; alpha=eltype(dx)(1), beta=eltype(dx)(1), kwargs...)
                end

                dy .= 0
            end
        end

        return (nothing, nothing, nothing)
    end

    end
end

# Rules for NNlib._dropout!: the sampled keep-mask is regenerated here and cached
# as the tape, so the reverse pass can reuse exactly the same mask.
function EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, rng, dst::OutType, src, p, dims) where {OutType, RT}
    T = float(real(eltype(dst.val)))
    val = convert(T, 1/(1-p.val))
    keep = if dims.val isa Colon
        similar(dst.val, T, size(dst.val))
    else
        similar(dst.val, T, ntuple(d -> d in dims.val ? size(dst.val,d) : 1, ndims(dst.val)))
    end
    rand!(rng.val, keep)
    keep = keep .> p.val

    if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
        dst.val .= (keep .* val) .* src.val
    end

    primal = if EnzymeRules.needs_primal(config)
        dst.val
    else
        nothing
    end
    shadow = if EnzymeRules.needs_shadow(config)
        dst.dval
    else
        nothing
    end

    # The mask is only needed on the tape when both dst and src are active.
    if typeof(dst) <: EnzymeCore.Const || typeof(src) <: EnzymeCore.Const
        keep = nothing
    end

    return EnzymeRules.AugmentedReturn(primal, shadow, keep)
end

function EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, keep, rng, dst::OutType, src, p, dims) where {OutType, RT}
    T = float(real(eltype(dst.val)))
    val = convert(T, 1/(1-p.val))

    ddsts = dst.dval
    dsrcs = (typeof(src) <: EnzymeCore.Const) ?
ddsts : src.dval
    if EnzymeRules.width(config) == 1
        ddsts = (ddsts,)
        dsrcs = (dsrcs,)
    end

    for (ddst, dsrc) in zip(ddsts, dsrcs)
        if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val
            if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val
                # Gradient of `dst .= (keep .* val) .* src` wrt `src`.
                dsrc .+= (keep .* val) .* ddst
            end

            ddst .= 0
        end
    end

    dp = if typeof(p) <: EnzymeCore.Active
        typeof(p.val)(0)
    else
        nothing
    end

    return (nothing, nothing, nothing, dp, nothing)
end

end

================================================
FILE: ext/NNlibFFTWExt/NNlibFFTWExt.jl
================================================

module NNlibFFTWExt

using FFTW
using NNlib
using KernelAbstractions

include("stft.jl")

end

================================================
FILE: ext/NNlibFFTWExt/stft.jl
================================================

function NNlib.stft(x;
    n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing,
    center::Bool = true, normalized::Bool = false,
)
    kab = get_backend(x)
    use_window = !isnothing(window)

    use_window && kab != get_backend(window) && throw(ArgumentError(
        "`window` must be on the same device as stft input `x` ($kab), \
        instead: `$(get_backend(window))`."))
    use_window && !(0 < length(window) ≤ n_fft) && throw(ArgumentError(
        "Expected `0 < length(window) ≤ n_fft=$n_fft`, \
        but got `length(window)=$(length(window))`."))
    # Fix: the previous check (`hop_length < 0`) let `hop_length == 0` through,
    # contradicting the error message and causing a division by zero below.
    hop_length ≤ 0 && throw(ArgumentError(
        "Expected `hop_length > 0`, but got `hop_length=$hop_length`."))

    # Pad window on both sides with `0` to `n_fft` length if needed.
    if use_window && length(window) < n_fft
        left = ((n_fft - length(window)) ÷ 2) + 1
        tmp = KernelAbstractions.zeros(kab, eltype(window), n_fft)
        tmp[left:left + length(window) - 1] .= window
        window = tmp
    end

    if center
        pad_amount = n_fft ÷ 2
        x = pad_reflect(x, pad_amount; dims=1)
    end

    n = size(x, 1)
    (0 < n_fft ≤ n) || throw(ArgumentError(
        "Expected `0 < n_fft ≤ size(x, 1)=$n`, but got `n_fft=$n_fft`."))

    n_frames = 1 + (n - n_fft) ÷ hop_length

    # time2col.
    # Reshape `x` to (n_fft, n_frames, B) if needed.
    # Each row in `n_frames` is shifted by `hop_length`.
    if n_frames > 1
        # TODO can be more efficient if we support something like torch.as_strided
        ids = [
            row + hop_length * col
            for row in 1:n_fft, col in 0:(n_frames - 1)]
        x = @inbounds x[ids, ntuple(_ -> Colon(), ndims(x) - 1)...]
    end

    region = 1
    use_window && (x = x .* window;)
    y = eltype(x) <: Complex ? fft(x, region) : rfft(x, region)
    normalized && (y = y .* eltype(y)(n_fft^-0.5);)
    return y
end

function NNlib.istft(y;
    n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing,
    center::Bool = true, normalized::Bool = false,
    return_complex::Bool = false,
    original_length::Union{Nothing, Int} = nothing,
)
    kab = get_backend(y)
    use_window = !isnothing(window)

    use_window && kab != get_backend(window) && throw(ArgumentError(
        "`window` must be on the same device as istft input `y` ($kab), \
        instead: `$(get_backend(window))`."))
    use_window && !(0 < length(window) ≤ n_fft) && throw(ArgumentError(
        "Expected `0 < length(window) ≤ n_fft=$n_fft`, \
        but got `length(window)=$(length(window))`."))
    # Same fix as in `stft`: 0 is as invalid as a negative hop.
    hop_length ≤ 0 && throw(ArgumentError(
        "Expected `hop_length > 0`, but got `hop_length=$hop_length`."))

    # TODO check `y` eltype is complex

    n_frames = size(y, 2)

    # Pad window on both sides with `0` to `n_fft` length if needed.
    if use_window && length(window) < n_fft
        left = ((n_fft - length(window)) ÷ 2) + 1
        tmp = KernelAbstractions.zeros(kab, eltype(window), n_fft)
        tmp[left:left + length(window) - 1] .= window
        window = tmp
    end

    # Denormalize.
    normalized && (y = y .* eltype(y)(n_fft^0.5);)

    region = 1
    x = return_complex ? ifft(y, region) : irfft(y, n_fft, region)

    # De-apply window.
    use_window && (x = x ./ window;)

    # col2time.
expected_output_len = n_fft + hop_length * (n_frames - 1)
    ids = Vector{Int}(undef, expected_output_len)
    in_idx, out_idx = 0, 0
    prev_e, v = 0, 0
    # Keep only the first occurrence of every output sample across overlapping frames.
    for col in 0:(n_frames - 1)
        for row in 1:n_fft
            in_idx += 1
            v = row + hop_length * col
            v > prev_e || continue
            out_idx += 1
            ids[out_idx] = in_idx
        end
        prev_e = v
    end

    # In case of batched input, reshape it (n_fft, n_frames, batch) -> (:, batch).
    nd = ntuple(_ -> Colon(), ndims(x) - 2)
    ndims(x) == 3 && (x = reshape(x, (:, size(x, 3)));)
    x = @inbounds x[ids, nd...]

    # Trim padding.
    left = center ? (n_fft ÷ 2 + 1) : 1
    right = if isnothing(original_length)
        center ? (size(x, 1) - n_fft ÷ 2) : expected_output_len
    else
        left + original_length - 1
    end
    x = x[left:right, nd...]
    return x
end

================================================
FILE: ext/NNlibForwardDiffExt.jl
================================================

module NNlibForwardDiffExt

using ForwardDiff: ForwardDiff
using NNlib: NNlib

# ForwardDiff dual numbers indicate that differentiation is in progress.
NNlib.within_gradient(x::ForwardDiff.Dual) = true
NNlib.within_gradient(x::AbstractArray{<:ForwardDiff.Dual}) = true

end

================================================
FILE: ext/NNlibMetalExt.jl
================================================

module NNlibMetalExt

using Metal: method_table, @device_override
using NNlib: NNlib

# Inside Metal GPU kernels, substitute the fast-math tanh for NNlib's tanh_fast.
@device_override NNlib.tanh_fast(x) = Base.FastMath.tanh_fast(x)

end

================================================
FILE: ext/NNlibSpecialFunctionsExt.jl
================================================

module NNlibSpecialFunctionsExt

using NNlib: NNlib, oftf
using SpecialFunctions: erf

# Full gelu (gelu_erf)
NNlib.gelu_erf(x) = x/2*(1 + erf(x/sqrt(oftf(x,2))))

# Derivative of gelu_erf: d/dx [x * Φ(x)] = Φ(x) + x * φ(x),
# with φ(x) = exp(-x²/2) / √(2π).
function NNlib.deriv_gelu_erf(x)
    SQRT2 = sqrt(oftf(x,2))
    Φ = (1 + erf(x/SQRT2))/2
    Φ + x/SQRT2*exp(-(x^2)/2)/sqrt(oftf(x,π))
end

end

================================================
FILE: src/NNlib.jl
================================================

module NNlib

import Atomix
import ChainRulesCore: rrule

using Base.Broadcast: broadcasted
using Base.Threads
using ChainRulesCore
using GPUArraysCore
using KernelAbstractions
using KernelAbstractions: @atomic
using LinearAlgebra
using LinearAlgebra.BLAS: @blasfunc, BlasInt
using LinearAlgebra: AdjOrTransAbsMat, Adjoint, BlasFloat, Transpose
using Random
using ScopedValues
using Statistics
using Statistics: mean

const Numeric = Union{AbstractArray{<:T}, T} where {T<:Number}

# internal. TODO: change to an approach where amount of threading is controlled, not just on/off
const ALLOW_SPAWNS = ScopedValue(true)
should_use_spawn() = Threads.nthreads(:default) > 1 && ALLOW_SPAWNS[]
"""
    @disallow_spawns ex

Disallow NNlib to use `@spawn` on divisible workloads. i.e. within `conv` etc.
"""
macro disallow_spawns(ex)
    quote
        @with ALLOW_SPAWNS => false $(esc(ex))
    end
end

# Include APIs
include("dim_helpers.jl")
export ConvDims, DenseConvDims, PoolDims, DepthwiseConvDims

include("activations.jl")
for f in ACTIVATIONS
    @eval export $(f)
end
export sigmoid, hardsigmoid, logsigmoid, thresholdrelu, gelu # Aliases

include("attention.jl")
export dot_product_attention, dot_product_attention_scores, make_causal_mask

include("dropout.jl")
export dropout, dropout!

include("softmax.jl")
export softmax, softmax!, ∇softmax, ∇softmax!, logsoftmax, logsoftmax!, ∇logsoftmax, ∇logsoftmax!, logsumexp

include("batched/batchedadjtrans.jl")
include("batched/batchedmul.jl")
export batched_mul, batched_mul!, ⊠, batched_vec, batched_transpose, batched_adjoint

include("gemm.jl")
export grid_sample, ∇grid_sample

include("conv.jl")
export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter, ∇conv_filter!, depthwiseconv, depthwiseconv!, ∇depthwiseconv_data, ∇depthwiseconv_data!, ∇depthwiseconv_filter, ∇depthwiseconv_filter!

include("conv_bias_act.jl")
export conv_bias_act, conv_bias_act!

include("bias_act.jl")
export bias_act!
include("fold.jl") include("ctc.jl") export ctc_loss include("pooling.jl") export maxpool, maxpool!, meanpool, meanpool!, lpnormpool, lpnormpool!, ∇maxpool, ∇maxpool!, ∇meanpool, ∇meanpool!, ∇lpnormpool, ∇lpnormpool! include("padding.jl") export pad_constant, pad_repeat, pad_reflect, pad_zeros, pad_symmetric, pad_circular include("upsample.jl") export upsample_nearest, ∇upsample_nearest, upsample_linear, ∇upsample_linear, upsample_bilinear, ∇upsample_bilinear, upsample_trilinear, ∇upsample_trilinear, pixel_shuffle include("gather.jl") include("scatter.jl") include("utils.jl") include("sampling.jl") include("functions.jl") include("normalization.jl") # export batchnorm, ∇batchnorm ## Include implementations include("impl/padding_edges.jl") # Direct implementations of convolutional and depthwise-convolutional algorithms include("impl/conv_direct.jl") include("impl/depthwiseconv_direct.jl") # im2col implementations of convolutional and depthwise-convolutional algorithms include("impl/conv_im2col.jl") include("impl/depthwiseconv_im2col.jl") # Direct implementations of pooling include("impl/pooling_direct.jl") include("deprecations.jl") include("rotation.jl") export imrotate, ∇imrotate include("audio/stft.jl") include("audio/spectrogram.jl") include("audio/mel.jl") export stft, istft, hann_window, hamming_window, spectrogram, melscale_filterbanks end # module NNlib ================================================ FILE: src/activations.jl ================================================ ## Activation functions # # Some of activation functions have its wrapper function for GPU in NNlibCUDAExt.jl. 
# https://github.com/JuliaGPU/CuArrays.jl/issues/614 ACTIVATIONS = [ :σ, :hardσ, :hardtanh, :relu, :leakyrelu, :relu6, :rrelu, :elu, :gelu_tanh, :gelu_sigmoid, :gelu_erf, :swish, :hardswish, :selu, :celu, :softplus, :softsign, :logσ, :logcosh, :mish, :tanhshrink, :softshrink, :trelu, :lisht, :tanh_fast, :sigmoid_fast, ] # of type float (to allow for integer inputs) oftf(x, y) = oftype(float(x), y) # oftype contains control flow on 1.10+, which can lead to type instabilities under AD function rrule(::typeof(oftf), x, y) proj_y = ChainRulesCore.ProjectTo(y) oftf_pullback(Δ) = (NoTangent(), NoTangent(), proj_y(Δ)) return oftf(x, y), oftf_pullback end """ σ(x) = 1 / (1 + exp(-x)) Classic [sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function) activation function. Unicode `σ` can be entered as `\\sigma` then tab, in many editors. The ascii name `sigmoid` is also exported. See also [`sigmoid_fast`](@ref). ```julia-repl julia> using UnicodePlots julia> lineplot(sigmoid, -5, 5, height=7) ┌────────────────────────────────────────┐ 1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⣀⡠⠤⠖⠒⠒⠋⠉⠉⠉⠉⠉⠉│ σ(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⢀⡠⠖⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⣀⠔⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⡏⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡔⠋⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠊⠁⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 0 │⣀⣀⣀⣀⣀⣀⣀⠤⠤⠤⠒⠊⠉⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ └────────────────────────────────────────┘ ⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ julia> sigmoid === σ true ``` """ function σ(x) t = exp(-abs(x)) ifelse(x ≥ 0, inv(1 + t), t / (1 + t)) end const sigmoid = σ """ hardσ(x) = max(0, min(1, (x + 3) / 6)) Piecewise linear approximation of [`sigmoid`](@ref). 
```julia-repl julia> lineplot(hardsigmoid, -5, 5, height=7) ┌────────────────────────────────────────┐ 1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⢀⡠⠖⠋⠉⠉⠉⠉⠉⠉⠉⠉│ hardσ(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⣀⡤⠒⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⡠⠔⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⡗⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠊⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡤⠖⠋⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 0 │⣀⣀⣀⣀⣀⣀⣀⣀⣀⠤⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ └────────────────────────────────────────┘ ⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ julia> lineplot(sigmoid, -5, 5, height=7) ┌────────────────────────────────────────┐ 1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⣀⡠⠤⠖⠒⠒⠋⠉⠉⠉⠉⠉⠉│ σ(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⢀⡠⠖⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⣀⠔⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⡏⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡔⠋⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠊⠁⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 0 │⣀⣀⣀⣀⣀⣀⣀⠤⠤⠤⠒⠊⠉⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ └────────────────────────────────────────┘ ⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ ``` """ hardσ(x) = clamp((x + 3) / 6, 0, 1) # https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html const hardsigmoid = hardσ """ logσ(x) Return `log(σ(x))` which is computed in a numerically stable way. 
```julia-repl julia> lineplot(logsigmoid, -5, 5, height=7) ┌────────────────────────────────────────┐ 0 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡧⠤⠔⠒⠒⠒⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉│ logσ(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠖⠊⠉⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠒⠉⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ f(x) │⠀⠀⠀⠀⠀⠀⢀⡤⠖⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⣀⠔⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⡤⠖⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ -6 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ └────────────────────────────────────────┘ ⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ ``` """ logσ(x) = -softplus(-x) const logsigmoid = logσ """ hardtanh(x) = max(-1, min(1, x)) Segment-wise linear approximation of `tanh`, much cheaper to compute. See ["Large Scale Machine Learning"](https://ronan.collobert.com/pub/matos/2004_phdthesis_lip6.pdf). See also [`tanh_fast`](@ref). ```julia-repl julia> lineplot(hardtanh, -2, 2, height=7) ┌────────────────────────────────────────┐ 1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⣀⠔⠋⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉│ hardtanh(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⣀⡤⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⢀⡤⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ f(x) │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⡤⡷⠥⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⠖⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⠖⠋⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ -1 │⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⠔⠋⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ └────────────────────────────────────────┘ ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x julia> lineplot(tanh, -2, 2, height=7) ┌────────────────────────────────────────┐ 1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⣀⡠⠤⠤⠒⠒⠒⠊⠉⠉⠉│ tanh(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⢀⡠⠔⠊⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⢀⡤⠒⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ f(x) │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⡤⡷⠥⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡤⠖⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡠⠔⠊⠁⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ -1 │⣀⣀⣀⡠⠤⠤⠤⠖⠒⠊⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ └────────────────────────────────────────┘ ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ ``` """ hardtanh(x) = clamp(x, oftype(x, -1), oftype(x, 1)) # clamp(x, -1, 1) is type-stable, 
but would promote Int32, for which we have tests """ relu(x) = max(0, x) [Rectified Linear Unit](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)) activation function. ```julia-repl julia> lineplot(relu, -2, 2, height=7) ┌────────────────────────────────────────┐ 2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠋│ relu(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠊⠁⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡤⠊⠁⠀⠀⠀⠀⠀│ f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⢀⡤⠖⠁⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⡠⠖⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⡠⠖⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 0 │⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣇⠔⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ └────────────────────────────────────────┘ ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ ``` """ relu(x) = ifelse(x<0, zero(x), x) # faster than max(zero(x), x), still preserves NaN """ leakyrelu(x, a=0.01) = max(a*x, x) Leaky [Rectified Linear Unit](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)) activation function. You can also specify the coefficient explicitly, e.g. `leakyrelu(x, 0.01)`. ```julia-repl julia> lineplot(x -> leakyrelu(x, 0.5), -2, 2, height=7) ┌────────────────────────────────────────┐ 2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠒⠉│ #42(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊⠉⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⣀⡤⠖⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀│ f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⣀⠤⠖⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⣤⣤⡤⡧⠶⠭⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│ │⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⣀⠤⠤⠒⠒⠋⠉⠁⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ -1 │⣀⣀⠤⠤⠒⠒⠊⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ └────────────────────────────────────────┘ ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ julia> leakyrelu(-10f0, 0.2) -2.0f0 julia> leakyrelu(-10f0, 0.02) -0.5f0 ``` """ leakyrelu(x, a=oftf(x, leakyrelu_a)) = ifelse(x>0, float(x), oftf(x, a*x)) # max(a*x, x) is 3x slower const leakyrelu_a = 0.01 # also used in gradient below """ relu6(x) = min(max(0, x), 6) [Rectified Linear Unit](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)) activation function capped at 6. 
See ["Convolutional Deep Belief Networks"](https://www.cs.toronto.edu/~kriz/conv-cifar10-aug2010.pdf) from CIFAR-10. ```julia-repl julia> lineplot(relu6, -10, 10, height=7) ┌────────────────────────────────────────┐ 6 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠎⠉⠉⠉⠉⠉⠉⠉⠉│ relu6(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⢀⡔⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⡤⠃⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⡠⠎⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⢀⠖⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⡔⠃⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 0 │⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⡧⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ └────────────────────────────────────────┘ ⠀-10⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀10⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ ``` """ relu6(x) = clamp(x, oftype(x, 0), oftype(x, 6)) # clamp promotes, but clamp(x, 0, 6) would promote x::Int32 """ rrelu(x, lo=1/8, hi=1/3) = max(a*x, x) # where `a` is randomly sampled from uniform distribution `U(lo, hi)` Randomized Leaky Rectified Linear Unit activation function. See ["Empirical Evaluation of Rectified Activations"](https://arxiv.org/abs/1505.00853) You can also specify the bound explicitly, e.g. `rrelu(x, 0.0, 1.0)`. ```julia-repl julia> lineplot(rrelu, -20, 10, height=7) ┌────────────────────────────────────────┐ 10 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠖⠋│ rrelu(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⢀⡠⠖⠋⠁⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⢀⡠⠔⠊⠁⠀⠀⠀⠀⠀⠀⠀│ f(x) │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⡤⠤⣤⣤⢤⣤⣤⠤⠤⠤⢼⠮⠥⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│ │⣰⢀⣆⡄⣄⡄⡠⡰⠦⠷⡜⢢⠷⠳⠢⠊⠉⠉⠀⠀⠁⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠃⠉⠙⠘⠃⠈⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ -10 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ └────────────────────────────────────────┘ ⠀-20⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀10⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ julia> extrema(rrelu.(fill(-10f0, 1000))) (-3.3316886f0, -1.2548422f0) ``` """ function rrelu(x::T, l=oftf(x,1/8), u=oftf(x,1/3)) where T<:Number a = (u - l) * rand(float(T)) + l return leakyrelu(x, a) end """ elu(x, α=1) = x > 0 ? x : α * (exp(x) - 1) Exponential Linear Unit activation function. 
See ["Fast and Accurate Deep Network Learning by Exponential Linear Units"](https://arxiv.org/abs/1511.07289). You can also specify the coefficient explicitly, e.g. `elu(x, 1)`. ```julia-repl julia> lineplot(elu, -2, 2, height=7) ┌────────────────────────────────────────┐ 2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠒⠉│ elu(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊⠉⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⣀⡤⠖⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀│ f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⣀⠤⠖⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⡤⡧⠶⠭⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⣀⠤⠔⠒⠋⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ -1 │⠤⠤⠤⠤⠔⠒⠒⠒⠊⠉⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ └────────────────────────────────────────┘ ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ julia> elu(-10f0) -0.9999546f0 julia> elu(-10f0, 2) -1.9999092f0 ``` """ elu(x, α=1) = ifelse(x ≥ 0, float(x), @fastmath oftf(x, α) * (exp(x) - 1)) deriv_elu(Ω, α=1) = ifelse(Ω ≥ 0, one(Ω), Ω + oftype(Ω, α)) """ gelu_tanh(x) = 0.5x * (1 + tanh(√(2/π) * (x + 0.044715x^3))) Activation function from ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415) using tanh approximation. This implementation uses `tanh` which allows for better pattern matching and fusion in optimizing compilers compared to the sigmoid-based implementation. For a potentially faster implementation that uses `sigmoid_fast`, see [`gelu_sigmoid`](@ref). 
```julia-repl julia> lineplot(gelu_tanh, -2, 2, height=7) ┌────────────────────────────────────────┐ 2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊│ gelu_tanh(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠊⠁⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠒⠉⠀⠀⠀⠀⠀⠀⠀│ f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⣀⡠⠤⠒⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⣤⣤⣤⣤⣤⣤⣤⣤⡤⠤⠤⠤⠤⠤⠤⠤⣤⣤⣤⡤⡧⠶⠶⠭⠥⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│ │⠀⠀⠀⠀⠀⠀⠀⠀⠈⠉⠉⠉⠉⠉⠉⠉⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ -1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ └────────────────────────────────────────┘ ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ julia> lineplot(gelu_tanh, -5, 0, height=7); julia> lineplot!(ans, swish) ┌────────────────────────────────────────┐ 0 │⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠒⠒⠤⣄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸│ gelu_tanh(x) │⠑⠒⠢⠤⣄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠓⢄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇│ swish(x) │⠀⠀⠀⠀⠀⠈⠉⠒⠤⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠑⢆⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣸⠁│ f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠒⢄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠑⢄⠀⠀⠀⠀⠀⠀⠀⠀⢠⡇⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠓⢄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠓⣄⠀⠀⠀⠀⠀⢠⡞⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠓⢄⣀⣀⡤⢣⠃⠀⠀│ -0.2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠓⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢠⠇⠀⠀⠀│ └────────────────────────────────────────┘ ⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀0⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ ``` """ function gelu_tanh(x) α = oftf(x, 0.044715) λ = oftf(x, gelu_λ) x/2 * (1 + tanh_fast(λ * (x + α * x^3))) end const gelu_λ = √(2 / π) const gelu_2λ = √(8 / π) function deriv_gelu_tanh(x) α = oftf(x, 0.044715) α2 = oftf(x, 0.08943) λ = oftf(x, gelu_λ) x2 = x * x t = muladd(x2, α, one(x)) z = λ * x * t Ω = tanh_fast(z) sech2 = 1 - Ω^2 (1 + Ω)/2 + x * λ * muladd(x2, α2, t) * sech2 / 2 end """ gelu_sigmoid(x) = x * σ(√(8/π) * (x + 0.044715x^3)) Alternative implementation of the GELU activation function using `sigmoid` instead of `tanh`. This is mathematically equivalent to [`gelu_tanh`](@ref) but may be faster in some cases. The sigmoid-based implementation may prevent pattern matching and fusion in some optimizing compilers. Use [`gelu_tanh`](@ref) if you need better compiler optimization support. 
See ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415). """ function gelu_sigmoid(x) α = oftf(x, 0.044715) λλ = oftf(x, gelu_2λ) x * sigmoid_fast(λλ * x * muladd(x^2, α, one(x))) end function deriv_gelu_sigmoid(x) α = oftf(x, 0.044715) α2 = oftf(x, 0.08943) λλ = oftf(x, gelu_2λ) x2 = x * x t = muladd(x2, α, one(x)) Ω = sigmoid_fast(λλ * x * t) dσ = conj(Ω * (1 - Ω)) muladd(dσ * λλ * muladd(x2, α2, t), x, Ω) end """ gelu_erf(x) = xΦ(x) = 0.5x * (1 + erf(x/√2)) Activation function from ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415) without approximation. The SpecialFunctions.jl package needs to be loaded to use this function. """ function gelu_erf end function deriv_gelu_erf end """ gelu(x) = gelu_tanh(x) Activation function from ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415). See [`gelu_tanh`](@ref). """ const gelu = gelu_tanh # Need to alias the type as well to ensure serialization libraries still work # See https://github.com/FluxML/NNlib.jl/issues/631 const var"#gelu" = typeof(gelu_tanh) const deriv_gelu = deriv_gelu_tanh """ swish(x) = x * σ(x) Self-gated activation function. See ["Swish: a Self-Gated Activation Function"](https://arxiv.org/abs/1710.05941). ```julia-repl julia> lineplot(swish, -2, 2, height=7) ┌────────────────────────────────────────┐ 2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤│ swish(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠖⠋⠁⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠖⠋⠁⠀⠀⠀⠀⠀│ f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⢀⣀⡤⠔⠊⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⣤⣤⡤⡧⠴⠶⠯⠥⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│ │⠉⠑⠒⠒⠒⠒⠒⠒⠒⠒⠒⠒⠉⠉⠉⠉⠁⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ -1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ └────────────────────────────────────────┘ ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ ``` """ @inline swish(x) = x * sigmoid_fast(x) """ hardswish(x) = x * hardσ(x) Hard-Swish activation function. See ["Searching for MobileNetV3"](https://arxiv.org/abs/1905.02244). 
```julia-repl julia> lineplot(hardswish, -2, 5, height = 7) ┌────────────────────────────────────────┐ 5 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠔⠒⠉│ hardswish(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡠⠔⠒⠉⠁⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠖⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀│ f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⢀⣀⠤⠖⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣇⣤⣤⣖⣚⣉⣁⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀│ -1 │⠉⠒⠒⠒⠒⠉⠉⠉⠉⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ └────────────────────────────────────────┘ ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ julia> lineplot(hardswish, -4, 0, height = 7); julia> lineplot!(ans, swish) ┌────────────────────────────────────────┐ 0 │⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⢣⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡜│ hardswish(x) │⠒⠒⠢⠤⢄⣀⡀⠀⠀⠀⠀⠱⡄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠎⠀│ swish(x) │⠀⠀⠀⠀⠀⠀⠈⠉⠑⠒⠦⢄⣘⢄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡴⠃⠀⠀│ f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠑⡖⠦⢄⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⢔⠏⠁⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠣⣄⠀⠉⠑⠒⠦⠤⢄⣀⣀⣀⣀⡠⠤⠖⣊⠕⠁⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠓⠤⡀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡤⠖⠁⠀⠀⠀⠀⠀⠀⠀│ -0.4 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠉⠒⠢⠤⠤⠔⠒⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ └────────────────────────────────────────┘ ⠀-4⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀0⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ julia> hardswish.(-5:5)' 1×11 adjoint(::Vector{Float64}) with eltype Float64: -0.0 -0.0 -0.0 -0.333333 -0.333333 0.0 0.666667 1.66667 3.0 4.0 5.0 ``` """ @inline hardswish(x) = x * hardσ(x) deriv_hardswish(x) = ifelse(x < -3, oftf(x,0), ifelse(x > 3, oftf(x,1), x/3 + oftf(x,1/2))) """ lisht(x) = x * tanh(x) Activation function from ["LiSHT: Non-Parametric Linearly Scaled Hyperbolic Tangent ..."](https://arxiv.org/abs/1901.05894) ```julia-repl julia> lineplot(lisht, -2, 2, height=7) ┌────────────────────────────────────────┐ 2 │⠢⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔│ lisht(x) │⠀⠈⠑⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡤⠊⠁⠀│ │⠀⠀⠀⠀⠈⠣⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠁⠀⠀⠀⠀│ f(x) │⠀⠀⠀⠀⠀⠀⠀⠑⢆⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⠊⠁⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠢⡄⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⢀⠔⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠓⢄⡀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⢀⡠⠖⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 0 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠓⠦⣄⣀⣀⣇⣀⣀⠤⠒⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 
└────────────────────────────────────────┘ ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ julia> lineplot!(ans, logcosh) ┌────────────────────────────────────────┐ 2 │⠢⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔│ lisht(x) │⠀⠈⠑⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡤⠊⠁⠀│ logcosh(x) │⠢⣄⠀⠀⠈⠣⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠁⠀⠀⣀⠔│ f(x) │⠀⠈⠑⠢⣀⠀⠀⠑⢆⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⠊⠁⠀⣀⠔⠊⠁⠀│ │⠀⠀⠀⠀⠀⠉⠢⢄⡀⠉⠢⡄⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⢀⠔⠋⠀⡠⠔⠋⠁⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠉⠓⠦⣌⡓⢄⡀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⢀⡠⠖⣁⠤⠒⠉⠀⠀⠀⠀⠀⠀⠀⠀│ 0 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠓⠪⠷⣦⣄⣀⣀⣇⣀⣀⣤⠶⠕⠒⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ └────────────────────────────────────────┘ ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ ``` """ lisht(x) = x * tanh_fast(x) """ selu(x) = λ * (x ≥ 0 ? x : α * (exp(x) - 1)) λ ≈ 1.05070... α ≈ 1.67326... Scaled exponential linear units. See ["Self-Normalizing Neural Networks"](https://arxiv.org/abs/1706.02515). ```julia-repl julia> lineplot(selu, -3, 2, height=7) ┌────────────────────────────────────────┐ 3 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ selu(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠤⠒│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⢀⣀⠤⠖⠊⠉⠀⠀⠀⠀│ f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⣀⡠⠤⠒⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⣉⠭⠛⡏⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⣀⡤⠤⠒⠊⠉⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ -2 │⠤⠤⠖⠒⠒⠒⠒⠒⠒⠒⠉⠉⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ └────────────────────────────────────────┘ ⠀-3⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ julia> selu(-10f0) -1.7580194f0 ``` """ function selu(x) λ = oftf(x, selu_λ) α = oftf(x, selu_α) λ * ifelse(x > 0, x, @fastmath α * (exp(x) - 1)) end const selu_λ = 1.0507009873554804934193349852946 const selu_α = 1.6732632423543772848170429916717 function deriv_selu(Ω) λ = oftf(Ω, selu_λ) α = oftf(Ω, selu_α) ifelse(Ω > 0, λ, Ω + α * λ) end """ celu(x, α=1) = x ≥ 0 ? x : α * (exp(x/α) - 1) Activation function from ["Continuously Differentiable Exponential Linear Units"](https://arxiv.org/abs/1704.07483). 
```julia-repl julia> lineplot(celu, -2, 2, height=7) ┌────────────────────────────────────────┐ 2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠒⠉│ celu(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊⠉⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⣀⡤⠖⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀│ f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⣀⠤⠖⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⡤⡧⠶⠭⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⣀⠤⠔⠒⠋⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ -1 │⠤⠤⠤⠤⠔⠒⠒⠒⠊⠉⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ └────────────────────────────────────────┘ ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ julia> celu(-10f0) -0.9999546f0 ``` """ celu(x, α=1) = ifelse(x ≥ 0, float(x), oftf(x,α) * (exp(x/oftf(x,α)) - 1)) deriv_celu(Ω, α=1) = ifelse(Ω > 0, oftf(Ω, 1), Ω / oftf(Ω, α) + 1) """ trelu(x, theta=1) = x > theta ? x : 0 Threshold gated rectified linear activation function. See ["Zero-bias autoencoders and the benefits of co-adapting features"](https://arxiv.org/abs/1402.3337) ```julia-repl julia> lineplot(trelu, -2, 4, height=7) ┌────────────────────────────────────────┐ 4 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠖⠋│ trelu(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠖⠋⠁⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊⠁⠀⠀⠀⠀⠀⠀⠀│ f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠴⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⣠⠤⠒⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⡏⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 0 │⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣇⣀⣀⣀⣀⣀⣀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ └────────────────────────────────────────┘ ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀4⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ ``` """ trelu(x, theta=1) = ifelse(x <= theta, zero(x), x) const thresholdrelu = trelu """ softsign(x) = x / (1 + |x|) See ["Quadratic Polynomials Learn Better Image Features"](http://www.iro.umontreal.ca/~lisa/publications2/index.php/attachments/single/205) (2009). 
```julia-repl julia> lineplot(softsign, -5, 5, height=7) ┌────────────────────────────────────────┐ 1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣀⣀⣀⣀⠤⠤⠤⠤⠤│ softsign(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⣀⡤⠖⠒⠋⠉⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⡔⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ f(x) │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⡯⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠁⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣀⣀⠤⠤⠒⠋⠁⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ -1 │⠒⠒⠒⠒⠒⠊⠉⠉⠉⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ └────────────────────────────────────────┘ ⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ julia> lineplot!(ans, tanh) ┌────────────────────────────────────────┐ 1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⢀⡤⠖⠊⠉⠉⠉⣉⣉⣉⣉⣉⠭⠭⠭⠭⠭│ softsign(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⡔⣃⡤⠖⠒⠋⠉⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀│ tanh(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣧⡞⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ f(x) │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⡯⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡴⠃⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣀⣀⠤⠤⠒⢋⠕⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ -1 │⣒⣒⣒⣒⣒⣊⣉⣉⣉⣉⣁⣀⣀⡠⠤⠒⠁⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ └────────────────────────────────────────┘ ⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ julia> softsign(1f0) 0.5f0 julia> softsign(100f0) 0.990099f0 ``` """ softsign(x) = x / (1 + abs(x)) deriv_softsign(x) = 1 / (1 + abs(x))^2 """ softplus(x) = log(exp(x) + 1) See ["Deep Sparse Rectifier Neural Networks"](http://proceedings.mlr.press/v15/glorot11a/glorot11a.pdf), JMLR 2011. 
```julia-repl julia> lineplot(softplus, -3, 3, height=7) ┌────────────────────────────────────────┐ 4 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ softplus(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊⠁⠀│ f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠔⠊⠁⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⣀⡠⠤⠒⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣀⡧⠤⠒⠊⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 0 │⣀⣀⣀⣀⣀⣀⣀⡠⠤⠤⠤⠤⠔⠒⠒⠚⠉⠉⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ └────────────────────────────────────────┘ ⠀-3⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀3⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ julia> lineplot!(ans, relu) ┌────────────────────────────────────────┐ 4 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ softplus(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣠│ relu(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣠⡴⠞⠋⠁│ f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣤⡴⠞⠋⠁⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⣀⡠⢤⡲⠝⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣀⡧⠤⠒⠊⣉⠥⠚⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 0 │⣀⣀⣀⣀⣀⣀⣀⣠⣤⣤⣤⣤⣔⣒⣒⣚⣉⣉⣁⣀⣇⠴⠒⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ └────────────────────────────────────────┘ ⠀-3⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀3⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ julia> softplus(16f0) 16.0f0 ``` """ softplus(x) = log1p(exp(-abs(x))) + relu(x) """ logcosh(x) Return `log(cosh(x))` which is computed in a numerically stable way. 
```julia-repl julia> lineplot(logcosh, -5, 5, height=7) ┌────────────────────────────────────────┐ 5 │⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ logcosh(x) │⠉⠢⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠋│ │⠀⠀⠀⠑⠢⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠊⠁⠀⠀│ f(x) │⠀⠀⠀⠀⠀⠀⠑⠦⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠊⠁⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠑⠦⡀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⢀⡤⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠓⠦⡀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⢀⡤⠒⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 0 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠑⠢⢄⣀⣀⣇⣀⡠⠔⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ └────────────────────────────────────────┘ ⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ ``` """ logcosh(x) = x + softplus(-2x) - oftf(x, log2) const log2 = log(2) """ mish(x) = x * tanh(softplus(x)) Activation function from ["Mish: A Self Regularized Non-Monotonic Neural Activation Function"](https://arxiv.org/abs/1908.08681). ```julia-repl julia> lineplot(mish, -5, 5, height=7) ┌────────────────────────────────────────┐ 5 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠖⠋│ mish(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠒⠁⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⠔⠋⠁⠀⠀⠀⠀⠀⠀│ f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⢀⡠⠖⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⢀⡤⠖⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣧⣔⣊⣁⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀│ -1 │⠀⠀⠀⠀⠀⠀⠀⠀⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ └────────────────────────────────────────┘ ⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ ``` """ mish(x) = x * tanh(softplus(x)) """ tanhshrink(x) = x - tanh(x) See ["Tanhshrink Activation Function"](https://www.gabormelli.com/RKB/Tanhshrink_Activation_Function). 
```julia-repl julia> lineplot(tanhshrink, -3, 3, height=7) ┌────────────────────────────────────────┐ 3 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ tanhshrink(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡠⠤⠖⠊│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⢀⣀⡠⠤⠒⠊⠉⠁⠀⠀⠀⠀│ f(x) │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⣤⡤⠤⠤⠤⠤⠤⠤⡷⠶⠶⠶⠶⠶⠮⠭⠥⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│ │⠀⠀⠀⠀⠀⣀⡠⠴⠒⠊⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⡠⠴⠒⠊⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ -3 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ └────────────────────────────────────────┘ ⠀-3⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀3⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ julia> tanhshrink.((-10f0, 10f0)) (-9.0f0, 9.0f0) ``` """ tanhshrink(x) = x - tanh_fast(x) """ softshrink(x, λ=0.5) = (x ≥ λ ? x - λ : (-λ ≥ x ? x + λ : 0)) See ["Softshrink Activation Function"](https://www.gabormelli.com/RKB/Softshrink_Activation_Function). ```julia-repl julia> lineplot(softshrink, -2, 2, height=7) ┌────────────────────────────────────────┐ 2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀│ softshrink(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⡤⠔⠒⠉⠁│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⣀⡤⠤⠒⠋⠁⠀⠀⠀⠀⠀⠀│ f(x) │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⣤⡤⠤⠤⠤⠤⠤⠤⡧⠤⠤⠤⠤⠶⠮⠭⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│ │⠀⠀⠀⠀⠀⠀⢀⣀⠤⠖⠒⠉⠁⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⣀⠤⠔⠒⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ -2 │⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ └────────────────────────────────────────┘ ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ julia> lineplot!(ans, tanhshrink) ┌────────────────────────────────────────┐ 2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀│ softshrink(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⡤⠔⠒⣉⡡│ tanhshrink(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⣀⡤⠤⣒⣋⠥⠤⠒⠊⠉⠁⠀│ f(x) │⠤⠤⠤⠤⠤⠤⠤⠤⠤⣤⣤⣤⣤⡤⠤⠤⠤⠤⠤⠤⡷⠶⠶⠶⠶⠶⠾⠿⠯⠭⠭⠤⠤⠤⠤⠤⠤⠤⠤⠤│ │⠀⢀⣀⡠⠤⠖⢒⣋⠭⠗⠒⠉⠁⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠊⣉⠤⠔⠒⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ -2 │⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ └────────────────────────────────────────┘ ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀ julia> softshrink.((-10f0, 10f0)) (-9.5f0, 9.5f0) ``` """ function softshrink(x, λ = 0.5) lo = x - 
oftf(x, λ)
    hi = x + oftf(x, λ)
    # Piecewise: x-λ above λ, x+λ below -λ, exactly zero in the band in between.
    ifelse(hi > 0, ifelse(lo < 0, zero(hi), lo), hi)
end

# Define broadcasts for activation functions on arrays
for f in ACTIVATIONS
    @eval $(f)(x::AbstractArray, args...) = $(f).(x, args...)
end

## Faster, less accurate, versions of some.

"""
    tanh_fast(x)

This is a faster but slightly less accurate version of `tanh`.

Where Julia's `tanh` function has an error under 2 eps, this
may be wrong by 5 eps, a reduction by less than one decimal digit.

For `x::Float32` this is usually about 10 times faster,
with a smaller speedup for `x::Float64`.

For any other number types, it just calls `tanh`.

See also [`sigmoid_fast`](@ref).

```julia-repl
julia> tanh(0.5f0)
0.46211717f0

julia> tanh_fast(0.5f0)
0.46211714f0

julia> hardtanh(0.5f0)
0.5f0
```
"""
@inline function tanh_fast(x::Float32)
    # This method added in NNlib.jl#345 by @mcabbott and @oscardssmith,
    # with coefficients found using Remez.jl
    x2 = abs2(x)
    n = evalpoly(x2, (1.0f0, 0.1346604f0, 0.0035974074f0, 2.2332108f-5, 1.587199f-8))
    d = evalpoly(x2, (1.0f0, 0.4679937f0, 0.026262015f0, 0.0003453992f0, 8.7767893f-7))
    # Rational approximation n/d for x^2 < 66; saturate to ±1 (via sign) beyond that.
    ifelse(x2 < 66f0, x * (n / d), sign(x))
end

@inline function tanh_fast(x::Float64)
    exp2x = @fastmath exp(x + x)
    y = (exp2x - 1) / (exp2x + 1)
    # That has large errors near zero; using `expm1` would be more accurate, but about as slow as `tanh`.
    # Instead, we switch to a polynomial, which is very accurate within its range:
    x2 = x * x
    ypoly = x * evalpoly(x2, (1.0, -0.33333333333324583, 0.13333333325511604,
        -0.05396823125794372, 0.02186660872609521, -0.008697141630499953))
    # Saturate for very large |x|; polynomial near zero; the exp-based form elsewhere.
    ifelse(x2 > 900.0, sign(x), ifelse(x2 < 0.017, ypoly, y))
end

# These approximations are very badly behaved for Float16; none are fast.
# They are also a bit slower with ForwardDiff.Dual numbers, let's use Base:
tanh_fast(x::Number) = Base.tanh(x)

"""
    sigmoid_fast(x)

This is a faster, and very slightly less accurate, version of `sigmoid`.
For `x::Float32`, perhaps 3 times faster, and maximum errors 2 eps instead of 1.
See also [`tanh_fast`](@ref).

```julia-repl
julia> sigmoid(0.2f0)
0.54983395f0

julia> sigmoid_fast(0.2f0)
0.54983395f0

julia> hardσ(0.2f0)
0.53333336f0
```
"""
function sigmoid_fast(x::Real)
    # NOTE(review): `@inline` as a statement is only applied on 1.11+ here — confirm
    # why the guard is 1.11 rather than an earlier version supporting the syntax.
    @static if VERSION ≥ v"1.11-"
        @inline
    end
    # A single exp of -|x| serves both signs; the two branches rearrange the same
    # quantity so neither overflows.
    t = @fastmath exp(-abs(x))
    y = ifelse(x ≥ 0, inv(1 + t), t / (1 + t))
    # Clamp to exactly one/zero where the true sigmoid is indistinguishable from them.
    ifelse(x > 40, one(y), ifelse(x < -80, zero(y), y))
end
# For x::Float32, this is not as quick as the rational tanh_fast(x) above,
# but that polynomial has poor relative accuracy for negative x.

sigmoid_fast(x::Float16) = sigmoid(x)  # sigmoid_fast is extremely badly behaved at large x

function sigmoid_fast(x::Number)
    Base.depwarn("sigmoid only makes sense on real numbers, got $(typeof(x))", :sigmoid_fast)
    sigmoid(x)
end

"""
    NNlib.fast_act(f, [x::AbstractArray])

Replaces `f == tanh` with [`tanh_fast`](@ref), etc.

Takes an optional 2nd argument, so that you can disable
this replacement for some array or element types.
"""
@inline fast_act(f::F, ::AbstractArray = 1:0) where {F<:Function} = f
@inline fast_act(::typeof(tanh), ::AbstractArray = 1:0) = tanh_fast
@inline fast_act(::typeof(sigmoid), ::AbstractArray = 1:0) = sigmoid_fast

## Define rrules for some activation functions, along with the
## broadcasted rrule activation functions.
## This is a performance hack specifically for Zygote, because it doesn't handle fused
## broadcasts well; but it generally should be good (or at least harmless) for any AD, as
## it saves ADing the broadcasting machinery.
## Related Issue https://github.com/JuliaDiff/ChainRulesCore.jl/issues/271

## TODO: add to the lists below all activations.

# Each entry is (function name, derivative expression); in the expressions,
# `x` is the input and `Ω` the already-computed output of the activation.
UNARY_ACTS = [ # f, dfdx
    ## In the same order as above!
    (:σ, :(conj(Ω * (1 - Ω)))),
    (:hardσ, :(ifelse((Ω>0)&(Ω<1), oftf(Ω, 1/6), oftf(Ω, 1)))),
    (:logσ, :(sigmoid_fast(-x))),
    (:hardtanh, :((Ω>-1) & (Ω<1))),
    (:relu, :(Ω > 0)),
    (:leakyrelu, :(ifelse(Ω > 0, oftf(Ω, 1), oftf(Ω, leakyrelu_a)))),
    (:relu6, :((Ω>0) & (Ω<6))),
    # rrelu is random, can't write a rule.
(:elu, :(deriv_elu(Ω))),
    (:gelu_tanh, :(deriv_gelu_tanh(x))),
    (:gelu_sigmoid, :(deriv_gelu_sigmoid(x))),
    (:gelu_erf, :(deriv_gelu_erf(x))),
    (:swish, :(Ω + sigmoid_fast(x) * (1 - Ω))),
    (:hardswish, :(deriv_hardswish(x))),
    # lisht
    (:selu, :(deriv_selu(Ω))),
    (:celu, :(deriv_celu(Ω))),
    (:trelu, :(Ω > 0)),
    (:softsign, :(deriv_softsign(x))),
    (:softplus, :(sigmoid_fast(x))),
    # (:softplus, :(1 - @fastmath exp(-Ω))), # slightly faster, check accuracy?
    # logcosh
    # mish
    (:tanhshrink, :((x - Ω)^2)),  # since x - Ω == tanh(x), this equals tanh(x)^2
    (:softshrink, :(Ω != 0)),
    ## Fast variants are the same!
    (:tanh_fast, :(conj(1 - Ω^2))),
    (:sigmoid_fast, :(conj(Ω * (1 - Ω)))),
]

# For each (f, dfdx) pair, generate (a) a ChainRules scalar rule and (b) an rrule for
# the *broadcasted* call, so ADs need not differentiate the broadcast machinery itself.
for (f, dfdx) in UNARY_ACTS
    @eval @scalar_rule($f(x), $dfdx)

    pullback = Symbol(:broadcasted_, f, :_pullback)
    @eval function rrule(::typeof(broadcasted), ::typeof($f),
                         x::Union{Numeric, Broadcast.Broadcasted})
        Ω = $f.(x)
        function $pullback(dΩ)
            # InplaceableThunk allows gradient accumulation in-place where supported,
            # with the plain thunk as the out-of-place fallback.
            x_thunk = InplaceableThunk(
                dx -> @.(dx += dΩ * $dfdx),
                @thunk @.(dΩ * $dfdx)
            )
            NoTangent(), NoTangent(), x_thunk
        end
        return Ω, $pullback
    end
end

# NO_ACT_GRAD = ChainRulesCore.@not_implemented "for simplicitly NNlib assumes the 2nd argument of this activation function is a constant"
NO_ACT_GRAD = NaN  ## Still reminds you not to use this, but is perhaps more GPU friendly.

# Two-argument activations: (function name, derivative wrt x1, derivative wrt x2).
BINARY_ACTS = [ # f, dfdx1, dfdx2
    ## In the same order as above!
(:leakyrelu, :(ifelse(Ω > 0, oftf(Ω, 1), oftf(Ω, x2))), NO_ACT_GRAD),
    (:elu, :(deriv_elu(Ω, x2)), NO_ACT_GRAD),
    (:celu, :(deriv_celu(Ω, x2)), NO_ACT_GRAD),
    (:trelu, :(Ω > 0), ZeroTangent()),
    (:softshrink, :(Ω != 0), NO_ACT_GRAD),
]

# Same code generation as for UNARY_ACTS, but for activations taking a scalar
# second argument (slope, α, threshold, λ), whose gradient is not provided.
for (f, dfdx1, dfdx2) in BINARY_ACTS
    @eval @scalar_rule($f(x1, x2), ($dfdx1, $dfdx2))

    pullback = Symbol(:broadcasted_, f, :_pullback_2arg)
    @eval function rrule(::typeof(broadcasted), ::typeof($f),
                         x1::Union{Numeric, Broadcast.Broadcasted}, x2::Number)
        Ω = $f.(x1, x2)
        ## Allowing x2::Array would allow size(Ω) != size(x1), which is not handled here:
        $pullback(dΩ) = (NoTangent(), NoTangent(), @.(dΩ * $dfdx1), NO_ACT_GRAD)
        return Ω, $pullback
    end
end

================================================
FILE: src/attention.jl
================================================

# Shorthand aliases for arrays by rank, used throughout the attention code.
const AA3{T} = AbstractArray{T,3}
const AA4{T} = AbstractArray{T,4}
const AA{N,T} = AbstractArray{T,N}

"""
    dot_product_attention(query, key, value, [bias]; [fdrop, mask, nheads])

Multihead dot product attention used in transformer architectures.

The input arrays must have the first two dimensions given by the number of features
and the sequence length, then an arbitrary number of batch dimensions or none.

Returns the attention output array of size `(v_dim, q_len, batch_size...)`
and the attention scores of size `(kv_len, q_len, nheads, batch_size...)`.

See also [`dot_product_attention_scores`](@ref) if you only need the attention scores.

# Arguments

- `query`: Query array of size `(qk_dim, q_len, batch_size...)`.
- `key`: Key array of size `(qk_dim, kv_len, batch_size...)`.
- `value`: Value array of size `(v_dim, kv_len, batch_size...)`.
- `bias`: Either `nothing` or an array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
  It will be added to the attention scores before applying the softmax. Default `nothing`.
- `fdrop`: A dropout function or layer to be applied on the attention scores right after the softmax.
  Default `identity` (no dropout).
- `mask`: Either `nothing` or a boolean array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
  The mask is applied to the attention scores just before the softmax.
  See [`make_causal_mask`](@ref) for creating causal masks. Default `nothing`.
- `nheads`: Number of heads to split the input arrays into. Default `1`.

# Examples

```julia
q, k, v = rand(10, 20, 2), rand(10, 30, 2), rand(20, 30, 2)
y, α = dot_product_attention(q, k, v)
```
"""
function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}, args...; kws...) where N
    # Collapse all trailing batch dimensions into one, dispatch to the 3-dim method,
    # then restore the original batch shape on both outputs.
    batch_size = size(q)[3:end]
    batch_size == size(k)[3:end] == size(v)[3:end] || throw(ArgumentError("Batch dimensions have to be the same."))
    q, k, v = map(x -> reshape(x, size(x, 1), size(x, 2), :), (q, k, v))

    x, α = dot_product_attention(q, k, v, args...; kws...)

    x = reshape(x, size(x, 1), size(x, 2), batch_size...)
    α = reshape(α, size(α)[1:3]..., batch_size...)
    return x, α
end

function dot_product_attention(q::AA3, k::AA3, v::AA3, bias=nothing;
            fdrop=identity, mask=nothing, nheads=1)

    # Validate shapes up front so the error messages name the offending arrays.
    (all(size.((q, k, v), 1) .% nheads .== 0)) || throw(ArgumentError("""
    First dimension in query, key and value must be divisible by `nheads`. Instead:
    - size(q): $(size(q))
    - size(k): $(size(k))
    - size(v): $(size(v))
    - nheads: $nheads
    """))
    (size(q, 3) == size(k, 3) == size(v, 3)) || throw(ArgumentError("""
    Batch dimensions have to be the same. Instead:
    - size(q): $(size(q))
    - size(k): $(size(k))
    - size(v): $(size(v))
    """))
    size(q, 1) == size(k, 1) || throw(ArgumentError("""
    First dimension in query and key has to be the same. Instead:
    - size(q): $(size(q))
    - size(k): $(size(k))
    """))
    size(k, 2) == size(v, 2) || throw(ArgumentError("""
    Second dimension in key and value has to be the same. Instead:
    - size(k): $(size(k))
    - size(v): $(size(v))
    """))

    # Multihead attention. TODO create fastpath for singlehead attention.
q, k, v = split_heads.((q, k, v), nheads)
    x, α = _dot_product_attention(q, k, v, bias, fdrop, mask)
    return join_heads(x), α
end

function _dot_product_attention(q::AA4, k::AA4, v::AA4, bias, fdrop, mask)
    # [q] = [qk_dim ÷ nheads, nheads, q_len, batch_size]
    # [k] = [qk_dim ÷ nheads, nheads, kv_len, batch_size]
    # [v] = [v_dim ÷ nheads, nheads, kv_len, batch_size]

    α = dot_product_attention_scores(q, k, bias; fdrop, mask)
    # [α] = [kv_len, q_len, nheads, batch_size]

    # The following permutedims and batched_mul are equivalent to
    # @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b]
    vt = permutedims(v, (1, 3, 2, 4))
    x = batched_mul(vt, α)
    x = permutedims(x, (1, 3, 2, 4))
    # [x] = [kv_dim ÷ nheads, nheads, q_len, batch_size]
    return x, α
end

"""
    dot_product_attention_scores(query, key, [bias]; [fdrop, mask])

Return the attention scores for the [`dot_product_attention`](@ref).
Input arrays must have dimensions
`(num_features ÷ nheads, nheads, sequence_length, batch_size)`.

See [`dot_product_attention`](@ref) for more details.
"""
function dot_product_attention_scores(q::AA4{T}, k::AA4{T}, bias=nothing;
            fdrop=identity, mask=nothing) where T

    # The following permutedims and batched_mul are equivalent to
    # @tullio logits[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] / √T(qk_dim)
    kt = permutedims(k, (3, 1, 2, 4))
    qt = permutedims(q, (1, 3, 2, 4)) ./ √T(size(q, 1))  # scale queries by √d_k first
    logits = batched_mul(kt, qt)
    # [logits] = [kv_len, q_len, nheads, batch_size]

    logits = apply_attn_bias(logits, bias)
    logits = apply_attn_mask(logits, mask)

    # Normalise over the key (first) dimension, then optionally drop out scores.
    α = softmax(logits, dims=1)
    return fdrop(α)
end

# `nothing` bias/mask are the identity; a real bias is broadcast-added to the logits.
apply_attn_bias(logits, bias::Nothing) = logits

apply_attn_bias(logits, bias) = logits .+ bias

apply_attn_mask(logits, mask::Nothing) = logits

function apply_attn_mask(logits, mask)
    # Masked-out (false) positions become typemin, so softmax gives them zero weight.
    neginf = typemin(eltype(logits))
    ifelse.(mask, logits, neginf)
end

"""
    make_causal_mask(x, dims=2)

Return a boolean square matrix `m` of the same type as `x` and of side `size(x, dims)`.
Its elements are set such that `m[i, j] == i ≤ j`.
Can be used to mask the attention scores in [`dot_product_attention`](@ref).
"""
function make_causal_mask(x::AbstractArray; dims::Int=2)
    len = size(x, dims)
    return triu(trues_like(x, (len, len)))
end

# Boolean arrays with the same storage type as `x` (e.g. a GPU array for GPU
# inputs), filled with `true` / `false` respectively.
trues_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), true)
falses_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), false)

# Split dim 1 into (features ÷ nheads, nheads); `join_heads` is its inverse.
split_heads(x, nheads) = reshape(x, size(x, 1) ÷ nheads, nheads, size(x)[2:end]...)
join_heads(x) = reshape(x, :, size(x)[3:end]...)

@non_differentiable make_causal_mask(::Any...)
@non_differentiable trues_like(::Any...)
@non_differentiable falses_like(::Any...)


================================================
FILE: src/audio/mel.jl
================================================
"""
    melscale_filterbanks(;
        n_freqs::Int, n_mels::Int, sample_rate::Int,
        fmin::Float32 = 0f0, fmax::Float32 = Float32(sample_rate ÷ 2))

Create triangular Mel scale filter banks
(ref: [Mel scale - Wikipedia](https://en.wikipedia.org/wiki/Mel_scale)).
Each column is a filterbank that highlights its own frequency.

# Arguments:

- `n_freqs::Int`: Number of frequencies to highlight.
- `n_mels::Int`: Number of mel filterbanks.
- `sample_rate::Int`: Sample rate of the audio waveform.
- `fmin::Float32`: Minimum frequency in Hz.
- `fmax::Float32`: Maximum frequency in Hz.

# Returns:

Filterbank matrix of shape `(n_freqs, n_mels)` where each column is a filterbank.
```jldoctest julia> n_mels = 8; julia> fb = melscale_filterbanks(; n_freqs=200, n_mels, sample_rate=16000); julia> plot = lineplot(fb[:, 1]); julia> for i in 2:n_mels lineplot!(plot, fb[:, i]) end julia> plot ┌────────────────────────────────────────┐ 1 │⠀⡀⢸⠀⢸⠀⠀⣧⠀⠀⢸⡄⠀⠀⠀⣷⠀⠀⠀⠀⠀⣷⠀⠀⠀⠀⠀⠀⢀⣿⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⡇⢸⡆⢸⡇⠀⣿⠀⠀⡜⡇⠀⠀⢰⠋⡆⠀⠀⠀⢰⠁⡇⠀⠀⠀⠀⠀⡸⠀⢣⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⣿⢸⡇⡇⡇⢰⠹⡄⠀⡇⢱⠀⠀⢸⠀⢣⠀⠀⠀⡜⠀⢸⡀⠀⠀⠀⢀⠇⠀⠈⡇⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⣿⡇⡇⡇⡇⢸⠀⡇⢀⠇⠸⡀⠀⡇⠀⠸⡀⠀⢀⠇⠀⠀⢇⠀⠀⠀⡸⠀⠀⠀⠸⡄⠀⠀⠀⠀⠀⠀⠀│ │⢠⢻⡇⡇⡇⢱⢸⠀⢇⢸⠀⠀⡇⢀⠇⠀⠀⡇⠀⢸⠀⠀⠀⠸⡀⠀⢠⠇⠀⠀⠀⠀⢱⠀⠀⠀⠀⠀⠀⠀│ │⢸⢸⡇⢱⡇⢸⡇⠀⢸⢸⠀⠀⢣⢸⠀⠀⠀⢸⠀⡇⠀⠀⠀⠀⢇⠀⡜⠀⠀⠀⠀⠀⠈⢇⠀⠀⠀⠀⠀⠀│ │⢸⢸⡇⢸⠀⢸⡇⠀⢸⡇⠀⠀⢸⡎⠀⠀⠀⠈⣶⠁⠀⠀⠀⠀⠸⣤⠃⠀⠀⠀⠀⠀⠀⠘⡆⠀⠀⠀⠀⠀│ │⢸⠀⡇⢸⠀⠀⡇⠀⠀⡇⠀⠀⠀⡇⠀⠀⠀⠀⣿⠀⠀⠀⠀⠀⠀⣿⠀⠀⠀⠀⠀⠀⠀⠀⢱⡀⠀⠀⠀⠀│ │⢸⢸⡇⢸⠀⢸⡇⠀⢸⡇⠀⠀⢸⢇⠀⠀⠀⢀⠿⡀⠀⠀⠀⠀⢰⠛⡄⠀⠀⠀⠀⠀⠀⠀⠀⢣⠀⠀⠀⠀│ │⢸⢸⡇⡸⡇⢸⡇⠀⢸⢸⠀⠀⡜⢸⠀⠀⠀⢸⠀⡇⠀⠀⠀⠀⡎⠀⢣⠀⠀⠀⠀⠀⠀⠀⠀⠘⡆⠀⠀⠀│ │⢸⢸⡇⡇⡇⡸⢸⠀⡎⢸⠀⠀⡇⠈⡆⠀⠀⡇⠀⢸⠀⠀⠀⢰⠁⠀⠘⡆⠀⠀⠀⠀⠀⠀⠀⠀⠸⡄⠀⠀│ │⡇⢸⡇⡇⡇⡇⢸⠀⡇⠈⡆⢰⠁⠀⡇⠀⢰⠁⠀⠈⡆⠀⠀⡎⠀⠀⠀⢱⠀⠀⠀⠀⠀⠀⠀⠀⠀⢣⠀⠀│ │⡇⢸⢸⡇⡇⡇⠸⣰⠃⠀⡇⡸⠀⠀⢸⠀⡜⠀⠀⠀⢣⠀⢸⠁⠀⠀⠀⠈⡆⠀⠀⠀⠀⠀⠀⠀⠀⠈⢇⠀│ │⡇⡇⢸⠇⢸⡇⠀⣿⠀⠀⢣⡇⠀⠀⠸⣄⠇⠀⠀⠀⠸⡀⡇⠀⠀⠀⠀⠀⢱⠀⠀⠀⠀⠀⠀⠀⠀⠀⠸⡄│ 0 │⣇⣇⣸⣀⣸⣀⣀⣟⣀⣀⣸⣃⣀⣀⣀⣿⣀⣀⣀⣀⣀⣿⣀⣀⣀⣀⣀⣀⣈⣇⣀⣀⣀⣀⣀⣀⣀⣀⣀⣱│ └────────────────────────────────────────┘ ⠀0⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀200⠀ ``` """ function melscale_filterbanks(; n_freqs::Int, n_mels::Int, sample_rate::Int, fmin::Float32 = 0f0, fmax::Float32 = Float32(sample_rate ÷ 2), ) mel_min, mel_max = _hz_to_mel(fmin), _hz_to_mel(fmax) mel_points = range(mel_min, mel_max; length=n_mels + 2) all_freqs = collect(range(0f0, Float32(sample_rate ÷ 2); length=n_freqs)) freq_points = _mel_to_hz.(mel_points) filter_banks = _triangular_filterbanks(freq_points, all_freqs) if any(maximum(filter_banks; dims=1) .≈ 0f0) @warn """At least one mel filterbank has all zero values. The value for `n_mels=$n_mels` may be set too high. Or the value for `n_freqs=$n_freqs` may be set too low. """ end return filter_banks end _hz_to_mel(freq::T) where T = T(2595) * log10(T(1) + (freq / T(700))) _mel_to_hz(mel::T) where T = T(700) * (T(10)^(mel / T(2595)) - T(1)) """ _triangular_filterbanks( freq_points::Vector{Float32}, all_freqs::Vector{Float32}) Create triangular filter banks. 
# Arguments:

- `freq_points::Vector{Float32}`: Filter midpoints of size `n_filters`.
- `all_freqs::Vector{Float32}`: Frequency points of size `n_freqs`.

# Returns:

Array of size `(n_freqs, n_filters)`.
"""
function _triangular_filterbanks(
    freq_points::Vector{Float32}, all_freqs::Vector{Float32},
)
    # Width of each band, between consecutive filter midpoints.
    widths = @view(freq_points[2:end]) .- @view(freq_points[1:end - 1])
    # slopes[i, j]: signed distance of frequency `i` from midpoint `j`.
    slopes = transpose(reshape(freq_points, :, 1) .- reshape(all_freqs, 1, :))

    down_slopes = -(@view(slopes[:, 1:end - 2]) ./ reshape(@view(widths[1:end - 1]), 1, :))
    up_slopes = @view(slopes[:, 3:end]) ./ reshape(@view(widths[2:end]), 1, :)

    # Each filter is the overlap of its two edges, clamped at zero.
    return max.(0f0, min.(down_slopes, up_slopes))
end


================================================
FILE: src/audio/spectrogram.jl
================================================
"""
    spectrogram(waveform;
        pad::Int = 0, n_fft::Int, hop_length::Int, window,
        center::Bool = true, power::Real = 2.0,
        normalized::Bool = false, window_normalized::Bool = false,
    )

Create a spectrogram or a batch of spectrograms from a raw audio signal.

# Arguments

- `pad::Int`: The amount of padding to apply on both sides.
- `window_normalized::Bool`: Whether to normalize the waveform by the window’s L2 energy.
- `power::Real`: Exponent for the magnitude spectrogram (must be ≥ 0)
  e.g., `1` for magnitude, `2` for power, etc.
  If `0`, complex spectrum is returned instead.

See [`stft`](@ref) for other arguments.

# Returns

Spectrogram in the shape `(T, F, B)`, where `T` is the number of window hops
and `F = n_fft ÷ 2 + 1`.
"""
function spectrogram(waveform::AbstractArray{T};
    pad::Int = 0, n_fft::Int, hop_length::Int, window,
    center::Bool = true, power::Real = 2.0,
    normalized::Bool = false, window_normalized::Bool = false,
) where T
    if pad > 0
        waveform = pad_zeros(waveform, pad; dims=1)
    end

    # Pack all batch-like dimensions into one before calling `stft`.
    sz = size(waveform)
    spec_ = stft(reshape(waveform, (sz[1], :));
        n_fft, hop_length, window, center, normalized)
    # Unpack batch dimensions again.
    spec = reshape(spec_, (size(spec_)[1:2]..., sz[2:end]...))

    if window_normalized
        # Scale by the inverse of the window's L2 norm.
        spec = spec .* inv(norm(window))
    end

    if power > 0
        p = T(power)
        spec = abs.(spec .+ eps(T)).^p
    end
    return spec
end

"""
    power_to_db(s; ref::Real = 1f0, amin::Real = 1f-10, top_db::Real = 80f0)

Convert a power spectrogram (amplitude squared) to decibel (dB) units.

# Arguments

- `s`: Input power.
- `ref`: Scalar w.r.t. which the input is scaled.
- `amin`: Minimum threshold for `s`.
- `top_db`: Threshold the output at `top_db` below the peak:
  `max.(s_db, maximum(s_db) - top_db)`.

# Returns

`s_db ~= 10 * log10(s) - 10 * log10(ref)`
"""
function power_to_db(s; ref::Real = 1f0, amin::Real = 1f-10, top_db::Real = 80f0)
    log_spec = 10f0 .* (log10.(max.(amin, s)) .- log10.(max.(amin, ref)))
    return max.(log_spec, maximum(log_spec) - top_db)
end

"""
    db_to_power(s_db; ref::Real = 1f0)

Inverse of [`power_to_db`](@ref).
"""
db_to_power(s_db; ref::Real = 1f0) = ref .* 10f0.^(s_db .* 0.1f0)


================================================
FILE: src/audio/stft.jl
================================================
"""
    hamming_window(
        window_length::Int, ::Type{T} = Float32; periodic::Bool = true,
        α::T = T(0.54), β::T = T(0.46),
    ) where T <: Real

Hamming window function
(ref: [Window function § Hann and Hamming windows - Wikipedia](https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows)).
Generalized version of `hann_window`.

``w[n] = \\alpha - \\beta \\cos(\\frac{2 \\pi n}{N - 1})``

Where ``N`` is the window length.
```julia-repl julia> lineplot(hamming_window(100); width=30, height=10) ┌──────────────────────────────┐ 1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⠚⠉⠉⠉⠢⡄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠎⠁⠀⠀⠀⠀⠀⠈⢢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⢠⠊⠀⠀⠀⠀⠀⠀⠀⠀⠀⢣⡀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⢰⠃⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠱⡀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⣠⠃⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠳⡀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⢰⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠱⡄⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⡰⠃⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠱⡀⠀⠀⠀⠀│ │⠀⠀⠀⢀⠴⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠘⢄⠀⠀⠀│ │⠀⢀⡠⠊⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠳⣀⠀│ 0 │⠉⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉│ └──────────────────────────────┘ ⠀0⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀100⠀ ``` # Arguments: - `window_length::Int`: Size of the window. - `::Type{T}`: Elemet type of the window. # Keyword Arguments: - `periodic::Bool`: If `true` (default), returns a window to be used as periodic function. If `false`, return a symmetric window. Following always holds: ```jldoctest julia> N = 256; julia> hamming_window(N; periodic=true) ≈ hamming_window(N + 1; periodic=false)[1:end - 1] true ``` - `α::Real`: Coefficient α in the equation above. - `β::Real`: Coefficient β in the equation above. # Returns: Vector of length `window_length` and eltype `T`. """ function hamming_window( window_length::Int, ::Type{T} = Float32; periodic::Bool = true, α::T = T(0.54), β::T = T(0.46), ) where T <: Real window_length < 1 && throw(ArgumentError( "`window_length` must be > 0, instead: `$window_length`.")) n::T = ifelse(periodic, window_length, window_length - 1) scale = T(2) * π / n return [α - β * cos(scale * T(k)) for k in 0:(window_length - 1)] end """ hann_window( window_length::Int, ::Type{T} = Float32; periodic::Bool = true, ) where T <: Real Hann window function (ref: [Window function § Hann and Hamming windows - Wikipedia](https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows)). ``w[n] = \\frac{1}{2}[1 - \\cos(\\frac{2 \\pi n}{N - 1})]`` Where ``N`` is the window length. 
```julia-repl julia> lineplot(hann_window(100); width=30, height=10) ┌──────────────────────────────┐ 1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣠⠚⠉⠉⠉⠢⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡔⠁⠀⠀⠀⠀⠀⠘⢄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠞⠀⠀⠀⠀⠀⠀⠀⠀⠈⢆⠀⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⢀⡎⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢣⠀⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⡎⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢦⠀⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⢀⠞⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢆⠀⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⠀⢀⡜⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢇⠀⠀⠀⠀⠀│ │⠀⠀⠀⠀⢀⠎⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢦⠀⠀⠀⠀│ │⠀⠀⠀⢠⠊⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠣⡀⠀⠀│ 0 │⣀⣀⠔⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢤⣀│ └──────────────────────────────┘ ⠀0⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀100⠀ ``` # Arguments: - `window_length::Int`: Size of the window. - `::Type{T}`: Elemet type of the window. # Keyword Arguments: - `periodic::Bool`: If `true` (default), returns a window to be used as periodic function. If `false`, return a symmetric window. Following always holds: ```jldoctest julia> N = 256; julia> hann_window(N; periodic=true) ≈ hann_window(N + 1; periodic=false)[1:end - 1] true julia> hann_window(N) ≈ hamming_window(N; α=0.5f0, β=0.5f0) true ``` # Returns: Vector of length `window_length` and eltype `T`. """ function hann_window( window_length::Int, ::Type{T} = Float32; periodic::Bool = true, ) where T <: Real hamming_window(window_length, T; periodic, α=T(0.5), β=T(0.5)) end """ stft(x; n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing, center::Bool = true, normalized::Bool = false, ) Short-time Fourier transform (STFT). The STFT computes the Fourier transform of short overlapping windows of the input, giving frequency components of the signal as they change over time. ``Y[\\omega, m] = \\sum_{k = 0}^{N - 1} \\text{window}[k] \\text{input}[m \\times \\text{hop length} + k] \\exp(-j \\frac{2 \\pi \\omega k}{\\text{n fft}})`` where ``N`` is the window length, ``\\omega`` is the frequency ``0 \\le \\omega < \\text{n fft}`` and ``m`` is the index of the sliding window. # Arguments: - `x`: Input, must be either a 1D time sequence (`(L,)` shape) or a 2D batch of time sequence (`(L, B)` shape). 
# Keyword Arguments: - `n_fft::Int`: Size of Fourier transform. - `hop_length::Int`: Distance between neighboring sliding window frames. - `window`: Optional window function to apply. Must be 1D vector `0 < length(window) ≤ n_fft`. If window is shorter than `n_fft`, it is padded with zeros on both sides. If `nothing` (default), then no window is applied. - `center::Bool`: Whether to pad input on both sides so that ``t``-th frame is centered at time ``t \\times \\text{hop length}``. Padding is done with `pad_reflect` function. - `normalized::Bool`: Whether to return normalized STFT, i.e. multiplied with ``\\text{n fft}^{-0.5}``. # Returns: Complex array of shape `(n_fft, n_frames, B)`, where `B` is the optional batch dimension. """ function stft end """ istft(y; n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing, center::Bool = true, normalized::Bool = false, return_complex::Bool = false, original_length::Union{Nothing, Int} = nothing, ) Inverse Short-time Fourier Transform. Return the least squares estimation of the original signal # Arguments: - `y`: Input complex array in the `(n_fft, n_frames, B)` shape. Where `B` is the optional batch dimension. # Keyword Arguments: - `n_fft::Int`: Size of Fourier transform. - `hop_length::Int`: Distance between neighboring sliding window frames. - `window`: Window function that was applied to the input of `stft`. If `nothing` (default), then no window was applied. - `center::Bool`: Whether input to `stft` was padded on both sides so that ``t``-th frame is centered at time ``t \\times \\text{hop length}``. Padding is done with `pad_reflect` function. - `normalized::Bool`: Whether input to `stft` was normalized. - `return_complex::Bool`: Whether the output should be complex, or if the input should be assumed to derive from a real signal and window. - `original_length::Union{Nothing, Int}`: Optional size of the first dimension of the input to `stft`. Helps restoring the exact `stft` input size. 
Otherwise, the array might be a bit shorter. """ function istft end ================================================ FILE: src/batched/batchedadjtrans.jl ================================================ import Base: - import Adapt: adapt_structure, adapt _batched_doc = """ batched_transpose(A::AbstractArray{T,3}) batched_adjoint(A) Equivalent to applying `transpose` or `adjoint` to each matrix `A[:,:,k]`. These exist to control how `batched_mul` behaves, as it operates on such matrix slices of an array with `ndims(A)==3`. `PermutedDimsArray(A, (2,1,3))` is equivalent to `batched_transpose(A)`, and is also understood by `batched_mul` (and more widely supported elsewhere). BatchedTranspose{T, S} <: AbstractBatchedMatrix{T, 3} BatchedAdjoint{T, S} Lazy wrappers analogous to `Transpose` and `Adjoint`, returned by `batched_transpose` etc. """ @doc _batched_doc struct BatchedTranspose{T, S} <: AbstractArray{T, 3} parent::S BatchedTranspose{T, S}(X::S) where {T, S} = new{T, S}(X) end @doc _batched_doc batched_transpose(A::AbstractArray{T, 3}) where T = BatchedTranspose(A) batched_transpose(A::BatchedTranspose) = A.parent @doc _batched_doc struct BatchedAdjoint{T, S} <: AbstractArray{T, 3} parent::S BatchedAdjoint{T, S}(X::S) where {T, S} = new{T, S}(X) end @doc _batched_doc batched_adjoint(A::AbstractArray{T, 3}) where T = BatchedAdjoint(A) batched_adjoint(A::BatchedAdjoint) = A.parent batched_adjoint(A::BatchedTranspose{<:Real}) = A.parent batched_transpose(A::BatchedAdjoint{<:Real}) = A.parent batched_adjoint(A::PermutedDimsArray{<:Real,3,(2,1,3)}) = A.parent batched_transpose(A::PermutedDimsArray{<:Number,3,(2,1,3)}) = A.parent # if you can't unwrap, put BatchedAdjoint outside (for dispatch): batched_transpose(A::BatchedAdjoint{<:Complex}) = BatchedAdjoint(BatchedTranspose(A.parent)) BatchedAdjoint(A) = BatchedAdjoint{Base.promote_op(adjoint,eltype(A)),typeof(A)}(A) BatchedTranspose(A) = BatchedTranspose{Base.promote_op(transpose,eltype(A)),typeof(A)}(A) const 
BatchedAdjOrTrans{T, S} = Union{BatchedTranspose{T, S}, BatchedAdjoint{T, S}} LinearAlgebra.wrapperop(A::BatchedAdjoint) = batched_adjoint LinearAlgebra.wrapperop(B::BatchedTranspose) = batched_transpose # AbstractArray Interface Base.length(A::BatchedAdjOrTrans) = length(A.parent) Base.size(m::BatchedAdjOrTrans) = (size(m.parent, 2), size(m.parent, 1), size(m.parent, 3)) Base.axes(m::BatchedAdjOrTrans) = (axes(m.parent, 2), axes(m.parent, 1), axes(m.parent, 3)) Base.IndexStyle(::Type{<:BatchedAdjOrTrans}) = IndexCartesian() Base.@propagate_inbounds Base.getindex(m::BatchedTranspose, i::Int, j::Int, k::Int) = getindex(m.parent, j, i, k) Base.@propagate_inbounds Base.getindex(m::BatchedAdjoint, i::Int, j::Int, k::Int) = adjoint(getindex(m.parent, j, i, k)) Base.@propagate_inbounds Base.setindex!(m::BatchedTranspose, v, i::Int, j::Int, k::Int) = setindex!(m.parent, v, j, i, k) Base.@propagate_inbounds Base.setindex!(m::BatchedAdjoint, v, i::Int, j::Int, k::Int) = setindex!(m.parent, adjoint(v), j, i, k) Base.similar(A::BatchedAdjOrTrans, T::Type, dims::Dims) = similar(A.parent, T, dims) Base.similar(A::BatchedAdjOrTrans, dims::Dims) = similar(A.parent, dims) Base.similar(A::BatchedAdjOrTrans, T::Type) = similar(A.parent, T, size(A)) Base.similar(A::BatchedAdjOrTrans) = similar(A.parent, size(A)) Base.parent(A::BatchedAdjOrTrans) = A.parent (-)(A::BatchedAdjoint) = BatchedAdjoint( -A.parent) (-)(A::BatchedTranspose) = BatchedTranspose(-A.parent) # C interface function Base.strides(A::Union{BatchedTranspose, BatchedAdjoint{<:Real}}) sp = strides(A.parent) (sp[2], sp[1], sp[3]) end function Base.stride(A::Union{BatchedTranspose, BatchedAdjoint{<:Real}}, d::Integer) d == 1 && return Base.stride(A.parent, 2) d == 2 && return Base.stride(A.parent, 1) Base.stride(A.parent, d) end Base.pointer(A::BatchedAdjOrTrans) = pointer(parent(A)) Base.unsafe_convert(::Type{Ptr{T}}, A::BatchedAdjOrTrans{T}) where {T} = Base.unsafe_convert(Ptr{T}, parent(A)) # Gradients function 
rrule(::typeof(batched_transpose), A::AbstractArray{<:Any,3}) b_transpose_back(Δ) = (NoTangent(), batched_transpose(unthunk(Δ))) batched_transpose(A), b_transpose_back end function rrule(::typeof(batched_adjoint), A::AbstractArray{<:Any,3}) b_adjoint_back(Δ) = (NoTangent(), batched_adjoint(unthunk(Δ))) batched_adjoint(A), b_adjoint_back end adapt_structure(to, x::BatchedAdjoint) = BatchedAdjoint(adapt(to, parent(x))) adapt_structure(to, x::BatchedTranspose) = BatchedTranspose(adapt(to, parent(x))) Broadcast.BroadcastStyle(::Type{<:BatchedAdjOrTrans{T, S}}) where {T, S} = Broadcast.BroadcastStyle(S) ================================================ FILE: src/batched/batchedmul.jl ================================================ _unbatch(A) = A _unbatch(A::BatchedAdjOrTrans) = parent(A) """ batched_mul(A, B) -> C A ⊠ B # \\boxtimes Batched matrix multiplication. Result has `C[:,:,k...] == A[:,:,k...] * B[:,:,k...]` where `k...` represent any indices in the last dimensions. If `ndims(A) == ndims(B) == 3` and `size(B,3) == 1` then instead `C[:,:,k] == A[:,:,k] * B[:,:,1]`, and similarly for `A`. To transpose each matrix, apply `batched_transpose` to the array, or `batched_adjoint` for conjugate-transpose: ```jldoctest julia> A, B = randn(2,5,17), randn(5,9,17); julia> A ⊠ B |> size (2, 9, 17) julia> batched_adjoint(A) |> size (5, 2, 17) julia> batched_mul(A, batched_adjoint(randn(9,5,17))) |> size (2, 9, 17) julia> A ⊠ randn(5,9,1) |> size (2, 9, 17) julia> batched_transpose(A) == PermutedDimsArray(A, (2,1,3)) true ``` The equivalent `PermutedDimsArray` may be used in place of `batched_transpose`. Other permutations are also handled by BLAS, provided that the batch index `k` is not the first dimension of the underlying array. Thus `PermutedDimsArray(::Array, (1,3,2))` and `PermutedDimsArray(::Array, (3,1,2))` are fine. However, `A = PermutedDimsArray(::Array, (3,2,1))` is not acceptable to BLAS, since the batch dimension is the contiguous one: `stride(A,3) == 1`. 
This will be copied, as doing so is faster than `batched_mul_generic!`.

Both this `copy` and `batched_mul_generic!` produce `@debug` messages,
and setting for instance `ENV["JULIA_DEBUG"] = NNlib` will display them.
"""
function batched_mul(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
    batch_size = size(x)[3:end]
    @assert batch_size == size(y)[3:end] "batch size has to be the same for the two arrays."
    # Flatten trailing batch dimensions, multiply, then restore them.
    xf = reshape(x, size(x, 1), size(x, 2), :)
    yf = reshape(y, size(y, 1), size(y, 2), :)
    zf = batched_mul(xf, yf)
    return reshape(zf, size(zf, 1), size(zf, 2), batch_size...)
end

function batched_mul(A::AbstractArray{T1, 3}, B::AbstractArray{T2, 3}) where {T1, T2}
    # A batch size of 1 on either side is broadcast out to the other's.
    size(A, 3) == size(B, 3) || size(A, 3) == 1 || size(B, 3) == 1 ||
        throw(DimensionMismatch("batch size mismatch: A != B"))
    _batched_mul(storage_typejoin(A, B), A, B)
end

const ⊠ = batched_mul

# Generic fallback: allocate the output and defer to batched_mul!.
function _batched_mul(::Type, A, B)
    T = promote_type(eltype(A), eltype(B))
    C = similar(A, T, (size(A, 1), size(B, 2), max(size(A, 3), size(B, 3))))
    batched_mul!(C, A, B)
    return C
end

# Dense BLAS-eligible path: possibly copy inputs first so gemm can be used.
function _batched_mul(::Type{<:DenseArray{T}}, A, B) where {T<:BlasFloat}
    C = similar(A, T, (size(A, 1), size(B, 2), max(size(A, 3), size(B, 3))))
    batched_mul!(C, _copy_if_faster(A), _copy_if_faster(B))
    return C
end

# Copy when the batch dimension is the contiguous one, since a copy plus
# batched_gemm! beats falling back to batched_mul_generic!.
function _copy_if_faster(X::AbstractArray{<:Number, 3})
    is_strided(X) || return X
    if Base.stride(X, 3) == 1 && Base.stride(X, 1) != 1
        @debug "copying to avoid batched_mul_generic!" typeof(X) size(X) strides(X)
        return copy(X)
    end
    return X
end

function _copy_if_faster(X::BatchedAdjoint{<:Complex})
    inner = _unbatch(X)
    is_strided(inner) || return X
    if Base.stride(inner, 1) != 1
        @debug "copying to avoid batched_mul_generic!" typeof(X) size(X) strides(_unbatch(X))
        return copy(X) # or batched_adjoint(copy(Xbase)), may be better on GPU?
    end
    return X
end

# Gradient, allowing that size(A,3)==1 means it's "broadcasted" out to size(B,3)
function rrule(::typeof(batched_mul), A::AbstractArray{<:Any,3}, B::AbstractArray{<:Any,3})
    function batched_mul_pullback(_Δ)
        Δ = unthunk(_Δ)
        Athunk = @thunk begin
            dA = batched_mul(Δ, batched_adjoint(B))
            # If A was broadcast along the batch, sum the cotangent back down.
            size(A,3) == 1 ? sum(dA, dims=3) : dA
        end
        Bthunk = @thunk begin
            dB = batched_mul(batched_adjoint(A), Δ)
            size(B,3) == 1 ? sum(dB, dims=3) : dB
        end
        return (NoTangent(), Athunk, Bthunk)
    end
    return batched_mul(A, B), batched_mul_pullback
end

"""
    batched_mul(A::Array{T,3}, B::Matrix)
    batched_mul(A::Matrix, B::Array{T,3})
    A ⊠ B

This is always matrix-matrix multiplication, but
either `A` or `B` may lack a batch index.

* When `B` is a matrix, result has `C[:,:,k] == A[:,:,k] * B[:,:]` for all `k`.

* When `A` is a matrix, then `C[:,:,k] == A[:,:] * B[:,:,k]`.
  This can also be done by reshaping and calling `*`,
  for instance `A ⊡ B` using TensorCore.jl, but is implemented here using
  `batched_gemm` instead of `gemm`.

```jldoctest
julia> randn(16,8,32) ⊠ randn(8,4) |> size
(16, 4, 32)

julia> randn(16,8,32) ⊠ randn(8,4,1) |> size  # equivalent
(16, 4, 32)

julia> randn(16,8) ⊠ randn(8,4,32) |> size
(16, 4, 32)
```

See also `batched_vec` to regard `B` as a batch of vectors, `A[:,:,k] * B[:,k]`.
""" batched_mul(A::AbstractArray{T,3} where T, B::AbstractMatrix) = _semi_batched_mul(A,B) # Simplify signature of batched_mul by hiding dispatch on Adjoint etc: _semi_batched_mul(A::AbstractArray{<:Any,3}, B::AbstractMatrix) = batched_mul(A, reshape(B, size(B)..., 1)) _semi_batched_mul(A::AbstractArray{<:Any,3}, B::Adjoint{<:Number,<:AbstractMatrix}) = batched_mul(A, batched_adjoint(reshape(parent(B), size(parent(B))..., 1))) _semi_batched_mul(A::AbstractArray{<:Any,3}, B::Transpose{<:Number,<:AbstractMatrix}) = batched_mul(A, batched_transpose(reshape(parent(B), size(parent(B))..., 1))) batched_mul(A::AbstractMatrix, B::AbstractArray{T,3} where T) = _semi_batched_mul(A,B) _semi_batched_mul(A::AbstractMatrix, B::AbstractArray{<:Any,3}) = batched_mul(reshape(A, size(A)..., 1), B) _semi_batched_mul(A::Adjoint{<:Number,<:AbstractMatrix}, B::AbstractArray{<:Any,3}) = batched_mul(batched_adjoint(reshape(parent(A), size(parent(A))..., 1)), B) _semi_batched_mul(A::Transpose{<:Number,<:AbstractMatrix}, B::AbstractArray{<:Any,3}) = batched_mul(batched_transpose(reshape(parent(A), size(parent(A))..., 1)), B) """ batched_vec(A::AbstractArray{T,3}, B::AbstractMatrix) batched_vec(A::AbstractArray{T,3}, b::AbstractVector) batched_vec(A::AbstractArray, B::AbstractArray) Batched matrix-vector multiplication. For the 3D case: the result has `C[:,:,k] == A[:,:,k] * B[:,k]` for all `k`, or else `C[:,:,k] == A[:,:,k] * b` for `b::Vector`. For the general N-D case where `ndims(A) == ndims(B) + 1`: the result has `C[:,k...] == A[:,:,k...] * B[:,k...]` for all batch indices `k...`. The batch dimensions must match: `size(A)[3:end] == size(B)[2:end]`. With the same argument types, `batched_mul(A, B)` would regard `B` as a fixed matrix, not a batch of vectors. Both reshape and then call `batched_mul(::Array{T,3}, ::Array{T,3})`. 
```jldoctest julia> A, B, b = randn(16,8,32), randn(8,32), randn(8); julia> batched_vec(A,B) |> size (16, 32) julia> batched_vec(A,b) |> size (16, 32) julia> A4d, B3d = randn(16,8,10,32), randn(8,10,32); # 4D and 3D arrays julia> batched_vec(A4d, B3d) |> size (16, 10, 32) ``` """ function batched_vec(A::AbstractArray, B::AbstractArray) ndims(A) == ndims(B) + 1 || throw(DimensionMismatch( "batched_vec requires ndims(A) == ndims(B) + 1, got ndims(A)=$(ndims(A)) and ndims(B)=$(ndims(B))")) size(A)[3:end] == size(B)[2:end] || throw(DimensionMismatch( "batch dimensions must match: size(A)[3:end]=$(size(A)[3:end]) != size(B)[2:end]=$(size(B)[2:end])")) # Reshape B to add a singleton dimension for matrix multiplication B_reshaped = reshape(B, size(B, 1), 1, size(B)[2:end]...) # Perform batched multiplication C = batched_mul(A, B_reshaped) # Remove the singleton dimension return dropdims(C, dims=2) end batched_vec(A::AbstractArray{T,3} where T, B::AbstractMatrix) = reshape(batched_mul(A, reshape(B, size(B,1), 1, size(B,2))), size(A,1), size(A,3)) # If B is transposed, then stride=1 is the batch dim, so we will end up copying anyway: batched_vec(A::AbstractArray{T,3} where T, B::AdjOrTransAbsMat{<:BlasFloat, <:StridedMatrix}) = batched_vec(A, copy(B)) batched_vec(A::AbstractArray{T,3} where T, b::AbstractVector) = reshape(batched_mul(A, reshape(b, length(b), 1, 1)), size(A,1), size(A,3)) """ batched_mul!(C, A, B) -> C batched_mul!(C, A, B, α=1, β=0) In-place batched matrix multiplication, equivalent to `mul!(C[:,:,k], A[:,:,k], B[:,:,k], α, β)` for all `k`. If `size(B,3) == 1` then every batch uses `B[:,:,1]` instead. This will call `batched_gemm!` whenever possible. For real arrays this means that, for `X ∈ [A,B,C]`, either `stride(X,1)==1` or `stride(X,2)==1`, the latter may be caused by `batched_transpose` or by for instance `PermutedDimsArray(::Array, (3,1,2))`. Unlike `batched_mul` this will never make a copy. 
For complex arrays, the wrapper made by `batched_adjoint` must be outermost to be seen. In this case the strided accepted by BLAS are more restricted, if `stride(C,1)==1` then only `stride(AorB::BatchedAdjoint,2) == 1` is accepted. """ function batched_mul!(C::AbstractArray{T,3}, A::AbstractArray{<:Any,3}, B::AbstractArray{<:Any,3}, α::Number=one(T), β::Number=zero(T)) where {T} _batched_mul!(storage_typejoin(C,A,B), C, A, B, α, β) C end _batched_mul!(::Type, C, A, B, α::Number, β::Number) = batched_mul_generic!(C, A, B, α, β) _batched_mul!(::Type{DT}, C, A, B, α::Number, β::Number) where {DT<:DenseArray{T}} where {T<:BlasFloat} = _batched_try_gemm!(DT, C, A, B, α, β) function _batched_try_gemm!(::Type{DT}, C, A, B, α::Number, β::Number) where {DT<:DenseArray{T}} where {T<:BlasFloat} alpha, beta = promote(α, β, zero(T)) alpha isa T && beta isa T || return batched_mul_generic!(C, A, B, α, β) are_strided(_unbatch(A), _unbatch(B)) || return batched_mul_generic!(C, A, B, α, β) C isa StridedArray || return batched_mul_generic!(C, A, B, α, β) blasA, transA = if A isa BatchedAdjoint && T <: Complex Base.stride(parent(A),1) == 1 || return batched_mul_generic!(C, A, B, α, β) parent(A), 'C' elseif Base.stride(A,2) == 1 && size(A,1) > 1 batched_transpose(A), 'T' elseif Base.stride(A,1) == 1 A, 'N' elseif Base.stride(A,2) == 1 # This is awful, but exhaustively tested. Issues 268, 282. 
batched_transpose(A), 'T' else return batched_mul_generic!(C, A, B, α, β) end blasB, transB = if B isa BatchedAdjoint && T <: Complex Base.stride(parent(B),1) == 1 || return batched_mul_generic!(C, A, B, α, β) parent(B), 'C' elseif Base.stride(B,2) == 1 && size(B,1) > 1 batched_transpose(B), 'T' elseif Base.stride(B,1) == 1 B, 'N' elseif Base.stride(B,2) == 1 batched_transpose(B), 'T' else return batched_mul_generic!(C, A, B, α, β) end _batched_gemm!(DT, transA, transB, alpha, blasA, blasB, beta, C) C end _batched_gemm!(::Type{<:Array}, transA::Char, transB::Char, α::Number, A, B, β::Number, C) = batched_gemm!(transA, transB, α, A, B, β, C) _BATCHED_LIST = [ (:(AbstractArray{<:Any, 3}), :identity), (:BatchedTranspose, :transpose), (:BatchedAdjoint, :adjoint), ] for (TA, fA) in _BATCHED_LIST, (TB, fB) in _BATCHED_LIST @eval function batched_mul_generic!(C::AbstractArray{T, 3}, A::$TA, B::$TB, α::Number=one(T), β::Number=zero(T)) where {T} size(A, 3) == size(C, 3) || size(A, 3) == 1 || throw(DimensionMismatch("batch size mismatch: A != C")) size(B, 3) == size(C, 3) || size(B, 3) == 1 || throw(DimensionMismatch("batch size mismatch: B != C")) @debug "calling fallback method for batched_mul!" typeof(A) size(A) typeof(B) size(B) typeof(C) Abase, Bbase = _unbatch(A), _unbatch(B) sA, oA = size(A,3) == 1 ? (0,1) : (1,0) sB, oB = size(B,3) == 1 ? (0,1) : (1,0) @inbounds for k in 1:size(C,3) @views mul!(C[:,:,k], $fA(Abase[:,:,k*sA+oA]), $fB(Bbase[:,:,k*sB+oB]), α, β) end C end end """ storage_type(A) -> Type Removes all wrappers to return the `Array` or `CuArray` (or whatever) type within. ``` julia> view(reshape(ones(10)',2,5),:, 3:4) |> storage_type Array{Float64,1} julia> reshape(sparse(rand(10)), 5,2) |> storage_type SparseVector{Float64,Int64} ``` """ function storage_type(A::AbstractArray) P = parent(A) typeof(A) === typeof(P) ? typeof(A) : storage_type(P) end storage_type(A) = typeof(A) """ storage_typejoin(A, B, C, ...) 
-> Type Reduces with `Base.promote_typejoin`, in order that this conveys useful information for dispatching to BLAS. It does not tell you what container to allocate: ``` julia> storage_typejoin(rand(2), rand(Float32, 2)) Array{T,1} where T julia> eltype(ans) <: LinearAlgebra.BlasFloat false julia> storage_typejoin(rand(2), rand(2,3), rand(2,3,4)) Array{Float64,N} where N ``` """ storage_typejoin(A, Bs...) = Base.promote_typejoin(storage_type(A), storage_typejoin(Bs...)) storage_typejoin(A) = storage_type(A) """ is_strided(A::AbstractArray) -> Bool This generalises `A isa StridedArray` to treat wrappers like `A::PermutedDimsArray`, for which it returns `is_strided(parent(A))`. It returns `true` for `CuArray`s, and `PermutedDimsArray`s of those. Other wrappers (defined outside Base, LinearAlgebra) are assumed not to break strided-ness, and hence also return `is_strided(parent(A))`. This correctly handles things like `NamedDimsArray` wihch don't alter indexing. However, it's a little pessimistic in that e.g. a `view` of such a container will return `false`, even in cases where the same `view` of `parent(A)` would be a `StridedArray`. """ is_strided(A::StridedArray) = true is_strided(A) = false function is_strided(A::AbstractArray) M = parentmodule(typeof(A)) if parent(A) === A # SparseMatrix, StaticArray, etc false elseif M === Base || M === Core || M ===LinearAlgebra # bad reshapes, etc, plus Diagonal, UpperTriangular, etc. false else is_strided(parent(A)) # PermutedDimsArray, NamedDimsArray end end is_strided(A::BatchedAdjoint) = eltype(A) <: Real && is_strided(parent(A)) is_strided(A::BatchedTranspose) = is_strided(parent(A)) is_strided(A::LinearAlgebra.Transpose) = is_strided(parent(A)) is_strided(A::LinearAlgebra.Adjoint) = eltype(A) <: Real && is_strided(parent(A)) # This needs Compat 3.14, for any Julia < 1.6 are_strided(As...) 
= mapfoldl(is_strided, &, As; init=true)

================================================
FILE: src/bias_act.jl
================================================
using NNlib: fast_act, tanh_fast
using ChainRulesCore

const RCR = RuleConfig{>:HasReverseMode}

# This just saves typing `only.(only.(` many times:
@inline only_derivative(y, f::F, x) where F =
    only(only(ChainRulesCore.derivatives_given_output(y, f, x)))

# This has no methods, used for testing whether `derivatives_given_output(Ω, f, x)`
# is independent of `x`, as `return_type` says `Union{}` when calling is an error.
struct NotaNumber <: Real end

"""
    bias_act!(σ, x, b)

This is equivalent to `x .= σ.(x .+ b)`, also replacing `sigmoid` & `tanh`
with `sigmoid_fast` & `tanh_fast`.
It will only overwrite `x` when `x isa StridedArray{<:AbstractFloat}`.

When used within a gradient, it will overwrite only when `σ` has
a method of `derivatives_given_output` which does not need the input at all.
Such methods are defined by e.g. `@scalar_rule relu(x) Ω > 0`
where the derivative contains only `Ω` (the output) not `x`.

!!! warning
    This is not safe to use if `x` is still needed for the gradient
    of some other function. Incorrect use will give silently wrong answers.
    It is intended mainly for Flux layers, in which the previous operation is
    known to be safe, e.g. `bias_act!(σ, weight * input, bias)` for a `Dense` layer.
"""
bias_act!(σ::Function, x::StridedArray{<:AbstractFloat},
          b::AbstractArray{<:Union{Bool, AbstractFloat}}) =
    _fast_broadcast!(fast_act(σ, x)∘(+), x, b)  # works around a SIMD bug

function bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::Bool)
    # `b::Bool` is only legal as `false` (i.e. "no bias"); `true` is rejected here
    # so that a Bool bias never silently adds 1.
    b === true && error("bias=true is not accepted; layer constructors should guarantee this")
    _fast_broadcast!(fast_act(σ, x), x)
end

function bias_act!(::typeof(identity), x::StridedArray{<:AbstractFloat}, b::Bool)
    b === true && error("bias=true is not accepted; layer constructors should guarantee this")
    x  # pass-through
end

function bias_act!(σ::Function, x::AbstractArray, b)
    # Generic fallback: does not mutate `x`, e.g. for immutable or GPU-unfriendly types.
    b === true && error("bias=true is not accepted; layer constructors should guarantee this")
    fast_act(σ, x).(x .+ b)  # fallback
end

function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), σ::F, x::AbstractArray{T,N}, b::B) where {F,T,N,B}
    # Build the bias-gradient closure once; for `b::Bool` (i.e. no bias) it is a constant.
    biasgrad = if eltype(B) !== Bool
        # Summing over ndims(x)+1 is a trick to make b_dims type-stable
        dims = ntuple(d -> size(b,d)==1 ? d : N+1, N)
        _biasgrad(dx) = reshape(sum(dx; dims), size(b))
    else
        Returns(NoTangent())
    end

    # Fast path: it is now safe to overwrite x, since this is not needed for gradient of σ
    if isconcretetype(Core.Compiler.return_type(only_derivative, Tuple{T, F, NotaNumber}))
        Ω = bias_act!(σ, x, b)  # now x === Ω, when x isa StridedArray{<:AbstractFloat}
        function bias_act!_fastback(Δ)
            # Tempting to overwrite x again, but only safe if you call pullback at most once,
            # TODO with e.g.
            # https://github.com/FluxML/Zygote.jl/pull/1340
            # https://github.com/JuliaDiff/ChainRulesCore.jl/pull/592
            dx = only_derivative.(Ω, σ, NotaNumber()) .* unthunk(Δ)
            return (NoTangent(), NoTangent(), dx, biasgrad(dx))
        end
        return Ω, bias_act!_fastback

    # # Slower path: can't overwrite x, but can use derivatives_given_output
    # # This case is WRONG and tests fail, but not sure why
    # elseif isconcretetype(Core.Compiler.return_type(only_derivative, Tuple{T, F, T}))
    #     Ω2 = fast_act(σ, x).(x) .+ b
    #     @show σ b
    #     function bias_act!_back2(Δ)
    #         dx = only_derivative.(Ω2, σ, x .+ b) .* unthunk(Δ)
    #         return (NoTangent(), NoTangent(), dx, biasgrad(dx))
    #     end
    #     return Ω2, bias_act!_back2

    # Fallback path: let AD handle the broadcast
    else
        Ω3, back = rrule_via_ad(cfg, broadcast, fast_act(σ, x), bias_act!(identity, x, b))
        @inline function bias_act!_slowback(Δ)
            _, _, dx = back(Δ)
            return (NoTangent(), NoTangent(), dx, biasgrad(dx))
        end
        return Ω3, bias_act!_slowback
    end
end

# Two easy cases with identity
function rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArray{T,N}, b::B) where {T,N,B}
    dims = ntuple(d -> size(b,d)==1 ? d : N+1, N)
    biasgrad(dx) = reshape(sum(dx; dims), size(b))
    function bias_act!_idback(Δ)
        dx = unthunk(Δ)
        return (NoTangent(), NoTangent(), dx, biasgrad(dx))
    end
    return bias_act!(identity, x, b), bias_act!_idback
end

function rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArray{T,N}, b::Bool) where {T,N}
    bias_act!_trivial(Δ) = (NoTangent(), NoTangent(), Δ, NoTangent())
    return x, bias_act!_trivial
end

================================================
FILE: src/conv.jl
================================================
## Convolution API
#
#  We provide the following generic methods, for 3d, 4d, and 5d tensors, calculating 1d,
#  2d and 3d convolutions, based on the rank of the input tensors, in both mutating and
#  non-mutating auto-allocating variants:
#   - Convolution:
#     - conv(x, w, cdims)
#     - conv!(y, x, w, cdims)
#   - Convolution data backpropagation
#     - ∇conv_data(dy, w, cdims)
#     - ∇conv_data!(dx, dy, w, cdims)
#   - Convolution filter backpropagation
#     - ∇conv_filter(x, dy, cdims)
#     - ∇conv_filter!(dw, x, dy, cdims)
#
#   All methods require a `ConvDims` object to define the dimensions and optional
#   elements of the convolution (padding, stride, dilation, kernel-flipping, etc...),
#   which is easily constructable through something like `DenseConvDims(x, w)`.  All
#   methods take in the `ConvDims` of the associated normal, forward-pass convolution,
#   that is, the following is legal:
#
#       cdims = ConvDims(x, w; stride=2, dilation=(3,2))
#       dx = ∇conv_data(conv(x, w, cdims), w, cdims)

# The computational flow, starting from the user facing functions,
# goes through the following steps:
#
# STEP 1:
#   use ConvDims objects (only for `conv` and `depthwiseconv`)
# STEP 2:
#   define autoallocating version (frontend and implementations)
# STEP 3:
#   reshape to 3d convolutions (frontend and implementations)
# STEP 4:
#   choose implementation
#
# TODO: should we also add
# STEP X:
#   use homogeneous datatypes
# to handle heterogeneous inputs now handled by conv_direct?
########## STEP 1 ############

"""
    conv(x, w; stride = 1, pad = 0, dilation = 1, flipped = false, groups = 1)

Apply convolution filter `w` to input `x`. `x` and `w` are 3d/4d/5d tensors
in 1d/2d/3d convolutions respectively. `x` and `w` may have real or complex element types.
"""
function conv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1, flipped = false, groups = 1) where {T, N}
    # `expand` turns scalar keyword arguments into per-spatial-dimension tuples
    # of length N-2 (the last two dims of x/w are channel and batch).
    stride = expand(Val(N - 2), stride)
    padding = expand(Val(N - 2), pad)
    dilation = expand(Val(N - 2), dilation)
    # Bundle all geometry into a DenseConvDims and dispatch to the cdims-based method.
    cdims = DenseConvDims(
        size(x), size(w); stride, padding, dilation, flipkernel=flipped, groups)
    return conv(x, w, cdims)
end

"""
    depthwiseconv(x, w; stride=1, pad=0, dilation=1, flipped=false)

Depthwise convolution operation with filter `w` on input `x`. `x` and `w`
are 3d/4d/5d tensors in 1d/2d/3d convolutions respectively.
"""
function depthwiseconv(x, w::AbstractArray{T, N}; stride=1, pad=0, dilation=1, flipped=false) where {T, N}
    stride = expand(Val(N-2), stride)
    pad = expand(Val(N-2), pad)
    dilation = expand(Val(N-2), dilation)
    cdims = DepthwiseConvDims(x, w; stride=stride, padding=pad, dilation=dilation, flipkernel=flipped)
    return depthwiseconv(x, w, cdims)
end
##############################

########### STEP 2 ###################
# Let's generate auto-allocating versions of all our functions, for all backends.
# We `@timeit` these methods separately, as we want to know how much time is spent in
# allocation. :P
# NOTE(review): no `@timeit` instrumentation remains in this code — the comment above
# appears to predate its removal; verify against repository history.
for backend in (Symbol(), :_direct, :_im2col)
    # First make auto-allocating versions of the conv()-like calls:
    # (generates e.g. `conv`, `conv_direct`, `conv_im2col`, which allocate `y`
    # with the promoted element type and forward to the in-place `…!` variant)
    for name in (:conv, :depthwiseconv)
        @eval begin
            function $(Symbol("$(name)$(backend)"))(
                            x::AbstractArray{xT,N}, w::AbstractArray{wT,N},
                            cdims::ConvDims; kwargs...) where {xT, wT, N}
                y = similar(x, promote_type(xT, wT), output_size(cdims)...,
                               channels_out(cdims), size(x,N))
                return $(Symbol("$(name)$(backend)!"))(y, x, w, cdims; kwargs...)
            end
        end
    end

    # Auto-allocating data backprop: allocates `dx` shaped like the forward input.
    for name in (:∇conv_data, :∇depthwiseconv_data)
        @eval begin
            function $(Symbol("$(name)$(backend)"))(
                            dy::AbstractArray{yT,N}, w::AbstractArray{wT,N},
                            cdims::C; kwargs...) where {yT, wT, N, C <: ConvDims}
                dx = similar(dy, input_size(cdims)..., channels_in(cdims),
                                 size(dy, N))
                return $(Symbol("$(name)$(backend)!"))(dx, dy, w, cdims; kwargs...)
            end
        end
    end

    # We do the conv/depthwiseconv filter backprops separately, as the shape calculation
    # for `w` is slightly different for depthwise than for normal dense convolution.
    @eval begin
        function $(Symbol("∇conv_filter$(backend)"))(
                        x::AbstractArray{xT,N}, dy::AbstractArray{yT,N},
                        cdims::ConvDims; kwargs...) where {xT, yT, N}
            # Dense: kernel holds channels_in ÷ groups input channels per group.
            dw = similar(dy, kernel_size(cdims)..., channels_in(cdims) ÷ groupcount(cdims),
                             channels_out(cdims))
            return $(Symbol("∇conv_filter$(backend)!"))(dw, x, dy, cdims; kwargs...)
        end
    end

    @eval begin
        function $(Symbol("∇depthwiseconv_filter$(backend)"))(
                        x::AbstractArray{xT,N}, dy::AbstractArray{yT,N},
                        cdims::ConvDims; kwargs...) where {xT, yT, N}
            # Depthwise: kernel is characterized by (channel multiplier, channels_in).
            dw = similar(dy, kernel_size(cdims)..., channel_multiplier(cdims),
                             channels_in(cdims))
            return $(Symbol("∇depthwiseconv_filter$(backend)!"))(dw, x, dy, cdims; kwargs...)
        end
    end
end
##########################################

########## STEP 3 ############
# Our strategy for 1d and 2d convolution is to reshape to 3d convolutions, which
# makes things MUCH EASIER for us on the backend side, and is in general pretty fast,
# since we can specialize on sizes.
for front_name in (:conv, :∇conv_data, :∇conv_filter,
                   :depthwiseconv, :∇depthwiseconv_data, :∇depthwiseconv_filter)
    for backend in (Symbol(), :_direct, :_im2col)
        for N in (3, 4)
            @eval begin
                function $(Symbol("$(front_name)$(backend)!"))(
                                y::AbstractArray{yT,$N}, x::AbstractArray{xT,$N},
                                w::AbstractArray{wT,$N}, cdims::ConvDims;
                                kwargs...)
                                where {yT, xT, wT}
                    # Pad the missing spatial dimension(s) with size-1 axes and
                    # forward to the 5d (3d-convolution) primitive.
                    $(Symbol("$(front_name)$(backend)!"))(
                        insert_singleton_spatial_dimension(y, $(5 - N)),
                        insert_singleton_spatial_dimension(x, $(5 - N)),
                        insert_singleton_spatial_dimension(w, $(5 - N)),
                        insert_singleton_spatial_dimension(cdims, $(5 - N));
                        kwargs...
                    )

                    # We explicitly return `y` here, because the backend call
                    # itself may return a reshaped view, which we don't want.
                    return y
                end
            end
        end
    end
end
#######################################

########### STEP 4 ############
# First, we will define mappings from the generic API names to our accelerated backend
# implementations. For homogeneous-datatype 1, 2 and 3d convolutions, we default to using
# im2col + GEMM.
# But we always support a fallback, non-accelerated path, where we use the direct, but
# slow, implementations. These should not typically be used, hence the `@warn`,

# These are the GEMM types we will accelerate with `im2col`
const G = Union{[x[2] for x in gemm_datatype_mappings]...}

for (front_name, backend, signature) in (
    # This maps from public, front-facing name, to internal backend name, given the function signature and the where clause
    # (frontend, backend, (out Array signature, in1 Array signature, in2 Array signature, (parametric Types)))
    (:conv, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
    (:conv, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),
)
    # We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
    @eval begin
        function $(Symbol("$(front_name)!"))(
                out::AbstractArray{$(signature[1][1]), $(signature[1][2])},
                in1::AbstractArray{$(signature[2][1]), $(signature[1][2])},
                in2::AbstractArray{$(signature[3][1]), $(signature[1][2])},
                cdims::$(signature[4]); kwargs...) where {$(signature[5]...)}
            if $(string(backend)) == "direct" && yT == Float64
                # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
                @warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
                    "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
            end

            # Split channels into one contiguous range per group; the grouped problem
            # is then solved as `groupcount` independent single-group convolutions.
            x_cs = Iterators.partition(1:size(in1, 4),
                channels_in(cdims) ÷ groupcount(cdims))
            w_cs = Iterators.partition(1:size(in2, 5),
                channels_out(cdims) ÷ groupcount(cdims))
            # `cdims2` describes the per-group convolution (G = 1, reduced channels).
            cdims2 = basetype(C)(cdims,
                G = 1,
                C_in = channels_in(cdims) ÷ groupcount(cdims),
                C_out = channels_out(cdims) ÷ groupcount(cdims))

            function conv_group(xc, wc)
                x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
                w = @view in2[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
                y = @view out[ntuple(i -> i == 4 ? wc : Colon(), 5)...]
                $(Symbol("$(front_name)_$(backend)!"))(y, x, w, cdims2; kwargs...)
            end

            # Run the per-group convolutions, possibly in parallel.
            if should_use_spawn() && length(x_cs) > 1
                Threads.@sync for (xc, wc) in zip(x_cs, w_cs)
                    Threads.@spawn conv_group(xc, wc)
                end
            else
                for (xc, wc) in zip(x_cs, w_cs)
                    conv_group(xc, wc)
                end
            end

            return out
        end
    end
end

# im2col-accelerated function forwarding definition
for (front_name, backend, signature) in (
    # This maps from public, front-facing name, to internal backend name, given the function signature and the where clause
    # (frontend, backend, (out Array signature, in1 Array signature, in2 Array signature, (parametric Types)))
    (:∇conv_data, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
    (:∇conv_data, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),
)
    # We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
    @eval begin
        function $(Symbol("$(front_name)!"))(
                out::AbstractArray{$(signature[1][1]), $(signature[1][2])},
                in1::AbstractArray{$(signature[2][1]), $(signature[1][2])},
                in2::AbstractArray{$(signature[3][1]), $(signature[1][2])},
                cdims::$(signature[4]); kwargs...) where {$(signature[5]...)}
            if $(string(backend)) == "direct" && yT == Float64
                # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
                @warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
                    "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
            end

            # Here `out` is dx, `in1` is dy, `in2` is w; partition each per group.
            dx_cs = Iterators.partition(1:size(out, 4),
                channels_in(cdims) ÷ groupcount(cdims))
            w_cs = Iterators.partition(1:size(in2, 5),
                channels_out(cdims) ÷ groupcount(cdims))
            dy_cs = Iterators.partition(1:size(in1, 4),
                channels_out(cdims) ÷ groupcount(cdims))
            cdims2 = basetype(C)(cdims,
                G = 1,
                C_in = channels_in(cdims) ÷ groupcount(cdims),
                C_out = channels_out(cdims) ÷ groupcount(cdims))

            function ∇conv_data_group(xc, yc, wc)
                dxv = @view out[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
                dyv = @view in1[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
                wv = @view in2[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
                $(Symbol("$(front_name)_$(backend)!"))(dxv, dyv, wv, cdims2; kwargs...)
            end

            if should_use_spawn() && length(dx_cs) > 1
                Threads.@sync for (xc, yc, wc) in zip(dx_cs, dy_cs, w_cs)
                    Threads.@spawn ∇conv_data_group(xc, yc, wc)
                end
            else
                for (xc, yc, wc) in zip(dx_cs, dy_cs, w_cs)
                    ∇conv_data_group(xc, yc, wc)
                end
            end

            return out
        end
    end
end

for (front_name, backend, signature) in (
    # This maps from public, front-facing name, to internal backend name, given the function signature and the where clause
    # (frontend, backend, (out Array signature, in1 Array signature, in2 Array signature, (parametric Types)))
    (:∇conv_filter, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
    (:∇conv_filter, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),
)
    # We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
    @eval begin
        function $(Symbol("$(front_name)!"))(
                out::AbstractArray{$(signature[1][1]), $(signature[1][2])},
                in1::AbstractArray{$(signature[2][1]), $(signature[1][2])},
                in2::AbstractArray{$(signature[3][1]), $(signature[1][2])},
                cdims::$(signature[4]); kwargs...) where {$(signature[5]...)}
            if $(string(backend)) == "direct" && yT == Float64
                # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
                @warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
                    "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
            end

            # Here `out` is dw, `in1` is x, `in2` is dy; partition each per group.
            dw_cs = Iterators.partition(1:size(out, 5),
                channels_out(cdims) ÷ groupcount(cdims))
            dy_cs = Iterators.partition(1:size(in2, 4),
                channels_out(cdims) ÷ groupcount(cdims))
            x_cs = Iterators.partition(1:size(in1, 4),
                channels_in(cdims) ÷ groupcount(cdims))
            cdims2 = basetype(C)(cdims,
                G = 1,
                C_in = channels_in(cdims) ÷ groupcount(cdims),
                C_out = channels_out(cdims) ÷ groupcount(cdims))

            function ∇conv_filter_group(wc, xc, yc)
                x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
                dy = @view in2[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
                dw = @view out[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
                $(Symbol("$(front_name)_$(backend)!"))(dw, x, dy, cdims2; kwargs...)
            end

            if should_use_spawn() && length(dw_cs) > 1
                Threads.@sync for (wc, xc, yc) in zip(dw_cs, x_cs, dy_cs)
                    Threads.@spawn ∇conv_filter_group(wc, xc, yc)
                end
            else
                for (wc, xc, yc) in zip(dw_cs, x_cs, dy_cs)
                    ∇conv_filter_group(wc, xc, yc)
                end
            end

            return out
        end
    end
end

for (front_name, backend, signature) in (
    # This maps from public, front-facing name, to internal backend name, given the function signature and the where clause
    # (frontend, backend, (out Array signature, in1 Array signature, in2 Array signature, (parametric Types)))
    (:depthwiseconv, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
    (:depthwiseconv, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),
    (:∇depthwiseconv_data, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
    (:∇depthwiseconv_data, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),
    (:∇depthwiseconv_filter, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
    (:∇depthwiseconv_filter, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),
)
    # We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
    @eval begin
        # im2col-accelerated function forwarding definition
        # (depthwise convolution needs no group splitting, so this forwards directly)
        function $(Symbol("$(front_name)!"))(
                out::AbstractArray{$(signature[1][1]), $(signature[1][2])},
                in1::AbstractArray{$(signature[2][1]), $(signature[1][2])},
                in2::AbstractArray{$(signature[3][1]), $(signature[1][2])},
                cdims::$(signature[4]); kwargs...) where {$(signature[5]...)}
            if $(string(backend)) == "direct" && yT == Float64
                # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
                @warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
                    "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
            end
            $(Symbol("$(front_name)_$(backend)!"))(out, in1, in2, cdims; kwargs...)
        end
    end
end

# ConvDims constructors carry no differentiable data.
for Dims in [:DenseConvDims, :DepthwiseConvDims, :PoolDims]
    @eval @non_differentiable $Dims(::Any...)
end

# Make sure the cotangent is dense and column-major before feeding it to the backends.
colmajor(x) = (is_strided(x) && Base.stride(x, 1) == 1) ? x : collect(x)

for conv in [:conv, :depthwiseconv]
    local ∇conv_data, ∇conv_filter = Symbol.(:∇, conv, [:_data, :_filter])
    conv_pullback, ∇conv_data_pullback = Symbol.([conv, ∇conv_data], :_pullback)

    @eval function rrule(::typeof($conv), x, w, cdims; kw...)
        function $conv_pullback(Δraw)
            Δ = colmajor(unthunk(Δraw))
            # NOTE(review): `kw...` here splats the keyword pairs as *positional*
            # arguments (no `;`); this is benign when `kw` is empty but looks
            # unintended — verify whether `; kw...` was meant.
            return (
                NoTangent(),
                @thunk($∇conv_data(Δ, w, cdims, kw...)),
                @thunk($∇conv_filter(x, Δ, cdims, kw...)),
                NoTangent(),
            )
        end
        return $conv(x, w, cdims; kw...), $conv_pullback
    end

    @eval function rrule(::typeof($∇conv_data), x, w, cdims; kw...)
        function $∇conv_data_pullback(Δraw)
            Δ = colmajor(unthunk(Δraw))
            return (
                NoTangent(),
                @thunk($conv(Δ, w, cdims, kw...)),
                @thunk($∇conv_filter(Δ, x, cdims, kw...)),
                NoTangent(),
            )
        end
        return $∇conv_data(x, w, cdims; kw...), $∇conv_data_pullback
    end
end

function rrule(::typeof(∇conv_filter), x, dy, cdims; kw...)
    function ∇conv_filter_pullback(Δ)
        Δ1 = colmajor(unthunk(Δ))
        return (
            NoTangent(),
            @thunk(∇conv_data(dy, Δ1, cdims, kw...)),
            @thunk(conv(x, Δ1, cdims, kw...)),
            NoTangent(),
        )
    end
    return ∇conv_filter(x, dy, cdims; kw...), ∇conv_filter_pullback
end

================================================
FILE: src/conv_bias_act.jl
================================================
# Auto-allocating fused conv + bias + activation: allocates `y` with the promoted
# element type and forwards to the in-place `conv_bias_act!`.
function conv_bias_act(x::AbstractArray{xT,N}, w::AbstractArray{wT,N},
                cdims::ConvDims, b::AbstractArray{bT,N}, σ=identity; kwargs...) where {xT, wT, bT, N}
    y = similar(x, promote_type(xT, wT, bT), output_size(cdims)...,
                   channels_out(cdims), size(x,N))
    conv_bias_act!(y, x, w, cdims, b, σ; kwargs...)
    return y
end

function conv_bias_act!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5},
                w::AbstractArray{wT,5}, cdims::ConvDims,
                b::AbstractArray{bT,5}, σ=identity; kwargs...)
                where {yT, xT, wT, bT}
    # Generic 5d fallback: ordinary convolution, then fused broadcast of bias + activation.
    conv!(y, x, w, cdims)
    y .= σ.(y .+ b)
    return y
end

# 1d/2d cases: insert singleton spatial dimension(s) and forward to the 5d method.
for N in (3, 4)
    @eval begin
        function $(Symbol("conv_bias_act!"))(
                        y::AbstractArray{yT,$N}, x::AbstractArray{xT,$N},
                        w::AbstractArray{wT,$N}, cdims::ConvDims,
                        b::AbstractArray{bT,$N}, σ=identity; kwargs...) where {yT, xT, wT, bT}
            $(Symbol("conv_bias_act!"))(
                insert_singleton_spatial_dimension(y, $(5 - N)),
                insert_singleton_spatial_dimension(x, $(5 - N)),
                insert_singleton_spatial_dimension(w, $(5 - N)),
                insert_singleton_spatial_dimension(cdims, $(5 - N)),
                insert_singleton_spatial_dimension(b, $(5 - N)),
                σ; kwargs...
            )
            # We explicitly return `y` here, because the backend call
            # itself may return a reshaped view, which we don't want.
            return y
        end
    end
end

================================================
FILE: src/ctc.jl
================================================
# CTC loss moved from Flux.jl to NNlib

## CPU implementation

"""
    logaddexp(a, b)

Adds log-space `a` and `b` such that the result equals `log(exp(a)+exp(b))`
"""
function logaddexp(a, b)
    # If either operand is -Inf (log of zero probability), the other dominates.
    isinf(a) && return b
    isinf(b) && return a

    # always want the greater number on the left in the exponentiation;
    # the magnitude difference may end up making the number very positive
    # which will cause exp() to return Inf
    # E.g., a = -900, b = -800, will give exp(-800 - -900), which will be
    # Inf for Float32 values
    if a < b
        a, b = b, a
    end
    return a + log(1+exp(b-a))
end

"""
    add_blanks(z)

Adds blanks to the start and end of `z`, and between items in `z`
"""
function add_blanks(z, blank)
    # Result has length 2|z|+1: blank, z[1], blank, z[2], …, blank.
    z′ = fill(blank, 2*length(z) + 1)
    z′[2 .* eachindex(z)] = z
    return z′
end

# Forward (alpha) pass of CTC; returns the loss together with intermediates
# needed by `∇ctc_loss`. Follows Graves (2012), chapter 7 — equation numbers
# in comments refer to that text.
function ctc_alpha(ŷ::AbstractArray, y)
    typed_zero = zero(ŷ[1])
    ŷ = logsoftmax(ŷ)
    # The blank label is, by convention, the last class (= row count).
    blank = size(ŷ, 1)
    z′ = add_blanks(y, blank)
    T = size(ŷ, 2)
    U′ = length(z′)

    # α[u,t] is the log-probability of all alignments of z′[1:u] with ŷ[:,1:t].
    α = fill(log(typed_zero), U′, T)
    α[1,1] = ŷ[blank, 1]
    α[2,1] = ŷ[z′[2], 1]
    for t=2:T
        # Below `bound`, not enough time steps remain to finish the sequence.
        bound = max(1, U′ - 2(T - t) - 1)
        for u=bound:U′
            if u == 1
                α[u,t] = α[u, t-1]
            else
                α[u,t] = logaddexp(α[u, t-1], α[u-1, t-1])

                # array bounds check and f(u) function from Eq. 7.9
                if u > 2 && !(z′[u] == blank || z′[u-2] == z′[u])
                    α[u,t] = logaddexp(α[u,t], α[u-2,t-1])
                end
            end
            α[u,t] += ŷ[z′[u], t]
        end
    end
    # Loss is -log of the total path probability (last two alpha entries).
    return (loss=-1 * logaddexp(α[end,T], α[end-1, T]), alpha=α, zprime=z′, logsoftyhat=ŷ)
end

# Backward pass: computes gradients of the CTC loss w.r.t. the raw activations `ŷ`
# from the saved forward intermediates `out` (as returned by `ctc_alpha`).
function ∇ctc_loss(ŷ::AbstractArray, y, out)
    loss, α, z′, ŷ = out
    U′, T = size(α)
    blank = size(ŷ, 1)
    typed_zero = zero(first(α))

    # Calculate beta coefficients, from the bottom-right, to the upper-left
    β = fill(log(typed_zero), U′, T)

    # Fill bottom-right corner so bounding errors can be avoided
    # by starting `u` at `U′-1`
    β[U′, T] = typed_zero
    β[U′-1, T] = typed_zero

    # start at T-1 so that β(T, u) = log(0) for all u < U′ - 1
    for t=(T-1):-1:1
        bound = min(U′, 2t)
        for u=bound:-1:1
            if u == U′
                β[u,t] = ŷ[z′[u], t+1] + β[u, t+1]
            else
                β[u,t] = logaddexp(ŷ[z′[u], t+1] + β[u, t+1], ŷ[z′[u+1], t+1] + β[u+1,t+1])

                # array bounds check and g(u) function from Eq. 7.16
                if u+2 <= U′ && z′[u] != blank && z′[u] != z′[u+2]
                    β[u,t] = logaddexp(β[u,t], ŷ[z′[u+2], t+1] + β[u+2, t+1])
                end
            end
        end
    end

    # Accumulate alpha-beta products for each category,
    # then calculate gradients
    accum = fill(log(typed_zero), size(ŷ))
    for t=1:T
        for u=1:U′
            accum[z′[u], t] = logaddexp(accum[z′[u], t], α[u,t] + β[u,t])
        end
    end
    grads = exp.(ŷ) .- exp.(accum .+ loss)
    return grads
end

"""
    ctc_loss(ŷ, y)

Computes the connectionist temporal classification loss between `ŷ`
and `y`.
`ŷ` must be a classes-by-time matrices, i.e., each row
represents a class and each column represents a time step.
Additionally, the `logsoftmax` function will be applied to `ŷ`, so
`ŷ` must be the raw activation values from the neural network and
not, for example, the activations after being passed through a
`softmax` activation function. `y` must be a 1D array of the labels
associated with `ŷ`. The blank label is assumed to be the last label
category in `ŷ`, so it is equivalent to `size(ŷ, 1)`.
Used for sequence-to-sequence classification problems such as
speech recognition and handwriting recognition where the exact
time-alignment of the output (e.g., letters) is not needed to
solve the problem. See [Graves et al. (2006)](https://www.cs.toronto.edu/~graves/icml_2006.pdf)
or [Graves (2012)](https://www.cs.toronto.edu/~graves/preprint.pdf#chapter.7)
for mathematical details.
"""
ctc_loss(ŷ::AbstractArray, y) = ctc_alpha(ŷ, y).loss

function ChainRulesCore.rrule(::typeof(ctc_loss), ŷ, y)
    # Save the forward intermediates so the pullback reuses them.
    tmp = ctc_alpha(ŷ, y)
    ctc_loss_pullback(Δ) = (NoTangent(), Δ .* ∇ctc_loss(ŷ, y, tmp), NoTangent())
    return tmp.loss, ctc_loss_pullback
end

================================================
FILE: src/deprecations.jl
================================================
### Deprecated while v0.8 was latest

export ∇softmax, ∇softmax!, logsoftmax, logsoftmax!, ∇logsoftmax, ∇logsoftmax!

function ∇softmax!(out::AbstractArray, Δ::AbstractArray,
                   x::AbstractArray, y::AbstractArray; dims = 1)
    Base.depwarn("`∇softmax!(dx, dy, x, y)` is deprecated, just use `∇softmax_data(dy, y)`", :∇softmax!)
    # Removed because using a mutating function blocks 2nd derivatives, and
    # the CUDA overload was slow anyway, https://github.com/FluxML/NNlibCUDA.jl/issues/30
    out .= Δ .* y
    out .= out .- y .* sum(out; dims)
end

function ∇logsoftmax!(out::AbstractArray, Δ::AbstractArray,
                      x::AbstractArray, y::AbstractArray; dims = 1)
    # NOTE(review): the funcsym below is `:∇softmax!` but this deprecates
    # `∇logsoftmax!` — likely a copy-paste slip; verify intended symbol.
    Base.depwarn("`∇logsoftmax!(dx, dy, x, y)` is deprecated, just use `∇logsoftmax_data(dy, y)`", :∇softmax!)
    out .= Δ .- sum(Δ; dims) .* exp.(y)
end

function ∇softmax(dy::AbstractArray{T}, x::AbstractArray, y::AbstractArray{S}; dims = 1) where {T,S}
    # Removed because there's no need to close over `x` here, that was done only to distinguish
    # this from `∇softmax(Δ, x; dims = 1)` which re-computed `y = softmax(x)`, which is slow.
Base.depwarn("`∇softmax(dy, x, y)` should be replaced with `∇softmax_data(dy, y)`", :∇softmax) ∇softmax_data(dy, y) end function ∇logsoftmax(dy::AbstractArray, x::AbstractArray, y::AbstractArray; dims = 1) Base.depwarn("`∇logsoftmax(dy, x, y)` should be replaced with `∇logsoftmax_data(dy, y)`", :∇softmax) ∇logsoftmax_data(dy, y) end ================================================ FILE: src/dim_helpers/ConvDims.jl ================================================ """ ConvDims Type system-level information about convolution dimensions. Critical for things like `im2col!()` to generate efficient code, and helpful to reduce the number of kwargs getting passed around. """ abstract type ConvDims{N} end @inline spatial_dims(::ConvDims{N}) where N = N @inline groupcount(c::ConvDims) = 1 # Below functions should be implemented by dims that subtype `ConvDims`. function input_size end function kernel_size end function stride end function padding end function dilation end function flipkernel end # Hack to get rid of type parameters function basetype(::Type{C}) where {C <: ConvDims} if C <: DepthwiseConvDims return DepthwiseConvDims elseif C <: DenseConvDims return DenseConvDims elseif C <: PoolDims return PoolDims else return nothing end end function output_size(c::ConvDims) I = input_size(c) K = kernel_size(c) S = stride(c) P = padding(c) D = dilation(c) return ntuple(spatial_dims(c)) do i return div(I[i] + P[(i-1)*2 + 1] + P[(i-1)*2 + 2] - (K[i] - 1) * D[i] - 1, S[i]) + 1 end end function Base.show(io::IO, cdims::C) where {C <: ConvDims} I = (input_size(cdims)..., channels_in(cdims)) O = (output_size(cdims)..., channels_out(cdims)) K = kernel_size(cdims) S = stride(cdims) P = padding(cdims) D = dilation(cdims) F = flipkernel(cdims) G = groupcount(cdims) print(io, "$(basetype(C)): $I * $K -> $O, stride: $S, pad: $P, dil: $D, flip: $F, groups: $G") end """ im2col_dims(c::ConvDims) im2col calculates, for each output pixel, the "convolution" of N kernels where N is the number of 
output channels, by doing a matrix multiply. The dimensions of that matrix are given by this function. Note that because im2col is multithreaded, we need to allocate a separate workspace of memory per-thread; hence the dimensions returned by this will depend on the number of threads Julia is currently running with. """ function im2col_dims(c::ConvDims) return ( # Output size prod(output_size(c)), # Size of single dotproduct within convolution prod(kernel_size(c))*channels_in(c), # One workspace per thread Threads.nthreads(:default), ) end """ ∇filter_im2col_dims(c::ConvDims) Like [`im2col_dims`](@ref), but saves some memory because multiple (Julia) threads are not required for the filter gradient calculation. Note: in the future, this may return `Dims{2}` instead of `Dims{3}`. """ function ∇filter_im2col_dims(c::ConvDims) return ( # Output size prod(output_size(c)), # Size of single dotproduct within convolution prod(kernel_size(c))*channels_in(c), # No threading, this is just here for backwards compat 1 ) end # Protect your skin, kids. Also do common validation of stride, padding, etc... function check_spdf(x_size::NTuple{N}, w_size::NTuple{N}, stride, padding, dilation) where {N} # Number of spatial dimensions in `x` and `w`. nd = N - 2 # Given a number, duplicate it out to have `nd` length. If it's already a collection, # just splat it out into a tuple so it's always a tuple. We'll lint length later. expand_size(p::Number) = ntuple(_ -> Int(p), nd) expand_size(p) = tuple(p...) # Convert stride, padding, dilation, etc.. 
to fully-specified tuples pstride = expand_size(stride) pdilation = expand_size(dilation) ppadding = expand_size(padding) if length(pstride) != nd throw(DimensionMismatch("Stride $(length(stride))d, should be $(nd)d!")) end if length(pdilation) != nd throw(DimensionMismatch("Dilation $(length(pdilation))d, should be $(nd)d!")) end # padding is kind of a special case; we allow it to be either 2-length or 4-length, # since we support asymmetrical padding if length(ppadding) == 2 * nd _validate_padding(x_size, w_size, ppadding, pdilation) return pstride, ppadding, pdilation end length(ppadding) != nd && throw(DimensionMismatch( "Padding $(length(ppadding))d, should be either $(nd)d or $(2*nd)d!")) # Do this repeat dance so that we get lo/hi symmetrical padding ppadding_expanded = ntuple(i -> ppadding[(i - 1) ÷ 2 + 1], 2 * nd) _validate_padding(x_size, w_size, ppadding_expanded, pdilation) return pstride, ppadding_expanded, pdilation end # Assert that kernel size * dilation is <= padded input size function _validate_padding(x_size::NTuple{N}, w_size::NTuple{N}, padding, dilation) where N for idx in 1:(N - 2) Is = x_size[idx] Ks = w_size[idx] Pl = padding[(idx - 1) * 2 + 1] Ph = padding[(idx - 1) * 2 + 2] Ds = dilation[idx] if Is + Pl + Ph < (Ks - 1) * Ds + 1 throw(DimensionMismatch("Kernel * dilation (($Ks - 1) * $Ds + 1) cannot be larger than input + padding ($Is + $Pl + $Ph)!")) end end nothing end ================================================ FILE: src/dim_helpers/DenseConvDims.jl ================================================ """ DenseConvDims Concrete subclass of `ConvDims` for a normal, dense, conv2d/conv3d. 
"""
struct DenseConvDims{N, K, S, P, D} <: ConvDims{N}
    input_size::NTuple{N, Int}     # spatial size of the input (no channel/batch dims)
    kernel_size::NTuple{K, Int}    # spatial size of the kernel
    channels_in::Int
    channels_out::Int
    groupcount::Int
    stride::NTuple{S, Int}
    padding::NTuple{P, Int}        # lo/hi padding per spatial dim (length 2N)
    dilation::NTuple{D, Int}
    flipkernel::Bool               # true = convolution, false = cross-correlation
end

function DenseConvDims(
    x_size::NTuple{M}, w_size::NTuple{M};
    stride = 1, padding = 0, dilation = 1, groups = 1, flipkernel::Bool = false,
) where {M}
    # Normalize & validate stride/padding/dilation against the spatial dims.
    sstride, ppadding, ddilation = check_spdf(
        x_size, w_size, stride, padding, dilation)

    # Ensure channels are equal
    if x_size[end - 1] != w_size[end - 1] * groups
        xs = x_size[end - 1]
        ws = w_size[end - 1]
        throw(DimensionMismatch("Input channels must match! ($xs vs. $ws)"))
    end

    # Ensure groups are valid
    # (fixed message: the channel counts must be divisible by the group count,
    # and "divisble" typo corrected)
    if x_size[end - 1] % w_size[end - 1] != 0 || w_size[end] % groups != 0
        throw(DimensionMismatch(
            "Input and output channels should be divisible by group count ($groups vs. $(w_size[end-1:end]))"))
    end

    DenseConvDims(
        x_size[1:(end - 2)], w_size[1:(end - 2)],
        x_size[end - 1], w_size[end], groups,
        sstride, ppadding, ddilation, flipkernel)
end

function DenseConvDims(x::AbstractArray, w::AbstractArray; kwargs...)
    if ndims(x) != ndims(w)
        throw(DimensionMismatch(
            "Rank of x and w must match! ($(ndims(x)) vs. $(ndims(w)))"))
    end
    return DenseConvDims(size(x), size(w); kwargs...)
end

# Useful for constructing a new DenseConvDims that has only a few elements different
# from the original progenitor object that it inherits shapes from.
@inline DenseConvDims(
    c::C; I=input_size(c), K=kernel_size(c),
    C_in=channels_in(c), C_out=channels_out(c), S=stride(c),
    P=padding(c), D=dilation(c), F=flipkernel(c), G=groupcount(c),
) where C <: ConvDims = DenseConvDims(
    I, K, C_in, C_out, G, S, P, D, F)

@inline groupcount(c::DenseConvDims) = c.groupcount
@inline channels_in(c::DenseConvDims) = c.channels_in
@inline channels_out(c::DenseConvDims) = c.channels_out
@inline input_size(c::DenseConvDims) = c.input_size
@inline kernel_size(c::DenseConvDims) = c.kernel_size
@inline stride(c::DenseConvDims) = c.stride
@inline padding(c::DenseConvDims) = c.padding
@inline dilation(c::DenseConvDims) = c.dilation
@inline flipkernel(c::DenseConvDims) = c.flipkernel

# Asserts that the size tuples of data `x`, kernel `w` and output `y` agree with `cdims`.
function check_dims(x::NTuple{M}, w::NTuple{M}, y::NTuple{M}, cdims::DenseConvDims) where {M}
    # First, check that channel counts are all correct:
    @assert x[M-1] * groupcount(cdims) == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[M-1]) vs. $(channels_in(cdims)))")
    @assert y[M-1] == channels_out(cdims) ÷ groupcount(cdims) DimensionMismatch("Data output channel count ($(y[M-1]) vs. $(channels_out(cdims)))")
    @assert w[M-1] * groupcount(cdims) == channels_in(cdims) DimensionMismatch("Kernel input channel count ($(w[M-1]) vs. $(channels_in(cdims)))")
    @assert w[M] * groupcount(cdims) == channels_out(cdims) DimensionMismatch("Kernel output channel count ($(w[M]) vs. $(channels_out(cdims)))")

    # Next, check that the spatial dimensions match up
    @assert x[1:M-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:M-2]) vs. $(input_size(cdims)))")
    @assert y[1:M-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:M-2]) vs. $(output_size(cdims)))")
    @assert w[1:M-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:M-2]) vs. $(kernel_size(cdims)))")

    # Check the groups match
    # (fixed message: input channels must be divisible by the group count, not vice versa)
    @assert channels_in(cdims) % groupcount(cdims) == 0 DimensionMismatch("Input channels ($(channels_in(cdims))) should be divisible by group count ($(groupcount(cdims)))")

    # Finally, check that the batch size matches
    @assert x[M] == y[M] DimensionMismatch("Batch size ($(x[M]) vs. $(y[M]))")
end

================================================
FILE: src/dim_helpers/DepthwiseConvDims.jl
================================================
"""
    DepthwiseConvDims

Concrete subclass of `ConvDims` for a depthwise convolution.  Differs primarily due to
characterization by `C_in`, `C_mult`, rather than `C_in`, `C_out`.  Useful to be
separate from DenseConvDims primarily for channel calculation differences.
"""
struct DepthwiseConvDims{N, K, S, P, D} <: ConvDims{N}
    input_size::NTuple{N, Int}
    kernel_size::NTuple{K, Int}
    channels_in::Int
    channels_multiplier::Int    # output channels per input channel
    stride::NTuple{S, Int}
    padding::NTuple{P, Int}
    dilation::NTuple{D, Int}
    flipkernel::Bool
end

function DepthwiseConvDims(
    x_size::NTuple{M}, w_size::NTuple{M};
    stride = 1, padding = 0, dilation = 1, flipkernel::Bool = false,
) where M
    sstride, ppadding, ddilation = check_spdf(
        x_size, w_size, stride, padding, dilation)

    # Ensure channels are equal
    # (for depthwise kernels the *last* dim of w is channels_in; second-to-last is the multiplier)
    if x_size[end-1] != w_size[end]
        xs = x_size[end-1]
        ws = w_size[end]
        throw(DimensionMismatch("Input channels must match! ($xs vs. $ws)"))
    end

    DepthwiseConvDims(
        x_size[1:(end - 2)], w_size[1:(end - 2)],
        x_size[end - 1], w_size[end - 1],
        sstride, ppadding, ddilation, flipkernel)
end

function DepthwiseConvDims(x::AbstractArray, w::AbstractArray; kwargs...)
    if ndims(x) != ndims(w)
        throw(DimensionMismatch("Rank of x and w must match! ($(ndims(x)) vs. $(ndims(w)))"))
    end
    return DepthwiseConvDims(size(x), size(w); kwargs...)
end

# Useful for constructing a new DepthwiseConvDims that has only a few elements different
# from the original progenitor object.
@inline DepthwiseConvDims(
    c::DepthwiseConvDims; I=input_size(c), K=kernel_size(c),
    C_in=channels_in(c), C_m=channel_multiplier(c), S=stride(c),
    P=padding(c), D=dilation(c), F=flipkernel(c),
) = DepthwiseConvDims(I, K, C_in, C_m, S, P, D, F)

@inline channels_in(c::DepthwiseConvDims) = c.channels_in
@inline channels_out(c::DepthwiseConvDims) = c.channels_in * c.channels_multiplier
@inline channel_multiplier(c::DepthwiseConvDims) = c.channels_multiplier

@inline input_size(c::DepthwiseConvDims) = c.input_size
@inline kernel_size(c::DepthwiseConvDims) = c.kernel_size
@inline stride(c::DepthwiseConvDims) = c.stride
@inline padding(c::DepthwiseConvDims) = c.padding
@inline dilation(c::DepthwiseConvDims) = c.dilation
@inline flipkernel(c::DepthwiseConvDims) = c.flipkernel

# This one is basically the same as for DenseConvDims, we only change a few lines for kernel channel count
function check_dims(x::NTuple{M}, w::NTuple{M}, y::NTuple{M}, cdims::DepthwiseConvDims) where {M}
    # First, check that channel counts are all correct:
    @assert x[M-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[M-1]) vs. $(channels_in(cdims)))")
    @assert y[M-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[M-1]) vs. $(channels_out(cdims)))")
    # (Fixed: the multiplier message below was missing its closing parenthesis.)
    @assert w[M-1] == channel_multiplier(cdims) DimensionMismatch("Kernel multiplier channel count ($(w[M-1]) vs. $(channel_multiplier(cdims)))")
    @assert w[M] == channels_in(cdims) DimensionMismatch("Kernel input channel count ($(w[M]) vs. $(channels_in(cdims)))")

    # Next, check that the spatial dimensions match up
    @assert x[1:M-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:M-2]) vs. $(input_size(cdims)))")
    @assert y[1:M-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:M-2]) vs. $(output_size(cdims)))")
    @assert w[1:M-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:M-2]) vs.
$(kernel_size(cdims)))") # Finally, check that the batch size matches @assert x[M] == y[M] DimensionMismatch("Batch size ($(x[M]) vs. $(y[M]))") end ================================================ FILE: src/dim_helpers/PoolDims.jl ================================================ """ PoolDims(x_size::NTuple{M}, k::Union{NTuple{L, Int}, Int}; stride=k, padding=0, dilation=1) where {M, L} Dimensions for a "pooling" operation that can have an arbitrary input size, kernel size, stride, dilation, and channel count. Used to dispatch onto efficient implementations at compile-time. """ struct PoolDims{N, K, S, P, D} <: ConvDims{N} input_size::NTuple{N, Int} kernel_size::NTuple{K, Int} channels_in::Int stride::NTuple{S, Int} padding::NTuple{P, Int} dilation::NTuple{D, Int} end function PoolDims( x_size::NTuple{M}, k::Union{NTuple{L, Int}, Int}; stride = k, padding = 0, dilation = 1, ) where {M, L} _check_kernel(k::Number, N::Int) = ntuple(_ -> Int(k), N) _check_kernel(k::NTuple, ::Int) = k kernel = _check_kernel(k, M - 2) length(x_size) == length(kernel) + 2 || error( "PoolDims expects ndim(x) == length(k)+2 or length(size(x)) == length(kernel)+2, dimension of x_size is $(length(x_size)), length of k need $(length(x_size) - 2), but now it's $(length(kernel))" ) spdf_kernel = NTuple{M, Int}([kernel..., 1, 1]) sstride, ppadding, ddilation = check_spdf( x_size, spdf_kernel, stride, padding, dilation) PoolDims( x_size[1:(end - 2)], kernel, x_size[end - 1], sstride, ppadding, ddilation) end PoolDims(x::AbstractArray, k; kwargs...) = PoolDims(size(x), k; kwargs...) # Useful for constructing a new PoolDims that has only a few elements different # from the original progenitor object that it inherits shapes from. 
PoolDims( c::C; I=input_size(c), K=kernel_size(c), C_in=channels_in(c), S=stride(c), P=padding(c), D=dilation(c), ) where C <: ConvDims = PoolDims(I, K, C_in, S, P, D) @inline channels_in(c::PoolDims) = c.channels_in @inline channels_out(c::PoolDims) = c.channels_in @inline input_size(c::PoolDims) = c.input_size @inline kernel_size(c::PoolDims) = c.kernel_size @inline stride(c::PoolDims) = c.stride @inline padding(c::PoolDims) = c.padding @inline dilation(c::PoolDims) = c.dilation @inline flipkernel(c::PoolDims) = false function check_dims(x::NTuple{M}, y::NTuple{M}, pdims::PoolDims) where {M} # First, check that channel counts are all correct: @assert x[end-1] == channels_in(pdims) DimensionMismatch("Data input channel count ($(x[end-1]) vs. $(channels_in(pdims)))") @assert y[end-1] == channels_out(pdims) DimensionMismatch("Data output channel count ($(y[end-1]) vs. $(channels_out(pdims)))") # Next, check that the spatial dimensions match up @assert x[1:end-2] == input_size(pdims) DimensionMismatch("Data input spatial size ($(x[1:end-2]) vs. $(input_size(pdims)))") @assert y[1:end-2] == output_size(pdims) DimensionMismatch("Data output spatial size ($(y[1:end-2]) vs. $(output_size(pdims)))") # Finally, check that the batch size matches @assert x[end] == y[end] DimensionMismatch("Batch size ($(x[end]) vs. $(y[end]))") end ================================================ FILE: src/dim_helpers.jl ================================================ # Various helper functions to calculate dimensions for operations include("dim_helpers/ConvDims.jl") include("dim_helpers/DenseConvDims.jl") include("dim_helpers/DepthwiseConvDims.jl") include("dim_helpers/PoolDims.jl") """ transpose_swapbatch(x::AbstractArray) Given an AbstractArray, swap its batch and channel axes, as we must during transposed convolution. We do this to the operands during convolution, and then again to the output once we're done. 
""" function transpose_swapbatch(x::AbstractArray) return permutedims(x, ((1:(ndims(x)-2))..., ndims(x), ndims(x)-1)) end function transpose_swapbatch(x::Tuple) return (x[1:end-2]..., x[end], x[end-1]) end """ transpose_pad(cdims::ConvDims) Transposed convolution can be calculated in terms of typical convolution with some extra padding. This method computes the padding of the convolution that would result in the transposed convolution of two operands, in essence taking care of that "extra padding". Note that this method should almost always be accompanied by a call that predilates one of the operands. """ function transpose_pad(cdims::ConvDims) I = input_size(cdims) K = kernel_size(cdims) D = dilation(cdims) P = padding(cdims) S = stride(cdims) return ntuple(length(P)) do i hi = ceil(Int, i/2) if mod(i, 2) == 1 return (K[hi] - 1)*D[hi] - P[i] else return (K[hi] - 1)*D[hi] - P[i] + mod(I[hi] + P[i-1] + P[i] - (K[hi] - 1)*D[hi] - 1, S[hi]) end end end """ insert_singleton_spatial_dimension(cdims::ConvDims) When converting a 1d convolution to a 2d, or a 2d to a 3d, we need to insert a singleton spatial dimension at the end of the spatial dimensions. This does so for a ConvDims. """ @inline function insert_singleton_spatial_dimension(cdims::C) where {C <: ConvDims} return basetype(C)(cdims; I=(input_size(cdims)..., 1), K=(kernel_size(cdims)..., 1), S=(stride(cdims)..., 1), # Padding is always the problem child.... 
P=(padding(cdims)..., 0, 0), D=(dilation(cdims)..., 1), ) end # We specialize common cases @inline function insert_singleton_spatial_dimension(x::AbstractArray{T,3}) where {T} return reshape(x, size(x,1), 1, size(x,2), size(x,3)) end @inline function insert_singleton_spatial_dimension(x::AbstractArray{T,4}) where {T} return reshape(x, size(x,1), size(x,2), 1, size(x,3), size(x,4)) end # Helper to do this as many times as needed @inline function insert_singleton_spatial_dimension(x, reps::Int) for r in 1:reps x = insert_singleton_spatial_dimension(x) end return x end """ predilated_size(x_size::Tuple, dilation::Tuple) Calculate the size of a predilated `x` given a particular dilation factor. This is used within `predilate()` and `transpose_cdims()`. """ function predilated_size(x_size::NTuple{N}, dilation::NTuple{M}) where {N, M} @assert (M == N - 2) DimensionMismatch("len(dilation) != number of spatial dims") return ntuple(N) do idx if idx <= N - 2 return (x_size[idx] - 1)*dilation[idx] + 1 else x_size[idx] end end end """ predilate(x, dilation::Tuple) Places elements of `x` within a lattice of zeros, used in expressing a transposed convolution in terms of normal convolution. Note that while we call this "predilation" for aesthetic reasons, you are typically passing a "stride" value into here. Yes, transposed convolution is confusing. """ function predilate(x::AbstractArray{T,N}, dilation::NTuple{M}) where {T, N, M} @assert (M == N - 2) DimensionMismatch("len(dilation) != number of spatial dims") # If there is no dilation to be done, then ignore it. 
if all(dilation .== 1) return x end # Validate dilation factors for idx in 1:length(dilation) @assert dilation[idx] >= 1 ArgumentError("dilation cannot be less than 1") end # Create new x that is bigger and holier x_dil = zeros(eltype(x), predilated_size(size(x), dilation)) # Fill in strategic locations within `x_dil`, such that there are `dilation[idx] - 1` # zeros between each element of `x` along each spatial dimension. x_dil[(1:dilation[idx]:size(x_dil,idx) for idx in 1:(N-2))..., :, :] .= x return x_dil end """ flipweight(w::AbstractArray) Reorders the weight tensor for supporting both convolution and cross-correlation operations. """ # For any array with ndims <= 3 it makes no sense to flip the weights so simply return the # original array @inline flipweight(w::AbstractArray) = w @inline flipweight(w::AbstractArray{T, 4}) where {T} = w[end:-1:1, end:-1:1, :, :] @inline flipweight(w::AbstractArray{T, 5}) where {T} = w[end:-1:1, end:-1:1, end:-1:1, :, :] ================================================ FILE: src/dropout.jl ================================================ """ dropout([rng], A, p; [dims]) Returns an array in which each element of `A` is either replaced with zero, with probability `p`, or else multiplied by `1/(1-p)`. By default every element is treated independently. With keyword `dims=1`, a choice is made for every value of the 1st index i.e. each row of a matrix is either zero or not. Optional first argument is the random number generator used. 
# Examples ```julia-repl julia> dropout(ones(2, 10), 0.2) 2×10 Matrix{Float64}: 1.25 1.25 0.0 1.25 1.25 1.25 1.25 1.25 1.25 1.25 1.25 1.25 1.25 0.0 1.25 1.25 0.0 1.25 1.25 1.25 julia> mean(dropout(ones(10^4, 5), 0.2), dims=1) 1×5 Matrix{Float64}: 0.998 1.00075 0.99125 0.99575 1.00075 julia> dropout(ones(5, 5), 0.7, dims=1) # whole row the same 5×5 Matrix{Float64}: 3.33333 3.33333 3.33333 3.33333 3.33333 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 3.33333 3.33333 3.33333 3.33333 3.33333 0.0 0.0 0.0 0.0 0.0 julia> mean(dropout(ones(10^4, 5), 0.3, dims=1), dims=1) 1×5 Matrix{Float64}: 1.00571 1.00571 1.00571 1.00571 1.00571 ``` """ dropout(A::AbstractArray, p::Real; dims = :) = dropout(_rng_from_array(A), A, p; dims) function dropout(rng::AbstractRNG, A::AbstractArray, p::Real; dims = :) _rng_compat_array(rng, A) T = float(eltype(A)) 0 <= p <= 1 || throw(ArgumentError("dropout expects a probability 0 <= p <= 1")) if p > 0 dst = similar(A, T, size(A)) pT = convert(real(T), p) _dropout!(rng, dst, A, pT, dims) else # Not so sure we want fast paths... this tries but doesn't guarantee type-stability, # and the rrule does not have such a fast paths. convert(AbstractArray{T}, A) end end """ dropout!(B, A, p; [dims]) This does exactly `B .= dropout(A, p; dims)`, or rather, it's the implementation of out-of-place [`dropout`](@ref). """ dropout!(B::AbstractArray, A::AbstractArray, p::Real; dims = :) = dropout!(_rng_from_array(B), B, A, p; dims) function dropout!(rng::AbstractRNG, dst::AbstractArray, src::AbstractArray, p::Real; dims=:) size(dst) == size(src) || throw(DimensionMismatch("dropout! 
expects output array the same size as input")) 0 <= p <= 1 || throw(ArgumentError("dropout expects a probability 0 <= p <= 1")) _rng_compat_array(rng, src) if p > 0 pT = convert(real(eltype(dst)), p) _dropout!(rng, dst, src, pT, dims) else # This fast path isn't free, but no concerns about types changing: copyto!(dst, src) end end # This is the easy case in that we can safely use the output array for random numbers. function _dropout!(rng::AbstractRNG, dst::AbstractArray, src::AbstractArray, p::Real, dims::Colon) T = real(eltype(dst)) val = convert(T, 1/(1-p)) rand!(rng, dst) ## This is what we want, but it hits a SIMD bug, solved by _fast_broadcast! # dst .= (dst.>p) .* val .* src _fast_broadcast!(dst, src) do q, x ((real(q)>p) * val) * x end dst end # For other dims, we we do need to allocate something. function _dropout!(rng::AbstractRNG, dst::AbstractArray, src::AbstractArray, p::Real, dims) T = real(eltype(dst)) tmp = similar(dst, T, ntuple(d -> d in dims ? size(src,d) : 1, ndims(src))) rand!(rng, tmp) val = convert(T, 1/(1-p)) ## One-pass strategy -- faster on GPU dst .= ((tmp.>p) .* val) .* src ## Two-pass strategy -- slightly faster on some CPUs? # _fast_broadcast!(tmp) do q # (q>p) * val # end # dst .= tmp .* src end # The gradient needs to keep the random choices made, thus store at least a BitArray, # but the following way turns out to be faster & simpler: function ChainRulesCore.rrule(::typeof(dropout), rng::AbstractRNG, A::AbstractArray, p::Real; dims = :) T = float(real(eltype(A))) val = convert(T, 1/(1-p)) keep = if dims isa Colon similar(A, T, size(A)) else similar(A, T, ntuple(d -> d in dims ? size(A,d) : 1, ndims(A))) end rand!(rng, keep) Y = @. ((keep>p) * val) * A function dropout_back(Δ) dY = unthunk(Δ) dA = @. ((keep>p) * val) * dY (NoTangent(), NoTangent(), dA, NoTangent()) end return Y, dropout_back end # Possibly TODO: another approach to the gradient would be to copy the RNG # and then re-generate the same mask, instead of storing it. 
# This saves memory
# and seems about as fast, but needs a method `copy(::CUDA.RNG)` and careful checking.
# https://github.com/FluxML/NNlib.jl/pull/454#issuecomment-1369357402

"""
    _rng_from_array(x)

Return the random number generator most appropriate for `x`:
`CUDA.default_rng()` for `CuArray`, else `Random.default_rng()`
"""
_rng_from_array(::AbstractArray) = Random.default_rng()

@non_differentiable _rng_from_array(::Any)

# This exists because `rand!(default_rng(), CUDA.rand(3))` ignores the RNG,
# and Flux would prefer an error. NNlibCUDAExt will overload it to produce that.
_rng_compat_array(::AbstractRNG, ::AbstractArray) = nothing

================================================
FILE: src/fold.jl
================================================

"""
    unfold(x, kernel_size; stride = 1, pad = 0, dilation = 1, flipped = true)

Places sliding windows of x into a container tensor of size
`(num_windows, window_size, batchsize)`. The window size is determined by
the `prod(spatial dims of kernel)*input_channels`. The number of sliding
windows will match those of convolution (`conv`) with the same kernel_size
and arguments. Note that by default `conv` flips the spatial dimensions of
its kernel (default `flipped=false`), whereas `unfold` does not
(default `flipped=true`). Uses `NNlib.im2col!` as backend.

See also [`fold`](@ref), the adjoint/transpose operator
and a potential inverse of `unfold`.

# Example
The below example demonstrates that `unfold` uses the same sliding windows as `conv`.
In general [`batched_mul`](@ref) + `unfold` should not be used to achieve convolution.
```jldoctest
julia> x = reshape([100 2 3 40 5 6 700], 7, 1, 1);  # 1D data, 1 channel, batch of 1

julia> w = reshape([1 0 -1], 3, 1, 1);  # 1D conv kernel of length 3

julia> kws = (pad=1, stride=2, flipped=true);  # use same args for conv and unfold

julia> z = NNlib.unfold(x, size(w); kws...)
4×3×1 Array{Int64, 3}:
[:, :, 1] =
  0  100    2
  2    3   40
 40    5    6
  6  700    0

julia> y1 = conv(x, w; kws...)
4×1×1 Array{Int64, 3}:
[:, :, 1] =
  -2
 -38
  34
   6

julia> y2 = z ⊠ w  # ⊠ (\\boxtimes) is NNlib.batched_mul
4×1×1 Array{Int64, 3}:
[:, :, 1] =
  -2
 -38
  34
   6
```
"""
function unfold(x::AbstractArray{T, N}, kernel_size::NTuple{K};
                stride = 1, pad = 0, dilation = 1, flipped = true) where {T, K, N}
    stride = expand(Val(N - 2), stride)
    padding = expand(Val(N - 2), pad)
    dilation = expand(Val(N - 2), dilation)
    cdims = DenseConvDims(
        size(x), kernel_size;
        stride, padding, dilation, flipkernel=flipped)
    return unfold(x, cdims)
end

"""
    fold(y, output_size, kernel_size; stride = 1, pad = 0, dilation = 1, flipped = true)

The adjoint/transpose operator of `unfold`. It accumulates sliding windows
from the output of `unfold` into a container tensor of size `output_size`.
An inverse to `unfold` may be obtained (in some cases) by using `fold` and
accounting for scaling issues with a divisor (see example).
Uses `NNlib.col2im!` as backend.

See also [`unfold`](@ref).

# Example
```jldoctest
julia> x = reshape([100 2 3 40 5 6 700], 7, 1, 1);  # 1D data, 1 channel, batch of 1

julia> y = NNlib.unfold(x, (3,1,1))  # sliding window of size 3
5×3×1 Array{Int64, 3}:
[:, :, 1] =
 100    2    3
   2    3   40
   3   40    5
  40    5    6
   5    6  700

julia> z = NNlib.fold(y, size(x), (3,1,1))  # sum of contributions in y. 100 appears once, 40 three times
7×1×1 Array{Int64, 3}:
[:, :, 1] =
 100
   4
   9
 120
  15
  12
 700

julia> divisor = NNlib.fold(NNlib.unfold(ones(size(x)...), (3,1,1)), size(x), (3,1,1))
7×1×1 Array{Float64, 3}:
[:, :, 1] =
 1.0
 2.0
 3.0
 3.0
 3.0
 2.0
 1.0

julia> z ./ divisor
7×1×1 Array{Float64, 3}:
[:, :, 1] =
 100.0
   2.0
   3.0
  40.0
   5.0
   6.0
 700.0
```
In general, an inverse to `unfold` does not exist if `divisor` contains zeros.
""" function fold(x::AbstractArray{T, 3}, output_size::NTuple{N}, kernel_size::NTuple{K}; stride = 1, pad = 0, dilation = 1, flipped = true) where {T, K, N} stride = expand(Val(N - 2), stride) padding = expand(Val(N - 2), pad) dilation = expand(Val(N - 2), dilation) cdims = DenseConvDims(output_size, kernel_size; stride, padding, dilation, flipkernel=flipped) return fold(x, output_size, cdims) end # im2col_dims returns (numblocks, blocksize, threadnum) where thread dim is used as thread-local # workspace for multithreaded conv. Ultimately, we want to threadnum with batchsize. unfold_dims(cdims::DenseConvDims) = im2col_dims(cdims)[1:2] # auto-allocating versions function unfold(x::AbstractArray{T, N}, cdims::DenseConvDims) where {T, N} y = similar(x, unfold_dims(cdims)..., size(x, N)) # (numblocks, blocksize, batchsize) return unfold!(y, x, cdims) end function fold(y::AbstractArray{T, 3}, output_size::NTuple, cdims::DenseConvDims) where {T} x = similar(y, output_size) return fold!(x, y, cdims) end # N < 5 -dimension in-place versions function unfold!(y::AbstractArray{yT, 3}, x::AbstractArray{xT, N}, cdims::DenseConvDims) where {yT, xT, N} unfold!( y, insert_singleton_spatial_dimension(x, 5-N), insert_singleton_spatial_dimension(cdims, 5-N), ) return y end function fold!(x::AbstractArray{xT, N}, y::AbstractArray{yT, 3}, cdims::DenseConvDims) where {yT, xT, N} fold!( insert_singleton_spatial_dimension(x, 5-N), y, insert_singleton_spatial_dimension(cdims, 5-N), ) return x end # 5-dimension in-place versions function unfold!(y::AbstractArray{yT, 3}, x::AbstractArray{xT, 5}, cdims::DenseConvDims) where {yT, xT} @threads for batch_idx in 1:size(x, 5) y_slice = view(y, :, :, batch_idx) im2col!(y_slice, view(x, :, :, :, :, batch_idx), cdims) end return y end function fold!(x::AbstractArray{xT, 5}, y::AbstractArray{yT, 3}, cdims::DenseConvDims) where {xT, yT} @threads for batch_idx in 1:size(x, 5) y_slice = view(y, :, :, batch_idx) col2im!(view(x, :, :, :, :, batch_idx), 
y_slice, cdims) end return x end @kernel function unfold_kernel!( col::AbstractArray{T}, x, col_size, input_size, output_size, kernel_size, flipkernel, stride, pad_lo, dilation, max_idx, ) where T index = @index(Global) @inbounds if index ≤ max_idx i, kw, kh, kd, c, b = CartesianIndices(col_size)[index].I # col indices w, h, d = CartesianIndices(output_size)[i].I # x indices # project w, h, d = @. ((w, h, d) - 1) * stride - pad_lo + 1 + ((kw, kh, kd) - 1) * dilation if !flipkernel kw, kh, kd = kernel_size .- (kw, kh, kd) .+ 1 end # check out of bounds if !all(checkindex.(Bool, UnitRange.(1, input_size), (w, h, d))) col[i, kw, kh, kd, c, b] = T(0) else xval::T = x[w, h, d, c, b] col[i, kw, kh, kd, c, b] = xval end end end @kernel function fold_kernel!( x::AbstractArray{T}, col, col_size, input_size, output_size, kernel_size, flipkernel, stride, pad_lo, dilation, max_idx, ) where T index = @index(Global) @inbounds if index ≤ max_idx i, kw, kh, kd, c, b = CartesianIndices(col_size)[index].I # col indices w, h, d = CartesianIndices(output_size)[i].I # x indices # project w, h, d = @. 
((w, h, d) - 1) * stride - pad_lo + 1 + ((kw, kh, kd) - 1) * dilation

        # check out of bounds
        if all(checkindex.(Bool, UnitRange.(1, input_size), (w, h, d)))
            if !flipkernel
                kw, kh, kd = kernel_size .- (kw, kh, kd) .+ 1
            end
            cval::T = col[i, kw, kh, kd, c, b]
            # Several windows may accumulate into the same x element, hence @atomic.
            @atomic x[w, h, d, c, b] += cval
        end
    end
end

# GPU path: dispatch the KernelAbstractions kernel over every element of `col`.
# (Fixed: the error messages below misspelled "convolutional".)
function unfold!(
    col::AnyGPUArray{cT,3}, x::AnyGPUArray{xT,5}, cdims::DenseConvDims,
) where {cT, xT}
    spatial_dims(cdims) != 3 && throw(DimensionMismatch(
        "unfold!() only accepts 3d convolutional inputs"))

    C_in = channels_in(cdims)
    ker_size = kernel_size(cdims)
    pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = padding(cdims)
    pad_lo = (pad_w_lo, pad_h_lo, pad_d_lo)
    out_size = output_size(cdims)

    col_reshaped = reshape(col, (prod(out_size), ker_size..., C_in, :))
    max_idx = prod(size(col))
    unfold_kernel!(get_backend(x))(
        col_reshaped, x, size(col_reshaped),
        input_size(cdims), out_size, ker_size,
        flipkernel(cdims), stride(cdims), pad_lo, dilation(cdims),
        max_idx; ndrange=max_idx)
    return col
end

function fold!(
    x::AnyGPUArray{xT,5}, col::AnyGPUArray{cT,3}, cdims::DenseConvDims,
) where {xT, cT}
    spatial_dims(cdims) != 3 && throw(DimensionMismatch(
        "fold!() only accepts 3d convolutional inputs"))

    # going to accumulate into x
    fill!(x, xT(0))

    C_in = channels_in(cdims)
    ker_size = kernel_size(cdims)
    pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = padding(cdims)
    pad_lo = (pad_w_lo, pad_h_lo, pad_d_lo)
    out_size = output_size(cdims)

    col_reshaped = reshape(col, (prod(out_size), ker_size..., C_in, :))
    max_idx = prod(size(col))
    fold_kernel!(get_backend(x))(
        x, col_reshaped, size(col_reshaped),
        input_size(cdims), out_size, ker_size,
        flipkernel(cdims), stride(cdims), pad_lo, dilation(cdims),
        max_idx; ndrange=max_idx)
    return x
end

# reverse diff rules
function rrule(::typeof(unfold), x, cdims::DenseConvDims; kw...)
function unfold_pullback(Δ) return ( NoTangent(), fold(unthunk(Δ), size(x), cdims; kw...), NoTangent(), ) end return unfold(x, cdims; kw...), unfold_pullback end function rrule(::typeof(fold), x, output_size, cdims::DenseConvDims; kw...) function fold_pullback(Δ) return ( NoTangent(), unfold(unthunk(Δ), cdims; kw...), NoTangent(), NoTangent(), ) end return fold(x, output_size, cdims; kw...), fold_pullback end ================================================ FILE: src/functions.jl ================================================ """ glu(x, dim = 1) The gated linear unit from the ["Language Modeling with Gated Convolutional Networks"](https://arxiv.org/abs/1612.08083) paper. Calculates `a .* sigmoid(b)`, where `x` is split in half along given dimension `dim` to form `a` and `b`. """ function glu(x, dim = 1) maxdim = size(x, dim) @assert maxdim % 2 == 0 "Dimension must be even" half = maxdim ÷ 2 a, b = selectdim(x, dim, 1:half), selectdim(x, dim, half+1:maxdim) a .* sigmoid.(b) end ================================================ FILE: src/gather.jl ================================================ """ NNlib.gather(src, idx) -> dst Reverse operation of [`scatter`](@ref). Gathers data from source `src` and writes it in a destination `dst` according to the index array `idx`. For each `k` in `CartesianIndices(idx)`, assign values to `dst` according to dst[:, ... , k] .= src[:, ... , idx[k]...] Notice that if `idx` is a vector containing integers and `src` is a matrix, previous expression simplifies to dst[:, k] .= src[:, idx[k]] and `k` will run over `1:length(idx)`. The elements of `idx` can be integers or integer tuples and may be repeated. A single `src` column can end up being copied into zero, one, or multiple `dst` columns. See [`gather!`](@ref) for an in-place version. 
# Examples ```jldoctest julia> NNlib.gather([1,20,300,4000], [2,4,2]) 3-element Vector{Int64}: 20 4000 20 julia> NNlib.gather([1 2 3; 4 5 6], [1,3,1,3,1]) 2×5 Matrix{Int64}: 1 3 1 3 1 4 6 4 6 4 ``` """ function gather( src::AbstractArray{Tsrc, Nsrc}, idx::AbstractArray{Tidx, Nidx}, ) where {Tsrc, Nsrc, Nidx, Tidx} M = typelength(Tidx) dstsize = (size(src)[1:Nsrc-M]..., size(idx)...) dst = similar(src, Tsrc, dstsize) return gather!(dst, src, idx) end """ gather(src, IJK...) Convert the tuple of integer vectors `IJK` to a tuple of `CartesianIndex` and call `gather` on it: `gather(src, CartesianIndex.(IJK...))`. # Examples ```jldoctest julia> src = reshape([1:15;], 3, 5) 3×5 Matrix{Int64}: 1 4 7 10 13 2 5 8 11 14 3 6 9 12 15 julia> NNlib.gather(src, [1, 2], [2, 4]) 2-element Vector{Int64}: 4 11 ``` """ function gather( src::AbstractArray{Tsrc, Nsrc}, I::AbstractVector{<:Integer}, J::AbstractVector{<:Integer}, Ks::AbstractVector{<:Integer}..., ) where {Nsrc, Tsrc} return gather(src, to_cartesian_index(I, J, Ks...)) end to_cartesian_index(IJK...) = CartesianIndex.(IJK...) @non_differentiable to_cartesian_index(::Any...) """ NNlib.gather!(dst, src, idx) Reverse operation of [`scatter!`](@ref). Gathers data from source `src` and writes it in destination `dst` according to the index array `idx`. For each `k` in `CartesianIndices(idx)`, assign values to `dst` according to dst[:, ... , k] .= src[:, ... , idx[k]...] Notice that if `idx` is a vector containing integers, and both `dst` and `src` are matrices, previous expression simplifies to dst[:, k] .= src[:, idx[k]] and `k` will run over `1:length(idx)`. The elements of `idx` can be integers or integer tuples and may be repeated. A single `src` column can end up being copied into zero, one, or multiple `dst` columns. See [`gather`](@ref) for an allocating version. 
""" function gather!(dst::AbstractArray, src::AbstractArray, idx::AbstractArray) dims = scatter_dims(src, dst, idx) colons = ntuple(i -> Colon(), dims) for k in CartesianIndices(idx) _view(dst, colons, k) .= _view(src, colons, idx[k]) end return dst end function gather!(dst::AnyGPUArray, src::AnyGPUArray, idx::AnyGPUArray) isempty(dst) && return dst n_dims = scatter_dims(src, dst, idx) dims = size(src)[1:n_dims] max_dims_idx = prod(dims) ndrange = max_dims_idx * length(idx) _gather!(KernelAbstractions.get_backend(src))( dst, src, idx, CartesianIndices(dims), max_dims_idx; ndrange) return dst end @kernel function _gather!( dst, @Const(src), @Const(idx), dim_ids::CartesianIndices, max_dims_idx::Int, ) i = @index(Global) j, k = divrem(i - 1, max_dims_idx) @inbounds dst[i] = src[dim_ids[k + 1], Tuple(idx[j + 1])...] end ∇gather_src(Δ, src_size, idx) = scatter!(+, fill!(similar(Δ, eltype(Δ), src_size), 0), Δ, idx) function rrule(::typeof(gather!), dst::AbstractArray, src::AbstractArray, idx::AbstractArray) y = gather!(dst, src, idx) src_size = size(src) gather!_pullback(Δ) = (NoTangent(), NoTangent(), ∇gather_src(unthunk(Δ), src_size, idx), NoTangent()) return y, gather!_pullback end ================================================ FILE: src/gemm.jl ================================================ ## Low level gemm! call with pointers ## Borrowed from Knet.jl, adapted for compile-time constants using LinearAlgebra.BLAS: get_num_threads, set_num_threads if isdefined(LinearAlgebra.BLAS, :libblastrampoline) const libblas = LinearAlgebra.BLAS.libblastrampoline else const libblas = Base.libblas_name end """ gemm!() Low-level gemm!() call with pointers, borrowed from Knet.jl Calculates `C = alpha*op(A)*op(B) + beta*C`, where: - `transA` and `transB` set `op(X)` to be either `identity()` or `transpose()` - alpha and beta are scalars - op(A) is an (M, K) matrix - op(B) is a (K, N) matrix - C is an (M, N) matrix. """ gemm! 
# These are the datatypes we have fast GEMM for gemm_datatype_mappings = ( (:dgemm_, Float64), (:sgemm_, Float32), (:zgemm_, ComplexF64), (:cgemm_, ComplexF32), ) for (gemm, elt) in gemm_datatype_mappings @eval begin @inline function gemm!(transA::Val, transB::Val, M::Int, N::Int, K::Int, alpha::$(elt), A::Ptr{$elt}, B::Ptr{$elt}, beta::$(elt), C::Ptr{$elt}) # Convert our compile-time transpose marker to a char for BLAS convtrans(V::Val{false}) = 'N' convtrans(V::Val{true}) = 'C' if transA == Val(false) lda = M else lda = K end if transB == Val(false) ldb = K else ldb = N end ldc = M ccall((@blasfunc($(gemm)), libblas), Nothing, (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ref{BlasInt}, Ref{$elt}, Ptr{$elt}, Ref{BlasInt}, Ptr{$elt}, Ref{BlasInt}, Ref{$elt}, Ptr{$elt}, Ref{BlasInt}), convtrans(transA), convtrans(transB), M, N, K, alpha, A, lda, B, ldb, beta, C, ldc) end end end for (gemm, elt) in gemm_datatype_mappings @eval begin @inline function batched_gemm!(transA::AbstractChar, transB::AbstractChar, alpha::($elt), A::AbstractArray{$elt, 3}, B::AbstractArray{$elt, 3}, beta::($elt), C::AbstractArray{$elt, 3}) @assert !Base.has_offset_axes(A, B, C) @assert size(A, 3) == 1 || size(A, 3) == size(C, 3) "batch size mismatch: A != C" @assert size(B, 3) == 1 || size(B, 3) == size(C, 3) "batch size mismatch: B != C" m = size(A, transA == 'N' ? 1 : 2) ka = size(A, transA == 'N' ? 2 : 1) kb = size(B, transB == 'N' ? 1 : 2) n = size(B, transB == 'N' ? 2 : 1) if ka != kb || m != size(C,1) || n != size(C,2) throw(DimensionMismatch("A1 has size ($m,$ka), B1 has size ($kb,$n), C1 has size $(size(C)[1:2])")) end LinearAlgebra.BLAS.chkstride1(A) LinearAlgebra.BLAS.chkstride1(B) LinearAlgebra.BLAS.chkstride1(C) ptrA = pointer(A) ptrB = pointer(B) ptrC = pointer(C) strA = size(A, 3) == 1 ? 0 : Base.stride(A, 3) strB = size(B, 3) == 1 ? 
0 : Base.stride(B, 3) strC = Base.stride(C, 3) n_threads = min( Threads.nthreads(:default), 1 + max(length(A), length(B)) ÷ 8000) # In some tests, size (20,20,20) is worth splitting between two threads, # as is size (32,32,8). if n_threads > 1 old_threads = get_num_threads() set_num_threads(1) parts = Iterators.partition(1:size(C, 3), cld(size(C, 3), n_threads)) function gemm!_part(ks) for k in ks ptrAk = ptrA + (k-1) * strA * sizeof($elt) ptrBk = ptrB + (k-1) * strB * sizeof($elt) ptrCk = ptrC + (k-1) * strC * sizeof($elt) ccall((@blasfunc($(gemm)), libblas), Nothing, (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ref{BlasInt}, Ref{$elt}, Ptr{$elt}, Ref{BlasInt}, Ptr{$elt}, Ref{BlasInt}, Ref{$elt}, Ptr{$elt}, Ref{BlasInt}), transA, transB, m, n, ka, alpha, ptrAk, max(1,Base.stride(A,2)), ptrBk, max(1,Base.stride(B,2)), beta, ptrCk, max(1,Base.stride(C,2))) end end if should_use_spawn() && length(parts) > 1 Threads.@sync for ks in parts Threads.@spawn gemm!_part(ks) end else for ks in parts gemm!_part(ks) end end set_num_threads(old_threads) else # small problem, no threads for k in 1:size(C, 3) # Identical loop body ptrAk = ptrA + (k-1) * strA * sizeof($elt) ptrBk = ptrB + (k-1) * strB * sizeof($elt) ptrCk = ptrC + (k-1) * strC * sizeof($elt) ccall((@blasfunc($(gemm)), libblas), Nothing, (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ref{BlasInt}, Ref{$elt}, Ptr{$elt}, Ref{BlasInt}, Ptr{$elt}, Ref{BlasInt}, Ref{$elt}, Ptr{$elt}, Ref{BlasInt}), transA, transB, m, n, ka, alpha, ptrAk, max(1,Base.stride(A,2)), ptrBk, max(1,Base.stride(B,2)), beta, ptrCk, max(1,Base.stride(C,2))) end end return C end end end ================================================ FILE: src/impl/conv_direct.jl ================================================ ## This file contains direct Julia implementations of 2d and 3d convolutions # Helper functions for restricting x/w overreach function clamp_lo(x, w) idx = 1 while idx <= length(x) && x[idx] <= 0 idx += 1 end return 
(x[idx:end], w[idx:end]) end function clamp_hi(x, w, L) idx = length(x) while idx >= 1 && x[idx] > L idx -= 1 end return (x[1:idx], w[1:idx]) end """ conv_direct!(y, x, w, cdims; alpha=1, beta=0) Direct convolution implementation; used for debugging, tests, and mixing/matching of strange datatypes within a single convolution. Uses naive nested for loop implementation and does not attempt to optimize performance. Rather, this implementation is intended to be maximally understandable and debuggable, to aid in testing other, more performant implementations. We also explicitly support mixing and matching of strange datatypes, so that if the user really wants to convolve an image of `UInt8`'s with a `Float16` kernel, storing the result in a `Float32` output, there is at least a function call for that madness. The keyword arguments `alpha` and `beta` control accumulation behavior; this function calculates `y = alpha * x * w + beta * y`, therefore by setting `beta` to a nonzero value, the user is able to accumulate values into a preallocated `y` buffer, or by setting `alpha` to a nonunitary value, an arbitrary gain factor can be applied. By defaulting `beta` to `false`, we make use of the Bradbury promotion trick to override `NaN`'s that may pre-exist within our output buffer, as `false*NaN == 0.0`, whereas `0.0*NaN == NaN`. Only set `beta` if you are certain that none of the elements within `y` are `NaN`. The basic implementation performs 3-dimensional convolution; 1-dimensional and 2- dimensional cases are supported by simply reshaping `y`, `x` and `w`, for which wrapper methods are available. """ conv_direct! 
# Entry point: lifts the run-time convolution parameters into `Val`s so the
# worker method below can specialize on kernel size, stride, etc.
function conv_direct!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5},
                      w::AbstractArray{wT,5}, cdims::DenseConvDims;
                      alpha::yT = yT(1), beta = false) where {yT, xT, wT}
    conv_direct!(
        y, x, w, cdims,
        Val(kernel_size(cdims)), Val(channels_out(cdims)),
        Val(padding(cdims)), Val(dilation(cdims)), Val(stride(cdims)),
        Val(flipkernel(cdims)); alpha, beta)
    return y
end

# Worker method: all stencil parameters arrive as type-level `Val`s, so the
# loop bounds below are compile-time constants for each distinct configuration.
function conv_direct!(
    y::AbstractArray{yT,5}, x::AbstractArray{xT,5},
    w::AbstractArray{wT,5}, cdims::DenseConvDims,
    # kernel size, output channels, padding, dilation, stride, flipped kernel
    ::Val{K}, ::Val{C}, ::Val{P}, ::Val{D}, ::Val{S}, fk::Val{F};
    alpha::yT = yT(1), beta = false,
) where {yT, xT, wT, K, C, P, D, S, F}
    check_dims(size(x), size(w), size(y), cdims)

    width, height, depth = input_size(cdims)
    kernel_w, kernel_h, kernel_d = K
    # Only the low padding is needed here; `calc_padding_regions` already
    # accounts for the high side.
    pad_w_lo, _, pad_h_lo, _, pad_d_lo, _ = P
    dil_w, dil_h, dil_d = D
    stride_w, stride_h, stride_d = S

    # Create a method that determines how we're going to index into `w`.
    # When the kernel is flipped (`Val{true}`) we index directly; otherwise we
    # reverse the index to get cross-correlation vs. convolution semantics.
    kproj(k, _, ::Val{true}) = k
    kproj(k, M, ::Val{false}) = M - k + 1

    # A helper function to project from output (w, h) to input (input_w, input_h)
    project(idx, stride, pad) = (idx - 1)*stride - pad + 1

    # Use `calc_padding_regions` to determine where we do or don't need to worry about padding
    padded_regions, central_region = calc_padding_regions(cdims)

    # Set outputs to zero to support custom datatypes (https://github.com/FluxML/NNlib.jl/issues/490)
    if iszero(beta)
        y = fill!(y, zero(yT))
    end

    # Start with the central region
    w_region, h_region, d_region = central_region
    @inbounds for batch in 1:size(x, 5),
        c_out in 1:C,
        d_idx in d_region,
        h_idx in h_region,
        w_idx in w_region

        # Since we're in the central region, we don't need to worry about clamping
        dotprod = yT(0)
        for c_in in 1:channels_in(cdims),
            kd in 1:kernel_d,
            kh in 1:kernel_h,
            kw in 1:kernel_w

            # Hoist me, you coward.
            x_d = project(d_idx, stride_d, pad_d_lo) + (kd - 1)*dil_d
            x_h = project(h_idx, stride_h, pad_h_lo) + (kh - 1)*dil_h
            x_w = project(w_idx, stride_w, pad_w_lo) + (kw - 1)*dil_w

            x_val = x[x_w, x_h, x_d, c_in, batch]
            w_val = w[kproj(kw, kernel_w, fk),
                      kproj(kh, kernel_h, fk),
                      kproj(kd, kernel_d, fk),
                      c_in, c_out]
            dotprod = muladd(x_val, w_val, dotprod)
        end
        y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch]
    end

    # Next, do potentially-padded regions:
    @inbounds for (w_region, h_region, d_region) in padded_regions,
        batch in 1:size(x, 5),
        c_out in 1:C,
        d_idx in d_region,
        h_idx in h_region,
        w_idx in w_region

        # Probe for out-of-bounds accesses on `x` and `continue` if we hit one
        dotprod = yT(0)
        for c_in in 1:channels_in(cdims),
            kd in 1:kernel_d

            x_d = project(d_idx, stride_d, pad_d_lo) + (kd - 1)*dil_d
            if x_d <= 0 || x_d > depth
                continue
            end
            for kh in 1:kernel_h
                x_h = project(h_idx, stride_h, pad_h_lo) + (kh - 1)*dil_h
                if x_h <= 0 || x_h > height
                    continue
                end
                for kw in 1:kernel_w
                    x_w = project(w_idx, stride_w, pad_w_lo) + (kw - 1)*dil_w
                    if x_w <= 0 || x_w > width
                        continue
                    end

                    x_val = x[x_w, x_h, x_d, c_in, batch]
                    w_val = w[kproj(kw, kernel_w, fk),
                              kproj(kh, kernel_h, fk),
                              kproj(kd, kernel_d, fk),
                              c_in, c_out]
                    dotprod = muladd(x_val, w_val, dotprod)
                end
            end
        end
        y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch]
    end

    return y
end

## Gradient definitions

"""
    ∇conv_data_direct!(dx, dy, w, cdims; alpha=1, beta=0)

Calculate the gradient imposed upon `x` in the convolution `y = x * w`.
"""
∇conv_data_direct!
function ∇conv_data_direct!(dx::AbstractArray{xT,5}, dy::AbstractArray{yT,5},
                            w::AbstractArray{wT,5}, cdims::DenseConvDims;
                            alpha::xT=xT(1), beta=false) where {xT, yT, wT}
    # Spatially reverse the kernel and swap its in/out channel dims, then run
    # a forward direct convolution of the stride-predilated `dy` against it,
    # with transposed padding — this computes the transposed convolution.
    w = conj(transpose_swapbatch(w[end:-1:1, end:-1:1, end:-1:1, :, :]))
    dy = predilate(dy, stride(cdims))
    ctdims = DenseConvDims(dy, w; padding=transpose_pad(cdims),
                                  dilation=dilation(cdims),
                                  flipkernel=flipkernel(cdims))
    dx = conv_direct!(dx, dy, w, ctdims; alpha=alpha, beta=beta)
    return dx
end

"""
    ∇conv_filter_direct!(dw, x, dy, cdims; alpha=1, beta=0)

Calculate the gradient imposed upon `w` in the convolution `y = x * w`.
"""
∇conv_filter_direct!
function ∇conv_filter_direct!(dw::AbstractArray{wT,5}, x::AbstractArray{xT,5},
                              dy::AbstractArray{yT,5}, cdims::DenseConvDims;
                              alpha::wT=wT(1), beta=false) where {xT, yT, wT}
    # The filter gradient is itself a convolution: of the reversed,
    # channel/batch-swapped input with the predilated `dy`.
    x = conj(transpose_swapbatch(x[end:-1:1, end:-1:1, end:-1:1, :, :]))
    dy = transpose_swapbatch(predilate(dy, stride(cdims)))
    ctdims = DenseConvDims(dy, x; padding=transpose_pad(cdims),
                                  stride=dilation(cdims))
    # When the kernel is flipped, write through a spatially-reversed view so
    # the result lands in the orientation the caller expects.
    dw_ = if flipkernel(cdims)
        view(dw, reverse(axes(dw, 1)), reverse(axes(dw, 2)), reverse(axes(dw, 3)), :, :)
    else
        dw
    end
    conv_direct!(dw_, dy, x, ctdims; alpha=alpha, beta=beta)
    return dw
end

================================================ FILE: src/impl/conv_im2col.jl ================================================
## This file contains im2col-backed implementations of convolution for 2d and 3d
## convolutions.  Expect to see a lot of indexing.

# Helper function for flipkernel-induced dyslexia: returns the kernel indices
# either as-is (flipped kernel) or reversed (convolution semantics).
function kernel_index(w, h, d, cdims::ConvDims)
    flipkernel(cdims) && return (w, h, d)
    kernel_w, kernel_h, kernel_d = kernel_size(cdims)
    return (kernel_w - w + 1, kernel_h - h + 1, kernel_d - d + 1)
end

"""
    conv_im2col!(y, x, w, cdims, col=similar(x); alpha=1, beta=0)

Perform a convolution using im2col and GEMM, store the result in `y`.
The kwargs `alpha` and `beta` control accumulation behavior; internally this operation
is implemented as a matrix multiply that boils down to `y = alpha * x * w + beta * y`,
thus by setting `beta` to a nonzero value, multiple results can be accumulated into `y`,
or by setting `alpha` to a nonunitary value, various gain factors can be applied.

Note for the particularly performance-minded, you can provide a pre-allocated `col`,
which should eliminate any need for large allocations within this method.
"""
function conv_im2col!(
        y::AbstractArray{T,5}, x::AbstractArray{T,5},
        w::AbstractArray{T,5}, cdims::DenseConvDims;
        col::AbstractArray{T,3}=similar(x, im2col_dims(cdims)),
        alpha::T=T(1), beta::T=T(0),
        ntasks::Int=nthreads()) where {T}
    check_dims(size(x), size(w), size(y), cdims)

    # COL * W -> Y
    # [M x K] * [K x N] -> [M x N]
    #
    # M: output spatial resolution
    # N: output channels
    # K: size of input "patch" (kernel size and input channels combined)
    #
    # In english, we're grabbing each input patch and laying them out along
    # the M dimension in `col`, so that the GEMM call below multiplies each
    # kernel (which is kernel_h * kernel_w * channels_in elements long) is
    # dotproducted with that input patch, effectively computing a convolution
    # in a somewhat memory-wasteful but easily-computed way (since we already
    # have an extremely highly-optimized GEMM call available in BLAS).
    M = prod(output_size(cdims))
    N = channels_out(cdims)
    K = prod(kernel_size(cdims))*channels_in(cdims)

    # Split the batch dimension into `ntasks` contiguous chunks.
    parts = Iterators.partition(axes(x, 5), ceil(Int, size(x, 5) / ntasks))

    # Process the batch indices in `part`, using slab `task_n` of `col` as a
    # task-local im2col workspace.
    function conv_part(task_n, part)
        # col_slice is a task-local workspace.
        # (Fix: the original bound this twice in one statement,
        # `col_slice = col_slice = view(...)`; a single binding suffices.)
        col_slice = view(col, :, :, task_n)
        for batch_idx in part
            im2col!(col_slice, view(x, :, :, :, :, batch_idx), cdims)
            GC.@preserve col_slice w y begin
                col_ptr = pointer(col_slice)
                w_ptr = pointer(w)
                y_ptr = pointer(y, (batch_idx - 1)*M*N + 1)
                gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
            end
        end
    end

    if should_use_spawn() && length(parts) > 1
        @sync for (task_n, part) in enumerate(parts)
            Threads.@spawn conv_part(task_n, part)
        end
    else
        for (task_n, part) in enumerate(parts)
            conv_part(task_n, part)
        end
    end
    return y
end

"""
    ∇conv_filter_im2col!(dw, x, dy, cdims, col=similar(dw, ∇filter_im2col_dims(cdims));
                         alpha=1, beta=0)

Conv backward pass onto the weights using im2col and GEMM; stores the result in `dw`.
See [`conv_im2col!`](@ref) for explanation of optional parameters.
"""
function ∇conv_filter_im2col!(
        dw::AbstractArray{T,5}, x::AbstractArray{T,5},
        dy::AbstractArray{T,5}, cdims::DenseConvDims;
        col::AbstractArray{T,3} = similar(dw, ∇filter_im2col_dims(cdims)),
        alpha::T=T(1), beta::T=T(0)) where {T}
    check_dims(size(x), size(dw), size(dy), cdims)

    # COL' * dY -> dW
    # [M x K] * [K x N] -> [M x N]
    #
    # M: size of input "patch" (kernel size and input channels combined)
    # N: output channels
    # K: output spatial resolution
    #
    # In english, we're grabbing each input patch and laying them out along
    # the K dimension in `col`, then multiplying in `dY` to compute a dot
    # product between all pixels in the input that were multiplied by a single
    # position in the W kernel, and all output pixels of the same location,
    # across output channels.  This slice of `col` therefore constitutes every
    # input pixel that touched a particular element of the kernel.
# # This is identical to a convolution between x and a dimension-permuted dY, # where we M = prod(kernel_size(cdims))*channels_in(cdims) N = channels_out(cdims) K = prod(output_size(cdims)) for batch_idx in 1:size(x,5) col_slice = view(col, :, :, 1) im2col!(col_slice, view(x, :, :, :, :, batch_idx), cdims) GC.@preserve col_slice dw dy begin col_ptr = pointer(col_slice) dy_ptr = pointer(dy,(batch_idx - 1)*K*N + 1) dw_ptr = pointer(dw) gemm!(Val(true), Val(false), M, N, K, alpha, col_ptr, dy_ptr, beta, dw_ptr) end # Because we accumulate over batches in this loop, we must set `beta` equal # to `1.0` from this point on. beta = T(1) end return dw end """ ∇conv_data_im2col!(dx, w, dy, cdims, col=similar(dx); alpha=1, beta=0) Conv2d backward pass onto the input using im2col and GEMM; stores the result in `dx`. See [`conv_im2col!`](@ref) for explanation of optional parameters. """ function ∇conv_data_im2col!( dx::AbstractArray{T,5}, dy::AbstractArray{T,5}, w::AbstractArray{T,5}, cdims::DenseConvDims; col::AbstractArray{T,3} = similar(dx, im2col_dims(cdims)), alpha::T=T(1), beta::T=T(0), ntasks::Int=nthreads()) where {T} check_dims(size(dx), size(w), size(dy), cdims) # dY W' -> dX # [M x K] * [K x N] -> [M x N] # # M: output spatial resolution # N: size of input "patch" (kernel size and input channels combined) # K: output channels # # In english, we're taking the output image and laying it out by pixel, # with channels lying along the `K` dimension in `col`. We then multiply # in `W'` to compute a dot product between each pixel location and the # entire kernel. This dot product therefore constitutes every output pixel # that was a function of a particular input pixel. 
# # This is identical to a transposed convolution between dY and W M = prod(output_size(cdims)) N = prod(kernel_size(cdims))*channels_in(cdims) K = channels_out(cdims) parts = Iterators.partition(axes(dx, 5), ceil(Int, size(dx, 5) / ntasks)) function ∇conv_data_part(task_n, part) col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace for batch_idx in part GC.@preserve col_slice w dy begin dy_ptr = pointer(dy, (batch_idx - 1)*M*K + 1) w_ptr = pointer(w) col_ptr = pointer(col_slice) gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr) end col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims, beta) end end if should_use_spawn() && length(parts) > 1 @sync for (task_n, part) in enumerate(parts) Threads.@spawn ∇conv_data_part(task_n, part) end else for (task_n, part) in enumerate(parts) ∇conv_data_part(task_n, part) end end return dx end """ im2col!(col, x, cdims) Converts a 3d image `x` into a matrix `col` for usage with GEMM-calculated convolution. Patches of `x` of size (kernel_w, kernel_h, kernel_d, C_in) will be extracted and laid out along the rows of `col`, one for each output pixel. This routine is used by all im2col-based convolutions, just with extra singleton dimensions added in the case of `2d` or `1d` images. """ function im2col!(col::AbstractArray{T,2}, x::AbstractArray{T,4}, cdims::ConvDims) where {T} if spatial_dims(cdims) != 3 throw(DimensionMismatch("im2col!() only accepts 3d convoluitional inputs")) end # Extract those nice, compile-time constant type parameters from `cdims`. width, height, depth = input_size(cdims) kernel_w, kernel_h, kernel_d = kernel_size(cdims) C_in = channels_in(cdims) pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = padding(cdims) dil_w, dil_h, dil_d = dilation(cdims) stride_w, stride_h, stride_d = stride(cdims) out_width, out_height, out_depth = output_size(cdims) # Reshape col for easy access. 
    # View `col` as a 7-d array: output position × kernel position × channel.
    col_reshaped = reshape(col, (
        # Output resolution
        out_width,
        out_height,
        out_depth,

        # By input patch size
        kernel_w,
        kernel_h,
        kernel_d,
        C_in,
    ))

    padded_regions, central_region = calc_padding_regions(cdims)

    # A helper function to project from output (w, h) to input (input_w, input_h)
    @inline project(idx, stride, pad) = (idx - 1)*stride - pad + 1

    # We begin by copying the central region of the image which requires no padding at all.
    # Eliminating the branches of the fully generalized version below gives us a nice
    # speedup on the majority of the data.
    @inbounds for c in 1:C_in
        # Unpack "central region"
        w_region, h_region, d_region = central_region
        for kd in 1:kernel_d,
            kh in 1:kernel_h,
            kw in 1:kernel_w,
            d in d_region,
            h in h_region,
            w in w_region

            input_kd = project(d, stride_d, pad_d_lo) + (kd - 1)*dil_d
            input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h
            input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w
            kidxs = kernel_index(kw, kh, kd, cdims)

            xval::T = x[input_kw, input_kh, input_kd, c]
            col_reshaped[w, h, d, kidxs..., c] = xval
        end
    end

    # For each "padded region", we run the fully general version
    @inbounds for (w_region, h_region, d_region) in padded_regions
        for c in 1:C_in,
            d in d_region,
            h in h_region,
            w in w_region,
            kd in 1:kernel_d,
            kh in 1:kernel_h,
            kw in 1:kernel_w

            input_kd = project(d, stride_d, pad_d_lo) + (kd - 1)*dil_d
            input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h
            input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w
            kidxs = kernel_index(kw, kh, kd, cdims)

            # Positions that fall within the padding contribute zeros.
            out_of_bounds = (
                input_kd <= 0 || input_kd > depth ||
                input_kh <= 0 || input_kh > height ||
                input_kw <= 0 || input_kw > width
            )
            if out_of_bounds
                col_reshaped[w, h, d, kidxs..., c] = T(0)
                continue
            end

            # Copy the data over
            xval::T = x[input_kw, input_kh, input_kd, c]
            col_reshaped[w, h, d, kidxs..., c] = xval
        end
    end
end

"""
    col2im!(x, col, cdims, beta=0)

Does the inverse of `im2col!()`, converting `col` back into a 3d image, used for
backward passes, transposed
convolutions, etc...

Note that this method has not been optimized in the same way as `im2col()` has, because
it is slightly more complicated due to the more chaotic data access patterns, and I'm not
desperate enough yet.
"""
col2im!

function col2im!(x::AbstractArray{T,4}, col::AbstractArray{T,2},
                 cdims::ConvDims, beta::T=T(0)) where T
    if spatial_dims(cdims) != 3
        # (Fix: corrected "convoluitional" typo in the error message.)
        throw(DimensionMismatch("col2im!() only accepts 3d convolutional inputs"))
    end

    # Extract those nice, compile-time constant type parameters from `cdims`.
    width, height, depth = input_size(cdims)
    kernel_w, kernel_h, kernel_d = kernel_size(cdims)
    C_in = channels_in(cdims)
    pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = padding(cdims)
    dil_w, dil_h, dil_d = dilation(cdims)
    stride_w, stride_h, stride_d = stride(cdims)
    out_width, out_height, out_depth = output_size(cdims)

    # TODO: Rewrite this method so we don't have this fill!() at the beginning!
    # Calculate each output pixel once rather than accumulating into it?
    if beta == T(0)
        fill!(x, T(0))
    elseif beta == T(1)
        # nothing
    else
        x .*= beta
    end

    # Reshape col for easy access.
    col_reshaped = reshape(col, (
        # Output resolution
        out_width,
        out_height,
        out_depth,

        # By input patch size
        kernel_w,
        kernel_h,
        kernel_d,
        C_in,
    ))

    # A helper function to project from output (w, h) to input (input_w, input_h)
    @inline project(idx, stride, pad) = (idx - 1)*stride - pad + 1

    @inbounds for c in 1:C_in
        for kd in 1:kernel_d,
            kh in 1:kernel_h,
            kw in 1:kernel_w

            for d in 1:out_depth
                input_kd = project(d, stride_d, pad_d_lo) + (kd - 1)*dil_d

                # If this d is off the edge, then deal with the entire plane
                # in one fell swoop, like a ravenous flock of crows.  CAW CAW.
                if input_kd <= 0 || input_kd > depth
                    continue
                end

                for h in 1:out_height
                    input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h

                    # Same for `h`, but in this case it's only a line, not a plane.
                    # This results in slightly less caw'ing.
                    if input_kh <= 0 || input_kh > height
                        continue
                    end

                    for w in 1:out_width
                        input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w

                        # If this `w` is off the edge, only it gets cleared out.
                        if input_kw <= 0 || input_kw > width
                            continue
                        end

                        # Copy the data over
                        kidxs = kernel_index(kw, kh, kd, cdims)
                        cval::T = col_reshaped[w, h, d, kidxs..., c]
                        x[input_kw, input_kh, input_kd, c] += cval
                    end
                end
            end
        end
    end
end

================================================ FILE: src/impl/depthwiseconv_direct.jl ================================================
## This file contains direct Julia implementations of depthwise convolutions

"""
    depthwiseconv_direct!(y, x, w, cdims; alpha=1, beta=0)

Direct depthwise convolution implementation; used for debugging, tests, and mixing/
matching of strange datatypes within a single convolution.  Uses naive nested for loop
implementation and does not attempt to optimize performance.  Rather, this implementation
is intended to be maximally understandable and debuggable, to aid in testing other, more
performant implementations.  We also explicitly support mixing and matching of strange
datatypes, so that if the user really wants to convolve an image of `UInt8`'s with a
`Float16` kernel, storing the result in a `Float32` output, there is at least a function
call for that madness.

One subtlety about depthwise convolutions; the shape of a depthwise convolutional kernel
is `(spatial_dims..., C_mult, C_in)`, so the axis that must match with the number of
channels in `x` is the last, not the second-to-last, as in a normal dense convolution.

See the docstring for `conv_direct!()` for more on the optional parameters.
"""
function depthwiseconv_direct!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5},
                               w::AbstractArray{wT,5}, cdims::DepthwiseConvDims;
                               alpha::yT=yT(1), beta=false) where {yT, xT, wT}
    check_dims(size(x), size(w), size(y), cdims)

    width, height, depth = input_size(cdims)
    kernel_w, kernel_h, kernel_d = kernel_size(cdims)
    # Only the low padding is needed here; `calc_padding_regions` accounts
    # for the high side.
    pad_w_lo, _, pad_h_lo, _, pad_d_lo, _ = padding(cdims)
    dil_w, dil_h, dil_d = dilation(cdims)
    stride_w, stride_h, stride_d = stride(cdims)

    # Create a method that determines how we're going to index into `w`
    kproj(k, M, cdims::DepthwiseConvDims) = flipkernel(cdims) ? k : (M - k + 1)

    # A helper function to project from output (w, h) to input (input_w, input_h)
    project(idx, stride, pad) = (idx - 1)*stride - pad + 1

    # Use `calc_padding_regions` to determine where we do or don't need to worry about padding
    padded_regions, central_region = calc_padding_regions(cdims)

    # Start with the central region
    w_region, h_region, d_region = central_region
    @inbounds for batch in 1:size(x)[end],
        c_mult in 1:channel_multiplier(cdims),
        c_in in 1:channels_in(cdims),
        d_idx in d_region,
        h_idx in h_region,
        w_idx in w_region

        # Since we're in the central region, we don't need to worry about clamping
        dotprod = yT(0)
        # Output channels are grouped by input channel, `channel_multiplier`
        # outputs per input channel.
        c_out = (c_in - 1)*channel_multiplier(cdims) + c_mult
        for kd in 1:kernel_d,
            kh in 1:kernel_h,
            kw in 1:kernel_w

            # Hoist me, you coward.
            x_d = project(d_idx, stride_d, pad_d_lo) + (kd - 1)*dil_d
            x_h = project(h_idx, stride_h, pad_h_lo) + (kh - 1)*dil_h
            x_w = project(w_idx, stride_w, pad_w_lo) + (kw - 1)*dil_w

            x_val = x[x_w, x_h, x_d, c_in, batch]
            w_val = w[kproj(kw, kernel_w, cdims),
                      kproj(kh, kernel_h, cdims),
                      kproj(kd, kernel_d, cdims),
                      c_mult, c_in]
            dotprod = muladd(x_val, w_val, dotprod)
        end
        y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch]
    end

    # Next, do potentially-padded regions:
    @inbounds for (w_region, h_region, d_region) in padded_regions,
        batch in 1:size(x)[end],
        c_mult in 1:channel_multiplier(cdims),
        c_in in 1:channels_in(cdims),
        d_idx in d_region,
        h_idx in h_region,
        w_idx in w_region

        dotprod = yT(0)
        c_out = (c_in - 1)*channel_multiplier(cdims) + c_mult
        for kd in 1:kernel_d
            # Probe for out-of-bounds accesses on `x` and `continue` if we hit one
            x_d = project(d_idx, stride_d, pad_d_lo) + (kd - 1)*dil_d
            if x_d <= 0 || x_d > depth
                continue
            end
            for kh in 1:kernel_h
                x_h = project(h_idx, stride_h, pad_h_lo) + (kh - 1)*dil_h
                if x_h <= 0 || x_h > height
                    continue
                end
                for kw in 1:kernel_w
                    x_w = project(w_idx, stride_w, pad_w_lo) + (kw - 1)*dil_w
                    if x_w <= 0 || x_w > width
                        continue
                    end

                    x_val = x[x_w, x_h, x_d, c_in, batch]
                    w_val = w[kproj(kw, kernel_w, cdims),
                              kproj(kh, kernel_h, cdims),
                              kproj(kd, kernel_d, cdims),
                              c_mult, c_in]
                    dotprod = muladd(x_val, w_val, dotprod)
                end
            end
        end
        y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch]
    end

    return y
end

"""
    ∇depthwiseconv_data_direct!(dx, dy, w, cdims; alpha=1, beta=0)

Calculate the gradient imposed upon `x` in the depthwise convolution `y = x * w`.
We make use of the fact that a depthwise convolution is equivalent to `C_in` separate
normal convolutions between that channel of `x` and the `C_mult` different kernels that
get applied to it.
The output of such a convolution is the gradient imposed upon that particular channel of
`x`, and so we simply walk through `x`, calculating the gradient for each batch and
channel independently.
"""
∇depthwiseconv_data_direct!
function ∇depthwiseconv_data_direct!(
                dx::AbstractArray{xT,5}, dy::AbstractArray{yT,5},
                w::AbstractArray{wT,5}, cdims::DepthwiseConvDims;
                alpha::xT=xT(1), beta=false) where {xT, yT, wT}
    # We do a separate convolution for each channel in x
    @inbounds for cidx in 1:channels_in(cdims)
        # For this batch and in-channel, we have a normal transposed convolution
        # between this slice of `x` and the corresponding slices of `w` and `dy`:
        dx_slice = view(dx, :, :, :, cidx:cidx, :)
        C_mult = channel_multiplier(cdims)
        dy_slice = view(dy, :, :, :, ((cidx-1)*C_mult + 1):cidx*C_mult, :)
        # Swap the C_mult/C_in axes of the kernel slice to match the dense
        # convolution's (spatial..., C_in, C_out) layout.
        w_slice = permutedims(view(w, :, :, :, :, cidx:cidx), (1, 2, 3, 5, 4))

        # Adapt a DenseConvDims out of this DepthwiseConvDims, setting the in/out
        # channels appropriately for this one convolution.
        cdims_slice = DenseConvDims(cdims;
            C_in=1,
            C_out=channel_multiplier(cdims),
        )

        ∇conv_data_direct!(dx_slice, dy_slice, w_slice, cdims_slice;
                           alpha=alpha, beta=beta)
    end
    return dx
end

"""
    ∇depthwiseconv_filter_direct!(dw, x, dy, cdims; alpha=1, beta=0)

Calculate the gradient imposed upon `w` in the depthwise convolution `y = x * w`.
"""
∇depthwiseconv_filter_direct!
function ∇depthwiseconv_filter_direct!(
                dw::AbstractArray{wT,5}, x::AbstractArray{xT,5},
                dy::AbstractArray{yT,5}, cdims::DepthwiseConvDims;
                alpha::wT=wT(1),beta=false) where {xT, yT, wT}
    # We do a separate convolution for each channel in x
    @inbounds for cidx in 1:channels_in(cdims)
        # For this batch and in-channel, we have a normal transposed convolution
        # between this slice of `x` and the corresponding slices of `w` and `dy`:
        x_slice = view(x, :, :, :, cidx:cidx, :)
        C_mult = channel_multiplier(cdims)
        dy_slice = view(dy, :, :, :, ((cidx-1)*C_mult + 1):cidx*C_mult, :)
        # Swap C_mult/C_in axes so the slice matches the dense layout expected
        # by ∇conv_filter_direct!; the result is swapped back below.
        dw_slice = permutedims(view(dw, :, :, :, :, cidx:cidx), (1, 2, 3, 5, 4))

        # Adapt a DenseConvDims out of this DepthwiseConvDims, setting the in/out
        # channels appropriately for this one convolution.
        cdims_slice = DenseConvDims(cdims;
            C_in=1,
            C_out=channel_multiplier(cdims),
        )

        ∇conv_filter_direct!(dw_slice, x_slice, dy_slice, cdims_slice;
                             alpha=alpha, beta=beta)
        dw[:, :, :, :, cidx:cidx] .= permutedims(dw_slice, (1, 2, 3, 5, 4))
    end
    return dw
end

================================================ FILE: src/impl/depthwiseconv_im2col.jl ================================================
## This file contains adapter code for doing depthwise convolutions with im2col.

"""
    depthwiseconv_im2col!(y, x, w, cdims, col=similar(x); alpha=1, beta=0)

Perform a depthwise convolution using im2col and GEMM, store the result in `y`.
See [`conv_im2col!`](@ref) for explanation of optional parameters.
"""
depthwiseconv_im2col!
function depthwiseconv_im2col!(
        y::AbstractArray{T,5}, x::AbstractArray{T,5},
        w::AbstractArray{T,5}, cdims::DepthwiseConvDims;
        col::AbstractArray{T,3} = similar(x, im2col_dims(cdims)),
        alpha::T=T(1), beta::T=T(0),
        ntasks::Int=nthreads()) where T
    check_dims(size(x), size(w), size(y), cdims)

    # This functions exactly the same as conv_im2col!(), except that we shard the
    # incoming data into slices of single channels.  This means that we need to walk
    # each pointer forward individually, as done below, taking a single input channel
    # and combining it with each kernel individually, before walking forward and doing
    # the next input channel.
    M = prod(output_size(cdims))
    N = channel_multiplier(cdims)
    K = prod(kernel_size(cdims))

    # Split the batch dimension into `ntasks` contiguous chunks.
    parts = Iterators.partition(axes(y)[end], ceil(Int, size(y, 5) / ntasks))
    dcdims = DenseConvDims(cdims)

    function depthwiseconv_part(task_n, part)
        # col_slice is a task-local workspace.
        # (Fix: the original bound this twice in one statement,
        # `col_slice = col_slice = view(...)`; a single binding suffices.)
        col_slice = view(col, :, :, task_n)
        for batch_idx in part
            im2col!(col_slice, view(x, :, :, :, :, batch_idx), dcdims)

            # We do a separate convolution for each channel in x, as we must
            for c_in in 1:channels_in(cdims)
                # Walk each pointer forward as we process each input channel
                GC.@preserve col_slice w y begin
                    col_ptr = pointer(col_slice, (c_in-1)*M*K+1)
                    w_ptr = pointer(w, (c_in-1)*K*N+1)
                    y_ptr = pointer(y, ((batch_idx - 1)*channels_in(cdims) + c_in - 1)*M*N + 1)
                    gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
                end
            end
        end
    end
    if should_use_spawn() && length(parts) > 1
        @sync for (task_n, part) in enumerate(parts)
            Threads.@spawn depthwiseconv_part(task_n, part)
        end
    else
        for (task_n, part) in enumerate(parts)
            depthwiseconv_part(task_n, part)
        end
    end

    return y
end

"""
    ∇depthwiseconv_filter_im2col!(dw, w, dy, cdims, col=similar(dw, ∇filter_im2col_dims(cdims));
                                  alpha=1, beta=0)

Depthwise conv backward pass onto the weights using im2col and GEMM.
See [`conv_im2col!`](@ref) for explanation of optional parameters.
"""
∇depthwiseconv_filter_im2col!
function ∇depthwiseconv_filter_im2col!(
        dw::AbstractArray{T,5}, x::AbstractArray{T,5},
        dy::AbstractArray{T,5}, cdims::DepthwiseConvDims;
        col::AbstractArray{T,3} = similar(dw, ∇filter_im2col_dims(cdims)),
        alpha::T=T(1), beta::T=T(0)) where T
    check_dims(size(x), size(dw), size(dy), cdims)

    M = prod(kernel_size(cdims))
    N = channel_multiplier(cdims)
    K = prod(output_size(cdims))

    for batch_idx in 1:size(x, 5)
        # Because we accumulate over batches in this loop, we must set `beta` equal
        # to `1.0` after the first sample.
        beta′ = batch_idx == 1 ? beta : T(1)
        # col_slice is a thread-local workspace
        col_slice = view(col, :, :, 1)
        im2col!(col_slice, view(x, :, :, :, :, batch_idx), cdims)

        # We do a separate convolution for each channel in x, as we must
        for c_in in 1:channels_in(cdims)
            # Walk each pointer forward as we process each input channel
            GC.@preserve col_slice dw dy begin
                col_ptr = pointer(col_slice, (c_in - 1)*M*K + 1)
                dy_ptr = pointer(dy, (batch_idx - 1)*N*K*channels_in(cdims) + (c_in - 1)*K*N + 1)
                dw_ptr = pointer(dw, (c_in - 1)*M*N + 1)
                gemm!(Val(true), Val(false), M, N, K, alpha, col_ptr, dy_ptr, beta′, dw_ptr)
            end
        end
    end
    return dw
end

"""
    ∇depthwiseconv_data_im2col!(dx, w, dy, cdims, col=similar(dx); alpha=1, beta=0)

Depthwise conv2d backward pass onto the input using im2col and GEMM.
See [`conv_im2col!`](@ref) for explanation of optional parameters.
"""
∇depthwiseconv_data_im2col!
function ∇depthwiseconv_data_im2col!(
        dx::AbstractArray{T,5}, dy::AbstractArray{T,5},
        w::AbstractArray{T,5}, cdims::DepthwiseConvDims;
        col::AbstractArray{T,3} = similar(dx, im2col_dims(cdims)),
        alpha::T=T(1), beta::T=T(0),
        ntasks::Int=nthreads()) where T
    check_dims(size(dx), size(w), size(dy), cdims)

    M = prod(output_size(cdims))
    N = prod(kernel_size(cdims))
    K = channel_multiplier(cdims)

    # Split the batch dimension into `ntasks` contiguous chunks.
    parts = Iterators.partition(axes(dx)[end], ceil(Int, size(dx, 5) / ntasks))

    function ∇depthwiseconv_data_part(task_n, part)
        # col_slice is a task-local workspace.
        # (Fix: the original bound this twice in one statement,
        # `col_slice = col_slice = view(...)`; a single binding suffices.)
        col_slice = view(col, :, :, task_n)
        for batch_idx in part
            # We do a separate convolution for each channel in x, as we must
            for cidx in 1:channels_in(cdims)
                GC.@preserve col_slice w dy begin
                    # Walk each pointer forward as we process each input channel
                    dy_ptr = pointer(dy, (batch_idx - 1)*M*K*channels_in(cdims)+(cidx - 1)*K*M + 1)
                    w_ptr = pointer(w, (cidx - 1)*K*N + 1)
                    col_ptr = pointer(col_slice, (cidx - 1)*M*N + 1)
                    gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
                end
            end
            col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims, beta)
        end
    end
    if should_use_spawn() && length(parts) > 1
        @sync for (task_n, part) in enumerate(parts)
            Threads.@spawn ∇depthwiseconv_data_part(task_n, part)
        end
    else
        for (task_n, part) in enumerate(parts)
            ∇depthwiseconv_data_part(task_n, part)
        end
    end

    return dx
end

================================================ FILE: src/impl/padding_edges.jl ================================================
"""
    calc_padding_regions(dims)

Padding is a jerk.  A HUGE jerk that tries to sneak a bunch of conditionals and edge
cases (quite literally) into our beautiful stencil operations such as convolution,
pooling, etc...
The way we deal with this is to, first, deal with everything in 3d, and then define a
single padding region helper function that returns the seven regions that all 3d
operations must deal with, including the central "unpadded" region where we can run at
full bore, not paying any attention to padding.
"""
function calc_padding_regions(dims)
    width, height, depth = input_size(dims)
    kernel_w, kernel_h, kernel_d = kernel_size(dims)
    C_in = channels_in(dims)
    pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = padding(dims)
    dil_w, dil_h, dil_d = dilation(dims)
    stride_w, stride_h, stride_d = stride(dims)
    out_width, out_height, out_depth = output_size(dims)

    # Let us first calculate the number of rows/cols within which we must zero out some
    # portion of the image patches we're copying over.  The "spillage" here is the number
    # of indices along a particular dimension for which a kernel will have some portion
    # of its input domain overlapping the padding.  If padding is zero, these values are
    # all trivially zero.  The low spillage is trivially the low padding divided by the
    # stride; literally the number of shifts that overlap some padding.  The high
    # spillage is slightly more complicated; we first figure out how many elements of
    # high padding are wasted (e.g. through strides not fitting to the end perfectly)
    # subtract that from the high padding, then do the same:
    calc_lo_spill(O, S, P) = max(min(ceil(Int, P/S), O),0)
    @inline function calc_hi_spill(O, S, Pl, Ph, K, D, I)
        wasted_Ph = (I + Pl + Ph - (K - 1)*D - 1)%S
        return max(min(ceil(Int, (Ph - wasted_Ph)/S), O), 0)
    end

    # Spillage (in output indices) for each of the six faces of the volume.
    spill_w_lo = calc_lo_spill(out_width, stride_w, pad_w_lo)
    spill_w_hi = calc_hi_spill(out_width, stride_w, pad_w_lo, pad_w_hi, kernel_w, dil_w, width)
    spill_h_lo = calc_lo_spill(out_height, stride_h, pad_h_lo)
    spill_h_hi = calc_hi_spill(out_height, stride_h, pad_h_lo, pad_h_hi, kernel_h, dil_h, height)
    spill_d_lo = calc_lo_spill(out_depth, stride_d, pad_d_lo)
    spill_d_hi = calc_hi_spill(out_depth, stride_d, pad_d_lo, pad_d_hi, kernel_d, dil_d, depth)

    # Convert the high spill counts into absolute output coordinates.
    spill_w_hi_abs = out_width  - spill_w_hi + 1
    spill_h_hi_abs = out_height - spill_h_hi + 1
    spill_d_hi_abs = out_depth  - spill_d_hi + 1

    # These are the regions we're going to have to run with cognizance of padding.
    # There are six of them; one for each face of the cube image.  We explicitly
    # design this so that we run over `width` most tightly, in the expectation that
    # this will generate better code for when `h` and `d` are singleton dimensions.
    # We visualize this as a cube, indexed by dimensions (w, h, d).
    padded_regions = (
        # First region is the lower-d WH face:
        (
            1:out_width,
            1:out_height,
            1:spill_d_lo,
        ),

        # The next largest chunk we choose will be the lower-h WD faces; we always
        # want to maximize going across full `w`, as its contiguous in memory.
        (
            1:out_width,
            1:spill_h_lo,
            (spill_d_lo+1):(spill_d_hi_abs-1),
        ),

        # Then the upper-h WD face
        (
            1:out_width,
            spill_h_hi_abs:out_height,
            (spill_d_lo+1):(spill_d_hi_abs-1),
        ),

        # Next, we fit the HD faces in, but without overlapping the `h` and `d`
        # regions we've done before:
        (
            1:spill_w_lo,
            (spill_h_lo+1):(spill_h_hi_abs-1),
            (spill_d_lo+1):(spill_d_hi_abs-1),
        ),
        (
            spill_w_hi_abs:out_width,
            (spill_h_lo+1):(spill_h_hi_abs-1),
            (spill_d_lo+1):(spill_d_hi_abs-1)
        ),

        # Last region is the higher-d WH face:
        (
            1:out_width,
            1:out_height,
            spill_d_hi_abs:out_depth,
        ),
    )

    # The central region that has no padding.
    central_region = (
        (spill_w_lo+1):(spill_w_hi_abs - 1),
        (spill_h_lo+1):(spill_h_hi_abs - 1),
        (spill_d_lo+1):(spill_d_hi_abs - 1),
    )

    return padded_regions, central_region
end

================================================ FILE: src/impl/pooling_direct.jl ================================================
# Pooling is so similar, we abstract over meanpooling and maxpooling, simply replacing
# the inner loop operation and a few initialization parameters.
for name in (:max, :mean, :lpnorm)
    # Entry point: lifts run-time pooling parameters into `Val`s for the
    # specialized worker method generated below.
    @eval function $((Symbol("$(name)pool_direct!")))(
            y::AbstractArray{<:Any, 5}, x::AbstractArray{<:Any, 5},
            pdims::PoolDims; alpha=1, beta=0, kwargs...)
        $((Symbol("$(name)pool_direct!")))(
            y, x, pdims,
            Val(kernel_size(pdims)), Val(channels_out(pdims)),
            Val(padding(pdims)), Val(dilation(pdims)), Val(stride(pdims));
            alpha, beta, kwargs...)
        return y
    end

    @eval function $((Symbol("$(name)pool_direct!")))(
            y::AbstractArray{T,5}, x::AbstractArray{<:Any,5},
            pdims::PoolDims,
            # kernel size, channels out, padding, dilation, stride
            ::Val{K}, ::Val{C}, ::Val{P}, ::Val{D}, ::Val{S};
            alpha=1, beta=0, kwargs...
    ) where {T, K, C, P, D, S}
        @assert iszero(beta) "beta not supported yet"
        check_dims(size(x), size(y), pdims)

        width, height, depth = input_size(pdims)
        kernel_w, kernel_h, kernel_d = K
        pad_w_lo, _, pad_h_lo, _, pad_d_lo, _ = P
        dil_w, dil_h, dil_d = D
        stride_w, stride_h, stride_d = S

        # We use calc_padding_regions to split outselves up into separate regions that may or
        # may not need to worry about padding:
        padded_regions, central_region = calc_padding_regions(pdims)

        # A helper function to project from output (w, h) to input (input_w, input_h)
        @inline project(idx, stride, pad) = (idx - 1) * stride - pad + 1

        # If we're doing mean pooling, we represent division by kernel size by rolling it
        # into the `alpha` multiplier.
        # The type might change here, that's why we prepend the underscore
        # (does it make a difference, though?)
        _alpha = if $(name == :mean)
            T(alpha / prod(K))
        else
            T(alpha)
        end
        # _beta = T(beta)

        # A quick note on the array element types `T` and `R`:
        # Ideally, `T == R`, but in some edge-cases, this might not be the case
        # (e.g. with `ReverseDiff.TrackedArray`, see issue #484).
        # If the types differ, we will initialize variables (like `_alpha` above) with the
        # target eltype `T`.

        # `p` is only meaningful for lpnormpool, where it must arrive as a keyword.
        p = if $(name != :lpnorm)
            0
        else
            !haskey(kwargs, :p) && error("lpnormpool needs keyword argument `p`")
            kwargs[:p]
        end

        # Each loop, we initialize `m` to something, set that here.
        # For maxpool the neutral element is the smallest representable value;
        # `nextfloat(typemin(T))` avoids starting the reduction from -Inf for floats.
        m_init = if $(name == :max)
            T <: AbstractFloat ? nextfloat(typemin(T)) : typemin(T)
        elseif $(name == :mean) || $(name == :lpnorm)
            T(0)
        else
            error("Unimplemented codegen path")
        end

        # Start with the central region
        w_region, h_region, d_region = central_region
        @inbounds for batch_idx in 1:size(x, 5), c in 1:C
            for d in d_region
                pd = project(d, stride_d, pad_d_lo)
                for h in h_region
                    ph = project(h, stride_h, pad_h_lo)
                    for w in w_region
                        pw = project(w, stride_w, pad_w_lo)

                        m = m_init

                        for kd in 1:kernel_d, kh in 1:kernel_h, kw in 1:kernel_w
                            input_kd = pd + (kd - 1) * dil_d
                            input_kh = ph + (kh - 1) * dil_h
                            input_kw = pw + (kw - 1) * dil_w

                            # This conditional will be optimized away at compile time
                            if $(name == :max)
                                xv = x[input_kw, input_kh, input_kd, c, batch_idx]
                                if xv > m
                                    m = xv
                                end
                            elseif $(name == :mean)
                                m += x[input_kw, input_kh, input_kd, c, batch_idx]
                            elseif $(name == :lpnorm)
                                # y = (∑ᵢ xᵢ^p)^(1 / p), here to calculate ∑ᵢ xᵢ^p
                                m += x[input_kw, input_kh, input_kd, c, batch_idx]^p
                            else
                                error("Unimplemented codegen path")
                            end
                        end

                        # for lpnormpool, y = (∑ᵢ xᵢ^p)^(1 / p)
                        m = $(name == :lpnorm) ? m^(T(1) / p) : m

                        y[w, h, d, c, batch_idx] = _alpha * m # + _beta * y[w, h, d, c, batch_idx]
                    end
                end
            end
        end

        # Next, the padded regions
        @inbounds for (w_region, h_region, d_region) in padded_regions
            for batch_idx in 1:size(x, 5), c in 1:C
                for d in d_region
                    pd = project(d, stride_d, pad_d_lo)
                    for h in h_region
                        ph = project(h, stride_h, pad_h_lo)
                        for w in w_region
                            pw = project(w, stride_w, pad_w_lo)

                            m = m_init

                            # Out-of-bounds window taps are simply skipped, so padding
                            # behaves as "ignore" (for meanpool this means the divisor
                            # is still the full kernel size, matching the central region).
                            for kd in 1:kernel_d
                                input_kd = pd + (kd - 1) * dil_d
                                if input_kd <= 0 || input_kd > depth
                                    # add here condition for handling options for paded value handling
                                    continue
                                end
                                for kh in 1:kernel_h
                                    input_kh = ph + (kh - 1) * dil_h
                                    if input_kh <= 0 || input_kh > height
                                        # add here condition for handling options for paded value handling
                                        continue
                                    end
                                    for kw in 1:kernel_w
                                        input_kw = pw + (kw - 1) * dil_w
                                        if input_kw <= 0 || input_kw > width
                                            # add here condition for handling options for paded value handling
                                            continue
                                        end

                                        if $(name == :max)
                                            xv = x[input_kw, input_kh, input_kd, c, batch_idx]
                                            if xv > m
                                                m = xv
                                            end
                                        elseif $(name == :mean)
                                            m += x[input_kw, input_kh, input_kd, c, batch_idx]
                                        elseif $(name == :lpnorm)
                                            m += x[input_kw, input_kh, input_kd, c, batch_idx]^p
                                        else
                                            error("Unimplemented codegen path")
                                        end
                                    end
                                end
                            end

                            $(name == :lpnorm) && (m = m^(T(1) / p))
                            y[w, h, d, c, batch_idx] = _alpha * m # + _beta * y[w, h, d, c, batch_idx]
                        end
                    end
                end
            end
        end

        return y
    end

    # Gradient entry point: lifts `kernel_size(pdims)` into a `Val` for the
    # specialized worker below.
    @eval function $((Symbol("∇$(name)pool_direct!")))(
            dx::AbstractArray{<:Any,5}, dy::AbstractArray{<:Any,5},
            y::AbstractArray{<:Any,5}, x::AbstractArray{<:Any,5},
            pdims::PoolDims; kwargs...)
        $((Symbol("∇$(name)pool_direct!")))(
            dx, dy, y, x, pdims, Val(kernel_size(pdims)); kwargs...)
        return dx
    end

    # Same story for gradients, and although this is very similar to the forward pass,
    # it's unfortunately different enough that I think we need a separate function.  :(
    @eval function $((Symbol("∇$(name)pool_direct!")))(
            dx::AbstractArray{T,5}, dy::AbstractArray{<:Any,5},
            y::AbstractArray{<:Any,5}, x::AbstractArray{<:Any,5},
            pdims::PoolDims,
            ::Val{K}; # == kernel_size(pdims)
            alpha=1, beta=0, kwargs...) where {T, K}
        check_dims(size(x), size(dy), pdims)

        width, height, depth = input_size(pdims)
        kernel_w, kernel_h, kernel_d = K
        out_c = channels_out(pdims)
        pad_w_lo, _, pad_h_lo, _, pad_d_lo, _ = padding(pdims)
        dil_w, dil_h, dil_d = dilation(pdims)
        stride_w, stride_h, stride_d = stride(pdims)

        # Concerning array eltypes `DX, DY, X, Y`, we want handle them like above, i.e.,
        # initialize everything with the left-hand-side type (target type).
        # Of course, ideally the types are all the same anyways.

        # We use calc_padding_regions to split outselves up into separate regions that
        # may or may not need to worry about padding:
        padded_regions, central_region = calc_padding_regions(pdims)

        # A helper function to project from output (w, h) to input (input_w, input_h)
        @inline project(idx, stride, pad) = (idx - 1) * stride - pad + 1

        # If we're doing mean pooling, we represent division by kernel size by rolling
        # it into the `_alpha` multiplier.
        # Roll meanpool's 1/prod(K) into `_alpha`, same trick as the forward pass.
        _alpha = if $(name == :mean)
            T(alpha / prod(K))
        else
            T(alpha)
        end

        # NOTE(review): the forward pass words this error as
        # "lpnormpool needs keyword argument `p`" — same requirement, different text.
        p = if $(name != :lpnorm)
            0
        else
            !haskey(kwargs, :p) && error("lpnormpool must pass p")
            kwargs[:p]
        end

        # Start with the central region
        w_region, h_region, d_region = central_region
        @inbounds for batch_idx in 1:size(x, 5), c in 1:out_c
            for d in d_region
                pd = project(d, stride_d, pad_d_lo)
                for h in h_region
                    ph = project(h, stride_h, pad_h_lo)
                    for w in w_region
                        pw = project(w, stride_w, pad_w_lo)

                        # Grab the output at this index for future use
                        y_idx = y[w, h, d, c, batch_idx]
                        dy_idx = dy[w, h, d, c, batch_idx]
                        maxpool_already_chose = false

                        for kd in 1:kernel_d, kh in 1:kernel_h, kw in 1:kernel_w
                            input_kd = pd + (kd - 1) * dil_d
                            input_kh = ph + (kh - 1) * dil_h
                            input_kw = pw + (kw - 1) * dil_w

                            # This conditional will be optimized away at compile time,
                            # or my name isn't shengdan jingyu
                            # x_idxs = (input_kw, input_kh, input_kd, c, batch_idx)
                            if $(name == :max)
                                if maxpool_already_chose
                                    break
                                end
                                # If it's equal; this is the one we chose. We only choose one per
                                # kernel window, all other elements of dx must be zero.
                                # Uncomment line below if using with non-precise output (e.g. by NNPACK)
                                # if abs(y_idx - x[x_idxs...]) < 1e-5 && !maxpool_already_chose
                                if y_idx ≈ x[input_kw, input_kh, input_kd, c, batch_idx]
                                    dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * _alpha #+ _beta * dx[x_idxs...]
                                    maxpool_already_chose = true
                                # Maxpooling does not support `beta` right now. :(
                                # else
                                #     dx[x_idxs...] = T(0) + beta*dx[x_idxs...]
                                end
                            elseif $(name == :mean)
                                # Either does meanpool :(
                                dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * _alpha
                            elseif $(name == :lpnorm)
                                # y = (∑ᵢ xᵢ^p)^(1 / p), ∂y/∂xᵢ = xᵢ^(p-1) × y^(1-p)
                                grad = x[input_kw, input_kh, input_kd, c, batch_idx]^(p-1) * y_idx^(1-p)
                                dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * grad
                            else
                                error("Unimplemented codegen path")
                            end
                        end
                    end
                end
            end
        end

        # Next, the padded regions
        @inbounds for (w_region, h_region, d_region) in padded_regions
            for batch_idx in 1:size(x, 5), c in 1:out_c
                for d in d_region
                    pd = project(d, stride_d, pad_d_lo)
                    for h in h_region
                        ph = project(h, stride_h, pad_h_lo)
                        for w in w_region
                            pw = project(w, stride_w, pad_w_lo)

                            # Grab the incoming gradient at this index for future use
                            y_idx = y[w, h, d, c, batch_idx]
                            dy_idx = dy[w, h, d, c, batch_idx]
                            maxpool_already_chose = false

                            # In these loops, we have to check that we're not reaching off the edge,
                            # we do so by putting in a bunch of conditionals.  :/
                            for kd in 1:kernel_d
                                input_kd = pd + (kd - 1) * dil_d
                                if input_kd <= 0 || input_kd > depth
                                    continue
                                end
                                for kh in 1:kernel_h
                                    input_kh = ph + (kh - 1) * dil_h
                                    if input_kh <= 0 || input_kh > height
                                        continue
                                    end
                                    for kw in 1:kernel_w
                                        input_kw = pw + (kw - 1) * dil_w
                                        if input_kw <= 0 || input_kw > width
                                            continue
                                        end

                                        # Same as above
                                        # x_idxs = (input_kw, input_kh, input_kd, c, batch_idx)
                                        if $(name == :max)
                                            if maxpool_already_chose
                                                break
                                            end
                                            # Uncomment line below if using with non-precise output
                                            # if abs(y_idx - x[x_idxs...]) < 1e-5 && !maxpool_already_chose
                                            if y_idx ≈ x[input_kw, input_kh, input_kd, c, batch_idx]
                                                dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * _alpha #+ _beta * dx[x_idxs...]
                                                maxpool_already_chose = true
                                            # else
                                            #     dx[x_idxs...] = T(0) + beta*dx[x_idxs...]
                                            end
                                        elseif $(name == :mean)
                                            dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * _alpha #+ _beta * dx[x_idxs...]
                                        elseif $(name == :lpnorm)
                                            # ∂y/∂xᵢ = xᵢ^(p-1) × y^(1-p), same as the central region
                                            grad = x[input_kw, input_kh, input_kd, c, batch_idx]^(p-1) * y_idx^(1-p)
                                            dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * grad
                                        else
                                            error("Unimplemented codegen path")
                                        end
                                    end
                                end
                            end
                        end
                    end
                end
            end
        end

        return dx
    end
end


================================================
FILE: src/normalization.jl
================================================
# TODO: add CPU implementation
# `batchnorm`/`∇batchnorm` are declared without methods here; backend extensions
# (e.g. the cuDNN extension) attach the actual implementations.
function batchnorm end
function ∇batchnorm end

# Reverse rule for `batchnorm`: delegates to `∇batchnorm` and splats its result.
# `grad` is expected to cover the three leading arguments (g, b, x) — the running
# statistics and `momentum` receive `NoTangent`.
function ChainRulesCore.rrule(::typeof(batchnorm), g, b, x, running_mean, running_var, momentum; kw...)
    y = batchnorm(g, b, x, running_mean, running_var, momentum; kw...)
    function batchnorm_pullback(Δ)
        grad = ∇batchnorm(g, b, x, unthunk(Δ), running_mean, running_var, momentum; kw...)
        (NoTangent(), grad..., NoTangent(), NoTangent(), NoTangent())
    end
    y, batchnorm_pullback
end


================================================
FILE: src/padding.jl
================================================
"""
    pad_zeros(x, pad::Tuple; [dims])
    pad_zeros(x, pad::Int; [dims])

Pad the array `x` with zeros.
Equivalent to [`pad_constant`](@ref) with the constant equal to 0.
"""
pad_zeros(x::AbstractArray, pad; dims = :) =
  pad_constant(x, pad, 0; dims = dims)

"""
    pad_constant(x, pad::Tuple, val = 0; [dims = :])
    pad_constant(x, pad::Int, val = 0; [dims = :])

Pad the array `x` with the constant value `val`.

`pad` can be a tuple of integers.
If it is of length `2 * length(dims)`, it specifies the left and right padding size
for each of the dimensions in `dims` as `(l1, r1, ..., ln, rn)`.
If supplied with a tuple of length `length(dims)` instead, it applies symmetric padding.
If `dims` is not given, it defaults to all dimensions.

For integer `pad` input, it is applied on both sides
on every dimension in `dims`.

See also [`pad_zeros`](@ref), [`pad_repeat`](@ref), [`pad_reflect`](@ref), [`pad_symmetric`](@ref), and [`pad_circular`](@ref).
```jldoctest
julia> r = reshape(1:4, 2, 2)
2×2 reshape(::UnitRange{Int64}, 2, 2) with eltype Int64:
 1  3
 2  4

julia> pad_constant(r, (1, 2, 3, 4), 8)
5×9 Matrix{Int64}:
 8  8  8  8  8  8  8  8  8
 8  8  8  1  3  8  8  8  8
 8  8  8  2  4  8  8  8  8
 8  8  8  8  8  8  8  8  8
 8  8  8  8  8  8  8  8  8

julia> pad_constant(r, 1, 8)
4×4 Matrix{Int64}:
 8  8  8  8
 8  1  3  8
 8  2  4  8
 8  8  8  8

julia> r = reshape(1:27, 3, 3, 3)
3×3×3 reshape(::UnitRange{Int64}, 3, 3, 3) with eltype Int64:
[:, :, 1] =
 1  4  7
 2  5  8
 3  6  9

[:, :, 2] =
 10  13  16
 11  14  17
 12  15  18

[:, :, 3] =
 19  22  25
 20  23  26
 21  24  27

julia> pad_constant(r, (2,1), dims = 1) # asymmetric padding
6×3×3 Array{Int64, 3}:
[:, :, 1] =
 0  0  0
 0  0  0
 1  4  7
 2  5  8
 3  6  9
 0  0  0

[:, :, 2] =
  0   0   0
  0   0   0
 10  13  16
 11  14  17
 12  15  18
  0   0   0

[:, :, 3] =
  0   0   0
  0   0   0
 19  22  25
 20  23  26
 21  24  27
  0   0   0

julia> pad_constant(r, (2,1, 3), dims = (1,2)) # padding must always be either the same length as dims, or double it
ERROR: ArgumentError: Could not parse padding (2, 1, 3) and dims (1, 2)
Stacktrace:
[...]
```
"""
# Both wrappers normalize `pad`/`dims` via `gen_pad` into the canonical
# `NTuple{N, (lo, hi)}` form consumed by the core method below.
pad_constant(x::AbstractArray{T,N}, pad::Int, val = 0; dims = :) where {T,N} =
  pad_constant(x, gen_pad(pad, dims isa Colon ? dims : (dims...,), N), val)
pad_constant(x::AbstractArray{T,N}, pad::Tuple, val = 0; dims = :) where {T,N} =
  pad_constant(x, gen_pad(pad, dims isa Colon ? dims : (dims...,), N), val)

# NOTE(review): `pad_idx` appears vestigial — it only builds (and implicitly
# returns) a `zip` of odd/even pad positions and is not referenced by the
# padding code in this file. Candidate for removal; verify no external callers.
function pad_idx(pad, dims, N)
  is = zip(
    (2 .* dims) .- 1,
    (2 .* dims))
end

# Flatten nested tuples: tuplejoin((1,2), (3,), (4,)) == (1, 2, 3, 4)
@inline tuplejoin(x) = x
@inline tuplejoin(x, y) = (x..., y...)
@inline tuplejoin(x, y, z...) = tuplejoin(tuplejoin(x, y), z...)
# `gen_pad` normalizes every accepted `pad`/`dims` combination into the canonical
# form `NTuple{N, Tuple{Int,Int}}` of per-dimension (lo, hi) padding amounts.
# The scalar/Colon/Int methods below all funnel into the tuple-tuple method.
gen_pad(pad::Int, dims, N) = gen_pad(ntuple(_ -> pad, length(dims)), dims, N)
gen_pad(pad::Int, dims::Colon, N) = ntuple(_ -> (pad, pad), N)
gen_pad(pad, dims::Colon, N) = gen_pad(pad, ntuple(identity, N), N)
gen_pad(pad, dims::Int, N) = gen_pad(pad, (dims,), N)
gen_pad(pad::Int, dims::Int, N) = gen_pad((pad,pad), (dims,), N)
function gen_pad(pad::NTuple{L,Int}, dims::NTuple{D,Int}, N) where {L,D}
  ntuple(N) do d
    if d in dims
      if L == D
        # one entry per dim: symmetric padding
        ix = findfirst(==(d), dims)
        (pad[ix], pad[ix])
      elseif L == 2D
        # two entries per dim: (lo, hi) pairs
        ix = findfirst(==(d), dims)
        (pad[2ix - 1], pad[2ix])
      else
        throw(ArgumentError("Could not parse padding $pad and dims $dims"))
      end
    else
      (0,0)
    end
  end
end

# Expects length(pad) == 2M
# Core implementation: allocate the enlarged array, fill with `val`, then copy
# `x` into the interior region computed by `size_and_center`.
function pad_constant(x::AbstractArray{T,M}, pad::NTuple{N,Tuple{Int,Int}},
                      val = 0) where {T,M,N}
  sz, c = size_and_center(x, pad)
  res = fill!(similar(x, sz...), val)
  res[c...] = x
  res
end

# Returns the padded size and the axis ranges where `x` sits inside the result.
function size_and_center(x, pad::NTuple{N,NTuple{2, Int}}) where N
  sz = ntuple(i -> pad[i][1] + pad[i][2], N) .+ size(x)
  center = broadcast((x,y) -> x .+ y, axes(x), ntuple(i -> pad[i][1], N))
  sz, center
end

# Reverse rule: the pullback just slices the interior `center` region back out
# of Δ; `pad` and `val` receive no gradient. (`outsize` is unused here.)
function rrule(::typeof(pad_constant), x::AbstractArray{T,N},
               pad, val; dims = :) where {T,N}
  y = pad_constant(x, pad, val; dims = dims)
  function pad_constant_pullback(Δ)
    p = gen_pad(pad, dims, N)
    outsize, center = size_and_center(x, p)
    (NoTangent(),
      @thunk(unthunk(Δ)[center...]),
      NoTangent(), NoTangent(),)
  end
  return y, pad_constant_pullback
end

"""
    pad_repeat(x, pad::Tuple; [dims])
    pad_repeat(x, pad::Int; [dims])

Pad the array `x` repeating the values on the border.

`pad` can be a tuple of integers `(l1, r1, ..., ln, rn)`
of some length `2n` that specifies the left and right padding size
for each of the dimensions in `dims`. If `dims` is not given,
it defaults to the first `n` dimensions.

If `pad` is an integer, it is applied on both sides
on every dimension in `dims`. In this case, `dims`
defaults to the first `ndims(x)-2` dimensions
(i.e. excludes the channel and batch dimension).
See also [`pad_reflect`](@ref), [`pad_symmetric`](@ref), [`pad_circular`](@ref), and [`pad_constant`](@ref).

```jldoctest
julia> r = reshape(1:9, 3, 3)
3×3 reshape(::UnitRange{Int64}, 3, 3) with eltype Int64:
 1  4  7
 2  5  8
 3  6  9

julia> pad_repeat(r, (1,2,3,4))
6×10 Matrix{Int64}:
 1  1  1  1  4  7  7  7  7  7
 1  1  1  1  4  7  7  7  7  7
 2  2  2  2  5  8  8  8  8  8
 3  3  3  3  6  9  9  9  9  9
 3  3  3  3  6  9  9  9  9  9
 3  3  3  3  6  9  9  9  9  9
```
"""
# Multi-dimension method: peels off one (lo, hi) pair per dimension and
# delegates to the single-dimension method below.
function pad_repeat(x::AbstractArray, pad::NTuple{M,Int};
                    dims = 1:M÷2) where M
  length(dims) == M ÷ 2 ||
    throw(ArgumentError("The number of dims should be equal to the number of padding dimensions"))
  for (i, d) in enumerate(dims)
    x = pad_repeat(x, (pad[2i-1], pad[2i]); dims=d)
  end
  return x
end

# Single-dimension worker: replicate the first/last slice along `dims` and
# concatenate left pad, original, right pad.
function pad_repeat(x::AbstractArray{F,N}, pad::NTuple{2,Int};
                    dims::Int = 1) where {F,N}
  lpad, rpad = pad

  xlborder = selectdim(x, dims, 1:1)
  nrepl = ntuple(i -> i == dims ? lpad : 1, N)
  xl = repeat(xlborder, outer = nrepl)

  n = size(x, dims)
  xrborder = selectdim(x, dims, n:n)
  nrepr = ntuple(i -> i == dims ? rpad : 1, N)
  xr = repeat(xrborder, outer = nrepr)

  return cat(xl, x, xr, dims = dims)
end

"""
    pad_reflect(x, pad::Tuple; [dims])
    pad_reflect(x, pad::Int; [dims])

Pad the array `x` reflecting its values across the border.

`pad` can be a tuple of integers `(l1, r1, ..., ln, rn)`
of some length `2n` that specifies the left and right padding size
for each of the dimensions in `dims`. If `dims` is not given,
it defaults to the first `n` dimensions.

If `pad` is an integer, it is applied on both sides
on every dimension in `dims`. In this case, `dims`
defaults to the first `ndims(x)-2` dimensions
(i.e. excludes the channel and batch dimension).
```jldoctest
julia> r = reshape(1:9, 3, 3)
3×3 reshape(::UnitRange{Int64}, 3, 3) with eltype Int64:
 1  4  7
 2  5  8
 3  6  9

julia> pad_reflect(r, (1,2,1,2))
6×6 Matrix{Int64}:
 5  2  5  8  5  2
 4  1  4  7  4  1
 5  2  5  8  5  2
 6  3  6  9  6  3
 5  2  5  8  5  2
 4  1  4  7  4  1
```
"""
# Multi-dimension method: applies the single-dimension worker once per dim.
function pad_reflect(x::AbstractArray, pad::NTuple{M,Int};
                     dims=1:M÷2) where M
  length(dims) == M ÷ 2 ||
    throw(ArgumentError("The number of dims should be equal to the number of padding dimensions"))
  for (i, d) in enumerate(dims)
    x = pad_reflect(x, (pad[2i-1], pad[2i]); dims = d)
  end
  return x
end

# Single-dimension worker: mirror without repeating the border element itself
# (slices 2:lpad+1 and n-rpad:n-1, reversed).
# NOTE(review): this indexing assumes pad < size(x, dims) — TODO confirm larger
# pads are meant to error here.
function pad_reflect(
    x::AbstractArray{F,N}, pad::NTuple{2,Int};
    dims::Int = 1,
) where {F,N}
  lpad, rpad = pad
  n = size(x, dims)

  # a zero-width slice keeps `cat` working when one side has no padding
  xl = lpad == 0 ?
    similar(x, ntuple(i -> i == dims ? 0 : size(x, i), ndims(x))) :
    reverse(selectdim(x, dims, 2:lpad+1); dims)
  xr = rpad == 0 ?
    similar(x, ntuple(i -> i == dims ? 0 : size(x, i), ndims(x))) :
    reverse(selectdim(x, dims, n-rpad:n-1); dims)
  return cat(xl, x, xr; dims)
end

"""
    pad_symmetric(x, pad::Tuple; [dims])
    pad_symmetric(x, pad::Int; [dims])

Pad the array `x` reflecting its values symmetrically across the border, i.e. the
border values of `x` are present in the padding values, in contrast to
[`pad_reflect`](@ref).

`pad` can be a tuple of integers `(l1, r1, ..., ln, rn)`
of some length `2n` that specifies the left and right padding size
for each of the dimensions in `dims`. If `dims` is not given,
it defaults to the first `n` dimensions.

If `pad` is an integer, it is applied on both sides
on every dimension in `dims`. In this case, `dims`
defaults to the first `ndims(x)-2` dimensions
(i.e. excludes the channel and batch dimension).

See also [`pad_repeat`](@ref), [`pad_reflect`](@ref), [`pad_circular`](@ref), and [`pad_constant`](@ref).
```jldoctest
julia> r = reshape(1:9, 3, 3)
3×3 reshape(::UnitRange{Int64}, 3, 3) with eltype Int64:
 1  4  7
 2  5  8
 3  6  9

julia> pad_symmetric(r, (1,2,1,2))
6×6 Matrix{Int64}:
 1  1  4  7  7  4
 1  1  4  7  7  4
 2  2  5  8  8  5
 3  3  6  9  9  6
 3  3  6  9  9  6
 2  2  5  8  8  5
```
"""
# Multi-dimension method: applies the single-dimension worker once per dim.
function pad_symmetric(x::AbstractArray, pad::NTuple{M,Int};
                       dims=1:M÷2) where M
  length(dims) == M ÷ 2 ||
    throw(ArgumentError("The number of dims should be equal to the number of padding dimensions"))
  for (i, d) in enumerate(dims)
    x = pad_symmetric(x, (pad[2i-1], pad[2i]); dims = d)
  end
  return x
end

# Single-dimension worker: mirror including the border element itself
# (slices 1:lpad and n-rpad+1:n, reversed) — this is what distinguishes
# `pad_symmetric` from `pad_reflect`.
function pad_symmetric(
    x::AbstractArray{F,N}, pad::NTuple{2,Int};
    dims::Int = 1,
) where {F,N}
  lpad, rpad = pad
  n = size(x, dims)

  # a zero-width slice keeps `cat` working when one side has no padding
  xl = lpad == 0 ?
    similar(x, ntuple(i -> i == dims ? 0 : size(x, i), ndims(x))) :
    reverse(selectdim(x, dims, 1:lpad); dims)
  xr = rpad == 0 ?
    similar(x, ntuple(i -> i == dims ? 0 : size(x, i), ndims(x))) :
    reverse(selectdim(x, dims, n-rpad+1:n); dims)
  return cat(xl, x, xr; dims)
end

"""
    pad_circular(x, pad::Tuple; [dims])
    pad_circular(x, pad::Int; [dims])

Pad the array `x` "circularly" across the border by wrapping around values
from the opposite side of `x`.

`pad` can be a tuple of integers `(l1, r1, ..., ln, rn)`
of some length `2n` that specifies the left and right padding size
for each of the dimensions in `dims`. If `dims` is not given,
it defaults to the first `n` dimensions.

If `pad` is an integer, it is applied on both sides
on every dimension in `dims`. In this case, `dims`
defaults to the first `ndims(x)-2` dimensions
(i.e. excludes the channel and batch dimension).

The pad length on either side in any dimension must not exceed the
size of `x` in that dimension, i.e. `pad_circular` is not able to create
arbitrary sized tilings of `x`.

See also [`pad_repeat`](@ref), [`pad_reflect`](@ref), [`pad_symmetric`](@ref), and [`pad_constant`](@ref).
```jldoctest
julia> r = reshape(1:9, 3, 3)
3×3 reshape(::UnitRange{Int64}, 3, 3) with eltype Int64:
 1  4  7
 2  5  8
 3  6  9

julia> pad_circular(r, (1,2,1,2))
6×6 Matrix{Int64}:
 9  3  6  9  3  6
 7  1  4  7  1  4
 8  2  5  8  2  5
 9  3  6  9  3  6
 7  1  4  7  1  4
 8  2  5  8  2  5
```
"""
# Multi-dimension method: applies the single-dimension worker once per dim.
function pad_circular(x::AbstractArray, pad::NTuple{M,Int};
                      dims=1:M÷2) where M
  length(dims) == M ÷ 2 ||
    throw(ArgumentError("The number of dims should be equal to the number of padding dimensions"))
  for (i, d) in enumerate(dims)
    x = pad_circular(x, (pad[2i-1], pad[2i]); dims = d)
  end
  return x
end

# Single-dimension worker: wrap the last `lpad` slices to the front and the
# first `rpad` slices to the back. Indexing `n-lpad+1:n` is what limits pads
# to at most the array size in that dimension (see docstring).
function pad_circular(x::AbstractArray{F,N}, pad::NTuple{2,Int};
                      dims::Int = 1) where {F,N}
  lpad, rpad = pad

  n = size(x, dims)
  xl = selectdim(x, dims, n-lpad+1:n)
  xr = selectdim(x, dims, 1:rpad)
  return cat(xl, x, xr, dims = dims)
end

# convenience methods for symmetric and homogeneous padding
pad_repeat(x::AbstractArray{F,N}, pad::Int; dims=1:N-2) where {F,N} =
  pad_repeat(x, ntuple(_ -> pad, 2length(dims)); dims = dims)

pad_reflect(x::AbstractArray{F,N}, pad::Int; dims=1:N-2) where {F,N} =
  pad_reflect(x, ntuple(_ -> pad, 2length(dims)); dims = dims)

pad_symmetric(x::AbstractArray{F,N}, pad::Int; dims=1:N-2) where {F,N} =
  pad_symmetric(x, ntuple(_ -> pad, 2length(dims)); dims = dims)

pad_circular(x::AbstractArray{F,N}, pad::Int; dims=1:N-2) where {F,N} =
  pad_circular(x, ntuple(_ -> pad, 2length(dims)); dims = dims)


================================================
FILE: src/pooling.jl
================================================
## Pooling API
#
#  We provide the following generic methods, for 3d, 4d, and 5d tensors, calculating 1d,
#  2d and 3d pooling, based on the rank of the input tensors, in both mutating and
#  non-mutating auto-allocating variants:
#   - Pooling:
#     - maxpool(x, pdims)
#     - maxpool!(y, x, pdims)
#     - meanpool(x, pdims)
#     - meanpool!(y, x, pdims)
#     - lpnormpool(x, pdims)
#     - lpnormpool!(y, x, pdims)
#   - Pooling input backprop
#     - ∇maxpool(dy, y, x, pdims)
#     - ∇maxpool!(dx, dy, y, x, pdims)
#     - ∇meanpool(dy, y, x, pdims)
#     - ∇meanpool!(dx, dy, y, x, pdims)
#     - ∇lpnormpool(dy, y, x, pdims)
#     - ∇lpnormpool!(dx, dy, y, x, pdims)
#
#   All methods require a `PoolDims` object to define the dimensions and optional
#   elements of the convolution (stride, dilation, etc...), which is easily constructable
#   through something like `PoolDims(x, w)`.

# First, we will define mappings from the generic API names to our accelerated backend
# implementations.  At the moment this is only the direct implementation, however this
# exists here so that other packages (NNPACK, MAGMA, etc...) can override this easily.
for (front_name, backend) in (
    # This maps from public, front-facing name, to internal backend name
    :maxpool  => :direct,
    :meanpool => :direct,
    :lpnormpool => :direct,
)
    # We only define 3d pooling primitives, we reshape lower down to get 1d and 2d pooling
    @eval begin
        function $(Symbol("$(front_name)!"))(
                        y::AbstractArray{<:Any,5}, x::AbstractArray{<:Any,5},
                        pdims::PoolDims; kwargs...)
            $(Symbol("$(front_name)_$(backend)!"))(y, x, pdims; kwargs...)
        end
    end
end

# Do the same for backprops
for (front_name, backend) in (
    :∇maxpool => :direct,
    :∇meanpool => :direct,
    :∇lpnormpool => :direct,
)
    @eval begin
        function $(Symbol("$(front_name)!"))(
                        dx::AbstractArray{<:Any,5}, dy::AbstractArray{<:Any,5},
                        y::AbstractArray{<:Any,5}, x::AbstractArray{<:Any,5},
                        pdims::PoolDims; kwargs...)
            $(Symbol("$(front_name)_$(backend)!"))(dx, dy, y, x, pdims; kwargs...)
        end
    end
end

# Our strategy for pooling is to reshape to an array with three spatial dimensions, which
# makes things MUCH EASIER for us on the backend side, and is in general pretty fast,
# since we can specialize on sizes.
for front_name in (:maxpool, :meanpool, :lpnormpool)
    for backend in (Symbol(), :_direct)
        for N in (3, 4)
            @eval begin
                function $(Symbol("$(front_name)$(backend)!"))(
                                y::AbstractArray{<:Any,$N}, x::AbstractArray{<:Any,$N},
                                pdims::PoolDims; kwargs...)
                    # Insert singleton spatial dimension(s) to promote the
                    # 1d/2d call to the 3d primitive defined above.
                    $(Symbol("$(front_name)$(backend)!"))(
                        insert_singleton_spatial_dimension(y, $(5 - N)),
                        insert_singleton_spatial_dimension(x, $(5 - N)),
                        insert_singleton_spatial_dimension(pdims, $(5 - N));
                        kwargs...
                    )

                    # We explicitly return `y` here, because the backend call
                    # itself may return a reshaped view, which we don't want.
                    return y
                end

                # backprops too
                function $(Symbol("∇$(front_name)$(backend)!"))(
                                dx::AbstractArray{<:Any,$N}, dy::AbstractArray{<:Any,$N},
                                y::AbstractArray{<:Any,$N}, x::AbstractArray{<:Any,$N},
                                pdims::PoolDims; kwargs...)
                    $(Symbol("∇$(front_name)$(backend)!"))(
                        insert_singleton_spatial_dimension(dx, $(5 - N)),
                        insert_singleton_spatial_dimension(dy, $(5 - N)),
                        insert_singleton_spatial_dimension(y, $(5 - N)),
                        insert_singleton_spatial_dimension(x, $(5 - N)),
                        insert_singleton_spatial_dimension(pdims, $(5 - N));
                        kwargs...
                    )

                    # We explicitly return `dx` here, because the backend call
                    # itself may return a reshaped view, which we don't want.
                    return dx
                end
            end
        end
    end
end

# Finally, let's generate auto-allocating versions of all our functions, for all backends:
for backend in (Symbol(), :_direct)
    # First make auto-allocating versions of the basic pooling calls:
    for name in (:maxpool, :meanpool, :lpnormpool)
        @eval begin
            function $(Symbol("$(name)$(backend)"))(
                            x::AbstractArray{<:Any,N}, pdims::PoolDims; kwargs...) where {N}
                # Allocate the output at the pooled size, zero it, then call the
                # mutating variant.
                y = similar(x, output_size(pdims)..., channels_out(pdims), size(x, N))
                fill!(y, 0)
                return $(Symbol("$(name)$(backend)!"))(y, x, pdims; kwargs...)
            end

            # Backprops too
            function $(Symbol("∇$(name)$(backend)"))(
                            dy::AbstractArray{<:Any,N}, y::AbstractArray{<:Any,N},
                            x::AbstractArray{<:Any,N}, pdims::PoolDims; kwargs...) where {N}
                # `dx` matches the input size; zero-filled because the direct
                # backends accumulate into it.
                dx = similar(x, input_size(pdims)..., channels_in(pdims), size(dy, N))
                fill!(dx, 0)
                return $(Symbol("∇$(name)$(backend)!"))(dx, dy, y, x, pdims; kwargs...)
            end
        end
    end
end

# Normalize `pad`/`stride` arguments: tuples pass through, integers are
# broadcast to all `N` spatial dimensions. (The first argument is a `Val{N}`
# at the call sites below; it is only used for dispatch/length.)
expand(N, i::Tuple) = i
expand(N, i::Integer) = ntuple(_ -> i, N)

"""
    maxpool(x, k::NTuple{N, Integer}; pad=0, stride=k)

Perform max pool operation with window size `k` on input tensor `x`.

Arguments:

* `x` and `k`: Expects `ndim(x) ∈ 3:5`, and always `length(k) == ndim(x) - 2`
* `pad`: See [`pad_zeros`](@ref) for details.
* `stride`: Either a tuple with the same length as `k`, or one integer for all directions. Default is `k`.
"""
function maxpool(x, k::NTuple{N, Integer}; pad=0, stride=k) where N
    pad = expand(Val(N), pad)
    stride = expand(Val(N), stride)
    pdims = PoolDims(x, k; padding=pad, stride=stride)
    return maxpool(x, pdims)
end

"""
    meanpool(x, k::NTuple{N, Integer}; pad=0, stride=k)

Perform mean pool operation with window size `k` on input tensor `x`.

Arguments:

* `x` and `k`: Expects `ndim(x) ∈ 3:5`, and always `length(k) == ndim(x) - 2`
* `pad`: See [`pad_zeros`](@ref) for details.
* `stride`: Either a tuple with the same length as `k`, or one integer for all directions. Default is `k`.
"""
function meanpool(x, k::NTuple{N, Integer}; pad=0, stride=k) where N
    pad = expand(Val(N), pad)
    stride = expand(Val(N), stride)
    pdims = PoolDims(x, k; padding=pad, stride=stride)
    return meanpool(x, pdims)
end

"""
    lpnormpool(x, p::Real, k::NTuple{N, Integer}; pad=0, stride=k)

Perform Lp pool operation with value of the Lp norm `p` and window size `k` on input tensor `x`, also known as LPPool in pytorch.
This pooling operator from [Learned-Norm Pooling for Deep Feedforward and Recurrent Neural Networks](https://arxiv.org/abs/1311.1780).

Arguments:

* `x` and `k`: Expects `ndim(x) ∈ 3:5`, and always `length(k) == ndim(x) - 2`
* `p` is restricted to `0 < p < Inf`.
* `pad`: See [`pad_zeros`](@ref) for details.
* `stride`: Either a tuple with the same length as `k`, or one integer for all directions. Default is `k`.

For all elements `x` in a size `k` window, lpnormpool computes `(∑ᵢ xᵢ^p)^(1 / p)` as an element of the output.
Thus `lpnormpool(x, 1, k) ./ prod(k) ≈ meanpool(x, k)` and `lpnormpool(x, 2, k).^2 ./ prod(k) ≈ meanpool(x.^2, k)`.
"""
function lpnormpool(x, p::Real, k::NTuple{N, Integer}; pad=0, stride=k) where {N}
    # Keep integer `p` exact; otherwise promote to the float type of `x` so the
    # per-element power computations don't repeatedly convert.
    pow = p isa Integer ? p : convert(float(eltype(x)), p)
    # The contract (and this very error message) is `0 < p < Inf`; previously the
    # check was `pow < 0`, which let `p == 0` through and made the final
    # `m^(1/p)` step blow up. Reject non-positive `p` as documented.
    (isinf(pow) || pow <= 0) && error("p value of Lp norm pool expects `0 < p < Inf`, but p is $(pow) now.")
    pdims = PoolDims(x, k; padding=expand(Val(N), pad), stride=expand(Val(N), stride))
    return lpnormpool(x, pdims; p=pow)
end

# Reverse-mode rules for the auto-allocating pooling entry points.
# The gradient flows only to `x`; the `PoolDims` argument gets `NoTangent`.
for pool in [:maxpool, :meanpool, :lpnormpool]
    ∇pool = Symbol(:∇, pool)
    pullback = Symbol(pool, :_pullback)
    @eval function rrule(::typeof($pool), x, pdims::PoolDims; kw...)
        Ω = $pool(x, pdims; kw...)
        $pullback(Δ) = (NoTangent(), $∇pool(unthunk(Δ), Ω, x, pdims; kw...), NoTangent())
        return Ω, $pullback
    end
end


================================================
FILE: src/rotation.jl
================================================
"""
    _rotate_coordinates(sinθ, cosθ, i, j, rotation_center, round_or_floor)

Rotate the output coordinate `(i, j)` around `rotation_center` by the angle
whose sine/cosine are `sinθ`/`cosθ`, yielding the corresponding source
coordinate in the input array. `round_or_floor` is `round` for `:nearest`
interpolation and `floor` for `:bilinear`.

Returns the real-valued source coordinates `(yrot, xrot)`, their
rounded/floored values, and the matching `Int` indices.
"""
@inline function _rotate_coordinates(sinθ, cosθ, i, j, rotation_center, round_or_floor)
    y = i - rotation_center[1]
    x = j - rotation_center[2]
    yrot = cosθ * y - sinθ * x + rotation_center[1]
    xrot = sinθ * y + cosθ * x + rotation_center[2]

    yrot_f = round_or_floor(yrot)
    xrot_f = round_or_floor(xrot)
    yrot_int = round_or_floor(Int, yrot)
    xrot_int = round_or_floor(Int, xrot)
    return yrot, xrot, yrot_f, xrot_f, yrot_int, xrot_int
end

"""
    _bilinear_helper(yrot, xrot, yrot_f, xrot_f)

Compute the fractional offsets of `(yrot, xrot)` from their floored values
`(yrot_f, xrot_f)` together with the complementary weights — the four bilinear
interpolation coefficients.
"""
@inline function _bilinear_helper(yrot, xrot, yrot_f, xrot_f)
    xdiff = (xrot - xrot_f)
    xdiff_1minus = 1 - xdiff
    ydiff = (yrot - yrot_f)
    ydiff_1minus = 1 - ydiff

    return ydiff, ydiff_1minus, xdiff, xdiff_1minus
end

"""
    _prepare_imrotate(arr, θ, rotation_center)

Prepares `sin` and `cos`, creates the (zero-filled) output array and converts
the type of `rotation_center` if required.
""" function _prepare_imrotate(arr::AbstractArray{T}, θ, rotation_center) where T # needed for rotation matrix θ = mod(real(T)(θ), real(T)(2π)) rotation_center = real(T).(rotation_center) sinθ, cosθ = sincos(real(T)(θ)) out = similar(arr) fill!(out, 0) return sinθ, cosθ, rotation_center, out end """ _check_trivial_rotations!(out, arr, θ, rotation_center) When `θ = 0 || π /2 || π || 3/2 || π` and if `rotation_center` is in the middle of the array. For an even array of size 4, the rotation_center would need to be 2.5. For an odd array of size 5, the rotation_center would need to be 3. In those cases, rotations are trivial just by reversing or swapping some axes. """ function _check_trivial_rotations!(out, arr, θ, rotation_center; adjoint=false) if iszero(θ) out .= arr return true end # check for special cases where rotations are trivial if (iseven(size(arr, 1)) && iseven(size(arr, 2)) && rotation_center[1] ≈ size(arr, 1) ÷ 2 + 0.5 && rotation_center[2] ≈ size(arr, 2) ÷ 2 + 0.5) || (isodd(size(arr, 1)) && isodd(size(arr, 2)) && (rotation_center[1] == size(arr, 1) ÷ 2 + 1 && rotation_center[1] == size(arr, 2) ÷ 2 + 1)) if θ ≈ π / 2 if adjoint == false out .= reverse(PermutedDimsArray(arr, (2, 1, 3, 4)), dims=(2,)) else out .= reverse(PermutedDimsArray(arr, (2, 1, 3, 4)), dims=(1,)) end return true elseif θ ≈ π out .= reverse(arr, dims=(1,2)) return true elseif θ ≈ 3 / 2 * π if adjoint == false out .= reverse(PermutedDimsArray(arr, (2, 1, 3, 4)), dims=(1,)) else out .= reverse(PermutedDimsArray(arr, (2, 1, 3, 4)), dims=(2,)) end return true end end return false end """ imrotate(arr::AbstractArray{T, 4}, θ; method=:bilinear, rotation_center=size(arr) .÷ 2 .+ 1) Rotates an array in the first two dimensions around the center pixel `rotation_center`. The default value of `rotation_center` is defined such that there is a integer center pixel for even and odd sized arrays which it is rotated around. 
For an even sized array of size `(4,4)` this would be `(3,3)`, for an odd array of size `(3,3)` this would be `(2,2)` However, `rotation_center` can be also non-integer numbers if specified. The angle `θ` is interpreted in radians. The adjoint is defined with ChainRulesCore.jl. This method also runs with CUDA (and in principle all KernelAbstractions.jl supported backends). # Keywords * `method=:bilinear` for bilinear interpolation or `method=:nearest` for nearest neighbour * `rotation_center=size(arr) .÷ 2 .+ 1` means there is a real center pixel around it is rotated. # Examples ```julia-repl julia> arr = zeros((4,4,1,1)); arr[2,2,1,1] = 1; julia> arr 4×4×1×1 Array{Float64, 4}: [:, :, 1, 1] = 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 julia> NNlib.imrotate(arr, deg2rad(90)) # rotation around (3,3) 4×4×1×1 Array{Float64, 4}: [:, :, 1, 1] = 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 julia> NNlib.imrotate(arr, deg2rad(90), rotation_center=(2,2)) 4×4×1×1 Array{Float64, 4}: [:, :, 1, 1] = 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 julia> arr = zeros((3,3,1,1)); arr[1,2,1,1] = 1 1 julia> arr 3×3×1×1 Array{Float64, 4}: [:, :, 1, 1] = 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 julia> NNlib.imrotate(arr, deg2rad(45)) 3×3×1×1 Array{Float64, 4}: [:, :, 1, 1] = 0.0 0.207107 0.0 0.0 0.0 0.207107 0.0 0.0 0.0 julia> NNlib.imrotate(arr, deg2rad(45), method=:nearest) 3×3×1×1 Array{Float64, 4}: [:, :, 1, 1] = 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 ``` """ function imrotate(arr::AbstractArray{T, 4}, θ; method=:bilinear, rotation_center::Tuple=size(arr) .÷ 2 .+ 1) where T if (T <: Integer && method==:nearest || !(T <: Integer)) == false throw(ArgumentError("If the array has an Int eltype, only method=:nearest is supported")) end # prepare out, the sin and cos and type of rotation_center sinθ, cosθ, rotation_center, out = _prepare_imrotate(arr, θ, rotation_center) # such as 0°, 90°, 180°, 270° and only if the rotation_center is 
suitable _check_trivial_rotations!(out, arr, θ, rotation_center) && return out # KernelAbstractions specific backend = KernelAbstractions.get_backend(arr) if method == :bilinear kernel! = imrotate_kernel_bilinear!(backend) elseif method == :nearest kernel! = imrotate_kernel_nearest!(backend) else throw(ArgumentError("No interpolation method such as $method")) end kernel!(out, arr, sinθ, cosθ, rotation_center, size(arr, 1), size(arr, 2), ndrange=size(arr)) return out end """ ∇imrotate(dy, arr::AbstractArray{T, 4}, θ; method=:bilinear, rotation_center=size(arr) .÷ 2 .+ 1) Adjoint for `imrotate`. Gradient only with respect to `arr` and not `θ`. # Arguments * `dy`: input gradient * `arr`: Input from primal computation * `θ`: rotation angle in radians * `method=:bilinear` or `method=:nearest` * `rotation_center=size(arr) .÷ 2 .+ 1` rotates around a real center pixel for even and odd sized arrays """ function ∇imrotate(dy, arr::AbstractArray{T, 4}, θ; method=:bilinear, rotation_center::Tuple=size(arr) .÷ 2 .+ 1) where T sinθ, cosθ, rotation_center, out = _prepare_imrotate(arr, θ, rotation_center) # for the adjoint, the trivial rotations go in the other direction! # pass dy and not arr _check_trivial_rotations!(out, dy, θ, rotation_center, adjoint=true) && return out backend = KernelAbstractions.get_backend(arr) if method == :bilinear kernel! = ∇imrotate_kernel_bilinear!(backend) elseif method == :nearest kernel! = ∇imrotate_kernel_nearest!(backend) else throw(ArgumentError("No interpolation method such as $method")) end # don't pass arr but dy! kernel!(out, dy, sinθ, cosθ, rotation_center, size(arr, 1), size(arr, 2), ndrange=size(arr)) return out end @kernel function imrotate_kernel_nearest!(out, arr, sinθ, cosθ, rotation_center, imax, jmax) i, j, c, b = @index(Global, NTuple) r(x...) 
= round(x..., RoundNearestTiesAway) _, _, _, _, yrot_int, xrot_int = _rotate_coordinates(sinθ, cosθ, i, j, rotation_center, r) if 1 ≤ yrot_int ≤ imax && 1 ≤ xrot_int ≤ jmax @inbounds out[i, j, c, b] = arr[yrot_int, xrot_int, c, b] end end @kernel function imrotate_kernel_bilinear!(out, arr, sinθ, cosθ, rotation_center, imax, jmax) i, j, c, b = @index(Global, NTuple) yrot, xrot, yrot_f, xrot_f, yrot_int, xrot_int = _rotate_coordinates(sinθ, cosθ, i, j, rotation_center, floor) if 1 ≤ yrot_int ≤ imax - 1 && 1 ≤ xrot_int ≤ jmax - 1 ydiff, ydiff_1minus, xdiff, xdiff_1minus = _bilinear_helper(yrot, xrot, yrot_f, xrot_f) @inbounds out[i, j, c, b] = ( xdiff_1minus * ydiff_1minus * arr[yrot_int , xrot_int , c, b] + xdiff_1minus * ydiff * arr[yrot_int + 1 , xrot_int , c, b] + xdiff * ydiff_1minus * arr[yrot_int , xrot_int + 1 , c, b] + xdiff * ydiff * arr[yrot_int + 1 , xrot_int + 1 , c, b]) end end @kernel function ∇imrotate_kernel_nearest!(out, arr, sinθ, cosθ, rotation_center, imax, jmax) i, j, c, b = @index(Global, NTuple) r(x...) 
= round(x..., RoundNearestTiesAway) _, _, _, _, yrot_int, xrot_int = _rotate_coordinates(sinθ, cosθ, i, j, rotation_center, r) if 1 ≤ yrot_int ≤ imax && 1 ≤ xrot_int ≤ jmax Atomix.@atomic out[yrot_int, xrot_int, c, b] += arr[i, j, c, b] end end @kernel function ∇imrotate_kernel_bilinear!(out, arr, sinθ, cosθ, rotation_center, imax, jmax) i, j, c, b = @index(Global, NTuple) yrot, xrot, yrot_f, xrot_f, yrot_int, xrot_int = _rotate_coordinates(sinθ, cosθ, i, j, rotation_center, floor) if 1 ≤ yrot_int ≤ imax - 1 && 1 ≤ xrot_int ≤ jmax - 1 o = arr[i, j, c, b] ydiff, ydiff_1minus, xdiff, xdiff_1minus = _bilinear_helper(yrot, xrot, yrot_f, xrot_f) Atomix.@atomic out[yrot_int , xrot_int , c, b] += xdiff_1minus * ydiff_1minus * o Atomix.@atomic out[yrot_int + 1 , xrot_int , c, b] += xdiff_1minus * ydiff * o Atomix.@atomic out[yrot_int , xrot_int + 1, c, b] += xdiff * ydiff_1minus * o Atomix.@atomic out[yrot_int + 1 , xrot_int + 1, c, b] += xdiff * ydiff * o end end # is this rrule good? # no @thunk and @unthunk function ChainRulesCore.rrule(::typeof(imrotate), arr::AbstractArray{T}, θ; method=:bilinear, rotation_center=size(arr) .÷ 2 .+ 1) where T res = imrotate(arr, θ; method, rotation_center) function pb_rotate(dy) ad = ∇imrotate(unthunk(dy), arr, θ; method, rotation_center) return NoTangent(), ad, NoTangent() end return res, pb_rotate end ================================================ FILE: src/sampling.jl ================================================ @inline in_bounds(h, w, H, W) = 1 ≤ h ≤ H && 1 ≤ w ≤ W @inline in_bounds(h, w, d, H, W, D) = 1 ≤ h ≤ H && 1 ≤ w ≤ W && 1 ≤ d ≤ D # Borders are considered out-of-bounds for gradient. 
# Clamp a (1-based) source coordinate into the valid index range [1, dim_size].
@inline clip_coordinate(coordinate, dim_size) = min(dim_size, max(1, coordinate))

# Clamp like `clip_coordinate`, and also return the local derivative of the
# clamp: 0 at (or beyond) the borders, 1 strictly inside.
@inline function ∇clip_coordinate(coordinate::C, dim_size) where {C}
    if coordinate ≤ 1
        return C(1), C(0)
    elseif coordinate ≥ dim_size
        return C(dim_size), C(0)
    end
    coordinate, C(1)
end

# Map a normalized coordinate in [-1, 1] to a pixel coordinate in [1, dim_size]
# (align-corners convention: -1 ↦ 1 and 1 ↦ dim_size).
@inline unnormalize(coordinate, dim_size) = ((coordinate + 1.0) * 0.5) * (dim_size - 1.0) + 1.0
# Same mapping plus its constant derivative w.r.t. the normalized coordinate.
@inline ∇unnormalize(coordinate, dim_size) = unnormalize(coordinate, dim_size), (dim_size - 1.0) * 0.5

# Unnormalize and apply the padding mode: `:zeros` leaves out-of-range
# coordinates as-is (handled later via in_bounds), `:border` clamps them.
@inline compute_source_index(coordinate, dim_size, ::Val{:zeros}) = unnormalize(coordinate, dim_size)
@inline compute_source_index(coordinate, dim_size, ::Val{:border}) = clip_coordinate(unnormalize(coordinate, dim_size), dim_size)

# Gradient versions: return (source_coordinate, d source / d normalized).
@inline ∇compute_source_index(coordinate, dim_size, ::Val{:zeros}) = ∇unnormalize(coordinate, dim_size)
@inline function ∇compute_source_index(coordinate, dim_size, ::Val{:border})
    source_coordinate, grad_in = ∇unnormalize(coordinate, dim_size)
    source_coordinate, grad_clip = ∇clip_coordinate(source_coordinate, dim_size)
    # chain rule through unnormalize and clip
    source_coordinate, grad_in * grad_clip
end

"""
    grid_sample(input::AbstractArray{T, 4}, grid::AbstractArray{T, 4}; padding_mode = :zeros)
    grid_sample(input::AbstractArray{T, 5}, grid::AbstractArray{T, 4}; padding_mode = :zeros)

Given `input`, compute output by sampling `input` values at pixel
locations from `grid`. Uses bilinear interpolation to calculate output values.

This implementation assumes the extrema (`-1` and `1`) are considered
as referring to the center points of the input’s corner pixels
(i.e. align corners is `true`).

# Arguments

- `input`: Input array in `(W_in, H_in, [D_in,] C, N)` shape.
- `grid`: Input grid in `(2, W_out, H_out, [D_out,] N)` shape.
    Where for each `(W_out, H_out, [D_out,] N)` grid contains `(x, y [,z])`
    coordinates that specify sampling locations normalized by the `input` shape.

    Therefore, `x`, `y` and [`z`] should have values in `[-1, 1]` range.
    For example, `(x = -1, y = -1, [z = -1])` is the left-top[-front] pixel of `input`,
    and `(x = 1, y = 1, [z = 1])` is the right-bottom-back pixel of `input`.

    Out-of-bound values are handled according to the `padding_mode`.
- `padding_mode`: Out-of-bound padding.
    `:zeros` to use `0` for out-of-bound grid locations.
    `:border` to use border values for out-of-bound grid locations.
    Default is `:zeros`.

# Returns

`(W_out, H_out, [D_out,] C, N)` sampled grid from `input`.

# Examples

In the example below, grid contains two out-of-bound sampling locations,
which are handled differently, depending on the `padding_mode`.

```jldoctest
julia> x = reshape(collect(1.0:4.0), (2, 2, 1, 1))
2×2×1×1 Array{Float64, 4}:
[:, :, 1, 1] =
 1.0  3.0
 2.0  4.0

julia> grid = Array{Float64}(undef, 2, 3, 2, 1);

julia> grid[:, 1, 1, 1] .= (-3, -1);

julia> grid[:, 2, 1, 1] .= (0, -1);

julia> grid[:, 3, 1, 1] .= (1, -1);

julia> grid[:, 1, 2, 1] .= (-1, 1);

julia> grid[:, 2, 2, 1] .= (0, 1);

julia> grid[:, 3, 2, 1] .= (3, 1);

julia> grid_sample(x, grid; padding_mode=:zeros)
3×2×1×1 Array{Float64, 4}:
[:, :, 1, 1] =
 0.0  3.0
 1.5  3.5
 2.0  0.0

julia> grid_sample(x, grid; padding_mode=:border)
3×2×1×1 Array{Float64, 4}:
[:, :, 1, 1] =
 1.0  3.0
 1.5  3.5
 2.0  4.0
```
"""
function grid_sample(input::AbstractArray{T,N}, grid; padding_mode = :zeros) where {T,N}
    if N ∉ (4,5)
        error("grid_sample is only supported for 4D and 5D arrays.")
    end
    iC, iN = size(input)[end-1:end]
    output_size = size(grid)[2:end-1] # W_out, H_out, [D_out]
    output = similar(input, T, (output_size..., iC, iN))
    grid_sample!(output, input, grid, padding_mode)
end

# In-place 2D (4D array) grid sampling; batches are processed in parallel.
function grid_sample!(output::AbstractArray{T,4}, input::AbstractArray{T,4}, grid, padding_mode=:zeros) where {T}
    pad = Val(padding_mode)
    iW, iH, iC, iN = size(input)
    _, gW, gH, _ = size(grid)
    # Loop over each output pixel.
    Threads.@threads for n in 1:iN
        for w in 1:gW, h in 1:gH
            _grid_sample_kernel!(output, input, grid, pad, w, h, n, iW, iH, iC)
        end
    end
    output
end

# In-place 3D (5D array) grid sampling; batches are processed in parallel.
function grid_sample!(output::AbstractArray{T,5}, input::AbstractArray{T,5}, grid, padding_mode=:zeros) where {T}
    pad = Val(padding_mode)
    iW, iH, iD, iC, iN = size(input)
    _, gW, gH, gD, _ = size(grid)
    # Loop over each output pixel.
    Threads.@threads for n in 1:iN
        for w in 1:gW, h in 1:gH, d in 1:gD
            _grid_sample_kernel!(output, input, grid, pad, w, h, d, n, iW, iH, iD, iC)
        end
    end
    output
end

# Bilinear sampling of one output pixel (all channels) at grid location (w, h, n).
@inline function _grid_sample_kernel!(
    output::AbstractArray{T,4}, input::AbstractArray{T,4}, grid, padding_mode,
    w, h, n, iW, iH, iC,
) where {T}
    # Get the corresponding (x, y) coordinates from the grid.
    @inbounds x, y = grid[1, w, h, n], grid[2, w, h, n]
    ix = compute_source_index(x, iW, padding_mode)
    iy = compute_source_index(y, iH, padding_mode)
    # Get corner pixel values from (ix, iy) in north-east-south-west directions.
    ix_nw, iy_nw = unsafe_trunc(Int, floor(ix)), unsafe_trunc(Int, floor(iy))
    ix_ne, iy_ne = ix_nw + 1, iy_nw
    ix_sw, iy_sw = ix_nw, iy_nw + 1
    ix_se, iy_se = ix_ne, iy_sw
    # Get surfaces to each neighbor (a.k.a. interpolation weights).
    nw = (ix_se - ix) * (iy_se - iy)
    ne = (ix - ix_sw) * (iy_sw - iy)
    sw = (ix_ne - ix) * (iy - iy_ne)
    se = (ix - ix_nw) * (iy - iy_nw)
    # ∀ channel: Calculate bilinear weighted pixel value.
    # Out-of-bounds corners contribute 0 (this is how :zeros padding is realized).
    @inbounds for c in 1:iC
        r = zero(T)
        if in_bounds(iy_nw, ix_nw, iH, iW)
            r += input[ix_nw, iy_nw, c, n] * nw
        end
        if in_bounds(iy_ne, ix_ne, iH, iW)
            r += input[ix_ne, iy_ne, c, n] * ne
        end
        if in_bounds(iy_sw, ix_sw, iH, iW)
            r += input[ix_sw, iy_sw, c, n] * sw
        end
        if in_bounds(iy_se, ix_se, iH, iW)
            r += input[ix_se, iy_se, c, n] * se
        end
        output[w, h, c, n] = r
    end
end

# Trilinear sampling of one output voxel (all channels) at grid location (w, h, d, n).
@inline function _grid_sample_kernel!(
    output::AbstractArray{T,5}, input::AbstractArray{T,5}, grid, padding_mode,
    w, h, d, n, iW, iH, iD, iC,
) where {T}
    # Get the corresponding (x, y, z) coordinates from the grid.
    @inbounds x, y, z = grid[1, w, h, d, n], grid[2, w, h, d, n], grid[3, w, h, d, n]
    ix = compute_source_index(x, iW, padding_mode)
    iy = compute_source_index(y, iH, padding_mode)
    iz = compute_source_index(z, iD, padding_mode)
    # Get corner voxel values from (ix, iy, iz) in 8 directions
    # (north-east-south-west-bottom-up).
    ix_nw, iy_nw, iz_nw = unsafe_trunc(Int, floor(ix)), unsafe_trunc(Int, floor(iy)), unsafe_trunc(Int, floor(iz))
    ix_ne, iy_ne, iz_ne = ix_nw + 1, iy_nw, iz_nw
    ix_sw, iy_sw, iz_sw = ix_nw, iy_nw + 1, iz_nw
    ix_se, iy_se, iz_se = ix_ne, iy_sw, iz_nw
    ix_nw_u, iy_nw_u, iz_nw_u = ix_nw, iy_nw, iz_nw + 1
    ix_ne_u, iy_ne_u, iz_ne_u = ix_ne, iy_ne, iz_ne + 1
    ix_sw_u, iy_sw_u, iz_sw_u = ix_sw, iy_sw, iz_sw + 1
    ix_se_u, iy_se_u, iz_se_u = ix_se, iy_se, iz_se + 1
    # Get volumes to each neighbor (a.k.a. interpolation weights).
    nw = (ix_se - ix) * (iy_se - iy) * (iz_se_u - iz)
    ne = (ix - ix_sw) * (iy_sw - iy) * (iz_sw_u - iz)
    sw = (ix_ne - ix) * (iy - iy_ne) * (iz_ne_u - iz)
    se = (ix - ix_nw) * (iy - iy_nw) * (iz_nw_u - iz)
    nw_u = (ix_se - ix) * (iy_se - iy) * (iz - iz_nw)
    ne_u = (ix - ix_sw) * (iy_sw - iy) * (iz - iz_sw)
    sw_u = (ix_ne - ix) * (iy - iy_ne) * (iz - iz_ne)
    se_u = (ix - ix_nw) * (iy - iy_nw) * (iz - iz_nw)
    # ∀ channel: Calculate trilinear weighted voxel value.
    # Out-of-bounds corners contribute 0 (this is how :zeros padding is realized).
    @inbounds for c in 1:iC
        r = zero(T)
        if in_bounds(iy_nw, ix_nw, iz_nw, iH, iW, iD)
            r += input[ix_nw, iy_nw, iz_nw, c, n] * nw
        end
        if in_bounds(iy_ne, ix_ne, iz_ne, iH, iW, iD)
            r += input[ix_ne, iy_ne, iz_ne, c, n] * ne
        end
        if in_bounds(iy_sw, ix_sw, iz_sw, iH, iW, iD)
            r += input[ix_sw, iy_sw, iz_sw, c, n] * sw
        end
        if in_bounds(iy_se, ix_se, iz_se, iH, iW, iD)
            r += input[ix_se, iy_se, iz_se, c, n] * se
        end
        if in_bounds(iy_nw_u, ix_nw_u, iz_nw_u, iH, iW, iD)
            r += input[ix_nw_u, iy_nw_u, iz_nw_u, c, n] * nw_u
        end
        if in_bounds(iy_ne_u, ix_ne_u, iz_ne_u, iH, iW, iD)
            r += input[ix_ne_u, iy_ne_u, iz_ne_u, c, n] * ne_u
        end
        if in_bounds(iy_sw_u, ix_sw_u, iz_sw_u, iH, iW, iD)
            r += input[ix_sw_u, iy_sw_u, iz_sw_u, c, n] * sw_u
        end
        if in_bounds(iy_se_u, ix_se_u, iz_se_u, iH, iW, iD)
            r += input[ix_se_u, iy_se_u, iz_se_u, c, n] * se_u
        end
        output[w, h, d, c, n] = r
    end
end

"""
    ∇grid_sample(Δ::AbstractArray{T, 4}, input::AbstractArray{T, 4}, grid::AbstractArray{T, 4}; padding_mode = :zeros) where T

# Arguments

- `Δ`: Input gradient in `(W_out, H_out, C, N)` shape
    (same as output of the primal computation).
- `input`: Input from primal computation in `(W_in, H_in, C, N)` shape.
- `grid`: Grid from primal computation in `(2, W_out, H_out, N)` shape.
- `padding_mode`: Out-of-bound padding.
    `:zeros` to use `0` for out-of-bound grid locations.
    `:border` to use border values for out-of-bound grid locations.
    Should be the same as in primal computation.
    Default is `:zeros`.

# Returns

`dinput` (same shape as `input`) and `dgrid` (same shape as `grid`) gradients.
"""
function ∇grid_sample(Δ::AbstractArray{T,N}, input::AbstractArray{T,N}, grid; padding_mode=:zeros) where {T, N}
    if N ∉ (4,5)
        error("∇grid_sample is only supported for 4D and 5D arrays.")
    end
    dx = zeros(T, size(input))
    dgrid = similar(grid)
    ∇grid_sample!(dx, dgrid, Δ, input, grid, padding_mode)
end

# In-place 2D gradient computation; dx is accumulated, dgrid is overwritten.
function ∇grid_sample!(dx::AbstractArray{T,4}, dgrid::AbstractArray{T,4}, Δ::AbstractArray{T,4},
                       input::AbstractArray{T,4}, grid::AbstractArray{T,4}, padding_mode) where {T}
    pad = Val(padding_mode)
    iW, iH, iC, iN = size(input)
    gW, gH = size(grid, 2), size(grid, 3)
    # Loop over each output pixel.
    Threads.@threads for n in 1:iN
        for w in 1:gW, h in 1:gH
            _∇grid_sample_kernel!(dx, dgrid, Δ, input, grid, pad, w, h, n, iW, iH, iC)
        end
    end
    dx, dgrid
end

# In-place 3D gradient computation; dx is accumulated, dgrid is overwritten.
function ∇grid_sample!(dx::AbstractArray{T,5}, dgrid::AbstractArray{T,5}, Δ::AbstractArray{T,5},
                       input::AbstractArray{T,5}, grid::AbstractArray{T,5}, padding_mode) where {T}
    pad = Val(padding_mode)
    iW, iH, iD, iC, iN = size(input)
    gW, gH, gD = size(grid, 2), size(grid, 3), size(grid, 4)
    # Loop over each output voxel.
    Threads.@threads for n in 1:iN
        for w in 1:gW, h in 1:gH, d in 1:gD
            _∇grid_sample_kernel!(dx, dgrid, Δ, input, grid, pad, w, h, d, n, iW, iH, iD, iC)
        end
    end
    dx, dgrid
end

# Backward pass for one 2D output pixel: scatters Δ-weighted contributions into
# dx and writes the (x, y) grid gradient for (w, h, n) into dgrid.
@inline function _∇grid_sample_kernel!(
    dx::AbstractArray{T,4}, dgrid::AbstractArray{V,4}, Δ::AbstractArray{T,4},
    input::AbstractArray{T,4}, grid::AbstractArray{V,4}, padding_mode,
    w, h, n, iW, iH, iC,
) where {T,V}
    # Get corresponding (x, y) from grid.
    @inbounds x, y = grid[1, w, h, n], grid[2, w, h, n]
    # Compute multipliers for gradients on ix, iy.
    ix, gix_mult = ∇compute_source_index(x, iW, padding_mode)
    iy, giy_mult = ∇compute_source_index(y, iH, padding_mode)
    # Get corner pixel values from (ix, iy) in north-east-south-west directions.
    ix_nw, iy_nw = unsafe_trunc(Int, floor(ix)), unsafe_trunc(Int, floor(iy))
    ix_ne, iy_ne = ix_nw + 1, iy_nw
    ix_sw, iy_sw = ix_nw, iy_nw + 1
    ix_se, iy_se = ix_ne, iy_sw
    # Get surfaces to each neighbor (a.k.a. interpolation weights).
    nw = (ix_se - ix) * (iy_se - iy)
    ne = (ix - ix_sw) * (iy_sw - iy)
    sw = (ix_ne - ix) * (iy - iy_ne)
    se = (ix - ix_nw) * (iy - iy_nw)
    # ∀ channel: Calculate bilinear weighted pixel value.
    gix, giy = zero(V), zero(V)
    @inbounds for c in 1:iC
        g_out = Δ[w, h, c, n]
        # Calculate dx and dgrid partials.
        if in_bounds(iy_nw, ix_nw, iH, iW)
            _safe_add!(dx, g_out * nw, ix_nw, iy_nw, c, n)
            nw_val = input[ix_nw, iy_nw, c, n]
            gix -= nw_val * (iy_se - iy) * g_out
            giy -= nw_val * (ix_se - ix) * g_out
        end
        if in_bounds(iy_ne, ix_ne, iH, iW)
            _safe_add!(dx, g_out * ne, ix_ne, iy_ne, c, n)
            ne_val = input[ix_ne, iy_ne, c, n]
            gix += ne_val * (iy_sw - iy) * g_out
            giy -= ne_val * (ix - ix_sw) * g_out
        end
        if in_bounds(iy_sw, ix_sw, iH, iW)
            _safe_add!(dx, g_out * sw, ix_sw, iy_sw, c, n)
            sw_val = input[ix_sw, iy_sw, c, n]
            gix -= sw_val * (iy - iy_ne) * g_out
            giy += sw_val * (ix_ne - ix) * g_out
        end
        if in_bounds(iy_se, ix_se, iH, iW)
            _safe_add!(dx, g_out * se, ix_se, iy_se, c, n)
            se_val = input[ix_se, iy_se, c, n]
            gix += se_val * (iy - iy_nw) * g_out
            giy += se_val * (ix - ix_nw) * g_out
        end
    end
    # chain through the unnormalize (and clip, for :border) derivatives
    @inbounds dgrid[1, w, h, n] = gix_mult * gix
    @inbounds dgrid[2, w, h, n] = giy_mult * giy
end

# Backward pass for one 3D output voxel: scatters Δ-weighted contributions into
# dx and writes the (x, y, z) grid gradient for (w, h, d, n) into dgrid.
@inline function _∇grid_sample_kernel!(
    dx::AbstractArray{T,5}, dgrid::AbstractArray{V,5}, Δ::AbstractArray{T,5},
    input::AbstractArray{T,5}, grid::AbstractArray{V,5}, padding_mode,
    w, h, d, n, iW, iH, iD, iC,
) where {T,V}
    # Get corresponding (x, y, z) from grid.
    @inbounds x, y, z = grid[1, w, h, d, n], grid[2, w, h, d, n], grid[3, w, h, d, n]
    # Compute multipliers for gradients on ix, iy, iz.
    ix, gix_mult = ∇compute_source_index(x, iW, padding_mode)
    iy, giy_mult = ∇compute_source_index(y, iH, padding_mode)
    iz, giz_mult = ∇compute_source_index(z, iD, padding_mode)
    # Get corner pixel values from (ix, iy, iz)
    ix_0 = unsafe_trunc(Int, floor(ix))
    iy_0 = unsafe_trunc(Int, floor(iy))
    iz_0 = unsafe_trunc(Int, floor(iz))
    ix_1 = ix_0 + 1
    iy_1 = iy_0 + 1
    iz_1 = iz_0 + 1
    # Get difference of coordinate
    wx_0 = ix - ix_0
    wy_0 = iy - iy_0
    wz_0 = iz - iz_0
    wx_1 = ix_1 - ix
    wy_1 = iy_1 - iy
    wz_1 = iz_1 - iz
    # Calculate weights (volume of diagonal vertex cube)
    # w_{abc} = wx_{¬a}*wy_{¬b}*wz_{¬c}
    weight_000 = wx_1 * wy_1 * wz_1
    weight_001 = wx_1 * wy_1 * wz_0
    weight_010 = wx_1 * wy_0 * wz_1
    weight_011 = wx_1 * wy_0 * wz_0
    weight_100 = wx_0 * wy_1 * wz_1
    weight_101 = wx_0 * wy_1 * wz_0
    weight_110 = wx_0 * wy_0 * wz_1
    weight_111 = wx_0 * wy_0 * wz_0
    # ∂w_{abc}/∂x=(-1)^{¬a} wy_{¬b}*wz_{¬c}, ∂w/∂y = (-1)^{¬b} wx_{¬a}*wz_{¬c}, ∂w/∂z=(-1)^{¬c} wx_{¬a}*wy_{¬b}
    # abc are the index of the vertex of the cube (001,010...)
    # Initialize gradient accumulators
    gix, giy, giz = zero(V), zero(V), zero(V)
    @inbounds for c in 1:iC
        g_out = Δ[w, h, d, c, n]
        # Calculate dx and dgrid partials for all 8 corners
        if in_bounds(iy_0, ix_0, iz_0, iH, iW, iD)
            _safe_add!(dx, g_out * weight_000, ix_0, iy_0, iz_0, c, n)
            val = input[ix_0, iy_0, iz_0, c, n]
            gix -= val * wy_1 * wz_1 * g_out
            giy -= val * wx_1 * wz_1 * g_out
            giz -= val * wx_1 * wy_1 * g_out
        end
        if in_bounds(iy_0, ix_0, iz_1, iH, iW, iD)
            _safe_add!(dx, g_out * weight_001, ix_0, iy_0, iz_1, c, n)
            val = input[ix_0, iy_0, iz_1, c, n]
            gix -= val * wy_1 * wz_0 * g_out
            giy -= val * wx_1 * wz_0 * g_out
            giz += val * wx_1 * wy_1 * g_out
        end
        if in_bounds(iy_1, ix_0, iz_0, iH, iW, iD)
            _safe_add!(dx, g_out * weight_010, ix_0, iy_1, iz_0, c, n)
            val = input[ix_0, iy_1, iz_0, c, n]
            gix -= val * wy_0 * wz_1 * g_out
            giy += val * wx_1 * wz_1 * g_out
            giz -= val * wx_1 * wy_0 * g_out
        end
        if in_bounds(iy_1, ix_0, iz_1, iH, iW, iD)
            _safe_add!(dx, g_out * weight_011, ix_0, iy_1, iz_1, c, n)
            val = input[ix_0, iy_1, iz_1, c, n]
            gix -= val * wy_0 * wz_0 * g_out
            giy += val * wx_1 * wz_0 * g_out
            giz += val * wx_1 * wy_0 * g_out
        end
        if in_bounds(iy_0, ix_1, iz_0, iH, iW, iD)
            _safe_add!(dx, g_out * weight_100, ix_1, iy_0, iz_0, c, n)
            val = input[ix_1, iy_0, iz_0, c, n]
            gix += val * wy_1 * wz_1 * g_out
            giy -= val * wx_0 * wz_1 * g_out
            giz -= val * wx_0 * wy_1 * g_out
        end
        if in_bounds(iy_0, ix_1, iz_1, iH, iW, iD)
            _safe_add!(dx, g_out * weight_101, ix_1, iy_0, iz_1, c, n)
            val = input[ix_1, iy_0, iz_1, c, n]
            gix += val * wy_1 * wz_0 * g_out
            giy -= val * wx_0 * wz_0 * g_out
            giz += val * wx_0 * wy_1 * g_out
        end
        if in_bounds(iy_1, ix_1, iz_0, iH, iW, iD)
            _safe_add!(dx, g_out * weight_110, ix_1, iy_1, iz_0, c, n)
            val = input[ix_1, iy_1, iz_0, c, n]
            gix += val * wy_0 * wz_1 * g_out
            giy += val * wx_0 * wz_1 * g_out
            giz -= val * wx_0 * wy_0 * g_out
        end
        if in_bounds(iy_1, ix_1, iz_1, iH, iW, iD)
            _safe_add!(dx, g_out * weight_111, ix_1, iy_1, iz_1, c, n)
            val = input[ix_1, iy_1, iz_1, c, n]
            gix += val * wy_0 * wz_0 * g_out
            giy += val * wx_0 * wz_0 * g_out
            giz += val * wx_0 * wy_0 * g_out
        end
    end
    # chain through the unnormalize (and clip, for :border) derivatives
    @inbounds dgrid[1, w, h, d, n] = gix_mult * gix
    @inbounds dgrid[2, w, h, d, n] = giy_mult * giy
    @inbounds dgrid[3, w, h, d, n] = giz_mult * giz
end

# Accumulate `value` into dx at the given (already bounds-checked) indices.
@inline function _safe_add!(dx, value, ix, iy, c, n)
    @inbounds dx[ix, iy, c, n] += value
end

@inline function _safe_add!(dx, value, ix, iy, iz, c, n)
    @inbounds dx[ix, iy, iz, c, n] += value
end

# rrule for grid_sample: gradient flows to both `x` and `grid`.
function rrule(::typeof(grid_sample), x, grid; padding_mode)
    y = grid_sample(x, grid; padding_mode=padding_mode)
    function grid_sample_pullback(Δ)
        ∇x, ∇grid = ∇grid_sample(unthunk(Δ), x, grid; padding_mode=padding_mode)
        NoTangent(), ∇x, ∇grid
    end
    return y, grid_sample_pullback
end


================================================
FILE: src/scatter.jl
================================================

## Scatter API
#  - Scatter:
#  - scatter(op, src, idx)
#  - scatter!(op, dst, src, idx)
#  - Scatter destination backpropagation
#  - ∇scatter!_dst
#  - Scatter source backpropagation
#  - ∇scatter_src
#  - ∇scatter!_src
#

# Number of index components contributed by one element of an index array.
typelength(::Type{<:Number}) = 1
typelength(::Type{<:NTuple{M}}) where M = M
typelength(::Type{CartesianIndex{M}}) where M = M

"""
Performs dimensional consistency checks and return the
dimensionality of the scattered objects.
"""
function scatter_dims(
    X::AbstractArray{Tx,Nx}, Y::AbstractArray{Ty,Ny},
    idx::AbstractArray{Tidx,Nidx},
) where {Tx,Ty,Tidx,Nx,Ny,Nidx}
    dims = scatter_dims(Nx, Ny, typelength(Tidx), Nidx)
    size(X)[1:dims] == size(Y)[1:dims] || throw(ArgumentError("Incompatible input shapes."))
    size(Y)[dims+1:end] == size(idx) || throw(ArgumentError("Incompatible input shapes."))
    return dims
end

function scatter_dims(Nx, Ny, M, Nidx)
    @assert Nx - M == Ny - Nidx "Incompatible input shapes of (dst, src, idx) = ($Nx, $Ny, $Nidx)."
    dims = Nx - M
    dims < 0 && throw(ArgumentError("dims must be non-negative but got dims=$dims."))
    return dims
end

# View the leading `dims` dimensions at destination/source position `k`
# (`k` may be a tuple index or a plain integer/CartesianIndex).
_view(X, colons, k) = view(X, colons..., k...)
_view(X, colons, k::Union{Integer, CartesianIndex}) = view(X, colons..., k)

"""
    NNlib.scatter!(op, dst, src, idx)

Scatter operation, which writes data in `src` into `dst` at locations `idx`.
A binary reduction operator `op` is applied during the scatter.
For each index `k` in `idx`, accumulates values in `dst` according to

    dst[:, ..., idx[k]...] = (op).(dst[:, ..., idx[k]...], src[:, ..., k...])

See also [`scatter`](@ref), [`gather`](@ref).

# Arguments

- `op`: Operations to be applied on `dst` and `src`, e.g. `+`, `-`, `*`, `/`, `max`, `min` and `mean`.
- `dst`: The destination for `src` to aggregate to. This argument will be mutated.
- `src`: The source data for aggregating.
- `idx`: The mapping for aggregation from source (index) to destination (value).
    The `idx` array can contain either integers or tuples.

# Examples

```jldoctest
julia> NNlib.scatter!(+, ones(3), [10,100], [1,3])
3-element Vector{Float64}:
  11.0
   1.0
 101.0

julia> NNlib.scatter!(*, fill(0.5, 2, 4), [1 10; 100 1000], [3,2])
2×4 Matrix{Float64}:
 0.5    5.0   0.5  0.5
 0.5  500.0  50.0  0.5
```
"""
function scatter!(op::OP, dst::AbstractArray, src::AbstractArray, idx::AbstractArray) where OP
    dims = scatter_dims(dst, src, idx)
    colons = Base.ntuple(_->Colon(), dims)
    # sequential CPU fallback: accumulate one index at a time
    for k in CartesianIndices(idx)
        dst_v = _view(dst, colons, idx[k])
        src_v = _view(src, colons, k)
        dst_v .= (op).(dst_v, src_v)
    end
    dst
end

# `mean` is expressed via two `+`-scatters: a sum of values and a count,
# then an element-wise safe division. Defined for both CPU and GPU arrays.
for AT in (AbstractArray, AnyGPUArray)
    @eval function scatter!(op::typeof(mean), dst::$AT, src::$AT, idx::$AT)
        Ns = scatter!(+, zero(dst), one.(src), idx)
        dst_ = scatter!(+, zero(dst), src, idx)
        dst .+= safe_div.(dst_, Ns)
        return dst
    end
end

# GPU path: launch a KernelAbstractions kernel; with leading dimensions the
# ndrange is flattened to (leading dims × number of indices).
function scatter!(op::OP, dst::AnyGPUArray, src::AnyGPUArray, idx::AnyGPUArray) where OP
    n_dims = scatter_dims(dst, src, idx)
    args = if n_dims == 0
        ndrange = length(idx)
        ()
    else
        dims = size(dst)[1:n_dims]
        max_dims_idx = prod(dims)
        ndrange = max_dims_idx * length(idx)
        (CartesianIndices(dims), max_dims_idx)
    end
    _scatter!(KernelAbstractions.get_backend(dst))(
        op, dst, src, idx, args...; ndrange)
    dst
end

# Kernel for the 0-leading-dims case: one work-item per index element.
@kernel function _scatter!(op::OP, dst, src, idxs) where OP
    i = @index(Global)
    @inbounds idx = Tuple(_convert_i64(idxs[i]))
    @inbounds Atomix.modify!(Atomix.IndexableRef(dst, idx), op, src[i])
    # FIXME `@atomic` macro silently fails to perform atomic op below
    # @atomic dst[idx...] = op(dst[idx...], src[i])
end

# Kernel for the general case: the flat work-item id is split into a leading
# dimension index and an idx-array index.
@kernel function _scatter!(
    op::OP, dst, src, idxs,
    dim_ids::CartesianIndices, max_dims_idx::Int,
) where OP
    i = @index(Global)
    j, k = divrem(i - 1, max_dims_idx)
    @inbounds idx = (Tuple(dim_ids[k + 1])..., Tuple(_convert_i64(idxs[j + 1]))...)
    @inbounds Atomix.modify!(Atomix.IndexableRef(dst, idx), op, src[i])
    # FIXME `@atomic` macro silently fails to perform atomic op below
    # dim_i = Tuple(dim_ids[k + 1])
    # idx = idxs[j + 1]
    # @atomic dst[dim_i..., idx...] = op(dst[dim_i..., idx...], src[i])
end

# Allow non-Int64 indices by converting them to Int64 when index eltype <: Integer.
# All other index types (tuples, cartesian indices) must be in Int64 already.
@inline _convert_i64(x::Int) = x
@inline _convert_i64(x::Integer) = Int(x)
@inline _convert_i64(x) = x

"""
    NNlib.scatter(op, src, idx; [init, dstsize])

Scatter operation allocating a destination array `dst` and
calling `scatter!(op, dst, src, idx)` on it.

* If keyword `init` is provided, it is used to initialize the content of `dst`.
  Otherwise, the init values is inferred from the reduction operator `op`
  for some common operators (e.g. `init = 0` for `op = +`).

* If `dstsize` is provided, it will be used to define the size of
  destination array, otherwise it will be inferred by `src` and `idx`.

See [`scatter!`](@ref) for full details on how `idx` works.

# Examples

```jldoctest
julia> NNlib.scatter(+, [10,100,1000], [3,1,2])
3-element Vector{Int64}:
  100
 1000
   10

julia> NNlib.scatter(+, [1 2 3 4; 5 6 7 8], [2,1,1,5])
2×5 Matrix{Int64}:
  5  1  0  0  4
 13  5  0  0  8

julia> NNlib.scatter(*, [10,200,3000], [1,4,2]; init = 10, dstsize = 6)
6-element Vector{Int64}:
   100
 30000
    10
  2000
    10
    10
```
"""
function scatter(
    op::OP,
    src::AbstractArray{Tsrc,Nsrc},
    idx::AbstractArray{Tidx,Nidx};
    init = nothing, dstsize = nothing,
) where {Tsrc,Tidx,Nsrc,Nidx,OP}
    dims = Nsrc - Nidx
    dstsz = isnothing(dstsize) ? (size(src)[1:dims]..., maximum_dims(idx)...) : dstsize
    dst = similar(src, Tsrc, dstsz)
    xinit = isnothing(init) ? scatter_empty(op, Tsrc) : init
    fill!(dst, xinit)
    scatter!(op, dst, src, idx)
end

# Neutral initial value for each supported reduction operator.
scatter_empty(op, T) = Base.reduce_empty(op, T)
scatter_empty(op::typeof(-), T) = zero(T)
scatter_empty(op::typeof(/), T) = one(T)
scatter_empty(op::typeof(min), T) = typemax(T)
scatter_empty(op::typeof(max), T) = typemin(T)
scatter_empty(op::typeof(mean), T) = zero(T)


## Gradients

∇scatter!_src(op, Δ, dst, src, idx) = ∇scatter_src(op, Δ, dst, src, idx)
∇scatter!_src(op::Union{typeof(*),typeof(/)}, Δ, dst, src, idx) =
    gather(dst, idx) .* ∇scatter_src(op, Δ, dst, src, idx)

∇scatter!_dst(op, Δ, dst, y) = Δ
# For max/min, gradient flows to dst only where the old dst value survived.
∇scatter!_dst(op::Union{typeof(max),typeof(min)}, Δ, dst_old, dst) =
    (dst_old .== op.(dst_old, dst)) .* Δ

# Per-operator transformation of the gathered gradient.
modify_src(::typeof(+), X) = X
modify_src(::typeof(-), X) = -X
modify_src(::typeof(*), X, Y) = X
modify_src(::typeof(/), X, Y) = .-X ./ Y.^2

∇scatter_src(op::Union{typeof(+),typeof(-)}, Δ, dst, src, idx) = modify_src(op, gather(Δ, idx))
# For max/min, gradient flows to the src elements that equal the reduced value.
∇scatter_src(::Union{typeof(max),typeof(min)}, Δ, dst, src, idx) =
    (src .== gather(dst, idx)) .* gather(Δ, idx)

# CPU gradient for * and /: multiply the gathered gradient by the product of
# all *other* source elements scattered to the same destination slot.
function ∇scatter_src(
    op::Union{typeof(*),typeof(/)}, Δ, dst,
    src::AbstractArray{Tsrc,Nsrc},
    idx::AbstractArray{Tidx,Nidx},
) where {Tsrc,Tidx,Nsrc,Nidx}
    dims = Nsrc - Nidx
    Δsrc = modify_src(op, gather(Δ, idx), src)
    rev_idx = reverse_indices(idx)
    ax = CartesianIndices(axes(src)[1:dims])
    for k in CartesianIndices(idx)
        # all other src positions targeting the same destination
        inds = filter(x -> x != k, rev_idx[idx[k]])
        for i in ax
            Δsrc[i, k] = op(Δsrc[i, k], prod(j -> src[i, j], inds))
        end
    end
    Δsrc
end

# GPU gradient for * and /: same math as the CPU path, computed in a kernel.
function ∇scatter_src(
    op::Union{typeof(*), typeof(/)}, Δ, dst,
    src::AnyGPUArray{Tsrc, Nsrc},
    idx::AnyGPUArray{Tidx, Nidx},
) where {Tsrc, Nsrc, Tidx, Nidx}
    n_dims = Nsrc - Nidx
    Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), src)
    rev_idx = NNlib.reverse_indices(idx)
    args = if n_dims == 0
        ndrange = length(idx)
        ()
    else
        dims = size(dst)[1:n_dims]
        max_dims_idx = prod(dims)
        ndrange = max_dims_idx * length(idx)
        (CartesianIndices(dims), max_dims_idx)
    end
    _∇scatter_src(KernelAbstractions.get_backend(src))(
        op, Δsrc, src, idx, rev_idx, args...; ndrange)
    KernelAbstractions.unsafe_free!(rev_idx)
    return Δsrc
end

# Kernel, 0-leading-dims case: the product over "all other" elements is
# computed as the full product divided by the element itself.
@kernel function _∇scatter_src(op, Δsrc, src::AbstractArray{T}, idx, rev_idx) where T
    i = @index(Global)
    cart_j = CartesianIndices(idx)[i]
    @inbounds begin
        inds = rev_idx[Tuple(idx[cart_j])...]
        x = one(T)
        for k in inds
            x *= src[k]
        end
        x /= src[cart_j]
        Δsrc[cart_j] = op(Δsrc[cart_j], x)
    end
end

# Kernel, general case: flat work-item id is split into leading-dims and
# idx-array components via fldmod1.
@kernel function _∇scatter_src(
    op, Δsrc, src::AbstractArray{T}, idx, rev_idx,
    dim_ids::CartesianIndices, max_dims_idx::Int,
) where T
    i = @index(Global)
    j, k = fldmod1(i, max_dims_idx)
    @inbounds begin
        cart_j = CartesianIndices(idx)[j]
        cart_k = dim_ids[k]
        inds = rev_idx[Tuple(cart_j)...]
        x = one(T)
        for s in inds
            x *= src[Tuple(cart_k)..., Tuple(s)...]
        end
        x /= src[i]
        Δsrc[i] = op(Δsrc[i], x)
    end
end

# Gradient of the mean reduction: gathered gradient divided by the number of
# contributions to each destination slot.
function ∇scatter_src(
    ::typeof(mean), Δ, dst,
    src::AbstractArray{Tsrc,Nsrc},
    idx::AbstractArray{Tidx,Nidx},
) where {Tsrc,Tidx,Nsrc,Nidx}
    M = typelength(Tidx)
    num = gather(Δ, idx)
    counts = fill!(similar(Δ, Int, size(Δ)[end-M+1:end]), 0)
    scatter!(+, counts, fill!(similar(idx, Int), 1), idx)
    den = gather(counts, idx)
    # make num and den broadcast compatible
    for i in 1:ndims(num)-ndims(den)
        den = unsqueeze(den)
    end
    return safe_div.(num, den)
end

∇scatter_src(op, Δ, dst, src, idx) = ∇scatter_src(op, unthunk(Δ), dst, src, idx)

# rrule for the mutating scatter!: dst is copied before mutation so the
# pullback can reconstruct the pre-scatter state.
function rrule(::typeof(scatter!), op, dst::AbstractArray, src::AbstractArray, idx::AbstractArray)
    dst_old = copy(dst)
    scatter!(op, dst, src, idx)
    scatter!_pullback(Δ) = (
        NoTangent(), NoTangent(),
        ∇scatter!_dst(op, unthunk(Δ), dst_old, dst),
        ∇scatter!_src(op, unthunk(Δ), dst, src, idx),
        NoTangent())
    dst, scatter!_pullback
end

function rrule(::typeof(scatter), op, src::AbstractArray, idx::AbstractArray; kws...)
    y = scatter(op, src, idx; kws...)
    scatter_pullback(Δ) = (
        NoTangent(), NoTangent(),
        ∇scatter_src(op, unthunk(Δ), y, src, idx),
        NoTangent())
    y, scatter_pullback
end


================================================
FILE: src/softmax.jl
================================================

"""
    softmax(x; dims = 1)

[Softmax](https://en.wikipedia.org/wiki/Softmax_function) turns input array `x`
into probability distributions that sum to 1 along the dimensions specified by `dims`.
It is semantically equivalent to the following:

    softmax(x; dims = 1) = exp.(x) ./ sum(exp.(x), dims = dims)

with additional manipulations enhancing numerical stability.

For a matrix input `x` it will by default (`dims = 1`) treat it as a batch of vectors,
with each column independent. Keyword `dims = 2` will instead treat rows independently, and so on.

See also [`logsoftmax`](@ref).

# Examples

```jldoctest; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
julia> softmax([1, 2, 3])
3-element Vector{Float64}:
 0.09003057317038046
 0.24472847105479764
 0.6652409557748218

julia> softmax([1 2 3; 2 2 2])  # dims=1
2×3 Matrix{Float64}:
 0.268941  0.5  0.731059
 0.731059  0.5  0.268941

julia> softmax([1 2 3; 2 2 2]; dims=2)
2×3 Matrix{Float64}:
 0.0900306  0.244728  0.665241
 0.333333   0.333333  0.333333
```

Note that, when used with Flux.jl, `softmax` must not be passed to layers like `Dense`
which accept an activation function. The activation is broadcasted over the result,
thus applies to individual numbers. But `softmax` always needs to see the whole column.

```julia-repl
julia> using Flux

julia> x = randn(Float32, 4, 4, 3, 13);

julia> model = Chain(Conv((4, 4), 3 => 8, tanh), Flux.flatten, Dense(8 => 7), softmax);

julia> model(x) |> size
(7, 13)

julia> Dense(4 => 7, softmax)(x)
ERROR: `softmax(x)` called with a number, but it expects an array.
```
"""
softmax(x::AbstractArray{T}; dims = 1) where {T} = softmax!(similar(x, float(T)), x; dims)

softmax!(x::AbstractArray; dims = 1) = softmax!(x, x; dims)

function softmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T}
    max_ = fast_maximum(x; dims)
    if all(isfinite, max_)
        # common case: plain shifted exponentials
        @fastmath out .= exp.(x .- max_)
    else
        # +Inf entries must map to 1 where x itself is +Inf, 0 elsewhere
        _zero, _one, _inf = T(0), T(1), T(Inf)
        @fastmath @. out = ifelse(isequal(max_,_inf), ifelse(isequal(x,_inf), _one, _zero), exp(x - max_))
    end
    # reuse the max_ buffer as the sum accumulator when dims is not Colon
    tmp = dims isa Colon ? sum(out) : sum!(max_, out)
    out ./= tmp
end

function ∇softmax_data(dy::AbstractArray{T}, y::AbstractArray{S}; dims = 1) where {T,S}
    dx = if within_gradient(y)
        tmp = dy .* y
        tmp .- y .* sum(tmp; dims)
    else
        # This path is faster, only safe for 1st derivatives though.
        # Was previously `∇softmax!(dx, dy, x, y; dims)` to allow CUDA overloads,
        # but that was slow: https://github.com/FluxML/NNlibCUDA.jl/issues/30
        out = similar(y, promote_type(T,S))  # sure to be mutable
        out .= dy .* y
        out .= out .- y .* sum(out; dims)
    end
end

function rrule(::typeof(softmax), x; dims = 1)
    y = softmax(x; dims)
    softmax_pullback(dy) = (NoTangent(), ∇softmax_data(unthunk(dy), y; dims))
    return y, softmax_pullback
end

# reduce(max, ...) with a float -Inf init; faster than `maximum` under @fastmath
fast_maximum(x::AbstractArray{T}; dims) where {T} = @fastmath reduce(max, x; dims, init = float(T)(-Inf))

"""
    logsoftmax(x; dims = 1)

Computes the log of softmax in a more numerically stable
way than directly taking `log.(softmax(xs))`. Commonly used in
computing cross entropy loss.

It is semantically equivalent to the following:

    logsoftmax(x; dims = 1) = x .- log.(sum(exp.(x), dims = dims))

See also [`softmax`](@ref).
"""
logsoftmax(x::AbstractArray{T}; dims = 1) where {T} = logsoftmax!(similar(x, float(T)), x; dims)

logsoftmax!(x::AbstractArray; dims = 1) = logsoftmax!(x, x; dims)

function logsoftmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T}
    max_ = fast_maximum(x; dims)
    if all(isfinite, max_)
        out .= x .- max_
    else
        # +Inf entries map to 0 where x itself is +Inf, -Inf elsewhere
        _zero, _minf, _inf = T(0), T(-Inf), T(Inf)
        @. out = ifelse(isequal(max_,_inf), ifelse(isequal(x,_inf), _zero, _minf), x - max_)
    end
    @fastmath log_ = log.(sum(exp, out; dims))
    out .-= log_
end

function ∇logsoftmax_data(dy::AbstractArray, y::AbstractArray; dims = 1)
    # This was previously `∇logsoftmax!(dx, dy, x, y; dims)` to allow CUDA overloads, but that was slow.
    dx = dy .- sum(dy; dims) .* exp.(y)
end

function rrule(::typeof(logsoftmax), x; dims = 1)
    y = logsoftmax(x; dims)
    logsoftmax_pullback(dy) = (NoTangent(), ∇logsoftmax_data(unthunk(dy), y; dims))
    return y, logsoftmax_pullback
end

"""
    logsumexp(x; dims = :)

Computes `log.(sum(exp.(x); dims))` in a numerically stable way.
Without `dims` keyword this returns a scalar.

See also [`logsoftmax`](@ref).
"""
function logsumexp(x::AbstractArray; dims = :)
    max_ = fast_maximum(x; dims)
    @fastmath max_ .+ log.(sum(exp.(x .- max_); dims))
end

function rrule(::typeof(logsumexp), x; dims = :)
    # The gradient is `softmax`, but both compute `tmp` so it's worth saving.
    max_ = fast_maximum(x; dims)
    @fastmath tmp = exp.(x .- max_)
    @fastmath y = max_ .+ log.(sum(tmp; dims))
    logsumexp_pullback(dy) = (NoTangent(), unthunk(dy) .* tmp ./ sum(tmp; dims))
    return y, logsumexp_pullback
end

# Informative error message if any of the softmax variants is called with a number
for f in (:softmax, :logsoftmax, :softmax!, :logsoftmax!, :logsumexp)
    @eval $(f)(x::Number, args...) =
        error("`", $(string(f)), "(x)` called with a number, but it expects an array. Usually this is because a layer like `Dense(3,4,softmax)` is broadcasting it like an activation function; `softmax` needs to be outside the layer.")
end


================================================
FILE: src/upsample.jl
================================================

"""
    pixel_shuffle(x, r::Integer)

Pixel shuffling operation, upscaling by a factor `r`.

For 4-arrays representing `N` images, the operation converts input `size(x) == (W, H, r^2*C, N)`
to output of size `(r*W, r*H, C, N)`. For `D`-dimensional data, it expects `ndims(x) == D+2`
with channel and batch dimensions, and divides the number of channels by `r^D`.

Used in super-resolution networks to upsample towards high resolution features.
Reference: Shi et.
al., "Real-Time Single Image and Video Super-Resolution ...", CVPR 2016, https://arxiv.org/abs/1609.05158 # Examples ```jldoctest julia> x = [10i + j + channel/10 for i in 1:2, j in 1:3, channel in 1:4, batch in 1:1] 2×3×4×1 Array{Float64, 4}: [:, :, 1, 1] = 11.1 12.1 13.1 21.1 22.1 23.1 [:, :, 2, 1] = 11.2 12.2 13.2 21.2 22.2 23.2 [:, :, 3, 1] = 11.3 12.3 13.3 21.3 22.3 23.3 [:, :, 4, 1] = 11.4 12.4 13.4 21.4 22.4 23.4 julia> pixel_shuffle(x, 2) # 4 channels used up as 2x upscaling of image dimensions 4×6×1×1 Array{Float64, 4}: [:, :, 1, 1] = 11.1 11.3 12.1 12.3 13.1 13.3 11.2 11.4 12.2 12.4 13.2 13.4 21.1 21.3 22.1 22.3 23.1 23.3 21.2 21.4 22.2 22.4 23.2 23.4 julia> y = [i + channel/10 for i in 1:3, channel in 1:6, batch in 1:1] 3×6×1 Array{Float64, 3}: [:, :, 1] = 1.1 1.2 1.3 1.4 1.5 1.6 2.1 2.2 2.3 2.4 2.5 2.6 3.1 3.2 3.3 3.4 3.5 3.6 julia> pixel_shuffle(y, 2) # 1D image, with 6 channels reduced to 3 6×3×1 Array{Float64, 3}: [:, :, 1] = 1.1 1.3 1.5 1.2 1.4 1.6 2.1 2.3 2.5 2.2 2.4 2.6 3.1 3.3 3.5 3.2 3.4 3.6 ``` """ function pixel_shuffle(x::AbstractArray, r::Integer) ndims(x) > 2 || throw(ArgumentError("expected x with at least 3 dimensions")) d = ndims(x) - 2 sizein = size(x)[1:d] cin, n = size(x, d+1), size(x, d+2) cin % r^d == 0 || throw(ArgumentError("expected channel dimension to be divisible by r^d = $( r^d), where d=$d is the number of spatial dimensions. 
Given r=$r, input size(x) = $(size(x))")) cout = cin ÷ r^d x = reshape(x, sizein..., ntuple(i->r, d)..., cout, n) perm = hcat(d+1:2d, 1:d) |> transpose |> vec # = [d+1, 1, d+2, 2, ..., 2d, d] x = permutedims(x, (perm..., 2d+1, 2d+2)) return reshape(x, map(s -> s*r, sizein)..., cout, n) end # # Upsampling # # GPU based bilinear upsampling including its gradient # # Based on the Caffe2 implementation at: # The code is a translation from the following files: # - https://github.com/pytorch/pytorch/blob/v1.8.0-rc1/caffe2/operators/upsample_op.cu # - https://github.com/pytorch/pytorch/blob/v1.8.0-rc1/caffe2/core/common_gpu.h # # Copyright (c) 2016-2021 Facebook Inc. # Copyright (c) 2015 Google Inc. # Copyright (c) 2015 Yangqing Jia # Copyright 2019-2020 Kakao Brain # # All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, are # permitted provided that the following conditions are met: # # 1. Redistributions of source code must retain the above copyright notice, this list of # conditions and the following disclaimer. # # 2. Redistributions in binary form must reproduce the above copyright notice, this list of # conditions and the following disclaimer in the documentation and/or other materials # provided with the distribution. # # 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America and # IDIAP Research Institute nor the names of its contributors may be used to endorse or # promote products derived from this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF # MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
# IN NO EVENT SHALL THE
# COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# Forward and backward pass have been tested to produce the same output
# as pytorch with align_corners=True - it works modulo bit noise.
# pytorch's default is align_corners=False, because otherwise the gradients depend on the
# image size, which should be avoided -> this should be considered here as well

"""
    upsample_nearest(x, scale::NTuple{S,Int})
    upsample_nearest(x; size::NTuple{S,Int})

Upsamples the array `x` by integer multiples along the first `S` dimensions.
Subsequent dimensions of `x` are not altered.

Either the `scale` factors or the final output `size` can be specified.

See also [`upsample_bilinear`](@ref), for two dimensions of an `N=4` array.

# Example
```jldoctest
julia> upsample_nearest([1 2 3; 4 5 6], (2, 3))
4×9 Matrix{$Int}:
 1  1  1  2  2  2  3  3  3
 1  1  1  2  2  2  3  3  3
 4  4  4  5  5  5  6  6  6
 4  4  4  5  5  5  6  6  6

julia> ans == upsample_nearest([1 2 3; 4 5 6]; size=(4, 9))  # equivalent
true

julia> upsample_nearest([1 2 3; 4 5 6], (2,))
4×3 Matrix{$Int}:
 1  2  3
 1  2  3
 4  5  6
 4  5  6

julia> ans == upsample_nearest([1 2 3; 4 5 6], size=(4,))
true
```
"""
function upsample_nearest(x::AbstractArray; size::NTuple{S,Int}) where S
    # Keyword form: derive integer scale factors from the requested output size.
    xsize = Base.size(x)[1:S]
    all(size .% xsize .== 0) || throw(ArgumentError("expected output size divisible by input size"))
    return upsample_nearest(x, size .÷ xsize)
end

function upsample_nearest(x::AbstractArray{T,N}, scales::NTuple{S, <:Integer}) where {T,N,S}
    S in 1:N || throw(ArgumentError("can't upsample ndims(x)=$N with scale=$scales"))
    outsize = ntuple(d -> d<=S ? scales[d] * size(x,d) : size(x,d), N)
    out = similar(x, T, outsize)
    # Interleave a length-`scales[d]` axis before each upsampled axis of the output,
    # and a length-1 axis before each upsampled axis of the input, so that a single
    # broadcast replicates every input pixel `scales[d]` times along dimension d.
    writesize = ntuple(N+S) do d
        d > 2S && return size(x, d-S)
        isodd(d) ? scales[cld(d,2)] : size(x, cld(d,2))
    end
    readsize = ntuple(N+S) do d
        d > 2S && return size(x, d-S)
        isodd(d) ? 1 : size(x, cld(d,2))
    end
    reshape(out, writesize) .= reshape(x, readsize)
    out
end

"""
    ∇upsample_nearest(Δ::AbstractArray{T,N}, scales::NTuple{S, <:Integer})

# Arguments
- `Δ`: Incoming gradient array, backpropagated from downstream layers
- `scales`: scales by which the image was upsampled in the first place

# Outputs
- `dx`: Downsampled version of `Δ`
"""
function ∇upsample_nearest(Δ::AbstractArray{T,N}, scales::NTuple{S, <:Integer}) where {T,N,S}
    outsize = ntuple(N) do d
        d > S && return size(Δ,d)
        rem(size(Δ,d), scales[d]) == 0 || throw(ArgumentError("expected input array evenly divisible by scale=$scales, got size(x)=$(size(Δ))"))
        div(size(Δ,d), scales[d])
    end
    # Sum the gradient over each block of replicated pixels: split every upsampled
    # axis into (scale, out_size) and reduce over the scale axes.
    tempsize = ntuple(N+S) do d
        d > 2S && return size(Δ, d-S)
        s = scales[cld(d,2)]
        isodd(d) ? s : div(size(Δ, cld(d,2)), s)
    end
    mid = sum(reshape(Δ, tempsize), dims=ntuple(d -> 2d-1, S))
    reshape(mid, outsize)
end

function rrule(::typeof(upsample_nearest), x::AbstractArray, s::Tuple)
    Ω = upsample_nearest(x, s)
    upsample_nearest_pullback(Δ) = (NoTangent(), ∇upsample_nearest(unthunk(Δ), s), NoTangent())
    return Ω, upsample_nearest_pullback
end

"""
    upsample_linear(x::AbstractArray{T,3}, scale::Real; align_corners::Bool = true)
    upsample_linear(x::AbstractArray{T,3}; size::Integer, align_corners::Bool = true)

Upsamples the first dimension of the array `x` by the provided `scale` factor,
using linear interpolation. As an alternative to using `scale`, the resulting array
`size` can be directly specified with a keyword argument.

The size of the output is equal to
`(scale*S1, S2, S3)`, where `S1, S2, S3 = size(x)`.
"""
# the user facing function
function upsample_linear(x::AbstractArray{<:Any,N}, scale::NTuple{M,Real}; align_corners::Bool = true) where {N,M}
    M == N-2 || error("The scale argument should be an NTuple with length $(N-2), but it has length $M.")
    outsize = ntuple(i -> floor(Int, scale[i] * Base.size(x, i)), N-2)
    return upsample_linear(x; size=outsize, align_corners)
end

# convenience for single-number scale
upsample_linear(x::AbstractArray{<:Any,N}, scale::Real; align_corners::Bool = true) where N =
    upsample_linear(x, ntuple(_ -> scale, N-2); align_corners)

# this actually calls the upsampling kernel
function upsample_linear(x::AbstractArray{T,N}; size::Union{Integer, NTuple{<:Any,Integer}}, align_corners::Bool = true) where {T,N}
    length(size) == N-2 || error("The size argument should be an NTuple with length $(N-2), but it has length $(length(size)).")
    if Base.size(x)[1:N-2] == size
        return x  # already at the target size, nothing to do
    end
    y = similar(x, T, size..., Base.size(x)[end-1:end]...)
    return upsample_linear_kernel!(y, x; align_corners)
end

# Convenience definition for integers. The algo internally works with floats and then rounds.
# Integer arrays are promoted to float, interpolated, then rounded back to T.
function upsample_linear(x::AbstractArray{T,<:Any}; size, align_corners::Bool = true) where T<:Integer
    y = float.(x)
    res = upsample_linear(y; size=size, align_corners)
    return round.(T, res)
end

"""
    ∇upsample_linear(Δ::AbstractArray{T,N}; size::NTuple{<:Any,Integer}, align_corners::Bool = true) where {T,N}

# Arguments
- `Δ`: Incoming gradient array, backpropagated from downstream layers
- `size`: Size of the image upsampled in the first place

# Outputs
- `dx`: Downsampled version of `Δ`
"""
function ∇upsample_linear(Δ::AbstractArray{T,N}; size::NTuple{<:Any,Integer}, align_corners::Bool = true) where {T,N}
    # If the forward pass was a no-op (input already at target size), so is the backward.
    if Base.size(Δ)[1:N-2] == size
        return Δ
    end
    # dx must start from zero: the kernel accumulates (+=) into it.
    dx = fill!(similar(Δ, T, size..., Base.size(Δ)[end-1:end]...), zero(T))
    return ∇upsample_linear_kernel!(dx, Δ; align_corners)
end

function rrule(::typeof(upsample_linear), x::AbstractArray{<:Any,N}; size, align_corners::Bool = true) where N
    Ω = upsample_linear(x; size, align_corners)
    function upsample_linear_pullback(Δ)
        (NoTangent(), ∇upsample_linear(unthunk(Δ); size=Base.size(x)[1:N-2], align_corners))
    end
    return Ω, upsample_linear_pullback
end

"""
    upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real}; align_corners::Bool = true)
    upsample_bilinear(x::AbstractArray{T,4}; size::NTuple{2,Integer}, align_corners::Bool = true)

Upsamples the first 2 dimensions of the array `x` by the upsample factors stored in `scale`,
using bilinear interpolation. As an alternative to using `scale`, the resulting image `size`
can be directly specified with a keyword argument.

The size of the output is equal to
`(scale[1]*S1, scale[2]*S2, S3, S4)`, where `S1, S2, S3, S4 = size(x)`.

# Examples

```jldoctest
julia> x = reshape(Float32[1 2 3; 4 5 6], (2,3,1,1))
2×3×1×1 Array{Float32, 4}:
[:, :, 1, 1] =
 1.0  2.0  3.0
 4.0  5.0  6.0

julia> upsample_bilinear(x, (2, 3))
4×9×1×1 Array{Float32, 4}:
[:, :, 1, 1] =
 1.0  1.25  1.5  1.75  2.0  2.25  2.5  2.75  3.0
 2.0  2.25  2.5  2.75  3.0  3.25  3.5  3.75  4.0
 3.0  3.25  3.5  3.75  4.0  4.25  4.5  4.75  5.0
 4.0  4.25  4.5  4.75  5.0  5.25  5.5  5.75  6.0

julia> ans == upsample_bilinear(x; size=(4, 9)) # specify output size instead
true

julia> upsample_bilinear(x, (2.5, 3.5)) # non-integer scaling factors are allowed
5×10×1×1 Array{Float32, 4}:
[:, :, 1, 1] =
 1.0   1.22222  1.44444  1.66667  1.88889  …  2.33333  2.55556  2.77778  3.0
 1.75  1.97222  2.19444  2.41667  2.63889     3.08333  3.30556  3.52778  3.75
 2.5   2.72222  2.94444  3.16667  3.38889     3.83333  4.05556  4.27778  4.5
 3.25  3.47222  3.69444  3.91667  4.13889     4.58333  4.80556  5.02778  5.25
 4.0   4.22222  4.44444  4.66667  4.88889     5.33333  5.55556  5.77778  6.0
```
"""
upsample_bilinear(x, scale; align_corners::Bool = true) = upsample_linear(x, scale; align_corners)
upsample_bilinear(x; size, align_corners::Bool = true) = upsample_linear(x; size, align_corners)

"""
    ∇upsample_bilinear(Δ::AbstractArray{T,4}; size::NTuple{2,Integer}, align_corners::Bool = true) where T

# Arguments
- `Δ`: Incoming gradient array, backpropagated from downstream layers
- `size`: Lateral (W,H) size of the image upsampled in the first place

# Outputs
- `dx`: Downsampled version of `Δ`
"""
∇upsample_bilinear(Δ; size, align_corners::Bool = true) = ∇upsample_linear(Δ; size, align_corners)

"""
    upsample_trilinear(x::AbstractArray{T,5}, scale::NTuple{3,Real}; align_corners::Bool = true)
    upsample_trilinear(x::AbstractArray{T,5}; size::NTuple{3,Integer}, align_corners::Bool = true)

Upsamples the first 3 dimensions of the array `x` by the upsample factors stored in `scale`,
using trilinear interpolation. As an alternative to using `scale`, the resulting image `size`
can be directly specified with a keyword argument.

The size of the output is equal to
`(scale[1]*S1, scale[2]*S2, scale[3]*S3, S4, S5)`, where `S1, S2, S3, S4, S5 = size(x)`.

# Examples

```julia
upsample_trilinear(x, (2, 3, 4))
upsample_trilinear(x; size=(4, 9, 11))  # specify output size instead
upsample_trilinear(x, (2.5, 3.5, pi))  # non-integer scaling factors are allowed
```
"""
upsample_trilinear(x, scale; align_corners::Bool = true) = upsample_linear(x, scale; align_corners)
upsample_trilinear(x; size, align_corners::Bool = true) = upsample_linear(x; size, align_corners)

"""
    ∇upsample_trilinear(Δ::AbstractArray{T,5}; size::NTuple{3,Integer}, align_corners::Bool = true) where T

# Arguments
- `Δ`: Incoming gradient array, backpropagated from downstream layers
- `size`: Lateral size & depth (W,H,D) of the image upsampled in the first place

# Outputs
- `dx`: Downsampled version of `Δ`
"""
∇upsample_trilinear(Δ; size, align_corners::Bool = true) = ∇upsample_linear(Δ; size, align_corners)

# Launcher: picks the ndrange (which dimensions to parallelize over) based on the
# backend, computes the output→input coordinate ratios, and dispatches to the
# appropriate `_upsample_linear_kernel!` method (3/4/5-dim, CPU/GPU).
function upsample_linear_kernel!(
    y::AbstractArray{T, N}, x::AbstractArray{T, N}; align_corners::Bool = true,
) where {T, N}
    backend = KernelAbstractions.get_backend(x)
    ndrange = backend isa CPU ?
        size(y)[N - 1:end] : # Parallelization along channel x batch.
        size(y)[1:N - 2] # Parallelization along WHD.
    # align_corners=true maps output corners exactly onto input corners.
    ratios = align_corners ?
        ntuple(i -> real(T)((size(x, i) - 1) / (size(y, i) - 1)), N - 2) :
        ntuple(i -> real(T)(size(x, i) / size(y, i)), N - 2)
    _upsample_linear_kernel!(backend)(backend, y, x, ratios..., Val(align_corners); ndrange)
    return y
end

# Backward launcher: mirror of `upsample_linear_kernel!`, with ratios computed
# from the (smaller) dx relative to the incoming gradient Δ.
function ∇upsample_linear_kernel!(
    dx::AbstractArray{T, N}, Δ::AbstractArray{T, N}; align_corners::Bool = true,
) where {T, N}
    backend = KernelAbstractions.get_backend(dx)
    ndrange = backend isa CPU ?
        size(Δ)[N - 1:end] : # Parallelization along channel x batch.
        size(Δ)[1:N - 2] # Parallelization along WHD.
    ratios = align_corners ?
        ntuple(i -> real(T)((size(dx, i) - 1) / (size(Δ, i) - 1)), N - 2) :
        ntuple(i -> real(T)(size(dx, i) / size(Δ, i)), N - 2)
    _∇upsample_linear_kernel!(backend)(backend, dx, Δ, ratios..., Val(align_corners); ndrange)
    return dx
end

# Linear (CPU): parallelization along channel x batch dimensions.
@kernel function _upsample_linear_kernel!(::CPU, y::T, x::T, rwidth, align::Val{A}) where {
    T <: AbstractArray{<:Any, 3}, A,
}
    @uniform in_width, channels, batch = size(x)
    @uniform out_width = size(y, 1)
    c, n = @index(Global, NTuple)
    yv, xv = @view(y[:, c, n]), @view(x[:, c, n])
    @inbounds for i in 1:out_width
        # Each output pixel i interpolates between two source pixels iw0, iw1
        # with weights w0λ, w1λ.
        iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, in_width)
        yv[i] = w0λ * xv[iw0] + w1λ * xv[iw1]
    end
end

@kernel function _∇upsample_linear_kernel!(::CPU, dx::T1, Δ::T2, rwidth, align::Val{A}) where {
    T1 <: AbstractArray{<:Any, 3}, T2 <: AbstractArray{<:Any, 3}, A,
}
    @uniform in_width, channels, batch = size(Δ)
    @uniform out_width = size(dx, 1)
    c, n = @index(Global, NTuple)
    Δv, dxv = @view(Δ[:, c, n]), @view(dx[:, c, n])
    @inbounds for i in 1:in_width
        # Scatter each gradient pixel back to the two source positions it was
        # interpolated from; accumulation is safe since each workitem owns a
        # whole (c, n) slice.
        ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, out_width)
        val = Δv[i]
        dxv[ow0] += w0λ * val
        dxv[ow1] += w1λ * val
    end
end

# Linear (GPU): parallelization along width dimension.
# GPU linear forward: one workitem per output column i; loops over channel x batch.
@kernel function _upsample_linear_kernel!(::B, y::T, x::T, rwidth, align::Val{A}) where {
    B <: GPU, T <: AbstractArray{<:Any, 3}, A,
}
    @uniform in_width, channels, batch = size(x)
    i = @index(Global)
    iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, in_width)
    @inbounds for n in 1:batch, c in 1:channels
        y[i, c, n] = w0λ * x[iw0, c, n] + w1λ * x[iw1, c, n]
    end
end

# GPU linear backward: different workitems may scatter into the same dx position,
# hence the `@atomic` accumulation.
@kernel function _∇upsample_linear_kernel!(::B, dx::T, Δ::T, rwidth, align::Val{A}) where {
    B <: GPU, T <: AbstractArray{<:Any, 3}, A,
}
    @uniform in_width, channels, batch = size(Δ)
    @uniform out_width = size(dx, 1)
    i = @index(Global)
    ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, out_width)
    @inbounds for n in 1:batch, c in 1:channels
        val = Δ[i, c, n]
        @atomic dx[ow0, c, n] += w0λ * val
        @atomic dx[ow1, c, n] += w1λ * val
    end
end

# Bilinear (CPU): parallelization along channel x batch dimensions.
@kernel function _upsample_linear_kernel!(::CPU, y::T, x::T, rwidth, rheight, align::Val{A}) where {
    T <: AbstractArray{<:Any, 4}, A,
}
    @uniform in_width, in_height, channels, batch = size(x)
    @uniform out_width, out_height = size(y)[1:2]
    c, n = @index(Global, NTuple)
    yv, xv = @view(y[:, :, c, n]), @view(x[:, :, c, n])
    for j in 1:out_height
        ih0, ih1, h0λ, h1λ = source_idx_and_λ(rheight, j - 1, align, in_height)
        for i in 1:out_width
            # Blend the four surrounding source pixels, weighted by the
            # separable width (w*λ) and height (h*λ) interpolation weights.
            iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, in_width)
            @inbounds yv[i, j] =
                h0λ * (w0λ * xv[iw0, ih0] + w1λ * xv[iw1, ih0]) +
                h1λ * (w0λ * xv[iw0, ih1] + w1λ * xv[iw1, ih1])
        end
    end
end

# Bilinear (CPU) backward: each workitem owns a (c, n) slice, so plain `+=` is
# race-free here.
@kernel function _∇upsample_linear_kernel!(::CPU, dx::T1, Δ::T2, rwidth, rheight, align::Val{A}) where {
    T1 <: AbstractArray{<:Any, 4}, T2 <: AbstractArray{<:Any, 4}, A,
}
    @uniform in_width, in_height, channels, batch = size(Δ)
    @uniform out_width, out_height = size(dx)[1:2]
    c, n = @index(Global, NTuple)
    Δv, dxv = @view(Δ[:, :, c, n]), @view(dx[:, :, c, n])
    for j in 1:in_height
        oh0, oh1, h0λ, h1λ = source_idx_and_λ(rheight, j - 1, align, out_height)
        @inbounds for i in 1:in_width
            # Scatter the gradient to the four source pixels it came from.
            ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, out_width)
            val = Δv[i, j]
            dxv[ow0, oh0] += w0λ * h0λ * val
            dxv[ow1, oh0] += w1λ * h0λ * val
            dxv[ow0, oh1] += w0λ * h1λ * val
            dxv[ow1, oh1] += w1λ * h1λ * val
        end
    end
end

# Bilinear (GPU): parallelization along width, height dimensions.
@kernel function _upsample_linear_kernel!(::B, y::T, x::T, rwidth, rheight, align::Val{A}) where {
    B <: GPU, T <: AbstractArray{<:Any, 4}, A,
}
    @uniform in_width, in_height, channels, batch = size(x)
    i, j = @index(Global, NTuple)
    iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, in_width)
    ih0, ih1, h0λ, h1λ = source_idx_and_λ(rheight, j - 1, align, in_height)
    @inbounds for n in 1:batch, c in 1:channels
        y[i, j, c, n] =
            h0λ * (w0λ * x[iw0, ih0, c, n] + w1λ * x[iw1, ih0, c, n]) +
            h1λ * (w0λ * x[iw0, ih1, c, n] + w1λ * x[iw1, ih1, c, n])
    end
end

# Bilinear (GPU) backward: workitems for different (i, j) may target the same dx
# element, hence `@atomic`.
@kernel function _∇upsample_linear_kernel!(::B, dx::T, Δ::T, rwidth, rheight, align::Val{A}) where {
    B <: GPU, T <: AbstractArray{<:Any, 4}, A,
}
    @uniform in_width, in_height, channels, batch = size(Δ)
    @uniform out_width, out_height = size(dx)[1:2]
    i, j = @index(Global, NTuple)
    ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, out_width)
    oh0, oh1, h0λ, h1λ = source_idx_and_λ(rheight, j - 1, align, out_height)
    @inbounds for n in 1:batch, c in 1:channels
        val = Δ[i, j, c, n]
        @atomic dx[ow0, oh0, c, n] += w0λ * h0λ * val
        @atomic dx[ow1, oh0, c, n] += w1λ * h0λ * val
        @atomic dx[ow0, oh1, c, n] += w0λ * h1λ * val
        @atomic dx[ow1, oh1, c, n] += w1λ * h1λ * val
    end
end

# Trilinear (CPU): parallelization along channel x batch dimensions.
@kernel function _upsample_linear_kernel!(::CPU, y::T, x::T, rwidth, rheight, rdepth, align::Val{A}) where { T <: AbstractArray{<:Any, 5}, A, } @uniform in_width, in_height, in_depth = size(x)[1:3] @uniform channels, batch = size(x, 4), size(x, 5) @uniform out_width, out_height, out_depth = size(y)[1:3] c, n = @index(Global, NTuple) yv, xv = @view(y[:, :, :, c, n]), @view(x[:, :, :, c, n]) for k in 1:out_depth id0, id1, d0λ, d1λ = source_idx_and_λ(rdepth, k - 1, align, in_depth) for j in 1:out_height ih0, ih1, h0λ, h1λ = source_idx_and_λ(rheight, j - 1, align, in_height) for i in 1:out_width iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, in_width) @inbounds yv[i, j, k] = d0λ * ( h0λ * (w0λ * xv[iw0, ih0, id0] + w1λ * xv[iw1, ih0, id0]) + h1λ * (w0λ * xv[iw0, ih1, id0] + w1λ * xv[iw1, ih1, id0])) + d1λ * ( h0λ * (w0λ * xv[iw0, ih0, id1] + w1λ * xv[iw1, ih0, id1]) + h1λ * (w0λ * xv[iw0, ih1, id1] + w1λ * xv[iw1, ih1, id1])) end end end end @kernel function _∇upsample_linear_kernel!(::CPU, dx::T1, Δ::T2, rwidth, rheight, rdepth, align::Val{A}) where { T1 <: AbstractArray{<:Any, 5}, T2 <: AbstractArray{<:Any, 5}, A, } @uniform in_width, in_height, in_depth = size(Δ)[1:3] @uniform channels, batch = size(Δ, 4), size(Δ, 5) @uniform out_width, out_height, out_depth = size(dx)[1:3] c, n = @index(Global, NTuple) Δv, dxv = @view(Δ[:, :, :, c, n]), @view(dx[:, :, :, c, n]) for k in 1:in_depth od0, od1, d0λ, d1λ = source_idx_and_λ(rdepth, k - 1, align, out_depth) for j in 1:in_height oh0, oh1, h0λ, h1λ = source_idx_and_λ(rheight, j - 1, align, out_height) @inbounds for i in 1:in_width ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, out_width) val = Δv[i, j, k] dxv[ow0, oh0, od0] += w0λ * h0λ * d0λ * val dxv[ow1, oh0, od0] += w1λ * h0λ * d0λ * val dxv[ow0, oh1, od0] += w0λ * h1λ * d0λ * val dxv[ow1, oh1, od0] += w1λ * h1λ * d0λ * val dxv[ow0, oh0, od1] += w0λ * h0λ * d1λ * val dxv[ow1, oh0, od1] += w1λ * h0λ * d1λ * val dxv[ow0, oh1, od1] += w0λ * h1λ * 
d1λ * val dxv[ow1, oh1, od1] += w1λ * h1λ * d1λ * val end end end end # Trilinear (GPU): parallelization along width x height x depth dimensions. @kernel function _upsample_linear_kernel!(::B, y::T, x::T, rwidth, rheight, rdepth, align::Val{A}) where { B <: GPU, T <: AbstractArray{<:Any, 5}, A, } @uniform in_width, in_height, in_depth = size(x)[1:3] @uniform channels, batch = size(x, 4), size(x, 5) i, j, k = @index(Global, NTuple) iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, in_width) ih0, ih1, h0λ, h1λ = source_idx_and_λ(rheight, j - 1, align, in_height) id0, id1, d0λ, d1λ = source_idx_and_λ(rdepth, k - 1, align, in_depth) @inbounds for n in 1:batch, c in 1:channels y[i, j, k, c, n] = d0λ * ( h0λ * (w0λ * x[iw0, ih0, id0, c, n] + w1λ * x[iw1, ih0, id0, c, n]) + h1λ * (w0λ * x[iw0, ih1, id0, c, n] + w1λ * x[iw1, ih1, id0, c, n])) + d1λ * ( h0λ * (w0λ * x[iw0, ih0, id1, c, n] + w1λ * x[iw1, ih0, id1, c, n]) + h1λ * (w0λ * x[iw0, ih1, id1, c, n] + w1λ * x[iw1, ih1, id1, c, n])) end end @kernel function _∇upsample_linear_kernel!(::B, dx::T, Δ::T, rwidth, rheight, rdepth, align::Val{A}) where { B <: GPU, T <: AbstractArray{<:Any, 5}, A, } @uniform in_width, in_height, in_depth = size(Δ)[1:3] @uniform channels, batch = size(Δ, 4), size(Δ, 5) @uniform out_width, out_height, out_depth = size(dx)[1:3] i, j, k = @index(Global, NTuple) ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, out_width) oh0, oh1, h0λ, h1λ = source_idx_and_λ(rheight, j - 1, align, out_height) od0, od1, d0λ, d1λ = source_idx_and_λ(rdepth, k - 1, align, out_depth) @inbounds for n in 1:batch, c in 1:channels val = Δ[i, j, k, c, n] @atomic dx[ow0, oh0, od0, c, n] += w0λ * h0λ * d0λ * val @atomic dx[ow1, oh0, od0, c, n] += w1λ * h0λ * d0λ * val @atomic dx[ow0, oh1, od0, c, n] += w0λ * h1λ * d0λ * val @atomic dx[ow1, oh1, od0, c, n] += w1λ * h1λ * d0λ * val @atomic dx[ow0, oh0, od1, c, n] += w0λ * h0λ * d1λ * val @atomic dx[ow1, oh0, od1, c, n] += w1λ * h0λ * d1λ * val @atomic 
dx[ow0, oh1, od1, c, n] += w0λ * h1λ * d1λ * val @atomic dx[ow1, oh1, od1, c, n] += w1λ * h1λ * d1λ * val end end @inline function source_idx_and_λ( ratio::T, out_idx::Int, ::Val{align}, in_width::Int, ) where {T, align} real_index = align ? ratio * out_idx : max(zero(T), ratio * (out_idx + T(0.5)) - T(0.5)) iw0 = if T <: Rational floor(Int, real_index) # Not GPU-friendly, but allows for Rational support. else unsafe_trunc(Int, floor(real_index)) end offset = ifelse(iw0 < in_width - 1, 1, 0) iw1 = iw0 + offset + 1 w1lambda = real_index - iw0 w0lambda = one(T) - w1lambda return iw0 + 1, iw1, w0lambda, w1lambda end ================================================ FILE: src/utils.jl ================================================ """ within_gradient(x) --> Bool Returns `false` except when used inside a `gradient` call, when it returns `true`. Useful for Flux regularisation layers which behave differently during training and inference. This should work with any ChainRules-based differentiation package, in which case `x` is ignored. But Tracker.jl overloads `with_gradient(x::TrackedArray)`, thus for widest use you should pass it an array whose gradient is of interest. There is also an overload for ForwardDiff.jl's `Dual` types (and arrays of them). 
# Examples ```julia-repl julia> using ForwardDiff, Zygote, NNlib julia> f_good(x) = if NNlib.within_gradient(x) @show 10x else x end; julia> Zygote.withgradient(f_good, 1.0) 10x = 10.0 (val = 10.0, grad = (10.0,)) julia> ForwardDiff.derivative(f_good, 1.0) 10x = Dual{ForwardDiff.Tag{typeof(f_good), Float64}}(10.0,10.0) 10.0 julia> f_bad(x, y) = if any(NNlib.within_gradient, (x, y)) @show x * y else x / y end; julia> Zygote.withgradient(f_bad, 2.0, 3.0) (val = 0.6666666666666666, grad = (0.3333333333333333, -0.2222222222222222)) julia> ForwardDiff.derivative(x -> f_bad(x, 3.0), 2.0) x * y = Dual{ForwardDiff.Tag{var"#9#10", Float64}}(6.0,3.0) 3.0 ``` What goes wrong in `f_bad` is that Zygote knows `any` to be non-differentiable, and thus completely ignores its contents. This is not a perfect mechanism, and the only style recommended is precisely that of `f_good` above. """ within_gradient(x) = false ChainRulesCore.rrule(::typeof(within_gradient), x) = true, _ -> (NoTangent(), NoTangent()) """ safe_div(x, y) Returns `x/y` unless `y==0`, in which case it just returns `x`. (Used internally by `scatter`.) """ safe_div(x, y) = ifelse(iszero(y), x, x/y) """ maximum_dims(dims) Given an array of `CartesianIndex{N}` or `NTuple{N,Int}`, returns a tuple containing the maximum of all the 1st entries, all the 2nd entries, and so on up to `N`. Given an array of integers, returns `(maximum(dims),)`. (These arguments are what [`scatter`](@ref NNlib.scatter) understands.) 
""" maximum_dims(dims::AbstractArray{<:Integer}) = (maximum(dims), ) maximum_dims(dims::AbstractArray{NTuple{N, T}}) where {N,T} = ntuple(i -> maximum(x->x[i], dims), N) maximum_dims(dims::AbstractArray{CartesianIndex{N}}) where {N} = ntuple(i -> maximum(x->x[i], dims), N) function reverse_indices!(rev::AbstractArray, idx::AbstractArray{<:Tuple}) for (ind, val) in pairs(Array(idx)) push!(rev[val...], ind) end # if CUDA supports `unique`, a more efficient version: # cidx in CartesianIndices(idx) # for i = unique(idx) # rev[i] = cidx[idx .== i] # end rev end function reverse_indices!(rev::AbstractArray, idx::AbstractArray) for (ind, val) in pairs(Array(idx)) push!(rev[val], ind) end rev end """ reverse_indices(idx) Return the reverse indices of `idx`. The indices of `idx` will be values, and values of `idx` will be index. # Arguments - `idx`: The indices to be reversed. Accepts array or cuarray of integer, tuple or `CartesianIndex`. """ function reverse_indices(idx::AbstractArray{<:Any,N}) where N max_dims = maximum_dims(idx) T = CartesianIndex{N} rev = Array{Vector{T}}(undef, max_dims...) for i in eachindex(rev) rev[i] = T[] end return reverse_indices!(rev, idx) end unsqueeze(x) = reshape(x, 1, size(x)...) """ _fast_broadcast!(f, x, y, z...) This does `x .= f.(x, y, z...)`, but works around an issue with broadcasting that prevents SIMD in such cases. Can perhaps be removed once https://github.com/JuliaLang/julia/issues/43153 is fixed. Has an `rrule` to avoid mutation within derivatives. !!! warning Not intended for general use. Uses `@inbounds` but does not check sizes! Assumes that `f` has no derivative! """ function _fast_broadcast!(f::F, x::Array, yz...) where {F<:Function} bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, yz...)) @simd ivdep for I in eachindex(bc) @inbounds x[I] = bc[I] end return x end function _fast_broadcast!(f::F, x::AbstractArray, yz...) where {F<:Function} # CUDA does not suffer from this bug broadcast!(f, x, x, yz...) 
end function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_broadcast!), f::F, x::AbstractArray, ys...) where {F<:Function} rrule_via_ad(cfg, broadcast, f, x, ys...) end ================================================ FILE: test/Project.toml ================================================ [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" ================================================ FILE: test/activations.jl ================================================ ACTIVATION_FUNCTIONS = [@eval($a) for a in NNlib.ACTIVATIONS] BINARY_ACTIVATIONS = filter(f -> hasmethod(f, Tuple{Float64, Float64}), ACTIVATION_FUNCTIONS) @test sigmoid(0.0) == 0.5 @test 
hardsigmoid(0.0) == 0.5 @test hardtanh(0.0) == 0.0 @test relu(0.0) == 0.0 @test leakyrelu(0.0) == 0.0 @test relu6(0.0) == 0.0 @test rrelu(0.0) == 0.0 @test elu(0.0) == 0.0 @test gelu(0.0) == 0.0 @test gelu_tanh(0.0) == 0.0 @test gelu_sigmoid(0.0) == 0.0 @test gelu_erf(0.0) == 0.0 @test swish(0.0) == 0.0 @test hardswish(0.0) == 0.0 @test lisht(0.0) == 0.0 @test softplus(0.0) ≈ log(2.0) @test softplus(1e8) ≈ 1e8 @test softplus(-1e8) ≈ 0.0 @test softsign(0.0) == 0.0 @test selu(0.0) == 0.0 @test celu(0.0) == 0.0 @test trelu(0.0) == 0.0 @test logcosh(0.0) == log(cosh(0.0)) @test mish(0.0) == 0.0 @test tanhshrink(0.0) == 0.0 @test softshrink(0.0) == 0.0 @test sigmoid(1.0) == 1.0 / (1.0 + exp(-1.0)) @test hardsigmoid(1.0) == max(0,min(1, (1 + 3)/6)) @test hardtanh(1.0) == 1.0 @test relu(1.0) == 1.0 @test leakyrelu(1.0) == 1.0 @test relu6(1.0) == 1.0 @test rrelu(1.0) == 1.0 @test elu(1.0) == 1.0 @test gelu(1.0) ≈ 0.8411919906082768 @test gelu_tanh(1.0) ≈ 0.8411919906082768 @test gelu_sigmoid(1.0) ≈ 0.8411919906082768 @test gelu_erf(1.0) == 0.8413447460685429 @test swish(1.0) == sigmoid(1.0) @test hardswish(1.0) == hardsigmoid(1.0) @test lisht(1.0) ≈ 1.0 * tanh(1.0) @test softplus(1.0) ≈ log(exp(1.0) + 1.0) @test softsign(1.0) == 0.5 @test selu(1.0) == 1.0507009873554804934193349852946 @test celu(1.0) == 1.0 @test trelu(1.0) == 0.0 @test logcosh(1.0) ≈ log(cosh(1.0)) @test mish(1.0) ≈ tanh(log(1.0 + exp(1.0))) @test tanhshrink(1.0) ≈ 0.23840584404423515 @test softshrink(1.0) == 0.5 @test sigmoid(-1.0) == exp(-1.0) / (1.0 + exp(-1.0)) @test hardsigmoid(-1.0) == max(0,min(1,(-1+3)/6 )) @test hardtanh(-1.0) == -1.0 @test relu(-1.0) == 0.0 @test leakyrelu(-1.0) == -0.01 @test relu6(-1.0) == 0.0 @test -1/3.0 <= rrelu(-1.0) <= -1/8.0 @test elu(-1.0) == exp(-1.0) - 1.0 @test gelu(-1.0) ≈ -0.15880800939172324 @test gelu_tanh(-1.0) ≈ -0.15880800939172324 @test gelu_sigmoid(-1.0) ≈ -0.15880800939172324 @test gelu_erf(-1.0) == -0.15865525393145707 @test swish(-1.0) == -sigmoid(-1.0) 
@test hardswish(-1.0) == -hardsigmoid(-1.0)
@test lisht(-1.0) ≈ -1.0 * tanh(-1.0)
@test softplus(-1.0) ≈ log(exp(-1.0) + 1.0)
@test softsign(-1.0) == -0.5
@test selu(-1.0) ≈ 1.0507009873554804934193349852946 * 1.6732632423543772848170429916717 * (exp(-1.0) - 1.0)
@test celu(-1.0) == exp(-1.0) - 1
@test trelu(-1.0) == 0.0
# Fix: the old line compared `log(cosh(-1.0))` to itself, a tautology that never
# exercised NNlib's `logcosh`. Test the function against the reference expression,
# matching the pattern of every other line here.
@test logcosh(-1.0) ≈ log(cosh(-1.0))
@test mish(-1.0) ≈ -tanh(log(1.0 + exp(-1.0)))
@test tanhshrink(-1.0) ≈ -0.23840584404423515
@test softshrink(-1.0) == -0.5

@testset "Float inference" begin
    # Each activation must return the same float type it was given, inferrably.
    @testset "$(a): " for a in ACTIVATION_FUNCTIONS
        for T in [Float16, Float32, Float64]
            for val in [-10, -1, 0, 1, 10]
                out = @inferred a(T(val))
                @test typeof(out) == T
            end
        end
    end
    # Two-argument activations must also preserve the primal type regardless of
    # the second argument's type.
    @testset "binary $a: " for a in BINARY_ACTIVATIONS
        for T in [Float16, Float32, Float64]
            for val in [-10, -1, 0, 1, 10], beta in Any[0.1, 0.5f0, 1]
                out = @inferred a(T(val), beta)
                @test typeof(out) == T
            end
        end
    end
end

# NOTE(review): despite the historical name, this testset checks that broadcasting
# over arrays *works* (and is differentiable) — no error is expected.
@testset "Array input -> error" begin
    x = rand(5)
    for a in ACTIVATION_FUNCTIONS
        @test size(a(x)) == size(x)
        grad = Zygote.gradient(p -> sum(a(p)), x)
        @test size(grad[1]) == size(x)
    end
    for a in BINARY_ACTIVATIONS
        @test size(a(x, 0.1)) == size(x)
        grad = Zygote.gradient(p -> sum(a(p, 0.1)), x)
        @test size(grad[1]) == size(x)
    end
end

@testset "NaN propagation" begin
    @testset "$a" for a in ACTIVATION_FUNCTIONS
        # With NaN input, all should produce NaN output:
        @test isnan(a(NaN32))
        # Ideally +-Inf would not lead to NaN, but perhaps
        # these aren't worth the complication of fixing:
        a == softsign && continue
        @test !isnan(a(Inf32))
        a in [gelu, gelu_tanh, gelu_sigmoid, gelu_erf, swish, hardswish, logcosh, mish] && continue
        @test !isnan(a(-Inf32))
    end
end

@testset "Integer inputs" begin
    # These should work without error, for e.g. readme examples,
    # but no serious use will involve integers, no need for performance.
    @testset "$a" for a in ACTIVATION_FUNCTIONS
        @test typeof(a(Int64(1))) <: Real
        @test typeof(a(Int32(1))) <: Real
    end
    # The following ones can pass integers through.
But it's not very important. @testset "relu: Int -> Int" begin @test typeof(relu(Int64(1))) == Int64 @test typeof(relu(Int32(1))) == Int32 end @testset "relu6: Int -> Int" begin @test typeof(relu6(Int64(1))) == Int64 @test typeof(relu6(Int32(1))) == Int32 end @testset "hardtanh: Int -> Int" begin @test typeof(hardtanh(Int64(1))) == Int64 @test typeof(hardtanh(Int32(1))) == Int32 end @testset "trelu: Int -> Int" begin @test typeof(trelu(Int64(1))) == Int64 @test typeof(trelu(Int32(1))) == Int32 end end @testset "elu" begin @test elu(42) == 42 @test elu(42.) == 42. @test elu(-4) ≈ (exp(-4) - 1) end @testset "mish" begin @test mish(-5) ≈ -0.033576237730161704 @test mish(9) == 9*tanh(log(1 + exp(9))) xs = Float32[1 2 3; 1000 2000 3000] @test typeof(mish.(xs)) == typeof(xs) end @test leakyrelu( 0.4,0.3) ≈ 0.4 @test leakyrelu(-0.4,0.3) ≈ -0.12 @test relu6(10.0) == 6.0 @test -0.2 <= rrelu(-0.4,0.25,0.5) <= -0.1 @testset "celu" begin @test celu(42) == 42 @test celu(42.) == 42. @test celu(-4, 0.5) ≈ 0.5*(exp(-4.0/0.5) - 1) end @testset "softshrink" begin @test softshrink(15., 5.) == 10. @test softshrink(4., 5.) == 0. @test softshrink(-15., 5.) == -10. end @testset "logsigmoid" begin xs = randn(10,10) @test logsigmoid.(xs) ≈ log.(sigmoid.(xs)) for T in [:Float32, :Float64] @eval @test logsigmoid.($T[-100_000, 100_000.]) ≈ $T[-100_000, 0.] end end @test logcosh(1_000.0) + log(2) == 1_000.0 @testset "hardsigmoid" begin @test hardsigmoid(0.3) == max(0,min(1,(0.3+3)/6)) @test hardsigmoid(-0.3) == max(0,min(1,(-0.3+3)/6)) for T in [:Float32, :Float64] @eval @test hardsigmoid.($T[-100_000, 100_000.]) ≈ $T[0., 1.] 
end
end

@test hardtanh(10.0) == 1.0
@test lisht(2.5) == 2.5*tanh(2.5)

@testset "trelu" begin
    @test trelu(0.5) == 0.0
    @test trelu(1.0) == 0.0
    @test trelu(1.1) == 1.1
    @test trelu(0.9,0.5) == 0.9
end

## Faster variants

using NNlib: tanh_fast, sigmoid_fast

# Distance from `x` to `T(xtrue)` measured in ULP-steps of `x`: searches
# n = 0, -1, 1, -2, 2, ... up to ±100 `nextfloat` steps for an exact match,
# and beyond that estimates the count as (target - x) / eps(x).
function countepsfrom(x::T, xtrue) where {T<:AbstractFloat}
    target = T(xtrue)
    for n in Iterators.flatten(zip(0:100, -1:-1:-100))  # alternates 0, -1, 1, -2, 2, ...
        nextfloat(x, n) === target && return n
    end
    return round(Int, (target - x) / eps(x))
end

# Mean / worst-case absolute ULP error of `f` over `xs`, against the
# reference `g` evaluated at BigFloat precision.
mean_eps(f, g, xs) = mean(x -> abs(countepsfrom(f(x), g(big(x)))), xs)
worst_eps(f, g, xs) = maximum(x -> abs(countepsfrom(f(x), g(big(x)))), xs)

# Returns (worst ULP error, the input in `xs` at which it occurs).
function find_worst(f, g, xs)
    c, i = findmax(x -> abs(countepsfrom(f(x), g(big(x)))), xs)
    c, xs[i]
end

@testset "tanh_fast & sigmoid_fast: Float64" begin
    x64 = 1e-6:1e-4:5
    xbig = vcat(6:3:200.0, 1000, 10^6, typemax(Float64))
    @testset "tanh" begin
        # Unasserted baselines: ULP error of Base's own tanh (measured values
        # recorded in the trailing comments for comparison).
        mean_eps(tanh, tanh, x64) # 0.06582
        worst_eps(tanh, tanh, x64) # 2
        @test mean_eps(tanh_fast, tanh, x64) < 0.2 # 0.13164
        @test worst_eps(tanh_fast, tanh, x64) <= 5 # 5
        @test mean_eps(tanh_fast, tanh, -x64) < 0.6 # 0.5248
        @test worst_eps(tanh_fast, tanh, -x64) <= 5 # 5
        # Large magnitudes must still agree (saturation region).
        @test tanh_fast.(xbig) ≈ tanh.(xbig)
        @test tanh_fast.(-xbig) ≈ tanh.(-xbig)
    end
    @testset "sigmoid" begin
        mean_eps(sigmoid, sigmoid, x64) # 0.39246
        worst_eps(sigmoid, sigmoid, x64) # 1
        @test mean_eps(sigmoid_fast, sigmoid, x64) < 0.5 # 0.40432
        @test worst_eps(sigmoid_fast, sigmoid, x64) <= 5 # 2
        mean_eps(sigmoid, sigmoid, -x64) # 0.37672
        worst_eps(sigmoid, sigmoid, -x64) # 2
        @test mean_eps(sigmoid_fast, sigmoid, -x64) < 0.6 # 0.56478
        @test worst_eps(sigmoid_fast, sigmoid, -x64) <= 5 # 4
        @test sigmoid_fast.(xbig) ≈ sigmoid.(xbig)
        @test sigmoid_fast.(-xbig) ≈ sigmoid.(-xbig)
    end
end

@testset "tanh_fast & sigmoid_fast: Float32" begin
    x32 = 1f-6:1f-4:5
    xbig32 = vcat(6:3:200f0, 1000, typemax(Float32))
    @testset "tanh" begin
        mean_eps(tanh, tanh, x32) # 0.065
        worst_eps(tanh, tanh, x32) # 1
        @test mean_eps(tanh_fast, tanh, x32) < 0.8 # 0.65414
        @test
worst_eps(tanh_fast, tanh, x32) <= 5 # 5 @test mean_eps(tanh_fast, tanh, -x32) < 0.8 # 0.65414 @test worst_eps(tanh_fast, tanh, -x32) <= 5 # 5 @test tanh_fast.(xbig32) ≈ tanh.(xbig32) @test tanh_fast.(-xbig32) ≈ tanh.(-xbig32) end @testset "sigmoid" begin mean_eps(sigmoid, sigmoid, x32) # 0.38896 worst_eps(sigmoid, sigmoid, x32) # 1 @test mean_eps(sigmoid_fast, sigmoid, x32) < 0.5 # 0.38896 @test worst_eps(sigmoid_fast, sigmoid, x32) <= 2 # 2 mean_eps(sigmoid, sigmoid, -x32) # 0.38088 worst_eps(sigmoid, sigmoid, -x32) # 2 @test mean_eps(sigmoid_fast, sigmoid, -x32) < 0.5 # 0.38088 @test worst_eps(sigmoid_fast, sigmoid, -x32) <= 2 # 2 @test sigmoid_fast.(xbig32) ≈ sigmoid.(xbig32) @test sigmoid_fast.(-xbig32) ≈ sigmoid.(-xbig32) end end ## Autodiff tests WITH_UNARY_RULE = [@eval($a) for (a, _) in NNlib.UNARY_ACTS] WITH_BINARY_RULE = [@eval($a) for (a, _, _) in NNlib.BINARY_ACTS] has_rule(a) = rrule(a, 1f0) === nothing ? "(no rule)" : "" @testset "Gradient inference" begin @testset "$(a): $(has_rule(a))" for a in ACTIVATION_FUNCTIONS @testset "$T" for T in [Float16, Float32, Float64] for val in [-10, -1, 0, 1, 10] grad = @inferred gradient(a, T(val)) @test typeof(grad[1]) == T end end end end using Base.Broadcast: broadcasted @testset "lazy broadcasting" begin # ChainRules returns a Broadcasted, check these rules accept it @test rrule(broadcasted, relu, rrule(broadcasted, +, [1,2], 3)[1]) != nothing @test rrule(broadcasted, leakyrelu, rrule(broadcasted, +, [1,2], 3)[1], 0.2) != nothing end @testset "Gradient correctness" begin local rng = StableRNG(17) @testset "$(f): $(has_rule(f))" for f in ACTIVATION_FUNCTIONS f == rrelu && continue # stocastich output ## Avoid singular points of some activations ## problematic for finite diff methods gradtest(f, +2 + rand(rng)) gradtest(f, -2 - rand(rng)) gradtest(f, +2 .+ rand(rng, 2, 2), check_broadcast=true) gradtest(f, -2 .- rand(rng, 2, 2), check_broadcast=true) if f in BINARY_ACTIVATIONS gradtest(x -> f(x, 0.2), 1 + 
rand(rng)) gradtest(x -> f(x, 0.7), 1 + rand(rng)) gradtest(x -> f(x, 0.2), -2 + rand(rng)) gradtest(x -> f(x, 0.7), -2 + rand(rng)) end ## Check that rules, including broadcast rules, are defined: if f in WITH_UNARY_RULE @test rrule(f, rand()) !== nothing @test rrule(broadcasted, f, rand(2)) !== nothing end if f in WITH_BINARY_RULE @test rrule(f, rand(), rand()) !== nothing @test rrule(broadcasted, f, rand(2), rand()) !== nothing end end @testset "Flux-like usage" begin ## This checks some broadcast rules for correctness: gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2) gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2) gradtest((x, W, b) -> relu.(W*x .+ b), 5, (2,5), 2) gradtest((x, W, b) -> relu.(W*x .+ b), (5,3), (2,5), 2) gradtest((x, W, b) -> selu.(W*x .+ b), 5, (2,5), 2) gradtest((x, W, b) -> selu.(W*x .+ b), (5,3), (2,5), 2, atol=1e-4) gradtest((x, W, b) -> elu.(W*x .+ b, 2), 5, (2,5), 2) gradtest((x, W, b) -> elu.(W*x .+ b, 2), (5,3), (2,5), 2, atol=1e-4) gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2) gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2) ## Binary functions have their own broadcast rules: gradtest((x, W, b) -> leakyrelu.(W*x .+ b, 0.2), 5, (2,5), 2) gradtest((x, W, b) -> leakyrelu.(W*x .+ b, 0.7), (5,3), (2,5), 2) end @testset "Zygote issue 758" begin ## Tests for https://github.com/FluxML/Zygote.jl/issues/758 @test gradient(xs -> sum(selu.(xs)), [1_000, 10_000])[1] ≈ [1.0507009873554805, 1.0507009873554805] rtol=1e-8 @test gradient(x -> selu(x), 1_000) == (1.0507009873554805,) @test gradient(xs -> sum(elu.(xs, 2)), [1_000, 10_000]) == ([1., 1.],) @test gradient(x -> elu(x, 2), 1_000) == (1.,) @test gradient(x -> elu(x, 2), -1) == (2*exp(-1),) gradtest(x-> selu.(x),[100., 1_000.]) gradtest(x -> elu.(x, 3.5),[100., 1_000.]) gradtest(x -> elu.(x, 3.5),[1_000., 10_000.]) gradtest(x -> selu.(x), [1_000., 10_000.]) gradtest(x -> selu.(x), 10, atol=1e-4) end end @testset "Second derivatives" begin ## Not extensive, but a start! 
## More careful tests could look for `nothing` gradients of piecewise functions @testset "$(f): $(has_rule(f))" for f in ACTIVATION_FUNCTIONS f == rrelu && continue ## Scalar h = Zygote.hessian_dual(x -> sin(f(x)), 0.23) @test h ≈ Zygote.hessian_reverse(x -> sin(f(x)), 0.23) ## Broadcasting x = [-0.9, -0.2, 0.1, 0.3, 1.2] H = Zygote.hessian_dual(x -> sum(abs2, f.(x .+ 0.1)), x) @test H ≈ Zygote.hessian_reverse(x -> sum(abs2, f.(x .+ 0.1)), x) end @testset "$(f): $(has_rule(f))" for f in BINARY_ACTIVATIONS f == rrelu && continue ## Scalar h = Zygote.hessian_dual(x -> sin(f(x, 0.3)), 0.45) @test h ≈ Zygote.hessian_reverse(x -> sin(f(x, 0.3)), 0.45) ## Broadcasting x = [-0.9, -0.2, 0.1, 0.3, 1.2] H = Zygote.hessian_dual(x -> sum(abs2, f.(x .+ 0.1, 0.3)), x) @test H ≈ Zygote.hessian_reverse(x -> sum(abs2, f.(x .+ 0.1, 0.3)), x) end end ================================================ FILE: test/attention.jl ================================================ @testset "different batchsizes" begin n = 15 lenq = 3 lenkv = 4 for batch_size in [(), 1, 2, (2,1,3)], nheads in [1, 3, 5] q = rand(Float32, n, lenq, batch_size...) k = rand(Float32, n, lenkv, batch_size...) v = rand(Float32, n, lenkv, batch_size...) y, α = dot_product_attention(q, k, v; nheads) @test y isa Array{Float32} @test size(y) == (n, lenq, batch_size...) @test size(α) == (lenkv, lenq, nheads, batch_size...) @test sum(α, dims=1) ≈ ones(1, lenq, nheads, batch_size...) 
end end @testset "dot_product_attention_scores" begin q = k = reshape([1:24;], 4, 2, 3, 1) ./ 24 α = dot_product_attention_scores(q, k) q2, k2 = reshape.((q, k), 8, 3, 1) y, α2 = dot_product_attention(q2, k2, k2; nheads=2) @test α ≈ α2 end @testset "specific results" begin q = k = v = reshape([1:12;], 4, 3, 1) ./ 12 y, α = dot_product_attention(q, k, v; nheads=2) ytrue = [0.429754, 0.513087, 0.613791, 0.697125, 0.46431, 0.547644, 0.647876, 0.73121, 0.49773, 0.581064, 0.680455, 0.763788] ytrue = reshape(ytrue, 4, 3, 1) αtrue = [0.313896, 0.332948, 0.353157, 0.264431, 0.328206, 0.407362, 0.219215, 0.31838, 0.462405, 0.288691, 0.331243, 0.380066, 0.241239, 0.323893, 0.434868, 0.198438, 0.311761, 0.489801] αtrue = reshape(αtrue, 3, 3, 2, 1) @test y ≈ ytrue atol=1e-5 @test α ≈ αtrue atol=1e-5 end @testset "mask" begin q = rand(4, 2, 3, 1) k = rand(4, 2, 5, 1) mask = rand(Bool, (5, 3)) α = dot_product_attention_scores(q, k; mask) @test all((α[:,:,1,1].> 0) .== mask) @test all((α[:,:,2,1].> 0) .== mask) @testset "causal" begin x = rand(4, 2, 3, 1) mask = make_causal_mask(x, dims=3) α = dot_product_attention_scores(x, x; mask) @test all((α[:,:,1,1].> 0) .== mask) @test all((α[:,:,2,1].> 0) .== mask) end end @testset "dropout" begin q = k = v = rand(10, 10, 10) fdrop(x, p) = (rand!(similar(x)) .> p) .* x ./ (1-p) y, α = dot_product_attention(q, k, v; nheads=2, fdrop = x -> fdrop(x, 0.5)) @test 0.6 > mean(>(0), α) > 0.4 end @testset "bias" begin q = rand(4, 5, 1) k = v = rand(4, 3, 1) bias = randn(3, 5) y, α = dot_product_attention(q, k, v, bias; nheads=2) @test size(α) == (3, 5, 2, 1) @test size(y) == (4, 5, 1) end @testset "gradient" begin q = rand(4, 5, 1) k = v = rand(4, 3, 1) bias = randn(3, 5) y, α = dot_product_attention(q, k, v, bias; nheads=2) gradtest((x...) 
-> dot_product_attention(x...; nheads=2)[1], q, k, v, bias)
end

================================================
FILE: test/batchedmul.jl
================================================
using NNlib, Test, LinearAlgebra, Logging
using NNlib: storage_type, storage_typejoin, is_strided, batched_mul_generic!,
    _unbatch, _copy_if_faster, BatchedAdjoint, BatchedTranspose

# Reference batched multiply: slice-by-slice `*`, with optional transposes
# (applied by materialising the permuted arrays).
function bmm_test(a, b; transA = false, transB = false)
    transA && (a = permutedims(a, [2,1,3]))
    transB && (b = permutedims(b, [2,1,3]))
    slices = [a[:,:,k] * b[:,:,k] for k in 1:size(a,3)]
    cat(slices...; dims = 3)
end

# Reference batched multiply with optional slice-wise adjoints.
function bmm_adjtest(a, b; adjA = false, adjB = false)
    slices = map(1:size(a,3)) do k
        lhs = adjA ? adjoint(a[:,:,k]) : a[:,:,k]
        rhs = adjB ? adjoint(b[:,:,k]) : b[:,:,k]
        lhs * rhs
    end
    cat(slices...; dims = 3)
end

# Reference for multiplying a batch `x` by a single matrix `y` (trailing
# batch dimension of size 1), via one flat mat-mul and a permute back.
function half_batched_mul(x, y)
    @assert size(y,3) == 1
    d = size(x,2)
    lhs = reshape(permutedims(x, (1,3,2)), :, d)
    rhs = reshape(y, d, :)
    flat = lhs * rhs
    permutedims(reshape(flat, size(x,1), size(x,3), :), (1,3,2))
end

@testset "batched_mul: Float64 * $TB" for TB in [Float64, Float32]
    # Real
    A = randn(7,5,3)
    B = randn(TB, 5,7,3)
    C = randn(7,6,3)

    @test batched_mul(A, B) ≈ bmm_test(A, B)
    @test batched_mul(batched_transpose(A), batched_transpose(B)) ≈ bmm_test(A, B; transA = true, transB = true)
    @test batched_mul(batched_transpose(A), C) ≈ bmm_test(A, C; transA = true)
    @test batched_mul(A, batched_transpose(A)) ≈ bmm_test(A, A; transB = true)

    # Complex
    cA = randn(Complex{Float64}, 7,5,3)
    cB = randn(Complex{TB}, 5,7,3)
    cC = randn(Complex{Float64}, 7,6,3)

    @test batched_mul(cA, cB) ≈ bmm_adjtest(cA, cB)
    @test batched_mul(batched_adjoint(cA), batched_adjoint(cB)) ≈ bmm_adjtest(cA, cB; adjA = true, adjB = true)
    @test batched_mul(batched_adjoint(cA), cC) ≈ bmm_adjtest(cA, cC; adjA = true)
    @test batched_mul(cA, batched_adjoint(cA)) ≈ bmm_adjtest(cA, cA; adjB = true)

    # Wrappers which cancel
    @test
batched_transpose(PermutedDimsArray(A, (2,1,3))) === A @test batched_adjoint(batched_adjoint(cA)) === cA @test batched_transpose(batched_adjoint(cA)) isa NNlib.BatchedAdjoint # Integers TBi = TB==Float64 ? Int64 : Int32 iA = rand(1:99, 7,5,3) iB = TB.(rand(1:99, 5,7,3)) iC = zeros(Int, 7,6,3) @test batched_mul(iA, iB) == bmm_adjtest(iA, iB) @test batched_mul(cA, iB) ≈ bmm_adjtest(cA, iB) # Errors @test_throws DimensionMismatch batched_mul(rand(2,2,2), rand(TB, 2,2,10)) @test_throws DimensionMismatch batched_mul(rand(2,2,2), rand(TB, 10,2,2)) @test_throws Exception batched_mul!(zeros(2,2,10), rand(2,2,2), rand(TB, 2,2,2)) # PermutedDimsArrays for perm in [(1,3,2), (2,1,3), (3,2,1)], fun in [identity, batched_adjoint], ty in [identity, complex] A = randn(ty(Float64), 4,4,4) B = randn(ty(TB), 4,4,4) @test batched_mul(fun(A), PermutedDimsArray(B, perm)) ≈ batched_mul(fun(A), permutedims(B, perm)) @test batched_mul(fun(PermutedDimsArray(A, perm)), B) ≈ batched_mul(fun(permutedims(A, perm)), B) # when TB=Float64, only the case perm=(2,1,3); fun=batched_adjoint; ty=complex; goes to fallback # but all the perm=(3,2,1); cases copy their inputs. end # PermutedDimsArray output A′ = randn(4,3,2) B′ = batched_adjoint(randn(TB, 5,3,2)) C1 = batched_mul(A′, B′) # size 4,5,2 C2 = PermutedDimsArray(zeros(5,2,4), (3,1,2)) # size 4,5,2 @test C1 ≈ batched_mul!(C2, A′, B′) # Float64: "Debug: transposing C = A * B into Cᵀ = Bᵀ * Aᵀ" @test C1 ≈ C2 # 5-arg mul! 
@test 10 .* C1 ≈ batched_mul!(C2, A′, B′, 10) rtol=1e-7 C2 .= 10 @test C1 .+ 100 ≈ batched_mul!(C2, A′, B′, 1, 10) # Trivial batches for B D′ = randn(TB, 3,5,1) @test size(batched_mul(A′,D′)) == (4,5,2) @test batched_mul(A′,D′) ≈ half_batched_mul(A′, D′) # Large output, multi-threaded path if TB == Float64 N = 50 A = rand(N,N,N) B = rand(N,N,N) C = reshape(reduce(hcat, [vec(A[:,:,k] * B[:,:,k]) for k in 1:N]), N,N,N) @test C ≈ A ⊠ B D = rand(N,N,1) E = reshape(reduce(hcat, [vec(A[:,:,k] * D[:,:,1]) for k in 1:N]), N,N,N) @test E ≈ A ⊠ D end end perm_12(A) = PermutedDimsArray(A, (2,1,3)) perm_23(A) = PermutedDimsArray(A, (1,3,2)) @testset "batched_mul: trivial dimensions & unit strides, $T" for T in [Float64, ComplexF64] @testset "$tA(rand$((sA...,3))) ⊠ $tB(rand$((sB...,3)))" for tA in [identity, batched_adjoint, batched_transpose, perm_12, perm_23], sA in [(1,1), (1,3), (3,1), (3,3)], tB in [identity, batched_adjoint, batched_transpose, perm_12, perm_23], sB in [(1,1), (1,3), (3,1), (3,3)] A = tA(rand(T, sA..., 3)) B = tB(rand(T, sB..., 3)) size(A,2) == size(B,1) && size(A,3) == size(B,3) == 3 || continue C = cat(A[:,:,1] * B[:,:,1], A[:,:,2] * B[:,:,2], A[:,:,3] * B[:,:,3]; dims=3) @test A ⊠ B ≈ C @test_logs min_level=Logging.Debug A ⊠ B # In-place batched_mul! α, β = rand(T), rand(T) D = rand(T, size(C)) @test batched_mul!(copy(D), A, B, α, β) ≈ α .* C .+ β .* D @test batched_mul_generic!(copy(D), A, B, α, β) ≈ α .* C .+ β .* D # ... and with weird LHS -- all to batched_mul_generic! 
right now C2 = batched_transpose(permutedims(C, (2,1,3))) C3 = batched_adjoint(permutedims(conj(C), (2,1,3))) @test C2 == C3 == C C2 .= D C3 .= D @test batched_mul!(C2, A, B, α, β) ≈ α .* C .+ β .* D @test C2 ≈ α .* C .+ β .* D @test batched_mul!(C3, A, B, α, β) ≈ α .* C .+ β .* D @test C3 ≈ α .* C .+ β .* D end end @testset "BatchedAdjOrTrans interface * $TB" for TB in [Float64, Float32] A = randn(7,5,3) B = randn(TB, 5,7,3) C = randn(7,6,3) function interface_tests(X, _X) @test length(_X) == length(X) @test size(_X) == (size(X, 2), size(X, 1), size(X, 3)) @test axes(_X) == (axes(X, 2), axes(X, 1), axes(X, 3)) # @test getindex(_X, 2, 3, 3) == getindex(X, 3, 2, 3) @test getindex(_X, 5, 4, 1) == getindex(X, 4, 5, 1) # setindex!(_X, 2.0, 2, 4, 1) @test getindex(_X, 2, 4, 1) == 2.0 setindex!(_X, 3.0, 1, 2, 2) @test getindex(_X, 1, 2, 2) == 3.0 _sim = similar(_X, TB, (2, 3)) @test size(_sim) == (2, 3) @test typeof(_sim) == Array{TB, 2} _sim = similar(_X, TB) @test length(_sim) == length(_X) @test typeof(_sim) == Array{TB, 3} _sim = similar(_X, (2, 3)) @test size(_sim) == (2, 3) @test typeof(_sim) == Array{Float64, 2} _sim = similar(_X) @test length(_sim) == length(_X) @test typeof(_sim) == Array{Float64, 3} @test parent(_X) == _X.parent end for (X, _X) in zip([A, B, C], map(batched_adjoint, [A, B, C])) interface_tests(X, _X) @test -_X == NNlib.BatchedAdjoint(-_X.parent) _copyX = copy(_X) @test _X == _copyX setindex!(_copyX, 2.0, 1, 2, 1) @test _X != _copyX end for (X, _X) in zip([A, B, C], map(batched_transpose, [A, B, C])) interface_tests(X, _X) @test -_X == NNlib.BatchedTranspose(-_X.parent) _copyX = copy(_X) @test _X == _copyX setindex!(_copyX, 2.0, 1, 2, 1) @test _X != _copyX end end @testset "batched_mul(ndims < 3), $TM" for TM in [ComplexF64, Int8] A = randn(ComplexF64, 3,3,3) M = rand(TM, 3,3) .+ im V = rand(TM, 3) # These are all reshaped and sent to batched_mul(3-array, 3-array) @test batched_mul(A, M) ≈ cat([A[:,:,k] * M for k in 1:3]...; dims=3) @test 
batched_mul(A, M') ≈ cat([A[:,:,k] * M' for k in 1:3]...; dims=3) @test A ⊠ transpose(M) ≈ cat([A[:,:,k] * transpose(M) for k in 1:3]...; dims=3) @test batched_mul(M, A) ≈ cat([M * A[:,:,k] for k in 1:3]...; dims=3) @test batched_mul(M', A) ≈ cat([M' * A[:,:,k] for k in 1:3]...; dims=3) @test transpose(M) ⊠ A ≈ cat([transpose(M) * A[:,:,k] for k in 1:3]...; dims=3) # batched_vec @test batched_vec(A, M) ≈ hcat([A[:,:,k] * M[:,k] for k in 1:3]...) @test batched_vec(A, M') ≈ hcat([A[:,:,k] * (M')[:,k] for k in 1:3]...) @test batched_vec(A, V) ≈ hcat([A[:,:,k] * V for k in 1:3]...) end @testset "storage_type" begin @test storage_type(transpose(reshape(view(rand(10), 2:9),4,:))) == Vector{Float64} @test storage_type(transpose(reshape(view(1:10, 2:9),4,:))) == UnitRange{Int} @test storage_typejoin(rand(2), rand(Float32, 2)) == Vector{<:Any} @test storage_typejoin(rand(2), rand(2,3)', rand(2,3,4)) == Array{Float64} @test storage_typejoin([1,2,3], 4:5) == AbstractVector{Int} end @testset "is_strided" begin M = ones(10,10) @test is_strided(M) @test is_strided(view(M, 1:2:5,:)) @test is_strided(PermutedDimsArray(M, (2,1))) @test !is_strided(reshape(view(M, 1:2:10,:), 10,:)) @test !is_strided((M.+im)') @test !is_strided(Diagonal(ones(3))) A = ones(2,2,2) @test is_strided(batched_adjoint(A)) @test is_strided(batched_transpose(A)) @test !is_strided(batched_adjoint(A .+ im)) @test is_strided(batched_transpose(A .+ im)) end FiniteDifferences.to_vec(x::BatchedAdjoint) = FiniteDifferences.to_vec(collect(x)) FiniteDifferences.to_vec(x::BatchedTranspose) = FiniteDifferences.to_vec(collect(x)) @testset "AutoDiff" begin M, P, Q = 13, 7, 11 B = 3 # Two 3-arrays gradtest(batched_mul, randn(rng, M, P, B), randn(rng, P, Q, B)) gradtest(batched_mul, batched_adjoint(randn(rng, P, M, B)), randn(rng, P, Q, B)) gradtest(batched_mul, randn(rng, M, P, B), batched_transpose(randn(rng, Q, P, B))) # One a matrix... 
gradtest(batched_mul, randn(rng, M, P), randn(rng, P, Q, B)) gradtest(batched_mul, adjoint(randn(rng, P, M)), randn(rng, P, Q, B)) gradtest(batched_mul, randn(rng, M, P), batched_adjoint(randn(rng, Q, P, B))) gradtest(batched_mul, randn(rng, M, P, B), randn(rng, P, Q)) gradtest(batched_mul, batched_transpose(randn(rng, P, M, B)), randn(rng, P, Q)) gradtest(batched_mul, randn(rng, M, P, B), transpose(randn(rng, Q, P))) # ... or equivalent to a matrix gradtest(batched_mul, randn(rng, M, P, 1), randn(rng, P, Q, B)) gradtest(batched_mul, batched_transpose(randn(rng, P, M, 1)), randn(rng, P, Q, B)) gradtest(batched_mul, randn(rng, M, P, 1), batched_transpose(randn(rng, Q, P, B))) gradtest(batched_mul, randn(rng, M, P, B), randn(rng, P, Q, 1)) gradtest(batched_mul, batched_adjoint(randn(rng, P, M, B)), randn(rng, P, Q, 1)) gradtest(batched_mul, randn(rng, M, P, B), batched_adjoint(randn(rng, Q, P, 1))) # batched_vec gradtest(batched_vec, randn(rng, M, P, B), randn(rng, P, B)) gradtest(batched_vec, randn(rng, M, P, B), transpose(randn(rng, B, P))) gradtest(batched_vec, randn(rng, M, P, B), randn(rng, P)) end @testset "batched_vec: N-D batches" begin # Test 4D case: A is 4D, B is 3D A4d = randn(4, 5, 3, 2) # (matrix_rows, matrix_cols, batch_dim1, batch_dim2) B3d = randn(5, 3, 2) # (vector_length, batch_dim1, batch_dim2) C = batched_vec(A4d, B3d) @test size(C) == (4, 3, 2) # Manual verification for i in 1:3, j in 1:2 @test C[:, i, j] ≈ A4d[:, :, i, j] * B3d[:, i, j] end # Test 5D case: A is 5D, B is 4D A5d = randn(3, 4, 2, 3, 2) # (matrix_rows, matrix_cols, batch1, batch2, batch3) B4d = randn(4, 2, 3, 2) # (vector_length, batch1, batch2, batch3) C5 = batched_vec(A5d, B4d) @test size(C5) == (3, 2, 3, 2) # Manual verification for a few cases @test C5[:, 1, 1, 1] ≈ A5d[:, :, 1, 1, 1] * B4d[:, 1, 1, 1] @test C5[:, 2, 3, 2] ≈ A5d[:, :, 2, 3, 2] * B4d[:, 2, 3, 2] # Test dimension mismatch errors @test_throws DimensionMismatch batched_vec(randn(3, 4, 2), randn(4, 3)) # ndims 
mismatch @test_throws DimensionMismatch batched_vec(randn(3, 4, 2, 3), randn(4, 2, 2)) # batch size mismatch end ================================================ FILE: test/bias_act.jl ================================================ using NNlib, Zygote, ChainRulesCore, Test using Zygote: ForwardDiff ACTIVATION_FUNCTIONS = [@eval($a) for a in NNlib.ACTIVATIONS] @testset "bias_act!" begin x = randn(3,4) b = randn(3) @test @inferred(bias_act!(identity, x, false)) === x # pass-through @test @inferred(bias_act!(identity, copy(x), b)) ≈ (x .+ b) @test @inferred(bias_act!(relu, copy(x), b)) ≈ relu.(x .+ b) @test @inferred(bias_act!(tanh, copy(x), b)) ≈ tanh.(x .+ b) @test @inferred(bias_act!(tanh, copy(x), false)) ≈ tanh.(x) # Check that it does overwrite: x32 = rand(Float32, 3, 4); x32copy = copy(x32) @test @inferred(bias_act!(cbrt, x32, b)) ≈ cbrt.(x32copy .+ b) @test x32 ≈ cbrt.(x32copy .+ b) x32 = rand(Float32, 3, 4); x32copy = copy(x32) # without bias @test @inferred(bias_act!(tanh, x32, false)) ≈ tanh.(x32copy) @test x32 ≈ tanh.(x32copy) x32 = rand(Float32, 3, 4); x32copy = copy(x32) # now check gradient rule y, back = rrule(Zygote.ZygoteRuleConfig(), bias_act!, relu, x32, b) @test y ≈ x32 ≈ relu.(x32copy .+ b) x32 = rand(Float32, 3, 4); x32copy = copy(x32) # without bias y, back = rrule(Zygote.ZygoteRuleConfig(), bias_act!, relu, x32, false) @test y ≈ x32 ≈ relu.(x32copy) # Check that it doesn't try to overwrite non-float arrays: xint = rand(-3:3, 3, 4) bint = rand(-2:2, 3) @test bias_act!(identity, copy(xint), bint) ≈ xint .+ bint @test bias_act!(tanh, copy(xint), bint) ≈ tanh.(xint .+ bint) @test bias_act!(tanh, copy(xint), false) ≈ tanh.(xint) # Reject bias===true so that Bool means one thing: @test_throws Exception bias_act!(identity, rand(3), true) @test_throws Exception bias_act!(cbrt, rand(3), true) @test_throws Exception bias_act!(cbrt, rand(1:3, 3), true) @testset "gradient with $fun" for fun in vcat([identity, tanh, cbrt], ACTIVATION_FUNCTIONS, [x->x, x 
-> 1/(x^2+2), x -> leakyrelu(x, 0.33)])
        # Only some of these go the fast path, `cbrt` is an example of a function NNlib knows nothing about.
        fun == rrelu && continue # this one is randomised!
        fun == hardσ && continue # this one has heisenbugs, not solved by discontinuity-avoidance code below

        @test bias_act!(fun, copy(x), b) ≈ fun.(x .+ b)
        @test bias_act!(fun, copy(x), false) ≈ fun.(x)

        # Gradient w.r.t. x, with bias. Probe at x ± eps and skip the test when
        # the forward-mode gradient is unstable (a discontinuity of `fun`).
        gx = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x)
        gxplus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x .+ eps())
        gxminus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x .- eps())
        if !(gx ≈ gxplus ≈ gxminus)
            @warn "skipping gradient tests due to discontinuity" fun x b
            continue
        end
        @test gx ≈ Zygote.gradient(x -> sum(bias_act!(fun, copy(x), b)), x)[1]

        # Gradient w.r.t. x, without bias.
        gx2 = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x)
        # Fix: probe on the + side here; previously this line also used
        # `x .- eps()` (identical to gx2minus), so the discontinuity guard for
        # the no-bias case never looked above x — cf. `gxplus` above.
        gx2plus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x .+ eps())
        gx2minus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x .- eps())
        if !(gx2 ≈ gx2plus ≈ gx2minus)
            @warn "skipping gradient tests due to discontinuity" fun x
            continue
        end
        @test gx2 ≈ Zygote.gradient(x -> sum(bias_act!(fun, copy(x), false)), x)[1]

        # Gradient w.r.t. the bias; a Bool bias must produce no gradient.
        gb = ForwardDiff.gradient(b -> sum(bias_act!(fun, copy(x), b)), b)
        @test gb ≈ Zygote.gradient(b -> sum(bias_act!(fun, copy(x), b)), b)[1]
        @test Zygote.gradient(b -> sum(bias_act!(fun, copy(x), b)), false) == (nothing,)
        @test Zygote.gradient(b -> sum(bias_act!(fun, copy(x), b)), b .> 0) == (nothing,)
    end

    @testset "gradient for fast_broadcast!"
begin # Gradient definition is just to disable mutation inside 2nd order AD gx = ForwardDiff.gradient(x -> sum(NNlib._fast_broadcast!(cbrt∘(+), copy(x), b)), x) @test gx ≈ Zygote.gradient(x -> sum(NNlib._fast_broadcast!(cbrt∘(+), copy(x), b)), x)[1] # relu should take the fast path g2 = ForwardDiff.gradient(x) do x sum(abs2, Zygote.gradient(x -> sum(abs2, bias_act!(relu, copy(x), b)), x)[1]) end @test_skip gx ≈ Zygote.gradient(x) do x # Here global variable b causes an error sum(abs2,Zygote. gradient(x -> sum(abs2, bias_act!(relu, copy(x), b)), x)[1]) end # Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_eqtable_get), Any, svec(Any, Any, Any), 0, :(:ccall), %5, %3, %4)). # [5] (::typeof(∂(accum_global)))(Δ::Nothing) @test g2 ≈ Zygote.gradient(x, b) do x, b sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(relu, copy(x), b)), x, b)[1]) end[1] g3 = ForwardDiff.gradient(x) do x sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(swish, copy(x), b)), x, b)[1]) end @test g3 ≈ Zygote.gradient(x, b) do x, b sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(swish, copy(x), b)), x, b)[1]) end[1] # Anon function sure to take the generic path g4 = ForwardDiff.gradient(x) do x sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(y -> cbrt(y/3), copy(x), b)), x, b)[1]) end @test g4 ≈ Zygote.gradient(x, b) do x, b sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(y -> cbrt(y/3), copy(x), b)), x, b)[1]) end[1] end end ================================================ FILE: test/conv.jl ================================================ using NNlib, Test using NNlib: input_size, kernel_size, channels_in, channels_out, channel_multiplier, stride, padding, dilation, flipkernel, output_size, groupcount using Random: AbstractRNG, SamplerType @testset "ConvDims" begin for T in (DenseConvDims, DepthwiseConvDims) @testset "$(T)" begin x = randn(5,4,3,2) if T == DenseConvDims w = randn(1,2,3,4) elseif T == DepthwiseConvDims w = randn(1,2,4,3) 
end # First, getters: cdims = T(x, w) @test input_size(cdims) == size(x)[1:2] @test kernel_size(cdims) == size(w)[1:2] @test channels_in(cdims) == size(x, 3) @test stride(cdims) == (1,1) @test dilation(cdims) == (1,1) @test padding(cdims) == (0,0,0,0) @test flipkernel(cdims) == false @test output_size(cdims) == (5,3) # Special-case channel output tests if T == DenseConvDims @test channels_out(cdims) == size(w, 4) elseif T == DepthwiseConvDims @test channel_multiplier(cdims) == size(w, 3) @test channels_out(cdims) == size(w,3)*size(w,4) end # Next, scalar settings: cdims = T(x, w; stride=2, dilation=2, padding=3, flipkernel=true) @test stride(cdims) == (2,2) @test dilation(cdims) == (2,2) @test padding(cdims) == (3,3,3,3) @test flipkernel(cdims) == true @test output_size(cdims) == (6,4) # Next, tuple settings cdims = T(x, w; stride=(1, 2), dilation=(1, 2), padding=(0,1)) @test stride(cdims) == (1,2) @test dilation(cdims) == (1,2) @test padding(cdims) == (0,0,1,1) @test output_size(cdims) == (5,2) # Special case for 4-d padding spec: cdims = T(x, w; padding=(1,2,3,4)) @test padding(cdims) == (1,2,3,4) @test output_size(cdims) == (8,10) # Make sure we throw on invalid settings: # Invalid dimensionality of settings: @test_throws DimensionMismatch T(x, w; stride=(1,)) @test_throws DimensionMismatch T(x, w; stride=(1, 1, 1)) @test_throws DimensionMismatch T(x, w; padding=(1, 1, 1)) @test_throws DimensionMismatch T(x, w; padding=(1, 1, 1, 1, 1)) @test_throws DimensionMismatch T(x, w; dilation=(1,)) @test_throws DimensionMismatch T(x, w; dilation=(1, 1, 1)) # Dilation will cause us to reach beyond the end of input + padding here: @test_throws DimensionMismatch T(x, w; dilation=(1, 5)) # Channel mismatch: if T == DenseConvDims @test_throws DimensionMismatch T(x, w[:,:,1:1,:]) elseif T == DepthwiseConvDims @test_throws DimensionMismatch T(x, w[:,:,:,1:1]) end end end end conv_answer_dict = Dict( # Known-good answers for 1d convolution operations 1 => Dict( "y_pad" => [1, 4, 
7, 10, 13, 10.], "y_dil" => [5, 8, 11.], "y_flip" => [5, 8, 11, 14.], "dx" => [ 8, 18, 27, 36, 13.], "dx_stride" => [ 8, 4, 20, 10, 0.], "dx_pad" => [ 9, 18, 27, 36, 33.], "dx_dil" => [10, 16, 27, 8, 11.], "dx_flip" => [ 5, 18, 27, 36, 28.], "dw" => [134, 100.], "dw_stride" => [ 48, 34.], "dw_pad" => [135, 150.], "dw_dil" => [102, 54.], "dw_flip" => [110, 148.], ), # Known-good answers for 2d convolution operations 2 => Dict( "y_pad" => [ 1 9 29 49 48; 4 29 79 129 115; 7 39 89 139 122; 10 49 99 149 129; 13 59 109 159 136; 10 40 70 100 80. ], "y_dil" => [ 48 98; 58 108; 68 118. ], "y_flip" => [ 51 101 151; 61 111 161; 71 121 171; 81 131 181. ], "dx" => [ 116 374 674 258; 243 700 1200 407; 313 800 1300 437; 383 900 1400 467; 177 386 586 159. ], "dx_stride" => [ 116 58 516 258; 87 29 387 129; 196 98 596 298; 147 49 447 149; 0 0 0 0. ], "dx_pad" => [ 152 470 850 911; 261 700 1200 1240; 340 800 1300 1319; 419 900 1400 1398; 370 746 1126 1087. ], "dx_dil" => [ 192 392 96 196; 232 432 116 216; 416 766 184 334; 174 324 58 108; 204 354 68 118. ], "dx_flip" => [ 51 254 454 453; 163 700 1200 1087; 193 800 1300 1157; 223 900 1400 1227; 162 586 886 724. ], "dw" => [ 17378 11738; 16250 10610. ], "dw_stride" => [ 5668 3888; 5312 3532. ], "dw_pad" => [ 18670 22550; 19850 23430. ], "dw_dil" => [ 8632 3652; 7636 2656. ], "dw_flip" => [ 12590 19550; 13982 20942. 
], ), # Known-good answers for 3d convolution operations (these are getting rather large) 3 => Dict( "y_pad" => reshape([ 1, 4, 7, 10, 13, 10, 9, 29, 39, 49, 59, 40, 29, 79, 89, 99, 109, 70, 49, 129, 139, 149, 159, 100, 48, 115, 122, 129, 136, 80, 26, 80, 94, 108, 122, 80, 126, 322, 358, 394, 430, 260, 206, 502, 538, 574, 610, 360, 286, 682, 718, 754, 790, 460, 220, 502, 524, 546, 568, 320, 146, 360, 374, 388, 402, 240, 446, 1042, 1078, 1114, 1150, 660, 526, 1222, 1258, 1294, 1330, 760, 606, 1402, 1438, 1474, 1510, 860, 420, 942, 964, 986, 1008, 560, 205, 456, 467, 478, 489, 270, 517, 1133, 1159, 1185, 1211, 660, 577, 1263, 1289, 1315, 1341, 730, 637, 1393, 1419, 1445, 1471, 800, 392, 847, 862, 877, 892, 480. ], (6,5,4)), "y_dil" => reshape([608, 644, 680, 788, 824, 860.], (3,2,1)), "y_flip" => reshape([ 686, 722, 758, 794, 866, 902, 938, 974, 1046, 1082, 1118, 1154, 1406, 1442, 1478, 1514, 1586, 1622, 1658, 1694, 1766, 1802, 1838, 1874. ], (4,3,2)), "dx" => reshape([ 2576, 5118, 5658, 6198, 3010, 5948, 11576, 12512, 13448, 6420, 8468, 16256, 17192, 18128, 8580, 4092, 7718, 8114, 8510, 3950, 9624, 18316, 19108, 19900, 9340, 18680, 34992, 36288, 37584, 17320, 22280, 41472, 42768, 44064, 20200, 9776, 17756, 18260, 18764, 8340, 4168, 7438, 7690, 7942, 3450, 6972, 11896, 12256, 12616, 5140, 8052, 13696, 14056, 14416, 5860, 2804, 4278, 4386, 4494, 1510. ], (5,4,3)), "dx_stride" => reshape([ 2576, 2254, 3152, 2758, 0, 1932, 1610, 2364, 1970, 0, 5456, 4774, 6032, 5278, 0, 4092, 3410, 4524, 3770, 0, 1288, 966, 1576, 1182, 0, 644, 322, 788, 394, 0, 2728, 2046, 3016, 2262, 0, 1364, 682, 1508, 754, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0. 
], (5,4,3)), "dx_pad" => reshape([ 4220, 6343, 7116, 7889, 6550, 8490, 12276, 13312, 14348, 11606, 12350, 17456, 18492, 19528, 15546, 11989, 16664, 17469, 18274, 14333, 16200, 22628, 23616, 24604, 19392, 25336, 34992, 36288, 37584, 29320, 30216, 41472, 42768, 44064, 34200, 26236, 35664, 36652, 37640, 28940, 22816, 30831, 31636, 32441, 24794, 32522, 43668, 44704, 45740, 34742, 36462, 48848, 49884, 50920, 38602, 29501, 39264, 40037, 40810, 30733. ], (5,4,3)), "dx_dil" => reshape([ 4864, 5152, 9696, 4508, 4760, 6304, 6592, 12396, 5768, 6020, 3648, 3864, 7120, 3220, 3400, 4728, 4944, 9100, 4120, 4300, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2432, 2576, 4544, 1932, 2040, 3152, 3296, 5804, 2472, 2580, 1216, 1288, 1968, 644, 680, 1576, 1648, 2508, 824, 860. ], (5,4,3)), "dx_flip" => reshape([ 686, 2094, 2202, 2310, 1588, 2924, 7544, 7904, 8264, 5124, 3644, 9344, 9704, 10064, 6204, 3138, 7430, 7682, 7934, 4616, 4836, 11980, 12484, 12988, 7792, 14936, 34992, 36288, 37584, 21640, 17816, 41472, 42768, 44064, 25240, 12620, 28412, 29204, 29996, 16728, 7030, 15646, 16042, 16438, 9084, 17772, 38968, 39904, 40840, 22276, 19932, 43648, 44584, 45520, 24796, 12362, 26742, 27282, 27822, 14992. 
], (5,4,3)), "dw" => reshape([1.058184e6, 1.0362e6, 948264, 926280, 618504, 596520, 508584, 486600], (2,2,2)), "dw_stride" => reshape([ 74760, 72608, 64000, 61848, 31720, 29568, 20960, 18808.], (2,2,2)), "dw_pad" => reshape([1.26055e6, 1.30805e6, 1.40327e6, 1.44923e6, 1.73731e6, 1.77589e6, 1.83259e6, 1.86731e6], (2,2,2)), "dw_dil" => reshape([ 250320, 241512, 206280, 197472, 74160, 65352, 30120, 21312.], (2,2,2)), "dw_flip" => reshape([ 639480, 670200, 793080, 823800, 1.25388e6, 1.2846e6, 1.40748e6, 1.4382e6], (2,2,2)), ), ) # A "drop channels and batch dimension" helper ddims(x) = dropdims(x, dims=(ndims(x)-1, ndims(x))) @testset "Dense Convolution" begin # Start with some easy-to-debug cases that we have worked through and _know_ work for rank in (1,2,3) @testset "conv$(rank)d" begin # Pull out known-good answers for y = conv(x, w) y_pad = conv_answer_dict[rank]["y_pad"] y_dil = conv_answer_dict[rank]["y_dil"] y_flip = conv_answer_dict[rank]["y_flip"] # We can always derive y_plain and y_stride from the other answers. y_plain = y_pad[((2:(size(y_pad,idx)-1)) for idx in 1:rank)...] y_stride = y_pad[((2:2:(size(y_pad,idx)-1)) for idx in 1:rank)...] 
# Same for dx and dw: dx = conv_answer_dict[rank]["dx"] dx_stride = conv_answer_dict[rank]["dx_stride"] dx_pad = conv_answer_dict[rank]["dx_pad"] dx_dil = conv_answer_dict[rank]["dx_dil"] dx_flip = conv_answer_dict[rank]["dx_flip"] dw = conv_answer_dict[rank]["dw"] dw_stride = conv_answer_dict[rank]["dw_stride"] dw_pad = conv_answer_dict[rank]["dw_pad"] dw_dil = conv_answer_dict[rank]["dw_dil"] dw_flip = conv_answer_dict[rank]["dw_flip"] # We generate x and w from the shapes we know they must be x = reshape(Float64[1:prod(size(dx));], size(dx)..., 1, 1) w = reshape(Float64[1:prod(size(dw));], size(dw)..., 1, 1) convs = [NNlib.conv, NNlib.conv_im2col, NNlib.conv_direct,] for conv in convs @testset "$(conv)" begin cdims = DenseConvDims(x, w) # First, your basic convolution with no parameters @test isapprox(ddims(conv(x, w, cdims)), y_plain, rtol = 1.0e-7) # Next, test convolution on views and alternate datatypes: @test isapprox(ddims(conv(view(x, repeat([:], ndims(x))...), w, cdims)), y_plain, rtol = 1.0e-7) @test isapprox(ddims(conv(Float32.(x), Float32.(w), cdims)), Float32.(y_plain), rtol = 1.0e-7) # Next, introduce stride: cdims = DenseConvDims(x, w; stride=2) @test isapprox(ddims(conv(x, w, cdims)), y_stride, rtol = 1.0e-7) # Next, introduce dilation: cdims = DenseConvDims(x, w; dilation=2) @test isapprox(ddims(conv(x, w, cdims)), y_dil, rtol = 1.0e-7) # Next, introduce padding: cdims = DenseConvDims(x, w; padding=1) @test isapprox(ddims(conv(x, w, cdims)), y_pad, rtol = 1.0e-7) # Next, test crosscor/conv with a flipped kernel cdims = DenseConvDims(x, w; flipkernel=true) @test isapprox(ddims(conv(x, w, cdims)), y_flip, rtol = 1.0e-7) end end # Test all in-place implementations/interfaces convs = [NNlib.conv!, NNlib.conv_im2col!, NNlib.conv_direct!,] for conv! 
in convs α, β = 2e0, -1e0 @testset "$(conv!)" begin # First, your basic convolution with no parameters cdims = DenseConvDims(x, w) y0 = rand(rng, -9e0:9e0, size(y_plain)..., 1, 1) @test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_plain + β*y0, rtol = 1.0e-7) # Next, test convolution on views and alternate datatypes: @test isapprox(ddims(conv!(copy(y0), view(x, repeat([:], ndims(x))...), w, cdims; alpha=α, beta=β)), α*y_plain + β*y0, rtol = 1.0e-7) @test isapprox(ddims(conv!(Float32.(copy(y0)), Float32.(x), Float32.(w), cdims; alpha=Float32(α), beta=Float32(β))), Float32.(α*y_plain + β*y0), rtol = 1.0e-7) # Next, introduce stride: cdims = DenseConvDims(x, w; stride=2) y0 = rand(rng, -9e0:9e0, size(y_stride)..., 1, 1) @test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_stride + β*y0, rtol = 1.0e-7) # Next, introduce dilation: cdims = DenseConvDims(x, w; dilation=2) y0 = rand(rng, -9e0:9e0, size(y_dil)..., 1, 1) @test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_dil + β*y0, rtol = 1.0e-7) # Next, introduce padding: cdims = DenseConvDims(x, w; padding=1) y0 = rand(rng, -9e0:9e0, size(y_pad)..., 1, 1) @test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_pad + β*y0, rtol = 1.0e-7) # Next, test crosscor/conv with a flipped kernel cdims = DenseConvDims(x, w; flipkernel=true) y0 = rand(rng, -9e0:9e0, size(y_flip)..., 1, 1) @test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_flip + β*y0, rtol = 1.0e-7) end end # Test all implementations/interfaces for (∇conv_filter, ∇conv_data) in ( (NNlib.∇conv_filter, NNlib.∇conv_data), (NNlib.∇conv_filter_im2col, NNlib.∇conv_data_im2col), (NNlib.∇conv_filter_direct, NNlib.∇conv_data_direct), ) @testset "$(∇conv_filter)/$(∇conv_data)" begin # First, your basic convolution with no parameters cdims = DenseConvDims(x, w) dy = NNlib.conv(x, w, cdims) @test isapprox(ddims(∇conv_filter(x, dy, cdims)), dw, rtol = 1.0e-7) @test 
isapprox(ddims(∇conv_data(dy, w, cdims)), dx, rtol = 1.0e-7) # Next, test convolution on views and alternate datatypes: @test isapprox(ddims(∇conv_filter(x, view(dy, repeat([:], ndims(dy))...), cdims)), dw, rtol = 1.0e-7) @test isapprox(ddims(∇conv_data(view(dy, repeat([:], ndims(dy))...), w, cdims)), dx, rtol = 1.0e-7) @test isapprox(ddims(∇conv_filter(Float32.(x), Float32.(dy), cdims)), dw, rtol = 1.0e-7) @test isapprox(ddims(∇conv_data(Float32.(dy), Float32.(w), cdims)), dx, rtol = 1.0e-7) # Next, introduce stride: cdims = DenseConvDims(x, w; stride=2) dy = NNlib.conv(x, w, cdims) @test isapprox(ddims(∇conv_filter(x, dy, cdims)), dw_stride, rtol = 1.0e-7) @test isapprox(ddims(∇conv_data(dy, w, cdims)), dx_stride, rtol = 1.0e-7) # Next, introduce dilation: cdims = DenseConvDims(x, w; dilation=2) dy = NNlib.conv(x, w, cdims) @test isapprox(ddims(∇conv_filter(x, dy, cdims)), dw_dil, rtol = 1.0e-7) @test isapprox(ddims(∇conv_data(dy, w, cdims)), dx_dil, rtol = 1.0e-7) # Next, introduce padding: cdims = DenseConvDims(x, w; padding=1) dy = NNlib.conv(x, w, cdims) @test isapprox(ddims(∇conv_filter(x, dy, cdims)), dw_pad, rtol = 1.0e-7) @test isapprox(ddims(∇conv_data(dy, w, cdims)), dx_pad, rtol = 1.0e-7) # Next, test crosscor/conv with a flipped kernel cdims = DenseConvDims(x, w; flipkernel=true) dy = NNlib.conv(x, w, cdims) @test isapprox(ddims(∇conv_filter(x, dy, cdims)), dw_flip, rtol = 1.0e-7) @test isapprox(ddims(∇conv_data(dy, w, cdims)), dx_flip, rtol = 1.0e-7) end end # Test im2col for beta in (-2.0, -1.0, 0.0, 0.5, 1.0, 2.0) cache_dx, cache_dy, cache_w = ([0.17;;; 0.19;;; 0.23], [0.11;;; 0.13;;; 0.15], [1.0;;;]) dx_old = copy(cache_dx) cdims = DenseConvDims(cache_dx, cache_w) NNlib.∇conv_data_im2col!(cache_dx, cache_dy, cache_w, cdims; alpha=1.0, beta) @test isapprox(cache_dx, dx_old * beta + cache_dy, rtol = 1.0e-7) end # Test all in-place implementations/interfaces for (∇conv_filter!, ∇conv_data!) 
in ( (NNlib.∇conv_filter!, NNlib.∇conv_data!), (NNlib.∇conv_filter_im2col!, NNlib.∇conv_data_im2col!), (NNlib.∇conv_filter_direct!, NNlib.∇conv_data_direct!), ) #α, β = 2*rand(rng) - 1, 2*rand(rng) - 1 α, β = 2e0, -1e0 @testset "$(∇conv_filter!)/$(∇conv_data!)" begin # First, your basic convolution with no parameters cdims = DenseConvDims(x, w) dy = NNlib.conv(x, w, cdims) @test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw + β*w, rtol = 1.0e-7) @test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx + β*x, rtol = 1.0e-7) # Next, test convolution on views and alternate datatypes: @test isapprox(ddims(∇conv_filter!(copy(w), x, view(dy, repeat([:], ndims(dy))...), cdims; alpha=α, beta=β)), α*dw + β*w, rtol = 1.0e-7) @test isapprox(ddims(∇conv_data!(copy(x), view(dy, repeat([:], ndims(dy))...), w, cdims; alpha=α, beta=β)), α*dx + β*x, rtol = 1.0e-7) @test isapprox(ddims(∇conv_filter!(Float32.(copy(w)), Float32.(x), Float32.(dy), cdims; alpha=Float32(α), beta=Float32(β))), α*dw + β*w, rtol = 1.0e-7) @test isapprox(ddims(∇conv_data!(Float32.(copy(x)), Float32.(dy), Float32.(w), cdims; alpha=Float32(α), beta=Float32(β))), α*dx + β*x, rtol = 1.0e-7) # Next, introduce stride: cdims = DenseConvDims(x, w; stride=2) dy = NNlib.conv(x, w, cdims) flag_ = ∇conv_filter! == NNlib.∇conv_filter_direct! && rank in (1,3) @test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_stride + β*w, rtol = 1.0e-7) broken=flag_ @test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_stride + β*x, rtol = 1.0e-7) # Next, introduce dilation: cdims = DenseConvDims(x, w; dilation=2) dy = NNlib.conv(x, w, cdims) flag_ = ∇conv_data! == NNlib.∇conv_data_direct! 
&& rank == 3 @test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_dil + β*w, rtol = 1.0e-7) @test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_dil + β*x, rtol = 1.0e-7) broken=flag_ # Next, introduce padding: cdims = DenseConvDims(x, w; padding=1) dy = NNlib.conv(x, w, cdims) @test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_pad + β*w, rtol = 1.0e-7) @test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_pad + β*x, rtol = 1.0e-7) # Next, test crosscor/conv with a flipped kernel cdims = DenseConvDims(x, w; flipkernel=true) dy = NNlib.conv(x, w, cdims) @test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_flip + β*w, rtol = 1.0e-7) @test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_flip + β*x, rtol = 1.0e-7) end end end end end @testset "Complex Dense Convolution" begin # For now only 1 dimensional 1x1 convolution x = reshape(complex.(Float64[1:4;], Float64[1:4;] .+ 1), 1, 4, 1) w = reshape(complex.(Float64[1:4;] .+ 2, Float64[1:4;] .+ 3), 1, 4, 1) cdims = DenseConvDims(x, w) convs = [NNlib.conv, NNlib.conv_im2col, NNlib.conv_direct,] for conv in convs @testset "$(conv)" begin @test isapprox(ddims(conv(x, w, cdims)), [transpose(vec(w)) * vec(x)], rtol = 1.0e-7) end end dy = NNlib.conv(x, w, cdims) for (∇conv_filter, ∇conv_data) in ( (NNlib.∇conv_filter, NNlib.∇conv_data), (NNlib.∇conv_filter_im2col, NNlib.∇conv_data_im2col), (NNlib.∇conv_filter_direct, NNlib.∇conv_data_direct), ) @testset "$(∇conv_filter)/$(∇conv_data)" begin @test isapprox(∇conv_filter(x, dy, cdims), conj(x) .* dy, rtol = 1.0e-7) @test isapprox(∇conv_data(dy, w, cdims), dy .* conj(w), rtol = 1.0e-7) end end end if get(ENV, "NNLIB_TEST_FUZZING", "false") == "true" # @info("Skipping Convolutional fuzzing tests, set NNLIB_TEST_FUZZING=true to run them") @testset "fuzzing" begin @info("Starting Convolutional fuzzing tests; this can 
take a few minutes...") # Now that we're fairly certain things are working, let's fuzz things a little bit: for x_size in ( # 1d tests (1,), (3,), (7,), # 2d tests (1, 3), (3, 3), (12, 3), (20, 17), # 3d tests (1, 1, 3), (3, 5, 4), (20, 17, 14), ), C_in in (1, 3), batch in (1, 5) # Allocate x in this outer loop to save on allocations and speed things up x = rand(x_size..., C_in, batch) dx_direct = similar(x) dx_im2col = similar(x) for w_size in ( (1,), (3,), (7,), (1,1), (1,3), (3,4), (7, 4), (1,1,1), (1,1,3,), (3,4,3), (7,3,2)), C_out in (1, 4) # Give some output to the user that something is in fact happening. print(".") # Allocate w in this outer loop to save on allocations and speed things up w = rand(w_size..., C_in, C_out) dw_direct = similar(w) dw_im2col = similar(w) for S_size in (1, 2, 4, (1,2), (4,1), (2,1,4)), P_size in (0, 1, 2, (0,3,0,3), (4,1,4,2), (1,2,3,4,5,6)), D_size in (1, 2, 4, (1,2), (3,2), (4,2,3)) # Skip tests that are impossible due to mismatched sizes try DenseConvDims(x, w; stride=S_size, padding=P_size, dilation=D_size, ) catch e if isa(e, DimensionMismatch) || isa(e, MethodError) continue end rethrow(e) end # Do the actual convolution, comparing convolution implementations cdims = DenseConvDims(x, w; stride=S_size, padding=P_size, dilation=D_size) # We use mutating calls with explicitly different initial values, so as # to be sure to catch when we're leaving pieces of the output untouched. y_direct = ones(output_size(cdims)..., C_out, batch) .* 666.666 y_im2col = ones(output_size(cdims)..., C_out, batch) .* 777.777 # Do the convolutions NNlib.conv_direct!(y_direct, x, w, cdims) NNlib.conv_im2col!(y_im2col, x, w, cdims) # Compare! @test y_direct ≈ y_im2col dy = y_im2col # Now push backwards; first for the filter. 
Again, we initialize our # memory so that segments that never get touched are immediately noticable fill!(dw_direct, 666.666) fill!(dw_im2col, 777.777) NNlib.∇conv_filter_direct!(dw_direct, x, dy, cdims) NNlib.∇conv_filter_im2col!(dw_im2col, x, dy, cdims) @test dw_direct ≈ dw_im2col # And then for the input fill!(dx_direct, 666.666) fill!(dx_im2col, 777.777) NNlib.∇conv_data_direct!(dx_direct, dy, w, cdims) NNlib.∇conv_data_im2col!(dx_im2col, dy, w, cdims) @test dx_direct ≈ dx_im2col end end end println() end else @info "Skipping Convolutional fuzzing tests, set NNLIB_TEST_FUZZING=true to run them" end @testset "Depthwise Convolution" begin # Start with some easy-to-debug cases that we have worked through and _know_ work. # NOTE: these examples are all single-channel... which doesn't really stress test # the important parts of depthwise convolution! for rank in (1,2,3) @testset "depthwiseconv$(rank)d" begin # Pull out known-good answers for y = depthwiseconv(x, w) y_pad = conv_answer_dict[rank]["y_pad"] y_dil = conv_answer_dict[rank]["y_dil"] y_flip = conv_answer_dict[rank]["y_flip"] # We can always derive y_plain and y_stride from the other answers. y_plain = y_pad[((2:(size(y_pad,idx)-1)) for idx in 1:rank)...] y_stride = y_pad[((2:2:(size(y_pad,idx)-1)) for idx in 1:rank)...] 
# Same for dx and dw: dx = conv_answer_dict[rank]["dx"] dx_stride = conv_answer_dict[rank]["dx_stride"] dx_pad = conv_answer_dict[rank]["dx_pad"] dx_dil = conv_answer_dict[rank]["dx_dil"] dx_flip = conv_answer_dict[rank]["dx_flip"] dw = conv_answer_dict[rank]["dw"] dw_stride = conv_answer_dict[rank]["dw_stride"] dw_pad = conv_answer_dict[rank]["dw_pad"] dw_dil = conv_answer_dict[rank]["dw_dil"] dw_flip = conv_answer_dict[rank]["dw_flip"] # We generate x and w from the shapes we know they must be x = reshape(Float64[1:prod(size(dx));], size(dx)..., 1, 1) w = reshape(Float64[1:prod(size(dw));], size(dw)..., 1, 1) for conv in (NNlib.depthwiseconv, NNlib.depthwiseconv_im2col, NNlib.depthwiseconv_direct) @testset "$(conv)" begin # First, your basic convolution with no parameters cdims = DepthwiseConvDims(x, w) @test ddims(conv(x, w, cdims)) == y_plain # Next, test convolution on views and alternate datatypes: @test isapprox(ddims(conv(view(x, repeat([:], ndims(x))...), w, cdims)), y_plain, rtol = 1.0e-7) @test isapprox(ddims(conv(Float32.(x), Float32.(w), cdims)), Float32.(y_plain), rtol = 1.0e-7) # Next, introduce stride: cdims = DepthwiseConvDims(x, w; stride=2) @test isapprox(ddims(conv(x, w, cdims)), y_stride, rtol = 1.0e-7) # Next, introduce dilation: cdims = DepthwiseConvDims(x, w; dilation=2) @test isapprox(ddims(conv(x, w, cdims)), y_dil, rtol = 1.0e-7) # Next, introduce padding: cdims = DepthwiseConvDims(x, w; padding=1) @test isapprox(ddims(conv(x, w, cdims)), y_pad, rtol = 1.0e-7) # Next, test crosscor/conv with a flipped kernel cdims = DepthwiseConvDims(x, w; flipkernel=true) @test isapprox(ddims(conv(x, w, cdims)), y_flip, rtol = 1.0e-7) end end # Test all implementations/interfaces for (∇conv_filter, ∇conv_data) in ( (NNlib.∇depthwiseconv_filter, NNlib.∇depthwiseconv_data), (NNlib.∇depthwiseconv_filter_im2col, NNlib.∇depthwiseconv_data_im2col), (NNlib.∇depthwiseconv_filter_direct, NNlib.∇depthwiseconv_data_direct), ) @testset 
"$(∇conv_filter)/$(∇conv_data)" begin # First, your basic convolution with no parameters cdims = DepthwiseConvDims(x, w) dy = NNlib.depthwiseconv(x, w, cdims) @test ddims(∇conv_filter(x, dy, cdims)) == dw @test ddims(∇conv_data(dy, w, cdims)) == dx # Next, test convolution on views and alternate datatypes: @test ddims(∇conv_filter(x, view(dy, repeat([:], ndims(dy))...), cdims)) == dw @test ddims(∇conv_data(view(dy, repeat([:], ndims(dy))...), w, cdims)) == dx @test ddims(∇conv_filter(Float32.(x), Float32.(dy), cdims)) == dw @test ddims(∇conv_data(Float32.(dy), Float32.(w), cdims)) == dx # Next, introduce stride: cdims = DepthwiseConvDims(x, w; stride=2) dy = NNlib.depthwiseconv(x, w, cdims) @test ddims(∇conv_filter(x, dy, cdims)) == dw_stride @test ddims(∇conv_data(dy, w, cdims)) == dx_stride # Next, introduce dilation: cdims = DepthwiseConvDims(x, w; dilation=2) dy = NNlib.depthwiseconv(x, w, cdims) @test ddims(∇conv_filter(x, dy, cdims)) == dw_dil @test ddims(∇conv_data(dy, w, cdims)) == dx_dil # Next, introduce padding: cdims = DepthwiseConvDims(x, w; padding=1) dy = NNlib.depthwiseconv(x, w, cdims) @test ddims(∇conv_filter(x, dy, cdims)) == dw_pad @test ddims(∇conv_data(dy, w, cdims)) == dx_pad # Next, test crosscor/conv with a flipped kernel cdims = DepthwiseConvDims(x, w; flipkernel=true) dy = NNlib.depthwiseconv(x, w, cdims) @test ddims(∇conv_filter(x, dy, cdims)) == dw_flip @test ddims(∇conv_data(dy, w, cdims)) == dx_flip end end end end # Do some real depthwise convolution tests x = Float64.(reshape(1:2, (1,2,1))) w = Float64.(reshape(1:6, (3,1,2))) cdims = DepthwiseConvDims(x, w; padding=1) for conv in (NNlib.depthwiseconv, NNlib.depthwiseconv_im2col, NNlib.depthwiseconv_direct) @test conv(x, w, cdims)[:] ≈ [2, 10] rtol=1e-7 end end if get(ENV,"NNLIB_TEST_FUZZING","false") == "true" @testset "fuzzing" begin @info("Starting Depthwise Convolutional fuzzing tests; this can take a few minutes...") # Now that we're fairly certain things are working, let's fuzz 
things a little bit: for x_size in ( # 1d tests (1,), (3,), (7,), # 2d tests (1, 3), (3, 3), (12, 3), (20, 17), # 3d tests (1, 1, 3), (3, 5, 4), (20, 17, 14), ), C_in in (1, 3), batch in (1, 5) # Allocate x in this outer loop to save on allocations and speed things up x = rand(x_size..., C_in, batch) dx_direct = similar(x) dx_im2col = similar(x) for w_size in ( (1,), (3,), (7,), (1,1), (1,3), (3,4), (7, 4), (1,1,1), (1,1,3,), (3,4,3), (7,3,2)), C_mult in (1, 4) # Give some output to the user that something is in fact happening. print(".") # Allocate w in this outer loop to save on allocations and speed things up w = rand(w_size..., C_mult, C_in) dw_direct = similar(w) dw_im2col = similar(w) for S_size in (1, 2, 4, (1,2), (4,1), (2,1,4)), P_size in (0, 1, 2, (0,3,0,3), (4,1,4,2), (1,2,3,4,5,6)), D_size in (1, 2, 4, (1,2), (3,2), (4,2,3)) # Skip tests that are impossible due to mismatched sizes try DepthwiseConvDims(x, w; stride=S_size, padding=P_size, dilation=D_size, ) catch e if isa(e, DimensionMismatch) || isa(e, MethodError) continue end rethrow(e) end # Do the actual convolution, comparing convolution implementations cdims = DepthwiseConvDims(x, w; stride=S_size, padding=P_size, dilation=D_size) # We use mutating calls with explicitly different initial values, so as # to be sure to catch when we're leaving pieces of the output untouched. y_direct = ones(output_size(cdims)..., channels_out(cdims), batch) .* 666.666 y_im2col = ones(output_size(cdims)..., channels_out(cdims), batch) .* 777.777 # Do the convolutions NNlib.depthwiseconv_direct!(y_direct, x, w, cdims) NNlib.depthwiseconv_im2col!(y_im2col, x, w, cdims) # Compare! @test y_direct ≈ y_im2col dy = y_im2col # Now push backwards; first for the filter. 
Again, we initialize our # memory so that segments that never get touched are immediately noticable fill!(dw_direct, 666.666) fill!(dw_im2col, 777.777) NNlib.∇depthwiseconv_filter_direct!(dw_direct, x, dy, cdims) NNlib.∇depthwiseconv_filter_im2col!(dw_im2col, x, dy, cdims) @test dw_direct ≈ dw_im2col # And then for the input fill!(dx_direct, 666.666) fill!(dx_im2col, 777.777) NNlib.∇depthwiseconv_data_direct!(dx_direct, dy, w, cdims) NNlib.∇depthwiseconv_data_im2col!(dx_im2col, dy, w, cdims) @test dx_direct ≈ dx_im2col end end end println() end else @info "Skipping Depthwise Convolutional fuzzing tests, set NNLIB_TEST_FUZZING=true to run them" end @testset "Grouped Convolutions" begin x′ = rand(Float32, 28, 28, 100, 2) w′ = rand(Float32, 3, 3, 20, 15) @test_throws DimensionMismatch DenseConvDims(x′, w′) cdims = DenseConvDims(x′, w′, groups = 5) @test groupcount(cdims) == 5 y = conv(x′, w′, cdims) _, back = Zygote.pullback((x, w) -> sum(conv(x, w, cdims)), x′, w′) gs_x, gs_w = back(1.f0) ips = Iterators.partition(1:100, 20) ops = Iterators.partition(1:15, 3) for (i,o) in zip(ips,ops) _, back_reg = Zygote.pullback((x, w) -> sum(conv(x, w)), x′[:,:,i,:], w′[:,:,:,o]) gs_x_reg, gs_w_reg = back_reg(1.f0) @test conv(x′[:,:,i,:], w′[:,:,:,o]) ≈ y[:,:,o,:] @test gs_x_reg ≈ gs_x[:,:,i,:] @test gs_w_reg ≈ gs_w[:,:,:,o] end # Currently hangs due to a FiniteDifferences issue @test_skip gradtest((x, w) -> sum(conv(x, w, cdims)), x′, w′) end @testset "conv_wrapper" begin x = rand(10, 10, 3, 10) w = rand(2, 2, 3, 16) w1 = rand(3, 4, 3, 16) @test size(conv(x, w)) == (9, 9, 16, 10) @test size(conv(x, w; stride = (2, 2), pad = (2, 2))) == (7, 7, 16, 10) @test size(conv(x, w1; stride = (1, 2), pad = (2, 3))) == (12, 7, 16, 10) @test size(conv(x, w; stride = (1, 2), pad = (2, 3), dilation = (2, 2))) == (12, 7, 16, 10) @test size(conv(x, w; stride = (1, 2), pad = (2, 3), dilation = (2, 2), flipped = true)) == (12, 7, 16, 10) end # https://github.com/FluxML/NNlib.jl/issues/369 @testset 
"conv_wrapper with groups - not equal types that trigger direct backend" begin
    # Mixed eltypes (Float32 input, Float64 weights) force the direct
    # (non-im2col) backend; its result must agree with the all-Float32 path.
    x = rand(Float32, 10, 10, 32, 8)
    w = rand(Float64, 2, 2, 16, 4)
    g = 2
    @test conv(x, w; groups=g) ≈ conv(x, Float32.(w); groups=g)
    # BUGFIX: the three tests below previously compared a call against an
    # identical call (trivially true). Compare the mixed-type call against the
    # converted-weight call instead, matching the first assertion above.
    @test conv(x, w; stride = (2, 2), pad = (2, 2), groups=g) ≈ conv(x, Float32.(w); stride = (2, 2), pad = (2, 2), groups=g)
    @test conv(x, w; stride = (1, 2), pad = (2, 3), dilation = (2, 2), groups=g) ≈ conv(x, Float32.(w); stride = (1, 2), pad = (2, 3), dilation = (2, 2), groups=g)
    @test conv(x, w; stride = (1, 2), pad = (2, 3), dilation = (2, 2), flipped = true, groups=g) ≈ conv(x, Float32.(w); stride = (1, 2), pad = (2, 3), dilation = (2, 2), flipped = true, groups=g)
end

@testset "depthwiseconv_wrapper" begin
    # Output-shape checks for the depthwiseconv keyword-argument wrapper.
    x = rand(10, 10, 3, 10)
    w = rand(2, 2, 3, 3)
    w1 = rand(3, 4, 3, 3)
    @test size(depthwiseconv(x, w)) == (9, 9, 9, 10)
    @test size(depthwiseconv(x, w; stride = (2, 2), pad = (2, 2))) == (7, 7, 9, 10)
    @test size(depthwiseconv(x, w1; stride = (1, 2), pad = (2, 3))) == (12, 7, 9, 10)
    @test size(depthwiseconv(x, w1; stride = (1, 2), pad = (2, 3), dilation = (2, 2))) == (10, 5, 9, 10)
    @test size(depthwiseconv(x, w1; stride = (1, 2), pad = (2, 3), dilation = (2, 2), flipped = true)) == (10, 5, 9, 10)
end

# https://github.com/FluxML/NNlib.jl/pull/171
@testset "conv_direct! - Check Sizes" begin
    # The in-place direct kernels must return arrays of the expected sizes.
    x_size = (6, 7, 8, 5, 3)
    y_size = (5, 6, 7, 4, 3)
    w_size = (2, 2, 2, 5, 4)
    x = randn(Float32, x_size); y = randn(Float32, y_size); w = randn(Float32, w_size);
    cdims = DenseConvDims(x_size, w_size)
    @test size(NNlib.conv_direct!(y, x, w, cdims)) == y_size
    @test size(NNlib.∇conv_data_direct!(x, y, w, cdims)) == x_size
    @test size(NNlib.∇conv_filter_direct!(w, x, y, cdims)) == w_size
end

# https://github.com/FluxML/NNlib.jl/issues/490
# https://github.com/FluxML/NNlib.jl/issues/405
@testset "conv_direct! - Unusual input types" begin
    # Create test type that can't be indexed when undefined.
    # This simulates the worst-case scenario for custom types.
# A minimal Real subtype that is not isbits, so Array{MyFloat}(undef, ...)
# holds #undef cells: any kernel that reads uninitialized output memory
# throws UndefRefError instead of silently using garbage.
struct MyFloat <: Real set::Set{Float32} end

# Test that direct indexing fails when undefined.
v = Array{MyFloat}(undef, 3)
@test_throws UndefRefError v[1]

# Define minimal set of functions required for conv_direct!
MyFloat(x::MyFloat) = x
MyFloat(x::Real) = MyFloat(Set(Float32(x)))
Base.:+(x::MyFloat, y::MyFloat) = MyFloat(only(x.set) + only(y.set))
Base.:*(x::MyFloat, y::MyFloat) = MyFloat(only(x.set) * only(y.set))
Base.promote_rule(::Type{MyFloat}, ::Type{Float32}) = MyFloat
Base.rand(::AbstractRNG, ::SamplerType{MyFloat}) = MyFloat(rand(Float32))
Base.zero(::MyFloat) = MyFloat(zero(Float32))
Base.zero(::Type{MyFloat}) = MyFloat(zero(Float32))

# Test conv_direct!
x_size = (6, 7, 8, 5, 3)
y_size = (5, 6, 7, 4, 3)
w_size = (2, 2, 2, 5, 4)
x = rand(MyFloat, x_size); w = randn(Float32, w_size); y = Array{MyFloat}(undef, y_size...);
cdims = DenseConvDims(x_size, w_size)
y_out = NNlib.conv_direct!(y, x, w, cdims)
@test eltype(y_out) == MyFloat
@test size(y_out) == y_size
end

@testset "AutoDiff: spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
    x = rand(rng, repeat([5], spatial_rank)..., 3, 2)
    w = rand(rng, repeat([3], spatial_rank)..., 3, 3)
    cdims = DenseConvDims(x, w)
    # Gradient checks for conv and both of its pullbacks.
    gradtest((x, w) -> conv(x, w, cdims), x, w)
    gradtest((x, w) -> sum(conv(x, w, cdims)), x, w)
    # https://github.com/FluxML/Flux.jl/issues/1055
    y = conv(x, w, cdims)
    gradtest((y, w) -> ∇conv_data(y, w, cdims), y, w)
    gradtest((y, w) -> sum(∇conv_data(y, w, cdims)), y, w)
    gradtest((x, y) -> ∇conv_filter(x, y, cdims), x, y)
    gradtest((x, y) -> sum(∇conv_filter(x, y, cdims)), x, y)
    dcdims = DepthwiseConvDims(x, w)
    gradtest((x, w) -> depthwiseconv(x, w, dcdims), x, w)
    # FIXME fails
    y = depthwiseconv(x, w, dcdims)
    gradtest((y, w) -> ∇depthwiseconv_data(y, w, dcdims), y, w)
    gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w)
end

@static if Test_Enzyme

# Verify NNlib's custom EnzymeRules for conv! over all supported
# activity-annotation combinations.
@testset "EnzymeRules: conv! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
    x = rand(rng, repeat([5], spatial_rank)..., 3, 2)
    w = rand(rng, repeat([3], spatial_rank)..., 3, 3)
    cdims = DenseConvDims(x, w)
    curconv = conv
    curconv! = conv!
    dst = curconv(x, w, cdims)

    for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
        Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
        Tx in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
        Tw in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated)

        Tret == EnzymeCore.Const && continue # ERROR
        EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tx, Tw) || continue
        EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (x, Tx), (w, Tw), (cdims, EnzymeCore.Const), atol=1e-6, rtol=1e-6)
    end
end

@testset "EnzymeRules: ∇conv_data! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
    x = rand(rng, repeat([5], spatial_rank)..., 3, 2)
    w = rand(rng, repeat([3], spatial_rank)..., 3, 3)
    cdims = DenseConvDims(x, w)
    y = conv(x, w, cdims)
    dy = randn(rng, size(y)...)
    dx = ∇conv_data(dy, w, cdims)

    for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
        Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
        Ty in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
        Tw in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated)

        Tret == EnzymeCore.Const && continue # ERROR
        EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Ty, Tw) || continue
        EnzymeTestUtils.test_reverse(∇conv_data!, Tret, (dx, Tdst), (dy, Ty), (w, Tw), (cdims, EnzymeCore.Const), atol=1e-6, rtol=1e-6)
    end
end

@testset "EnzymeRules: ∇conv_filter! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
    x = rand(rng, repeat([5], spatial_rank)..., 3, 2)
    w = rand(rng, repeat([3], spatial_rank)..., 3, 3)
    cdims = DenseConvDims(x, w)
    y = conv(x, w, cdims)
    dy = randn(rng, size(y)...)
    dw = ∇conv_filter(x, dy, cdims)

    for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
        Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
        Tx in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
        Ty in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated)

        Tret == EnzymeCore.Const && continue # ERROR
        EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tx, Ty) || continue
        EnzymeTestUtils.test_reverse(∇conv_filter!, Tret, (dw, Tdst), (x, Tx), (dy, Ty), (cdims, EnzymeCore.Const), atol=1e-6, rtol=1e-6)
    end
end

@testset "EnzymeRules: depthwiseconv! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
    x = rand(rng, repeat([5], spatial_rank)..., 3, 2)
    w = rand(rng, repeat([3], spatial_rank)..., 3, 3)
    cdims = DepthwiseConvDims(x, w)
    curconv = depthwiseconv
    curconv! = depthwiseconv!
    dst = curconv(x, w, cdims)

    for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
        Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
        Tx in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
        Tw in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated)

        Tret == EnzymeCore.Const && continue # ERROR
        EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tx, Tw) || continue
        EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (x, Tx), (w, Tw), (cdims, EnzymeCore.Const), atol=1e-6, rtol=1e-6)
    end
end

end

================================================
FILE: test/conv_bias_act.jl
================================================
@testset "conv_bias_act" begin
    x = rand(4,4,3,3)
    w = rand(2,2,3,3)
    b = rand(1,1,1,3)
    cdims = DenseConvDims(x, w; stride=2)
    # The fused conv+bias+activation must match the unfused reference.
    @test NNlib.conv_bias_act(x, w, cdims, b, relu) ≈ relu.(conv(x, w, cdims) .+ b) atol=1e-5
end

================================================
FILE: test/ctc.jl
================================================
using Test
using NNlib: ctc_loss
using Zygote: gradient
using LinearAlgebra

# Custom function to check numerical gradient of ctc loss,
# based on `ngradient` in `Tracker.jl`.
# Central finite difference: perturbs each entry of x by ±δ/2 in place,
# restoring it afterwards, and returns the elementwise gradient estimate.
function ctc_ngradient(x, y)
    f = ctc_loss
    grads = zero(x)
    for i in 1:length(x)
        δ = sqrt(eps())
        tmp = x[i]
        x[i] = tmp - δ/2
        y1 = f(x, y)
        x[i] = tmp + δ/2
        y2 = f(x, y)
        x[i] = tmp
        grads[i] = (y2-y1)/δ
    end
    return grads
end

@testset "ctc_loss" begin
    # Zygote gradient must agree with the numerical gradient above.
    x = rand(10, 50)
    y = rand(1:9, 30)
    g1 = gradient(ctc_loss, x, y)[1]
    g2 = ctc_ngradient(x, y)
    @test g1 ≈ g2 rtol=1e-5 atol=1e-5

    # tests using hand-calculated values
    x = [1. 2. 3.; 2. 1. 1.; 3. 3. 2.]
    y = [1, 2]
    @test ctc_loss(x, y) ≈ 3.6990738275138035

    g = [-0.317671 -0.427729 0.665241; 0.244728 -0.0196172 -0.829811; 0.0729422 0.447346 0.16457]
    ghat = gradient(ctc_loss, x, y)[1]
    @test g ≈ ghat rtol=1e-5 atol=1e-5

    x = [-3. 12. 8. 15.; 4. 20. -2. 20.; 8. -33. 6. 5.]
    y = [1, 2]
    @test ctc_loss(x, y) ≈ 8.02519869363453

    g = [-2.29294774655333e-06 -0.999662657278862 1.75500863563993e-06 0.00669284889063; 0.017985914969696 0.999662657278861 -1.9907078755387e-06 -0.006693150917307; -0.01798362202195 -2.52019580677916e-20 2.35699239251042e-07 3.02026677058789e-07]
    ghat = gradient(ctc_loss, x, y)[1]
    @test g ≈ ghat rtol=1e-5 atol=1e-5
end

================================================
FILE: test/dropout.jl
================================================
using NNlib, Test, Statistics, Random, LinearAlgebra
using Zygote, StableRNGs, ChainRulesCore, Enzyme

@testset "dropout" begin
    # Basics: shape and eltype are preserved, and calls infer.
    x1 = randn(Float32, 3, 4)
    @test size(@inferred dropout(x1, 0.1)) == (3, 4)
    @test size(@inferred dropout(x1, 0.2; dims=2)) == (3, 4)
    @test size(@inferred dropout(x1, 0.3; dims=(1,2))) == (3, 4)
    @test eltype(dropout(x1, 0.1)) == Float32
    @test eltype(dropout(x1, 0.1; dims=1)) == Float32
    @test eltype(dropout(x1, 0.1; dims=(1,2))) == Float32

    rng = Random.default_rng()
    @test size(@inferred dropout(rng, x1, 0.1)) == (3, 4)
    @test size(@inferred dropout(rng, x1, 0.1; dims=2)) == (3, 4)

    x2 = Diagonal(randn(Float32, 10)) # Just to check it runs on weird matrices.
@test dropout(x2, 0.3) isa Matrix{Float32} # does not infer, but that's OK?

    # Values
    @test dropout(x1, 0) == x1
    @test dropout(x1.+0im, 0) == x1
    @test dropout(x1, 1) == zero.(x1)
    @test dropout(x1.+im, 1) == zero.(x1)
    d45 = dropout(trues(100, 100, 100), 0.45)
    @test mean(d45) ≈ 1 atol=1e-2
    dpi2 = dropout(fill(pi, 1000), 0.2)
    @test sort(unique(dpi2)) ≈ [0, 5pi/4]
    d33 = dropout(fill(3, 10, 1000), 0.3, dims=2)
    @test sort(unique(vec(d33))) ≈ [0, 3/(1-0.3)]

    # Complex -- not worth too much optimisation, but should work!
    x2 = [1.0+0im,2.0+1im,3.0+3im] # from Flux's tests
    @test dropout(x2, 0.5) isa Vector{ComplexF64}
    @test dropout(x2, 0.5; dims=1) isa Vector{ComplexF64}

    # Gradient rule
    y, back = rrule(dropout, rng, hcat(trues(1000), falses(1000)), 0.45)
    dx = back(fill(3, 1000, 2))[3]
    @test !all(iszero, dx[:,2]) # this is why we save the random choices
    @test sort(unique(vec(dx))) ≈ [0, 3/(1-0.45)]
    y2, back2 = rrule(dropout, rng, x2, 0.5)
    @test y2 isa Vector{ComplexF64}
    @test back2(one.(y2))[3] isa Vector{ComplexF64}

    @testset "Zygote" begin
        @test Zygote.gradient(x -> sum(dropout(x, 0.3)), x1)[1] isa Matrix{Float32}
        @test Zygote.gradient(x -> sum(dropout(rng, x, 0.3)), x1)[1] isa Matrix{Float32}
        @test Zygote.gradient(x -> sum(dropout(x, 0.3, dims=1)), x1)[1] isa Matrix{Float32}
        # p=0 & p=1
        @test Zygote.gradient(x -> sum(dropout(x, 0)), x1)[1] == ones(3,4)
        @test Zygote.gradient(x -> sum(dropout(x, 1)), x1)[1] == zeros(3,4)
        # Second order
        f1(x) = sum(dropout(x, 0.5))
        @test_broken Zygote.hessian(f1, [1.0,2.0,3.0]) == zeros(3, 3) # forward over reverse
        @test Zygote.hessian_reverse(f1, [1.0,2.0,3.0]) == zeros(3, 3)
    end

    # Bang
    y1 = fill!(similar(x1), NaN)
    @test dropout!(y1, x1, 0.0) == x1
    @test y1 == x1
    @test dropout!(rng, y1, x1, 1) == zero(x1)
    @test y1 == zero(x1)

    # Errors
    @test_throws ArgumentError dropout(x1, -1)
    @test_throws ArgumentError dropout(x1, 2)
    @test_throws ArgumentError dropout!(y1, x1, 3)
end

@static if Test_Enzyme

@testset "EnzymeRules: dropout " begin
    rng = Random.default_rng()
    x1 = randn(Float32, 3000, 4000)
    dx1 = zeros(Float32, 3000, 4000)
    dout = randn(Float32, 3000, 4000)
    p = 0.2f0
    forward, reverse = Enzyme.autodiff_thunk(ReverseSplitWithPrimal,
        typeof(Const(dropout)), Duplicated,
        typeof(Const(rng)), typeof(Duplicated(x1, dx1)), typeof(Const(0.2f0)))
    tape, primal, shadow = forward(Const(dropout), Const(rng), Duplicated(x1, dx1), Const(p))
    shadow .= dout
    reverse(Const(dropout), Const(rng), Duplicated(x1, dx1), Const(p), tape)
    # Dropped entries get zero gradient; kept entries are scaled by 1/(1-p).
    @test dx1[.!tape[1]] ≈ zero(x1)[.!tape[1]]
    val = convert(Float32, 1/(1-p))
    @test dx1[tape[1]] ≈ (val * dout)[tape[1]]
end

end

================================================ FILE: test/ext_amdgpu/activations.jl ================================================

@testset "Compare CPU & GPU" begin
    for (T, atol) in ((Float16, 1.0f-2), (Float32, 1.0f-5))
        @testset "ndims: $(ndims(x))" for x in (
            randn(T, 16), randn(T, ntuple(_ -> 2, 5)...), randn(T, ntuple(_ -> 2, 6)...))
            gputest(x -> NNlib.relu.(x), x; atol)
            gputest(x -> NNlib.relu6.(x), x; atol)
            gputest(x -> NNlib.softplus.(x), x; atol)
            gputest(x -> tanh.(x), x; atol)
            gputest(x -> identity.(x), x; atol)
        end
    end
end

================================================ FILE: test/ext_amdgpu/attention.jl ================================================

@testset "Compare CPU & GPU" begin
    n = 15
    lenq = 3
    lenkv = 4
    for batch_size in [(), 1, 2, (2, 1, 3)], nheads in [1, 3, 5]
        q = AMDGPU.rand(Float32, n, lenq, batch_size...)
        k = AMDGPU.rand(Float32, n, lenkv, batch_size...)
        v = AMDGPU.rand(Float32, n, lenkv, batch_size...)
        y, α = @inferred dot_product_attention(q, k, v; nheads)
        @test y isa ROCArray{Float32}
        @test size(y) == (n, lenq, batch_size...)
        @test size(α) == (lenkv, lenq, nheads, batch_size...)
        # Attention weights are a probability distribution over keys.
        @test sum(Array(α), dims=1) ≈ ones(1, lenq, nheads, batch_size...)
        qh = rand(Float32, n, lenq, batch_size...)
        kh = rand(Float32, n, lenkv, batch_size...)
        vh = rand(Float32, n, lenkv, batch_size...)
        gputest( (x...)
-> dot_product_attention(x...; nheads)[1], qh, kh, vh; atol=1f-5)
    end
end

@testset "Mask" begin
    x = AMDGPU.rand(Float32, 4, 2, 3, 1)
    mask = make_causal_mask(x, dims=3)
    @test mask isa ROCArray{Bool}
    α = dot_product_attention_scores(x, x; mask)
    α_host, mask_host = Array.((α, mask))
    # Masked-out positions must have zero attention weight.
    @test all((α_host[:, :, 1, 1] .> 0) .== mask_host)
    @test all((α_host[:, :, 2, 1] .> 0) .== mask_host)
end

@testset "Dropout" begin
    q = k = v = AMDGPU.rand(Float32, 10, 10, 10)
    # FIX: removed the dead local helper `fdrop(x, p) = (rand!(similar(x)) .> p) .* x ./ (1-p)`
    # — it was defined but never used; the test passes NNlib's `dropout` directly below.
    y, α = dot_product_attention(
        q, k, v; nheads=2, fdrop=x -> dropout(x, 0.5))
    # With p = 0.5 roughly half of the attention weights should be zeroed.
    @test 0.6 > mean(>(0), α) > 0.4
end

================================================ FILE: test/ext_amdgpu/batched_mul.jl ================================================

@testset "batched_mul" begin
    A = rand(Float32, 3, 3, 2)
    B = rand(Float32, 3, 3, 2)
    dA, dB = ROCArray.((A, B))

    C = batched_mul(A, B)
    @test ROCArray(C) ≈ batched_mul(dA, dB)

    Ct = batched_mul(batched_transpose(A), B)
    @test ROCArray(Ct) ≈ batched_mul(batched_transpose(dA), dB)

    Ca = batched_mul(A, batched_adjoint(B))
    @test ROCArray(Ca) ≈ batched_mul(dA, batched_adjoint(dB))

    # 5-arg batched_mul!
C .= pi
    batched_mul!(C, A, B, 2f0, 3f0)
    Cpi = ROCArray(similar(C)) .= pi
    @test ROCArray(C) ≈ batched_mul!(Cpi, dA, dB, 2f0, 3f0)

    # PermutedDimsArray
    @test ROCArray(Ct) ≈ batched_mul(PermutedDimsArray(dA, (2, 1, 3)), dB)
    # FIXME same but with (1, 3, 2) errors
    D = permutedims(B, (2, 1, 3))
    Cp = batched_mul(batched_adjoint(A), B)
    @test ROCArray(Cp) ≈ batched_mul(
        batched_adjoint(dA), PermutedDimsArray(ROCArray(D), (2, 1, 3)))

    # Methods which reshape
    M = randn(Float32, 3, 3)
    Cm = batched_mul(A, M)
    @test ROCArray(Cm) ≈ batched_mul(dA, ROCArray(M))
end

================================================ FILE: test/ext_amdgpu/batched_repr.jl ================================================

# All printed lines of `x` except the type header.
function print_array_strs(x)
    str = sprint((io, x)->show(io, MIME"text/plain"(), x), x)
    return @view split(str, '\n')[2:end]
end

@testset "BatchedAdjOrTrans" begin
    x = rand(Float32, 3, 4, 2)
    y = ROCArray(x)

    bax = batched_adjoint(x)
    btx = batched_transpose(x)
    bay = batched_adjoint(y)
    bty = batched_transpose(y)

    # CPU and GPU wrappers must print and materialize identically.
    @test sprint(show, bax) == sprint(show, bay)
    @test sprint(show, btx) == sprint(show, bty)
    @test print_array_strs(bax) == print_array_strs(bay)
    @test print_array_strs(btx) == print_array_strs(bty)
    @test Array(bax) == Array(bay)
    @test collect(bax) == collect(bay)
    @test Array(btx) == Array(bty)
    @test collect(btx) == collect(bty)

    for shape in (:, (12, 2))
        rbax = reshape(bax, shape)
        rbtx = reshape(btx, shape)
        rbay = reshape(bay, shape)
        rbty = reshape(bty, shape)

        @test sprint(show, rbax) == sprint(show, rbay)
        @test sprint(show, rbtx) == sprint(show, rbty)
        @test print_array_strs(rbax) == print_array_strs(rbay)
        @test print_array_strs(rbtx) == print_array_strs(rbty)
        @test Array(rbax) == Array(rbay)
        @test collect(rbax) == collect(rbay)
        @test Array(rbtx) == Array(rbty)
        @test collect(rbtx) == collect(rbty)
    end
end

================================================ FILE: test/ext_amdgpu/conv.jl ================================================

@testset "Compare CPU & GPU" begin
    channels, batch = 3, 2
for T in (Float16, Float32), nd in (1, 2, 3)
        # FIX: previously `rand(Float32, ...)` ignored the loop variable `T`,
        # so the Float32 case ran twice and Float16 was never exercised.
        # Use `T` and the same Float16 tolerance convention as pool.jl/softmax.jl.
        atol = T == Float16 ? 1.0f-2 : 1.0f-4
        x = rand(T, fill(4, nd)..., 3, 1)
        w = rand(T, fill(2, nd)..., channels, 4)
        cdims = DenseConvDims(x, w, flipkernel=true)
        gputest((x, w) -> NNlib.conv(x, w, cdims), x, w; atol)
        # This one flips manually kernel for AMDGPU.
        cdims = DenseConvDims(x, w)
        gputest((x, w) -> NNlib.conv(x, w, cdims), x, w; atol)
    end
end

================================================ FILE: test/ext_amdgpu/dropout.jl ================================================

@testset "Test API" begin
    x = AMDGPU.randn(Float32, 3, 4)
    @test size(@inferred dropout(x, 0.1)) == (3, 4)
    @test size(@inferred dropout(x, 0.2; dims=2)) == (3, 4)
    @test size(@inferred dropout(x, 0.3; dims=(1, 2))) == (3, 4)

    rng = AMDGPU.rocrand_rng()
    @test size(@inferred dropout(rng, x, 0.1)) == (3, 4)
    @test size(@inferred dropout(rng, x, 0.1; dims=2)) == (3, 4)

    # Values
    d45 = dropout(AMDGPU.ones(100, 100, 100), 0.45)
    @test mean(d45) ≈ 1 atol=1e-2
    dpi2 = dropout(AMDGPU.fill(1f0 * pi, 1000), 0.2)
    @test sort(unique(Array(dpi2))) ≈ [0, 5 * pi / 4]
    d33 = dropout(AMDGPU.fill(3f0, 10, 1000), 0.3, dims=2)
    @test sort(unique(vec(Array(d33)))) ≈ [0, 3 / (1 - 0.3)]

    @test Zygote.gradient(x -> sum(dropout(x, 0.1)), x)[1] isa ROCArray{Float32}
end

================================================ FILE: test/ext_amdgpu/pool.jl ================================================

@testset "Compare CPU & GPU" begin
    channels, batch = 3, 2
    for T in (Float16, Float32), nd in (1, 2, 3)
        x = rand(T, fill(8, nd)..., channels, batch)
        pdims = PoolDims(x, 2)
        # NOTE: Disable grad check for maxpool as *sometimes*
        # it does not *completely* agree with CPU :/
        gputest(x -> NNlib.maxpool(x, pdims), x; checkgrad=false)
        gputest(x -> NNlib.meanpool(x, pdims), x)
    end
end

================================================ FILE: test/ext_amdgpu/runtests.jl ================================================

using NNlib: batched_adjoint, batched_mul, batched_mul!, batched_transpose
using NNlib: is_strided,
storage_type
using LinearAlgebra

AMDGPU.allowscalar(false)

# Run `f` on CPU and GPU inputs and compare outputs (and, optionally, gradients).
function gputest(f, xs...; checkgrad=true, atol=1e-6, kws...)
    cpu_in = xs
    gpu_in = ROCArray.(xs)

    cpu_out = f(cpu_in...; kws...)
    gpu_out = f(gpu_in...; kws...)
    @test collect(cpu_out) ≈ collect(gpu_out)

    if checkgrad
        cpu_grad = gradient((x...) -> sum(f(x...; kws...)), cpu_in...)
        gpu_grad = gradient((x...) -> sum(f(x...; kws...)), gpu_in...)
        for (cpu_g, gpu_g) in zip(cpu_grad, gpu_grad)
            if cpu_g === nothing
                @test gpu_g === nothing
            else
                @test collect(cpu_g) ≈ collect(gpu_g) atol=atol
            end
        end
    end
end

@testset "Storage types" begin
    include("storage_type.jl")
end
@testset "Batched repr" begin
    include("batched_repr.jl")
end
@testset "Batched multiplication" begin
    include("batched_mul.jl")
end
@testset "Convolution" begin
    include("conv.jl")
end
@testset "Pooling" begin
    include("pool.jl")
end
@testset "Softmax" begin
    include("softmax.jl")
end
@testset "Activations" begin
    include("activations.jl")
end
@testset "Dropout" begin
    include("dropout.jl")
end
@testset "Attention" begin
    include("attention.jl")
end

================================================ FILE: test/ext_amdgpu/softmax.jl ================================================

@testset "Compare CPU & GPU" begin
    for (T, atol) in ((Float16, 1f-2), (Float32, 1f-5))
        for (sz, dims) in [
            ((5,), :), ((5,), 1),
            ((5, 5), :), ((5, 5), 1), ((5, 5), 2),
            ((5, 5, 5, 5), (2, 3)), ((5, 5, 5, 5), (2, 4)),
        ]
            if T == Float16
                x = ones(T, sz) # Really low precision.
else
                x = randn(T, sz)
            end
            gputest(NNlib.softmax, x; atol)
            gputest(NNlib.logsoftmax, x; atol)
        end
    end
end

================================================ FILE: test/ext_amdgpu/storage_type.jl ================================================

@testset "NNlib storage type" begin
    x = ROCArray(ones(Float32, 10, 10))
    @test storage_type(x) <: ROCArray{Float32, 2}
    @test storage_type(reshape(view(x, 1:2:10,:), 10, :)) <: ROCArray{Float32, 2}
    @test is_strided(x)
    @test is_strided(view(x, 1:2:5,:))
    @test is_strided(PermutedDimsArray(x, (2, 1)))
    @test !is_strided(reshape(view(x, 1:2:10, :), 10, :))
    @test !is_strided((x .+ im)')
    @test !is_strided(Diagonal(ROCArray(ones(3))))
end

================================================ FILE: test/ext_cuda/activations.jl ================================================

@testset "activation broadcast" begin
    for f in NNlib.ACTIVATIONS
        if f ∉ [:rrelu]
            @eval gputest(x -> $f.(x), rand(Float64, 5))
        end
    end
end

@testset "forward diff" begin
    for f in NNlib.ACTIVATIONS
        if f ∉ [:rrelu]
            @eval gputest(x -> $f.(x), Dual.(rand(5), 1))
        end
    end
end

# Broadcasting over complex CuArray works without NNlibCUDAExt, this test checks that
# NNlibCUDAExt does not cause such operations to take a fast path which does not support
# complex numbers (e.g.
# cuDNN)
@testset "complex" begin
    f(x) = tanh.(x)
    cs = rand(ComplexF64, 5)
    @test f(cs) ≈ collect(f(CuArray(cs)))
end

@testset "softplus" begin
    # softplus does not give `Inf` for large arguments
    x = CuArray([1000.])
    @test all(softplus.(x) .== x)
end

@testset "input is preserved" begin
    x = CUDA.ones(1)
    @test Array(x) == [1f0]
    tanh.(x)
    @test Array(x) == [1f0]
    y = tanh.(x)
    @test Array(x) == [1f0]
    @test Array(y) == [tanh(1f0)]
    x .= tanh.(y)
    @test Array(y) == [tanh(1f0)]
    @test Array(x) == [tanh(tanh(1f0))]
end

@testset "fused act addition broadcast" begin
    x = CUDA.rand(Float32, 10, 10)
    b = CUDA.rand(Float32, 10)
    for act in getfield.((NNlib,), NNlib.ACTIVATIONS)
        fused_act_add = act ∘ +
        @test fused_act_add.(x, b) ≈ act.(x .+ b)
    end
end

================================================ FILE: test/ext_cuda/batchedadjtrans.jl ================================================

# All printed lines of `x` except the type header.
function print_array_strs(x)
    str = sprint((io, x)->show(io, MIME"text/plain"(), x), x)
    return @view split(str, '\n')[2:end]
end

@testset "BatchedAdjOrTrans" begin
    x = randn(Float32, 3,4,2)
    y = cu(x)

    bax = batched_adjoint(x)
    btx = batched_transpose(x)
    bay = batched_adjoint(y)
    bty = batched_transpose(y)

    # CPU and GPU wrappers must print and materialize identically.
    @test sprint(show, bax) == sprint(show, bay)
    @test sprint(show, btx) == sprint(show, bty)
    @test print_array_strs(bax) == print_array_strs(bay)
    @test print_array_strs(btx) == print_array_strs(bty)
    @test Array(bax) == Array(bay)
    @test collect(bax) == collect(bay)
    @test Array(btx) == Array(bty)
    @test collect(btx) == collect(bty)

    for shape in (:, (12, 2))
        rbax = reshape(bax, shape)
        rbtx = reshape(btx, shape)
        rbay = reshape(bay, shape)
        rbty = reshape(bty, shape)

        @test sprint(show, rbax) == sprint(show, rbay)
        @test sprint(show, rbtx) == sprint(show, rbty)
        @test print_array_strs(rbax) == print_array_strs(rbay)
        @test print_array_strs(rbtx) == print_array_strs(rbty)
        @test Array(rbax) == Array(rbay)
        @test collect(rbax) == collect(rbay)
        @test Array(rbtx) == Array(rbty)
        @test collect(rbtx) == collect(rbty)
    end
end
================================================ FILE: test/ext_cuda/batchedmul.jl ================================================

@testset "batched_mul" begin
    using NNlib: batched_mul, batched_mul!, batched_vec, batched_adjoint, batched_transpose

    A = randn(Float32, 3,3,2);
    B = randn(Float32, 3,3,2);

    C = batched_mul(A, B)
    @test CuArray(C) ≈ batched_mul(CuArray(A), CuArray(B))

    Ct = batched_mul(batched_transpose(A), B)
    @test CuArray(Ct) ≈ batched_mul(batched_transpose(CuArray(A)), CuArray(B))

    Ca = batched_mul(A, batched_adjoint(B))
    @test CuArray(Ca) ≈ batched_mul(CuArray(A), batched_adjoint(CuArray(B)))

    # 5-arg batched_mul!
    C .= pi
    batched_mul!(C, A, B, 2f0, 3f0)
    cuCpi = CuArray(similar(C)) .= pi
    @test CuArray(C) ≈ batched_mul!(cuCpi, CuArray(A), CuArray(B), 2f0, 3f0)

    # PermutedDimsArray
    @test CuArray(Ct) ≈ batched_mul(PermutedDimsArray(CuArray(A), (2,1,3)), CuArray(B))
    D = permutedims(B, (1,3,2))
    Cp = batched_mul(batched_adjoint(A), B)
    @test CuArray(Cp) ≈ batched_mul(batched_adjoint(CuArray(A)), PermutedDimsArray(CuArray(D), (1,3,2)))

    # Methods which reshape
    M = randn(Float32, 3,3)
    Cm = batched_mul(A, M)
    @test CuArray(Cm) ≈ batched_mul(CuArray(A), CuArray(M))
    Cv = batched_vec(permutedims(A,(3,1,2)), M)
    @test CuArray(Cv) ≈ batched_vec(PermutedDimsArray(CuArray(A),(3,1,2)), CuArray(M))
end

@testset "NNlib storage_type etc." begin
    using LinearAlgebra
    using NNlib: is_strided, are_strided, storage_type

    M = cu(ones(10,10))
    @test is_strided(M)
    @test is_strided(view(M, 1:2:5,:))
    @test is_strided(PermutedDimsArray(M, (2,1)))
    @test !is_strided(reshape(view(M, 1:2:10,:), 10,:))
    @test !is_strided((M .+ im)')
    @test !is_strided(Diagonal(cu(ones(3))))
    @test storage_type(M) <: CuArray{Float32,2}
    @test storage_type(reshape(view(M, 1:2:10,:), 10,:)) <: CuArray{Float32,2}
end

================================================ FILE: test/ext_cuda/batchnorm.jl ================================================

using Statistics

@testset "Batchnorm" begin
    v = CUDA.rand(Float32, 2)
    m = CUDA.rand(Float32, 2, 5)

    @testset for training in (true, false), track_stats in (true, false)
        kws = (training=training, track_stats=track_stats)

        # Normal
        batchnorm(v, v, m, v, v, 1.0; kws...)
        ∇batchnorm(v, v, m, m, v, v, 1.0; kws...)

        # No affine
        batchnorm(nothing, nothing, m, v, v, 1.0; kws...)
        ∇batchnorm(nothing, nothing, m, m, v, v, 1.0; kws...)

        # No tracking
        batchnorm(v, v, m, nothing, nothing, 1.0; kws...)
        ∇batchnorm(v, v, m, m, nothing, nothing, 1.0; kws...)

        # Both or neither tracked or affine params must be set
        for (α, β) in ((v, nothing), (nothing, v))
            @test_throws MethodError batchnorm(α, β, m, v, v, 1.0; kws...)
            @test_throws MethodError ∇batchnorm(α, β, m, m, v, v, 1.0; kws...)
            @test_throws ArgumentError batchnorm(v, v, m, α, β, 1.0; kws...)
        end
    end

    @testset "test mode" begin
        y_no_track_stats = batchnorm(v, v, m, nothing, nothing, 1.0;
            training=false, track_stats=false)
        running_mean = mean(m, dims=[2])
        running_var = var(m, mean=running_mean, dims=[2], corrected=false)
        y_track_stats = batchnorm(v, v, m, running_mean, running_var, 1.0;
            training=false, track_stats=true)
        # batchnorm without tracked stats should equal batchnorm with tracked stats where the
        # stats are calculated only on the input.
@test y_no_track_stats ≈ y_track_stats
    end
end

================================================ FILE: test/ext_cuda/conv.jl ================================================

using NNlib: DenseConvDims

@testset "convolution" begin
@testset "$T" for T in (Float64, ComplexF64)
    a, b, c = rand(T, 10, 10, 3, 1), rand(T, 2, 2, 3, 4), rand(T, 9, 9, 4, 1)
    da, db, dc = CuArray(a), CuArray(b), CuArray(c)
    cdims = DenseConvDims(a, b)
    @test NNlib.conv(a, b, cdims) ≈ collect(NNlib.conv(da, db, cdims))
    @test ∇conv_data(c, b, cdims) ≈ collect(∇conv_data(dc, db, cdims))
    @test ∇conv_filter(a, c, cdims) ≈ collect(∇conv_filter(da, dc, cdims))

    if T <: Complex
        @testset "mixed real and complex" begin
            @test NNlib.conv(real(a), b, cdims) ≈ collect(NNlib.conv(real(da), db, cdims))
            @test NNlib.conv(a, real(b), cdims) ≈ collect(NNlib.conv(da, real(db), cdims))
            @test ∇conv_data(c, real(b), cdims) ≈ collect(∇conv_data(dc, real(db), cdims))
            @test ∇conv_filter(real(a), c, cdims) ≈ collect(∇conv_filter(real(da), dc, cdims))
        end
    end

    # Test Conv Bias Activation
    bias = rand(T, 1, 1, 4, 1)
    dbias = CuArray(bias)
    act = T <: Complex ? abs2 : NNlib.relu
    @test conv_bias_act(a, b, cdims, bias, act) ≈ collect(conv_bias_act(da, db, cdims, dbias, act))
    @test conv_bias_act(a, b, cdims, bias, identity) ≈ collect(conv_bias_act(da, db, cdims, dbias, identity))

    # Test for agreement between CPU NNlib and CuDNN versions, across a variety of kwargs
    options = Dict{Any, Any}.((
        (), (:dilation => 2), (:flipkernel => true), (:stride => 2),
        (:padding => 1), (:padding => (1,0)), (:padding => (0,1)), (:padding => (2,3)),
    ))
    C_in_ = 3
    C_out = 4
    batch_size = 1

    # we use this activation for the gpu tests
    # as we can't take gradients of complex quantities
    act = T <: Complex ? x-> abs2(x) : identity
    @testset "groups=$groups, num_spatial_dims=$num_spatial_dims" for groups in (1, 2, 4), num_spatial_dims in (1, 2, 3)
        # Make `C_in = C_out` when using grouped convolution.
        C_in = groups == 1 ? C_in_ : C_out
        # Initialize data we'll run our tests over
        x = rand(T, fill(8, num_spatial_dims)..., C_in, batch_size)
        w = rand(T, fill(2, num_spatial_dims)..., C_in ÷ groups, C_out)

        @testset "opts #$i" for (i,opts) in enumerate(options)
            opts[:groups] = groups
            if :padding in keys(opts)
                padding = opts[:padding]
                if 1 < length(padding) && length(padding) != 2num_spatial_dims
                    opts[:padding] = ntuple(i -> padding[mod1(i,2)] .+ 2div(i-1,2), 2num_spatial_dims)
                end
            end
            cdims = DenseConvDims(x, w; opts...)
            y = NNlib.conv(x, w, cdims)

            # Test that basic convolution is equivalent across GPU/CPU
            @testset "cpu==gpu" begin
                @testset "conv" begin
                    gputest((x, w) -> act.(NNlib.conv(x, w, cdims)), x, w)
                    if T <: Complex
                        gputest((x, w) -> act.(NNlib.conv(x, w, cdims)), real(x), w)
                        gputest((x, w) -> act.(NNlib.conv(x, w, cdims)), x, real(w))
                    end
                end
                @testset "∇conv_data" begin
                    gputest((y, w) -> act.(NNlib.∇conv_data(y, w, cdims)), y, w)
                    if T <: Complex
                        gputest((y, w) -> act.(NNlib.∇conv_data(y, w, cdims)), y, real(w))
                    end
                end
                @testset "∇conv_filter" begin
                    gputest((x, y) -> act.(NNlib.∇conv_filter(x, y, cdims)), x, y)
                    if T <: Complex
                        gputest((x, y) -> act.(NNlib.∇conv_filter(x, y, cdims)), real(x), y)
                    end
                end
            end

            # Scaling factors
            @testset "scale-alpha" begin
                gputest((x, w) -> act.(NNlib.conv(x, w, cdims; alpha=T(2.0))), x, w, checkgrad=false) # TODO
                gputest((y, w) -> act.(NNlib.∇conv_data(y, w, cdims; alpha=T(2.0))), y, w, checkgrad=false) # TODO
                gputest((x, y) -> act.(NNlib.∇conv_filter(x, y, cdims; alpha=T(2.0))), x, y, checkgrad=false) # TODO
                if T <: Complex
                    gputest((x, w) -> act.(NNlib.conv(x, w, cdims; alpha=T(2.0))), real(x), w, checkgrad=false)
                    gputest((x, w) -> act.(NNlib.conv(x, w, cdims; alpha=T(2.0))), x, real(w), checkgrad=false) # TODO
                    gputest((y, w) -> act.(NNlib.∇conv_data(y, w, cdims; alpha=T(2.0))), y, real(w), checkgrad=false) # TODO
                    gputest((x, y) -> act.(NNlib.∇conv_filter(x, y, cdims; alpha=T(2.0))), real(x), y, checkgrad=false) # TODO
                end
            end

            @testset "scale-beta" begin
                gputest((y, x, w) -> act.(NNlib.conv!(copy(y), x, w, cdims; beta=T(2.0))), y, x, w, checkgrad=false, broken=false)
                gputest((w, x, y) -> act.(NNlib.∇conv_filter!(copy(w), x, y, cdims; beta=T(2.0))), w, x, y, checkgrad=false, broken=false)
                gputest((x, y, w) -> act.(NNlib.∇conv_data!(copy(x), y, w, cdims; beta=T(2.0))), x, y, w, checkgrad=false, broken=false)
                if T <: Complex
                    gputest((y, x, w) -> act.(NNlib.conv!(copy(y), x, w, cdims; beta=T(2.0))), y, real(x), w, checkgrad=false)
                    gputest((y, x, w) -> act.(NNlib.conv!(copy(y), x, w, cdims; beta=T(2.0))), y, x, real(w), checkgrad=false)
                    gputest((x, y, w) -> act.(NNlib.∇conv_data!(copy(x), y, w, cdims; beta=T(2.0))), x, y, real(w), checkgrad=false)
                    gputest((w, x, y) -> act.(NNlib.∇conv_filter!(copy(w), x, y, cdims; beta=T(2.0))), w, real(x), y, checkgrad=false)
                end
            end
        end
    end
end
end

================================================ FILE: test/ext_cuda/ctc.jl ================================================

# Custom function to check numerical gradient of ctc loss,
# based on `ngradient` in `Tracker.jl`
function ctc_ngradient(x, y)
    f = ctc_loss
    grads = zero(x)
    for i in 1:length(x)
        δ = sqrt(eps())
        tmp = x[i]
        x[i] = tmp - δ/2
        y1 = f(x, y)
        x[i] = tmp + δ/2
        y2 = f(x, y)
        x[i] = tmp
        grads[i] = (y2-y1)/δ
    end
    return grads
end

@testset "ctc-gpu" begin
    x = rand(10, 50)
    y = rand(1:9, 30)
    x_cu = CuArray(x)

    g1 = gradient(ctc_loss, x_cu, y)[1]
    g1 = g1 |> collect
    g2 = ctc_ngradient(x, y)
    @test g1 ≈ g2 rtol=1e-5 atol=1e-5

    # test that GPU loss matches CPU implementation
    l1 = ctc_loss(x_cu, y)
    l2 = ctc_loss(x, y)
    @test l1 ≈ l2

    # tests using hand-calculated values
    x_cu = [1. 2. 3.; 2. 1. 1.; 3. 3. 2.] |> CuArray
    y = [1, 2]
    @test ctc_loss(x_cu, y) ≈ 3.6990738275138035
    g = [-0.317671 -0.427729 0.665241; 0.244728 -0.0196172 -0.829811; 0.0729422 0.447346 0.16457]
    ghat = gradient(ctc_loss, x_cu, y)[1] |> collect
    @test g ≈ ghat rtol=1e-5 atol=1e-5

    x_cu = [-3. 12. 8. 15.; 4. 20. -2. 20.; 8. -33. 6. 5.] |> CuArray
    y = [1, 2] |> CuArray
    @test ctc_loss(x_cu, y) ≈ 8.02519869363453
    g = [-2.29294774655333e-06 -0.999662657278862 1.75500863563993e-06 0.00669284889063; 0.017985914969696 0.999662657278861 -1.9907078755387e-06 -0.006693150917307; -0.01798362202195 -2.52019580677916e-20 2.35699239251042e-07 3.02026677058789e-07]
    ghat = gradient(ctc_loss, x_cu, y)[1] |> collect
    @test g ≈ ghat rtol=1e-5 atol=1e-5
end

================================================ FILE: test/ext_cuda/dropout.jl ================================================

@testset "dropout + CUDA" begin
    # Basics
    x1 = CUDA.randn(3, 4)
    @test size(@inferred dropout(x1, 0.1)) == (3, 4)
    @test size(@inferred dropout(x1, 0.2; dims=2)) == (3, 4)
    @test size(@inferred dropout(x1, 0.3; dims=(1,2))) == (3, 4)

    rng = CUDA.default_rng()
    @test size(@inferred dropout(rng, x1, 0.1)) == (3, 4)
    @test size(@inferred dropout(rng, x1, 0.1; dims=2)) == (3, 4)

    # Values
    d45 = dropout(CUDA.ones(100, 100, 100), 0.45)
    @test mean(d45) ≈ 1 atol=1e-2
    dpi2 = dropout(CUDA.fill(1f0 * pi, 1000), 0.2)
    @test sort(unique(Array(dpi2))) ≈ [0, 5pi/4]
    d33 = dropout(CUDA.fill(3f0, 10, 1000), 0.3, dims=2)
    @test sort(unique(vec(Array(d33)))) ≈ [0, 3/(1-0.3)]

    # Gradient rule
    y, back = rrule(dropout, rng, hcat(CUDA.ones(1000), CUDA.zeros(1000)), 0.45)
    dx = back(CUDA.fill(3f0, 1000, 2))[3]
    @test !all(iszero, dx[:,2]) # this is why we save the random choices
    @test sort(unique(vec(Array(dx)))) ≈ [0, 3/(1-0.45)]

    @testset "Zygote" begin
        @test Zygote.gradient(x -> sum(dropout(x, 0.3)), x1)[1] isa CuArray{Float32}
        @test Zygote.gradient(x -> sum(dropout(rng, x, 0.3)), x1)[1] isa CuArray{Float32}
        @test Zygote.gradient(x -> sum(dropout(x, 0.3, dims=1)), x1)[1] isa CuArray{Float32}
    end
end

================================================ FILE: test/ext_cuda/fold.jl ================================================

@testset "fold" begin
    # Test for agreement between CPU/GPU versions, across a variety of kwargs
    options = Dict{Any, Any}.((
        (), (:dilation => 2),
(:flipkernel => true), (:stride => 2),
        (:padding => 1), (:padding => (1,0)), (:padding => (0,1)), (:padding => (2,3)),
    ))
    C_in = 3
    C_out = 4
    batch_size = 1

    @testset "spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
        for opts in options
            if :padding in keys(opts)
                padding = opts[:padding]
                if 1 < length(padding) && length(padding) != 2spatial_rank
                    opts[:padding] = ntuple(i -> padding[mod1(i,2)] .+ 2div(i-1,2), 2spatial_rank)
                end
            end
            x = rand(Float64, fill(8, spatial_rank)..., C_in, batch_size)
            w = rand(Float64, fill(2, spatial_rank)..., C_in, C_out)
            cdims = DenseConvDims(x, w; opts...)
            y = NNlib.unfold(x, cdims)
            # test equivalence of fold/unfold across GPU/CPU
            gputest(x -> NNlib.unfold(x, cdims), x)
            gputest(y -> NNlib.fold(y, size(x), cdims), y)
        end
    end
end

================================================ FILE: test/ext_cuda/gather.jl ================================================

@testset "gather" begin
    T = Float32
    CT = CuArray{Float32}

    ## 1d src, 2d index of ints -> 2d output
    src = CT([3, 4, 5, 6, 7])
    index = cu([1 2 3 4; 4 2 1 3; 3 5 5 3])
    output = CT([3 4 5 6; 6 4 3 5; 5 7 7 5])

    y = NNlib.gather(src, index)
    @test y isa CuArray{Float32,2}
    @test size(y) == size(index)
    gputest(src -> NNlib.gather(src, index), src, checkgrad=true)
    @test NNlib.gather!(CUDA.zeros(T, size(index)...), src, index) == output
    @test_throws ArgumentError NNlib.gather!(zeros(T, 3, 5), src, index)

    ## 1d src, 2d index of tuples -> 2d output
    src = CT([3, 4, 5, 6, 7])
    index = cu([(1,) (2,) (3,) (4,); (4,) (2,) (1,) (3,); (3,) (5,) (5,) (3,)])
    output = CT([3 4 5 6; 6 4 3 5; 5 7 7 5])

    y = NNlib.gather(src, index)
    @test y isa CuArray{Float32,2}
    @test size(y) == size(index)
    gputest(src -> NNlib.gather(src, index), src, checkgrad=true)
    @test NNlib.gather!(CUDA.zeros(T, size(index)...), src, index) == output
    @test_throws ArgumentError NNlib.gather!(zeros(T, 3, 5), src, index)

    ## 1d src, 2d index of CartesianIndex -> 2d output
    src = CT([3, 4, 5, 6, 7])
    index = cu(CartesianIndex.([(1,) (2,) (3,) (4,); (4,) (2,) (1,) (3,); (3,) (5,) (5,) (3,)]))
    output = CT([3 4 5 6; 6 4 3 5; 5 7 7 5])

    y = NNlib.gather(src, index)
    @test y isa CuArray{Float32,2}
    @test size(y) == size(index)
    gputest(src -> NNlib.gather(src, index), src, checkgrad=true)
    @test NNlib.gather!(CUDA.zeros(T, size(index)...), src, index) == output
    @test_throws ArgumentError NNlib.gather!(zeros(T, 3, 5), src, index)

    ## 1d src, 3d index of ints -> 3d output
    src = CT([3, 4, 5, 6, 7])
    index = cu([1 2 3 4; 4 2 1 3; 3 5 5 3][:,:,1:1])
    output = CT([3 4 5 6; 6 4 3 5; 5 7 7 5][:,:,1:1])

    y = NNlib.gather(src, index)
    @test y isa CuArray{Float32,3}
    @test size(y) == size(index)
    gputest(src -> NNlib.gather(src, index), src, checkgrad=true)

    ## 2d src, 2d index of ints -> 3d output
    # NOTE(review): the row breaks of these matrix literals were lost in the dump;
    # reconstructed as 2x3 matrices to be consistent with `zeros(T, 2, 3, 3)` below.
    src = CT([3 5 7
              4 6 8])
    index = cu([1 2 3
                2 2 1
                3 1 3])
    output = zeros(T, 2, 3, 3)
    output[:,:,1] = [3 5 7
                     4 6 8]
    output[:,:,2] = [5 5 3
                     6 6 4]
    output[:,:,3] = [7 3 7
                     8 4 8]

    y = NNlib.gather(src, index)
    M = NNlib.typelength(eltype(index))
    Nsrc = ndims(src)
    @test y isa CuArray{Float32,3}
    @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
    gputest(src -> NNlib.gather(src, index), src, checkgrad=true)

    @testset "views" begin
        x = cu(rand(2, 5))
        v = view(x, axes(x)...)
i = cu([1, 2])
        outx = NNlib.gather(x, i)
        outv = NNlib.gather(v, i)
        @test outx == outv

        # discontinuous view
        v2 = view(x, :, [1,3,5])
        outv2 = NNlib.gather(v2, i)
        @test collect(outv2) == NNlib.gather(collect(v2), collect(i))
    end

    # Zero-sized
    x = CT([1,2,3])
    i = CT(Int[])
    y = NNlib.gather(x, i)
    @test isempty(y)
end

================================================ FILE: test/ext_cuda/pooling.jl ================================================

@testset "pooling" begin
    # Test for agreement between CPU NNlib and CuDNN versions, across a variety of kwargs
    for num_spatial_dims in (1, 2, 3)
        # Initialize data we'll run our tests over
        C_in = 3
        batch_size = 1
        x = rand(Float64, fill(8, num_spatial_dims)..., C_in, batch_size)

        pdims = PoolDims(x, 2)
        y = maxpool(x, pdims)
        dy = ones(size(y))
        gputest(x -> maxpool(x, pdims), x)
        gputest((dy, y, x) -> ∇maxpool(dy, y, x, pdims), dy, y, x, checkgrad=false)

        # FIX: these lines previously repeated the maxpool/∇maxpool pair verbatim;
        # they now exercise meanpool/∇meanpool as evidently intended.
        y = meanpool(x, pdims)
        dy = ones(size(y))
        gputest(x -> meanpool(x, pdims), x)
        gputest((dy, y, x) -> ∇meanpool(dy, y, x, pdims), dy, y, x, checkgrad=false)
    end
end

================================================ FILE: test/ext_cuda/runtests.jl ================================================

using Test
using NNlib
using Zygote
using ForwardDiff: Dual
using Statistics: mean
using CUDA, cuDNN
import CUDA.CUSPARSE: CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO
using NNlib: batchnorm, ∇batchnorm

CUDA.allowscalar(false)

include("test_utils.jl")
include("activations.jl")
include("dropout.jl")
include("batchedadjtrans.jl")
include("batchedmul.jl")
include("conv.jl")
include("ctc.jl")
include("fold.jl")
include("pooling.jl")
include("softmax.jl")
include("batchnorm.jl")
include("scatter.jl")
include("gather.jl")
include("sampling.jl")

================================================ FILE: test/ext_cuda/sampling.jl ================================================

@testset "Grid Sampling" begin
    for T in (Float32, Float64)
        x = ones(T, (2, 2, 1, 1))
        grid = Array{T}(undef, 2, 2, 2, 1)
grid[:, 1, 1, 1] .= (-1, -1)
        grid[:, 2, 1, 1] .= (1, -1)
        grid[:, 1, 2, 1] .= (-1, 1)
        grid[:, 2, 2, 1] .= (1, 1)

        ∇grid_true = Array{T}(undef, size(grid))
        ∇grid_true[:, :, 1, 1] = [[0.0, 0.0] [-0.5, 0.0]]
        ∇grid_true[:, :, 2, 1] = [[0.0, -0.5] [-0.5, -0.5]]

        x_gpu, grid_gpu = CuArray(x), CuArray(grid)

        padding_mode = :zeros
        y_gpu = grid_sample(x_gpu, grid_gpu; padding_mode=padding_mode)
        @test x == collect(y_gpu)
        @test eltype(y_gpu) == T

        external_grad = CUDA.ones(T, size(y_gpu))
        ∇input, ∇grid = ∇grid_sample(external_grad, x_gpu, grid_gpu; padding_mode=padding_mode)
        @test x == collect(∇input)
        @test ∇grid_true == collect(∇grid)
        @test eltype(∇input) == T
        @test eltype(∇grid) == T

        padding_mode = :border
        fill!(∇grid_true, 0.0)
        sampled = grid_sample(x_gpu, grid_gpu; padding_mode=padding_mode)
        @test x == collect(sampled)
        @test eltype(sampled) == T

        ∇input, ∇grid = ∇grid_sample(external_grad, x_gpu, grid_gpu; padding_mode=padding_mode)
        @test x == collect(∇input)
        @test ∇grid_true == collect(∇grid)
        @test eltype(∇input) == T
        @test eltype(∇grid) == T
    end
end

@testset "Compare grid sampling with NNlib" begin
    w, h, c, n = 16, 16, 2, 4
    input = rand(Float64, w, h, c, n)
    grid = zeros(Float64, 2, w, h, n)
    @inbounds for xi in 1:w, yi in 1:h, ni in 1:n
        grid[1, xi, yi, ni] = (xi / w) * 2.0 - 1.0 + 0.01
        grid[2, xi, yi, ni] = (yi / h) * 2.0 - 1.0
    end
    for padding_mode in (:zeros, :border)
        gputest(grid_sample, input, grid; atol=1e-6, padding_mode=padding_mode)
    end
end

@testset "Grid Sampling 3D" begin
    for T in (Float32, Float64)
        x = ones(T, (2, 2, 2, 1, 1)) # 3D input with depth=2
        grid = Array{T}(undef, 3, 2, 2, 2, 1) # 3D grid with depth=2
        grid[:, 1, 1, 1, 1] .= (-1, -1, -1)
        grid[:, 2, 1, 1, 1] .= (1, -1, -1)
        grid[:, 1, 2, 1, 1] .= (-1, 1, -1)
        grid[:, 2, 2, 1, 1] .= (1, 1, -1)
        grid[:, 1, 1, 2, 1] .= (-1, -1, 1)
        grid[:, 2, 1, 2, 1] .= (1, -1, 1)
        grid[:, 1, 2, 2, 1] .= (-1, 1, 1)
        grid[:, 2, 2, 2, 1] .= (1, 1, 1)

        ∇grid_true = Array{T}(undef, size(grid))
        ∇grid_true[:, 1, 1, 1, 1] .= (0.0, 0.0, 0.0)
∇grid_true[:, 2, 1, 1, 1] .= (-0.5, 0.0, 0.0) ∇grid_true[:, 1, 2, 1, 1] .= (0.0, -0.5, 0.0) ∇grid_true[:, 2, 2, 1, 1] .= (-0.5, -0.5, 0.0) ∇grid_true[:, 1, 1, 2, 1] .= (0.0, 0.0, -0.5) ∇grid_true[:, 2, 1, 2, 1] .= (-0.5, 0.0, -0.5) ∇grid_true[:, 1, 2, 2, 1] .= (0.0, -0.5, -0.5) ∇grid_true[:, 2, 2, 2, 1] .= (-0.5, -0.5, -0.5) x_gpu, grid_gpu = CuArray(x), CuArray(grid) padding_mode = :zeros y_gpu = grid_sample(x_gpu, grid_gpu; padding_mode=padding_mode) @test x == collect(y_gpu) @test eltype(y_gpu) == T external_grad = CUDA.ones(T, size(y_gpu)) ∇input, ∇grid = ∇grid_sample(external_grad, x_gpu, grid_gpu; padding_mode=padding_mode) @test x == collect(∇input) @test ∇grid_true == collect(∇grid) @test eltype(∇input) == T @test eltype(∇grid) == T padding_mode = :border fill!(∇grid_true, 0.0) sampled = grid_sample(x_gpu, grid_gpu; padding_mode=padding_mode) @test x == collect(sampled) @test eltype(sampled) == T ∇input, ∇grid = ∇grid_sample(external_grad, x_gpu, grid_gpu; padding_mode=padding_mode) @test x == collect(∇input) @test ∇grid_true == collect(∇grid) @test eltype(∇input) == T @test eltype(∇grid) == T end end @testset "Compare grid sampling with NNlib 3D" begin w, h, d, c, n = 16, 16, 16, 2, 4 # Added depth dimension `d` input = rand(Float64, w, h, d, c, n) grid = zeros(Float64, 3, w, h, d, n) # 3D grid with depth `d` @inbounds for xi in 1:w, yi in 1:h, zi in 1:d, ni in 1:n grid[1, xi, yi, zi, ni] = (xi / w) * 2.0 - 1.0 + 0.01 grid[2, xi, yi, zi, ni] = (yi / h) * 2.0 - 1.0 grid[3, xi, yi, zi, ni] = (zi / d) * 2.0 - 1.0 end for padding_mode in (:zeros, :border) gputest(grid_sample, input, grid; atol=1e-6, padding_mode=padding_mode) end end ================================================ FILE: test/ext_cuda/scatter.jl ================================================ dsts = Dict( 0 => cu([3, 4, 5, 6, 7]), 1 => cu([3 3 4 4 5; 5 5 6 6 7]), ) srcs = Dict( (0, true) => cu(ones(Int, 3, 4)), (0, false) => cu(ones(Int, 3) * collect(1:4)'), (1, true) => cu(ones(Int, 2, 3, 
4)), (1, false) => cu([1, 2] .* reshape(ones(Int, 3) * collect(1:4)', 1,3,4)), ) idxs = [ cu([1 2 3 4; 4 2 1 3; 3 5 5 3]), # integer index cu([(1,) (2,) (3,) (4,); (4,) (2,) (1,) (3,); (3,) (5,) (5,) (3,)]), # tuple index cu(CartesianIndex.([(1,) (2,) (3,) (4,); (4,) (2,) (1,) (3,); (3,) (5,) (5,) (3,)])), # CartesianIndex index ] types = [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}] @testset "scatter" begin for T = types @testset "$(T)" begin @testset "+" begin for idx = idxs, dims = [0, 1] mutated = true gputest((dst, src) -> NNlib.scatter!(+, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) mutated = false gputest(src -> NNlib.scatter(+, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) end end @testset "-" begin for idx = idxs, dims = [0, 1] mutated = true gputest((dst, src) -> NNlib.scatter!(-, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) mutated = false gputest(src -> NNlib.scatter(-, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) end end @testset "max" begin for idx = idxs, dims = [0, 1] mutated = true gputest((dst, src) -> NNlib.scatter!(max, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) mutated = false gputest(src -> NNlib.scatter(max, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) end end @testset "min" begin for idx = idxs, dims = [0, 1] mutated = true gputest((dst, src) -> NNlib.scatter!(min, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) mutated = false gputest(src -> NNlib.scatter(min, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) end end end end for T = [CuArray{Float32}, CuArray{Float64}] @testset "$(T)" begin @testset "*" begin for idx = idxs, dims = [0, 1] mutated = true gputest((dst, src) -> NNlib.scatter!(*, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) mutated = false gputest(src -> NNlib.scatter(*, src, idx), T(srcs[(dims, mutated)]), 
checkgrad=true) end end @testset "/" begin for idx = idxs, dims = [0, 1] mutated = true gputest((dst, src) -> NNlib.scatter!(/, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) mutated = false gputest(src -> NNlib.scatter(/, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) end end @testset "mean" begin for idx = idxs, dims = [0, 1] mutated = true gputest((dst, src) -> NNlib.scatter!(mean, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) mutated = false gputest(src -> NNlib.scatter(mean, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) end end end end end ================================================ FILE: test/ext_cuda/softmax.jl ================================================ @testset "softmax" begin for (sz, dims) in [((5,), :), ((5,), 1), ((5,5), :), ((5,5), 1), ((5,5), 2), ((5,5,5,5), (2,3)), ((5,5,5,5), (2,4))] x = randn(Float64, sz) dy = randn(Float64, sz) y = softmax(x, dims=dims) gputest(softmax, x, dims=dims) gputest(NNlib.∇softmax_data, dy, y; dims=dims) y2 = logsoftmax(x, dims=dims) gputest(logsoftmax, x, dims=dims) gputest(NNlib.∇logsoftmax_data, dy, y2; dims=dims) # From NNlib 0.8.3, ∇softmax! is not used in the gradient. # But NNlibCUDA still knows how to call cuDNN routines, let's test they agree: @test NNlib.∇softmax_data(dy, y; dims=dims) ≈ collect(∇softmax!(similar(cu(x)), cu(dy), cu(x), cu(y); dims=dims)) atol=1e-4 @test NNlib.∇logsoftmax_data(dy, y2; dims=dims) ≈ collect(∇logsoftmax!(similar(cu(x)), cu(dy), cu(x), cu(y2); dims=dims)) atol=1e-4 # (Note that ∇softmax! does not depend on x, it's just there to disambiguate from an even older signature.) end end ================================================ FILE: test/ext_cuda/test_utils.jl ================================================ function gputest(f, xs...; checkgrad=true, rtol=1e-7, atol=1e-10, broken=false, broken_grad=false, kws...) cpu_in = xs gpu_in = CuArray.(xs) cpu_out = f(cpu_in...; kws...) 
gpu_out = f(gpu_in...; kws...) @test collect(cpu_out) ≈ collect(gpu_out) rtol=rtol atol=atol broken=broken if checkgrad # use mean instead of sum to prevent error accumulation (for larger # tensors) which causes error to go above atol cpu_grad = gradient((x...) -> mean(f(x...; kws...)), cpu_in...) gpu_grad = gradient((x...) -> mean(f(x...; kws...)), gpu_in...) for (cpu_g, gpu_g) in zip(cpu_grad, gpu_grad) if cpu_g === nothing @test gpu_g === nothing else @test collect(cpu_g) ≈ collect(gpu_g) rtol=rtol atol=atol broken=broken_grad end end end end ================================================ FILE: test/ext_metal/activations.jl ================================================ @testset "activation broadcast" begin broken_f = (:hardσ, :leakyrelu) for name in NNlib.ACTIVATIONS # println("Testing forward diff for activation: ", name) f = @eval $name @test gputest(DEVICE, x -> f.(x), rand(5)) broken=name ∈ broken_f end end @testset "forward diff" begin broken_f = (:hardσ, :leakyrelu) for name in NNlib.ACTIVATIONS # println("Testing forward diff for activation: ", name) f = @eval $name @test gputest(DEVICE, x -> f.(x), Dual.(rand(Float32, 5), 1)) broken=name ∈ broken_f end end ================================================ FILE: test/ext_metal/runtests.jl ================================================ using NNlib using Test using Metal using Zygote: gradient using MLDataDevices: gpu_device using ForwardDiff: Dual Metal.allowscalar(false) #TODO move this to test/ test_utils.jl and use it with all backends function gputest(device, f, xs...; checkgrad=true, atol=1e-6, kws...) cpu_in = xs gpu_in = device(xs) cpu_out = f(cpu_in...; kws...) gpu_out = f(gpu_in...; kws...) @test collect(cpu_out) ≈ collect(gpu_out) if checkgrad cpu_grad = gradient((x...) -> sum(f(x...; kws...)), cpu_in...) gpu_grad = gradient((x...) -> sum(f(x...; kws...)), gpu_in...) 
for (cpu_g, gpu_g) in zip(cpu_grad, gpu_grad) if cpu_g === nothing @test gpu_g === nothing else @test collect(cpu_g) ≈ collect(gpu_g) atol=atol end end end return true end DEVICE = gpu_device(force=true) include("activations.jl") ================================================ FILE: test/functions.jl ================================================ using NNlib: glu using Zygote @testset "glu" begin x = [1 2 3; 4 5 6; 7 8 9; 10 11 12] @test ceil.(glu(x, 1)) == [1 2 3; 4 5 6] @test_throws AssertionError glu(x, 2) end @testset "AutoDiff" begin local rng = StableRNG(17) gradtest(glu, rand(rng, 4, 3)) end ================================================ FILE: test/inference.jl ================================================ import NNlib: conv_direct, conv_im2col, channels_in, channels_out @testset "Conv Inference" begin for T in (Float32, Float64) impl = [conv, conv_direct, conv_im2col] x = rand(T, 10, 10, 3, 2) w = rand(T, 3, 3, 3, 1) cdims = DenseConvDims(x, w) dy = conv(x, w, cdims) for f in impl @test @inferred(f(x, w, cdims)) isa Array{T,4} end @test @inferred(conv(x, w)) isa Array{T,4} @test @inferred(∇conv_filter(x, dy, cdims)) isa Array{T,4} @test @inferred(∇conv_data(dy, w, cdims)) isa Array{T,4} end end @testset "DepthwiseConv Inference" begin for T in (Float32, Float64) x = rand(T, 10, 10, 3, 2) w = rand(T, 3, 3, 3, 3) cdims = DepthwiseConvDims(x, w) dy = depthwiseconv(x, w) @test @inferred(depthwiseconv(x, w)) isa Array{T,4} @test @inferred(∇depthwiseconv_filter(x, dy, cdims)) isa Array{T,4} @test @inferred(∇depthwiseconv_data(dy, w, cdims)) isa Array{T,4} end end @testset "DenseConvDims Inference" begin # this needs to be in a function to trigger inference problems function channels_in_test(w::AbstractArray) cdims = DenseConvDims((1,1,1,1), size(w)) channels_in(cdims) end # this needs to be in a function to trigger inference problems function channels_out_test(w::AbstractArray) cdims = DenseConvDims((1,1,1,1), size(w)) channels_out(cdims) end w = 
rand(Float32, 1, 1, 1, 1)
    @test @inferred(channels_in_test(w)) isa Int
    @test @inferred(channels_out_test(w)) isa Int
end

@testset "Pooling inference" begin
    for T in (Float32, Float64)
        x = rand(T, 10, 10, 3, 2)
        pdims = PoolDims(x, 3)
        y_maxpool = NNlib.maxpool(x, pdims)
        y_meanpool = NNlib.meanpool(x, pdims)
        dy = ones(T, size(y_maxpool)...)
        @test @inferred(NNlib.maxpool(x, pdims)) isa Array{T, 4}
        @test @inferred(NNlib.meanpool(x, pdims)) isa Array{T, 4}
        @test @inferred(NNlib.∇maxpool(dy, y_maxpool, x, pdims)) isa Array{T, 4}
        # BUG FIX: this previously called ∇maxpool with meanpool's output,
        # so ∇meanpool's return-type inferability was never actually tested.
        @test @inferred(NNlib.∇meanpool(dy, y_meanpool, x, pdims)) isa Array{T, 4}
    end
end

================================================
FILE: test/padding.jl
================================================
using NNlib: pad_constant, pad_repeat, pad_zeros, pad_reflect, pad_symmetric, pad_circular

@testset "padding constant" begin
    x = rand(2, 2, 2)

    p = NNlib.gen_pad((1,2,3,4,5,6), (1,2,3), 4)
    @test p == ((1, 2), (3, 4), (5, 6), (0, 0))
    @test_throws ArgumentError NNlib.gen_pad((1,2,3,4,5,), (1,2,3), 4)
    p = NNlib.gen_pad((1,3), (1,3), 4)
    @test p == ((1, 1), (0, 0), (3, 3), (0, 0))
    p = NNlib.gen_pad(1, (1,2,3), 4)
    @test p == ((1, 1), (1, 1), (1, 1), (0, 0))
    p = NNlib.gen_pad(3, :, 2)
    @test p == ((3, 3), (3, 3))
    p = NNlib.gen_pad((1,0), 1, 2)
    @test p == ((1,0), (0,0))

    y = pad_constant(x, (3, 2, 4))
    @test size(y) == (8, 6, 10)
    @test y[4:5, 3:4, 5:6] ≈ x
    y[4:5, 3:4, 5:6] .= 0
    @test all(y .== 0)
    @test pad_constant(x, (3, 2, 4)) ≈ pad_zeros(x, (3, 2, 4))
    @test pad_zeros(x, 2) ≈ pad_zeros(x, (2,2,2))

    y = pad_constant(x, (3, 2, 4, 5), 1.2, dims = (1,3))
    @test size(y) == (7, 2, 11)
    @test y[4:5, 1:2, 5:6] ≈ x
    y[4:5, 1:2, 5:6] .= 1.2
    @test all(y .== 1.2)
    @test pad_constant(x, (2,2,2,2), 1.2, dims = (1,3)) ≈ pad_constant(x, 2, 1.2, dims = (1,3))
    @test pad_constant(x, 1, dims = 1:2) == pad_constant(x, 1, dims = (1,2))
    @test size(pad_constant(x, 1, dims = 1)) == (4,2,2)

    @test all(pad_zeros(randn(2), (1, 2))[[1, 4, 5]] .== 0)

    gradtest(x -> pad_constant(x, 2),
rand(2,2,2)) gradtest(x -> pad_constant(x, (2, 1, 1, 2)), rand(2,2)) gradtest(x -> pad_constant(x, (2, 1,)), rand(2)) end @testset "padding repeat" begin x = rand(2, 2, 2) # y = @inferred pad_repeat(x, (3, 2, 4, 5)) y = pad_repeat(x, (3, 2, 4, 5)) @test size(y) == (7, 11, 2) @test y[4:5, 5:6, :] ≈ x # y = @inferred pad_repeat(x, (3, 2, 4, 5), dims=(1,3)) y = pad_repeat(x, (3, 2, 4, 5), dims=(1,3)) @test size(y) == (7, 2, 11) @test y[4:5, :, 5:6] ≈ x @test pad_repeat(reshape(1:9, 3, 3), (1,2)) == [1 4 7 1 4 7 2 5 8 3 6 9 3 6 9 3 6 9] @test pad_repeat(reshape(1:9, 3, 3), (2,2), dims=2) == [1 1 1 4 7 7 7 2 2 2 5 8 8 8 3 3 3 6 9 9 9] @test pad_repeat(x, (2, 2, 2, 2), dims=(1,3)) ≈ pad_repeat(x, 2, dims=(1,3)) gradtest(x -> pad_repeat(x, (2,2,2,2)), rand(2,2,2)) end @testset "padding reflect" begin y = pad_reflect(reshape(1:9, 3, 3), (2,2), dims=2) @test y == [7 4 1 4 7 4 1 8 5 2 5 8 5 2 9 6 3 6 9 6 3] y = pad_reflect(reshape(1:9, 3, 3), (2,2,2,2)) @test y == [9 6 3 6 9 6 3 8 5 2 5 8 5 2 7 4 1 4 7 4 1 8 5 2 5 8 5 2 9 6 3 6 9 6 3 8 5 2 5 8 5 2 7 4 1 4 7 4 1] x = rand(4, 4, 4) @test pad_reflect(x, (2, 2, 2, 2), dims=(1,3)) ≈ pad_reflect(x, 2, dims=(1,3)) # pad_reflect needs larger test input as padding must # be strictly less than array size in that dimension gradtest(x -> pad_reflect(x, (2,2,2,2)), rand(3,3,3)) x = reshape(1:9, 3, 3, 1, 1) @test NNlib.pad_reflect(x, (1, 0, 1, 0); dims=1:2) == [ 5 2 5 8; 4 1 4 7; 5 2 5 8; 6 3 6 9;;;;] @test NNlib.pad_reflect(x, (0, 1, 0, 1); dims=1:2) == [ 1 4 7 4; 2 5 8 5; 3 6 9 6; 2 5 8 5;;;;] end @testset "padding symmetric" begin y = pad_symmetric(reshape(1:9, 3, 3), (2,2), dims=2) @test y == [4 1 1 4 7 7 4 5 2 2 5 8 8 5 6 3 3 6 9 9 6] y = pad_symmetric(reshape(1:9, 3, 3), (2,2,2,2)) @test y == [5 2 2 5 8 8 5 4 1 1 4 7 7 4 4 1 1 4 7 7 4 5 2 2 5 8 8 5 6 3 3 6 9 9 6 6 3 3 6 9 9 6 5 2 2 5 8 8 5] x = rand(4, 4, 4) @test pad_symmetric(x, (2, 2, 2, 2), dims=(1,3)) ≈ pad_symmetric(x, 2, dims=(1,3)) gradtest(x -> pad_symmetric(x, (2,2,2,2)), 
rand(2,2,2)) x = reshape(1:9, 3, 3, 1, 1) @test NNlib.pad_symmetric(x, (1, 0, 1, 0); dims=1:2) == [ 1 1 4 7; 1 1 4 7; 2 2 5 8; 3 3 6 9;;;;] @test NNlib.pad_symmetric(x, (0, 1, 0, 1); dims=1:2) == [ 1 4 7 7; 2 5 8 8; 3 6 9 9; 3 6 9 9;;;;] end @testset "padding circular" begin y = pad_circular(reshape(1:9, 3, 3), (2,2), dims=2) @test y == [4 7 1 4 7 1 4 5 8 2 5 8 2 5 6 9 3 6 9 3 6] y = pad_circular(reshape(1:9, 3, 3), (2,2,2,2)) @test y == [5 8 2 5 8 2 5 6 9 3 6 9 3 6 4 7 1 4 7 1 4 5 8 2 5 8 2 5 6 9 3 6 9 3 6 4 7 1 4 7 1 4 5 8 2 5 8 2 5] x = rand(4, 4, 4) @test pad_circular(x, (2, 2, 2, 2), dims=(1,3)) ≈ pad_circular(x, 2, dims=(1,3)) gradtest(x -> pad_circular(x, (2,2,2,2)), rand(2,2,2)) end ================================================ FILE: test/pooling.jl ================================================ # using NNlib, Test maxpool_answer_dict = Dict( 1 => Dict( "y" => [2, 4.], "y_nostride" => [2, 3, 4, 5.], "y_pad" => [1, 3, 5.], "dx" => [0, 2, 0, 4, 0.], "dx_nostride" => [0, 2, 3, 4, 5.], "dx_pad" => [1, 0, 3, 0, 5.], ), 2 => Dict( "y" => [ 7 17.; 9 19. ], "y_nostride" => [ 7 12 17; 8 13 18; 9 14 19; 10 15 20. ], "y_pad" => [ 1 11 16; 3 13 18; 5 15 20. ], "dx" => [ 0 0 0 0; 0 7 0 17; 0 0 0 0; 0 9 0 19; 0 0 0 0. ], "dx_nostride" => [ 0 0 0 0; 0 7 12 17; 0 8 13 18; 0 9 14 19; 0 10 15 20. ], "dx_pad" => [ 1 0 11 16; 0 0 0 0; 3 0 13 18; 0 0 0 0; 5 0 15 20. ], ), 3 => Dict( "y" => reshape([ 27, 29, 37, 39. ], (2, 2, 1)), "y_nostride" => reshape([ 27, 28, 29, 30, 32, 33, 34, 35, 37, 38, 39, 40, 47, 48, 49, 50, 52, 53, 54, 55, 57, 58, 59, 60. ], (4, 3, 2)), "y_pad" => reshape([ 1, 3, 5, 11, 13, 15, 16, 18, 20, 41, 43, 45, 51, 53, 55, 56, 58, 60. ], (3, 3, 2)), "dx" => reshape([ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 27, 0, 29, 0, 0, 0, 0, 0, 0, 0, 37, 0, 39, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0. 
], (5, 4, 3)), "dx_nostride" => reshape([ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 27, 28, 29, 30, 0, 32, 33, 34, 35, 0, 37, 38, 39, 40, 0, 0, 0, 0, 0, 0, 47, 48, 49, 50, 0, 52, 53, 54, 55, 0, 57, 58, 59, 60. ], (5, 4, 3)), "dx_pad" => reshape([ 1, 0, 3, 0, 5, 0, 0, 0, 0, 0, 11, 0, 13, 0, 15, 16, 0, 18, 0, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 41, 0, 43, 0, 45, 0, 0, 0, 0, 0, 51, 0, 53, 0, 55, 56, 0, 58, 0, 60. ], (5, 4, 3)), ) ) meanpool_answer_dict = Dict( 1 => Dict( "y" => [1.5, 3.5], "y_nostride" => [1.5, 2.5, 3.5, 4.5], "y_pad" => [0.5, 2.5, 4.5], "dx" => [0.75, 0.75, 1.75, 1.75, 0.0], "dx_nostride" => [0.75, 2.0, 3.0, 4.0, 2.25], "dx_pad" => [0.25, 1.25, 1.25, 2.25, 2.25], ), 2 => Dict( "y" => [ 4.0 14.0; 6.0 16.0 ], "y_nostride" => [ 4.0 9.0 14.0 5.0 10.0 15.0 6.0 11.0 16.0 7.0 12.0 17.0 ], "y_pad" => [ 0.25 4.25 4.0 1.25 10.0 8.75 2.25 12.0 9.75 ], "dx" => [ 1.0 1.0 3.5 3.5; 1.0 1.0 3.5 3.5; 1.5 1.5 4.0 4.0; 1.5 1.5 4.0 4.0; 0.0 0.0 0.0 0.0 ], "dx_nostride" => [ 1.0 3.25 5.75 3.5; 2.25 7.0 12.0 7.25; 2.75 8.0 13.0 7.75; 3.25 9.0 14.0 8.25; 1.75 4.75 7.25 4.25 ], "dx_pad" => [ 0.0625 1.0625 1.0625 1.0; 0.3125 2.5 2.5 2.1875; 0.3125 2.5 2.5 2.1875; 0.5625 3.0 3.0 2.4375; 0.5625 3.0 3.0 2.4375 ], ), 3 => Dict( "y" => reshape([ 14.0, 16.0, 24.0, 26.0 ], (2, 2, 1)), "y_nostride" => reshape([ 14.0, 15.0, 16.0, 17.0, 19.0, 20.0, 21.0, 22.0, 24.0, 25.0, 26.0, 27.0, 34.0, 35.0, 36.0, 37.0, 39.0, 40.0, 41.0, 42.0, 44.0, 45.0, 46.0, 47.0 ], (4, 3, 2)), "y_pad" => reshape([ 0.125, 0.625, 1.125, 2.125, 5.0, 6.0, 2.0, 4.375, 4.875, 7.75, 16.25, 17.25, 19.25, 40.0, 42.0, 11.5, 23.75, 24.75, ], (3, 3, 2)), "dx" => reshape([ 1.75, 1.75, 2.0, 2.0, 0.0, 1.75, 1.75, 2.0, 2.0, 0.0, 3.0, 3.0, 3.25, 3.25, 0.0, 3.0, 3.0, 3.25, 3.25, 0.0, 1.75, 1.75, 2.0, 2.0, 0.0, 1.75, 1.75, 2.0, 2.0, 0.0, 3.0, 3.0, 3.25, 3.25, 0.0, 3.0, 3.0, 3.25, 3.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, ], (5, 4, 3)), "dx_nostride" => reshape([ 1.75, 3.625, 3.875, 4.125, 2.125, 4.125, 8.5, 9.0, 9.5, 4.875, 5.375, 11.0, 11.5, 12.0, 6.125, 3.0, 6.125, 6.375, 6.625, 3.375, 6.0, 12.25, 12.75, 13.25, 6.75, 13.25, 27.0, 28.0, 29.0, 14.75, 15.75, 32.0, 33.0, 34.0, 17.25, 8.5, 17.25, 17.75, 18.25, 9.25, 4.25, 8.625, 8.875, 9.125, 4.625, 9.125, 18.5, 19.0, 19.5, 9.875, 10.375, 21.0, 21.5, 22.0, 11.125, 5.5, 11.125, 11.375, 11.625, 5.875 ], (5, 4, 3)), "dx_pad" => reshape([ 0.015625, 0.078125, 0.078125, 0.140625, 0.140625, 0.265625, 0.625, 0.625, 0.75, 0.75, 0.265625, 0.625, 0.625, 0.75, 0.75, 0.25, 0.546875, 0.546875, 0.609375, 0.609375, 0.96875, 2.03125, 2.03125, 2.15625, 2.15625, 2.40625, 5.0, 5.0, 5.25, 5.25, 2.40625, 5.0, 5.0, 5.25, 5.25, 1.4375, 2.96875, 2.96875, 3.09375, 3.09375, 0.96875, 2.03125, 2.03125, 2.15625, 2.15625, 2.40625, 5.0, 5.0, 5.25, 5.25, 2.40625, 5.0, 5.0, 5.25, 5.25, 1.4375, 2.96875, 2.96875, 3.09375, 3.09375 ], (5, 4, 3)), ) ) lpnormpool_answer_dict = Dict( 1 => Dict( "y" => [2.019312856150994, 4.221163518110637], "y_nostride" => [ 2.080083823051904, 3.2710663101885897, 4.497941445275415, 5.738793548317167 ], "y_pad" => [1.0, 3.605551275463989, 6.4031242374328485], "dx" => [ 0.17258020254042603, 1.9525221042381296, 1.2774501198988355, 3.496467771732918, 0.0 ], "dx_nostride" => [ 0.48074985676913606, 3.1458422620080637, 4.752311710531486, 6.345225258061685, 4.356316321455918 ], "dx_pad" => [1.0, 2.0, 3.0, 4.0, 5.0], "p" => 4.5, "p_nostride" => 3.0, "p_pad" => 2.0 ), 2 => Dict( "y" => [ 8.71909 24.9703; 11.7336 28.3804 ], "y_nostride" => [ 11.1128 23.134 35.5704; 13.4219 25.6082 38.0707; 15.8033 28.0907 40.5735; 18.2249 30.5795 43.0782 ], "y_pad" => [ 1.0 11.3616 16.0; 3.19158 15.9662 21.3545; 5.56869 18.7771 23.7903 ], "dx" => [ 0.33866 4.97727 7.30092 12.8076; 0.957876 6.27208 8.31879 14.0269; 1.51693 6.6057 8.79844 14.3351; 2.33547 7.8822 9.83293 15.5461; 0.0 0.0 0.0 0.0 ], "dx_nostride" => [ 3.33359 19.9471 35.7329 
23.8564; 9.89551 44.627 76.2257 50.0307; 13.231 50.9101 82.5686 53.2022; 16.4888 57.223 88.9133 56.3742; 9.54591 30.9869 46.8371 29.3524 ], "dx_pad" => [ 1.0 2.30261 10.4791 16.0; 0.992125 2.0321 7.81903 12.075; 2.73398 2.83743 9.5512 13.9299; 2.43512 2.98652 9.0132 13.5608; 4.25398 3.8865 10.7099 15.4161 ], "p" => 2.5, "p_nostride" => 1.5, "p_pad" => 3.5 ) ) for rank in (1, 2, 3) @testset "pool$(rank)d" begin for (pool, ∇pool, answer_dict) in ( # Main API name (maxpool, ∇maxpool, maxpool_answer_dict), (meanpool, ∇meanpool, meanpool_answer_dict), # _direct name (NNlib.maxpool_direct, NNlib.∇maxpool_direct, maxpool_answer_dict), (NNlib.meanpool_direct, NNlib.∇meanpool_direct, meanpool_answer_dict),) @testset "$(pool)$(rank)d" begin y = answer_dict[rank]["y"] y_nostride = answer_dict[rank]["y_nostride"] y_pad = answer_dict[rank]["y_pad"] dx = answer_dict[rank]["dx"] dx_nostride = answer_dict[rank]["dx_nostride"] dx_pad = answer_dict[rank]["dx_pad"] x = reshape(Float64[1:prod(size(dx));], size(dx)..., 1, 1) # A "drop channels and batch dimension" helper ddims(x) = dropdims(x, dims=(rank + 1, rank + 2)) # Let's ensure that a 1x1x1 pooling kernel always just returns `x` @test pool(x, PoolDims(x, 1)) == x # Test vanilla pooling pdims = PoolDims(x, 2) y_hat = pool(x, pdims) @test ddims(y_hat) == y @test ddims(∇pool(y_hat, y_hat, x, pdims)) == dx # Strided pooling pdims = PoolDims(x, 2; stride=1) y_hat = pool(x, pdims) @test ddims(y_hat) == y_nostride @test ddims(∇pool(y_hat, y_hat, x, pdims)) == dx_nostride # Padded pooling pdims = PoolDims(x, 2; padding=1) y_hat = pool(x, pdims) @test ddims(y_hat) == y_pad @test ddims(∇pool(y_hat, y_hat, x, pdims)) == dx_pad end end end end for rank in (1, 2) for (pool, ∇pool, answer_dict) in ( (lpnormpool, ∇lpnormpool, lpnormpool_answer_dict), (NNlib.lpnormpool_direct, NNlib.∇lpnormpool_direct, lpnormpool_answer_dict),) @testset "$(pool)$(rank)d" begin y = answer_dict[rank]["y"] y_nostride = answer_dict[rank]["y_nostride"] y_pad = 
answer_dict[rank]["y_pad"] dx = answer_dict[rank]["dx"] dx_nostride = answer_dict[rank]["dx_nostride"] dx_pad = answer_dict[rank]["dx_pad"] p = answer_dict[rank]["p"] p_nostride = answer_dict[rank]["p_nostride"] p_pad = answer_dict[rank]["p_pad"] x = reshape(Float64[1:prod(size(dx));], size(dx)..., 1, 1) ddims(x) = dropdims(x, dims=(rank + 1, rank + 2)) @test pool(x, PoolDims(x, 1); p=p) ≈ x atol = 1e-3 # Test vanilla pooling pdims = PoolDims(x, 2) y_hat = pool(x, pdims; p=p) @test ddims(y_hat) ≈ y atol = 1e-3 @test ddims(∇pool(y_hat, y_hat, x, pdims; p=p)) ≈ dx atol = 1e-3 # Strided pooling pdims = PoolDims(x, 2; stride=1) y_hat = pool(x, pdims; p=p_nostride) @test ddims(y_hat) ≈ y_nostride atol = 1e-3 @test ddims(∇pool(y_hat, y_hat, x, pdims; p=p_nostride)) ≈ dx_nostride atol = 1e-3 # Padded pooling pdims = PoolDims(x, 2; padding=1) y_hat = pool(x, pdims; p=p_pad) @test ddims(y_hat) ≈ y_pad atol = 1e-3 @test ddims(∇pool(y_hat, y_hat, x, pdims; p=p_pad)) ≈ dx_pad atol = 1e-3 end end end @testset "Pooling - Check Sizes" begin x = rand(10, 10, 3, 10) @test size(maxpool(x, (2, 2))) == (5, 5, 3, 10) @test size(maxpool(x, (2, 2); pad=(1, 1), stride=(2, 2))) == (6, 6, 3, 10) @test size(meanpool(x, (2, 2))) == (5, 5, 3, 10) @test size(meanpool(x, (2, 2); pad=(1, 1), stride=(2, 2))) == (6, 6, 3, 10) end # Add another test for 2d maxpool that uses an odd-length size: @testset "Issue #133" begin x = reshape([(1.:9.)...], 3, 3, 1, 1) pdims = PoolDims(size(x), (2, 2), padding=(1, 1), stride=(2, 2)) y = maxpool(x, pdims) dy = y .* 0 .+ 1 dx = ∇maxpool(dy, y, x, pdims) @test dx[:,:,1,1] == [1.0 0.0 1.0; 0.0 0.0 0.0; 1.0 0.0 1.0] end # test "true" strided case, see https://github.com/FluxML/NNlib.jl/issues/205 # obtained with # using FiniteDifferences maxpool_answer_nature = Dict( "rank1" => Dict( # kernel size 2, stride 1, pad 0 "k2s1p0" => (size = (2,), stride = 1, pad = 0, x = reshape([ 0.0299635, 0.233456, 0.596161, 0.161514, 0.0094027 ], 5, 1, 1), # width, channel, 
batch_size dx_maxpool = reshape([ 0.0, 1.0, 2.0, 1.0, 0.0 ], 5, 1, 1), dx_meanpool = reshape([ 0.5, 1.0, 1.0, 1.0, 0.5 ], 5, 1, 1),), "k2s1p1" => (size = (2,), stride = 1, pad = 1, x = reshape([ 0.0299635, 0.233456, 0.596161, 0.161514, 0.0094027 ], 5, 1, 1), dx_maxpool = reshape([ 1.0, 1.0, 2.0, 1.0, 1.0 ], 5, 1, 1), dx_meanpool = reshape([ 1.0, 1.0, 1.0, 1.0, 1.0 ], 5, 1, 1),), "k3s1p1" => (size = (3,), stride = 1, pad = 1, x = reshape([ 0.0299635, 0.233456, 0.596161, 0.161514, 0.0094027 ], 5, 1, 1), dx_maxpool = reshape([ 0.0, 1.0, 3.0, 1.0, 0.0 ], 5, 1, 1), dx_meanpool = reshape([ 0.6666666666, 1.0, 1.0, 1.0, 0.6666666666 ], 5, 1, 1),), "k3s2p1" => (size = (3,), stride = 2, pad = 1, x = reshape([ 0.0299635, 0.233456, 0.596161, 0.161514, 0.0094027 ], 5, 1, 1), dx_maxpool = reshape([ 0.0, 1.0, 1.0, 1.0, 0.0 ], 5, 1, 1), dx_meanpool = reshape([ 0.333333333, 0.666666666, 0.333333333, 0.666666666, 0.333333333, ], 5, 1, 1),) ), "rank2" => Dict( # kernel size 2, stride 1, pad 0 "k2s1p0" => (size = (2, 2), stride = 1, pad = 0, x = reshape([ 0.0299635 0.233456 0.596161 0.161514 0.0094027 0.389984 0.235158 0.579525 0.301893 0.561358 0.0830242 0.483759 0.914904 0.253871 0.820061 0.425287 0.53451 0.0405225 0.729861 0.403925 0.473724 0.571418 0.558427 0.552183 0.561624 ], 5, 5, 1, 1), dx_maxpool = reshape([ 0.0 0.0 2.0 0.0 0.0 1.0 0.0 0.0 0.0 1.0 0.0 1.0 4.0 0.0 2.0 0.0 1.0 0.0 2.0 0.0 0.0 2.0 0.0 0.0 0.0 ], 5, 5, 1, 1), dx_meanpool = reshape([ 0.25 0.5 0.5 0.5 0.25 0.5 1.0 1.0 1.0 0.5 0.5 1.0 1.0 1.0 0.5 0.5 1.0 1.0 1.0 0.5 0.25 0.5 0.5 0.5 0.25 ], 5, 5, 1, 1)), "k2s1p1" => (size = (2, 2), stride = 1, pad = 1, x = reshape([ 0.0299635 0.233456 0.596161 0.161514 0.0094027 0.389984 0.235158 0.579525 0.301893 0.561358 0.0830242 0.483759 0.914904 0.253871 0.820061 0.425287 0.53451 0.0405225 0.729861 0.403925 0.473724 0.571418 0.558427 0.552183 0.561624 ], 5, 5, 1, 1), dx_maxpool = reshape([ 1.0 1.0 4.0 1.0 1.0 3.0 0.0 0.0 0.0 2.0 0.0 1.0 4.0 0.0 4.0 1.0 1.0 0.0 2.0 0.0 2.0 4.0 
1.0 0.0 3.0 ], 5, 5, 1, 1), dx_meanpool = reshape([ 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ], 5, 5, 1, 1)), "k3s1p1" => (size = (3, 3), stride = 1, pad = 1, x = reshape([ 0.0299635 0.233456 0.596161 0.161514 0.0094027 0.389984 0.235158 0.579525 0.301893 0.561358 0.0830242 0.483759 0.914904 0.253871 0.820061 0.425287 0.53451 0.0405225 0.729861 0.403925 0.473724 0.571418 0.558427 0.552183 0.561624 ], 5, 5, 1, 1), dx_maxpool = reshape([ 0.0 0.0 3.0 0.0 0.0 1.0 0.0 0.0 0.0 1.0 0.0 1.0 9.0 0.0 3.0 0.0 1.0 0.0 3.0 0.0 0.0 3.0 0.0 0.0 0.0 ], 5, 5, 1, 1), dx_meanpool = reshape([ 0.444444 0.666667 0.666667 0.666667 0.444444 0.666667 1.0 1.0 1.0 0.666667 0.666667 1.0 1.0 1.0 0.666667 0.666667 1.0 1.0 1.0 0.666667 0.444444 0.666667 0.666667 0.666667 0.444444 ], 5, 5, 1, 1)), "k3s2p1" => (size = (3, 3), stride = 2, pad = 1, x = reshape([ 0.0299635 0.233456 0.596161 0.161514 0.0094027 0.389984 0.235158 0.579525 0.301893 0.561358 0.0830242 0.483759 0.914904 0.253871 0.820061 0.425287 0.53451 0.0405225 0.729861 0.403925 0.473724 0.571418 0.558427 0.552183 0.561624 ], 5, 5, 1, 1), dx_maxpool = reshape([ 0.0 0.0 1.0 0.0 0.0 1.0 0.0 0.0 0.0 1.0 0.0 0.0 1.0 0.0 1.0 0.0 1.0 0.0 2.0 0.0 0.0 1.0 0.0 0.0 0.0 ], 5, 5, 1, 1), dx_meanpool = reshape([ 0.111111 0.222222 0.111111 0.222222 0.111111 0.222222 0.444444 0.222222 0.444444 0.222222 0.111111 0.222222 0.111111 0.222222 0.111111 0.222222 0.444444 0.222222 0.444444 0.222222 0.111111 0.222222 0.111111 0.222222 0.111111 ], 5, 5, 1, 1)) ), "rank3" => Dict( # kernel size 2, stride 1, pad 0 "k2s1p0" => (size = (2, 2, 2), stride = 1, pad = 0, x = reshape(cat([ 0.82584 0.416818 0.92668 0.471931 0.798798 0.131608 0.344556 0.79681 0.716898 0.320672 0.24453 0.288568 0.261484 0.258469 0.121916 0.0685961 ], [ 0.73934 0.16631 0.525109 0.0223458 0.164918 0.790875 0.444085 0.469671 0.116848 0.359845 0.0653075 0.804886 0.525431 0.0402844 0.846814 0.84876 ], [ 0.709245 0.325828 0.715952 
0.719116 0.576722 0.405659 0.770104 0.259131 0.640221 0.28811 0.129229 0.97571 0.953795 0.1316 0.94538 0.705337 ],dims=3), 4,4,3,1,1), dx_maxpool = reshape(cat([ 1.0 0.0 2.0 0.0 1.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ], [ 0.0 0.0 0.0 0.0 0.0 5.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 1.0 1.0 ], [ 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 2.0 1.0 0.0 1.0 0.0 ],dims=3), 4,4,3,1,1), dx_meanpool = reshape(cat([ 0.125 0.25 0.25 0.125 0.25 0.5 0.5 0.25 0.25 0.5 0.5 0.25 0.125 0.25 0.25 0.125 ], [ 0.25 0.5 0.5 0.25 0.5 1.0 1.0 0.5 0.5 1.0 1.0 0.5 0.25 0.5 0.5 0.25 ], [ 0.125 0.25 0.25 0.125 0.25 0.5 0.5 0.25 0.25 0.5 0.5 0.25 0.125 0.25 0.25 0.125 ],dims=3), 4,4,3,1,1)), "k2s1p1" => (size = (2, 2, 2), stride = 1, pad = 1, x = reshape(cat([ 0.82584 0.416818 0.92668 0.471931 0.798798 0.131608 0.344556 0.79681 0.716898 0.320672 0.24453 0.288568 0.261484 0.258469 0.121916 0.0685961 ], [ 0.73934 0.16631 0.525109 0.0223458 0.164918 0.790875 0.444085 0.469671 0.116848 0.359845 0.0653075 0.804886 0.525431 0.0402844 0.846814 0.84876 ], [ 0.709245 0.325828 0.715952 0.719116 0.576722 0.405659 0.770104 0.259131 0.640221 0.28811 0.129229 0.97571 0.953795 0.1316 0.94538 0.705337 ],dims=3), 4,4,3,1,1), dx_maxpool = reshape(cat([ 8.0 0.0 8.0 2.0 4.0 0.0 1.0 4.0 4.0 1.0 0.0 2.0 2.0 1.0 1.0 1.0 ], [ 3.0 0.0 0.0 0.0 0.0 5.0 0.0 0.0 0.0 0.0 0.0 2.0 2.0 0.0 2.0 5.0 ], [ 4.0 0.0 2.0 6.0 0.0 0.0 4.0 0.0 3.0 0.0 0.0 8.0 8.0 0.0 6.0 1.0 ],dims=3), 4,4,3,1,1), dx_meanpool = reshape(cat([ 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ], [ 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ], [ 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 ],dims=3), 4,4,3,1,1)), "k3s1p1" => (size = (3, 3, 2), stride = 1, pad = 1, x = reshape(cat([ 0.82584 0.416818 0.92668 0.471931 0.798798 0.131608 0.344556 0.79681 0.716898 0.320672 0.24453 0.288568 0.261484 0.258469 0.121916 0.0685961 ], [ 0.73934 0.16631 0.525109 0.0223458 0.164918 0.790875 0.444085 0.469671 
0.116848 0.359845 0.0653075 0.804886 0.525431 0.0402844 0.846814 0.84876 ], [ 0.709245 0.325828 0.715952 0.719116 0.576722 0.405659 0.770104 0.259131 0.640221 0.28811 0.129229 0.97571 0.953795 0.1316 0.94538 0.705337 ],dims=3), 4,4,3,1,1), dx_maxpool = reshape(cat([ 4.0 0.0 12.0 0.0 3.0 0.0 0.0 2.0 3.0 1.0 0.0 1.0 0.0 0.0 0.0 0.0 ], [ 0.0 0.0 0.0 0.0 0.0 5.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 2.0 4.0 ], [ 2.0 0.0 0.0 0.0 0.0 0.0 5.0 0.0 0.0 0.0 0.0 12.0 8.0 0.0 0.0 0.0 ],dims=3), 4,4,3,1,1), dx_meanpool = reshape(cat([ 0.444444 0.666667 0.666667 0.444444 0.666667 1.0 1.0 0.666667 0.666667 1.0 1.0 0.666667 0.444444 0.666667 0.666667 0.444444 ], [ 0.444444 0.666667 0.666667 0.444444 0.666667 1.0 1.0 0.666667 0.666667 1.0 1.0 0.666667 0.444444 0.666667 0.666667 0.444444 ], [ 0.444444 0.666667 0.666667 0.444444 0.666667 1.0 1.0 0.666667 0.666667 1.0 1.0 0.666667 0.444444 0.666667 0.666667 0.444444 ],dims=3), 4,4,3,1,1)), "k3s2p1" => (size = (3, 3, 2), stride = 2, pad = 1, x = reshape(cat([ 0.82584 0.416818 0.92668 0.471931 0.798798 0.131608 0.344556 0.79681 0.716898 0.320672 0.24453 0.288568 0.261484 0.258469 0.121916 0.0685961 ], [ 0.73934 0.16631 0.525109 0.0223458 0.164918 0.790875 0.444085 0.469671 0.116848 0.359845 0.0653075 0.804886 0.525431 0.0402844 0.846814 0.84876 ], [ 0.709245 0.325828 0.715952 0.719116 0.576722 0.405659 0.770104 0.259131 0.640221 0.28811 0.129229 0.97571 0.953795 0.1316 0.94538 0.705337 ],dims=3), 4,4,3,1,1), dx_maxpool = reshape(cat([ 1.0 0.0 1.0 0.0 1.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ], [ 0.0 0.0 0.0 0.0 0.0 2.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ], [ 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 1.0 0.0 0.0 0.0 ],dims=3), 4,4,3,1,1), dx_meanpool = reshape(cat([ 0.0555556 0.111111 0.0555556 0.0555556 0.111111 0.222222 0.111111 0.111111 0.0555556 0.111111 0.0555556 0.0555556 0.0555556 0.111111 0.0555556 0.0555556 ], [ 0.0555556 0.111111 0.0555556 0.0555556 0.111111 0.222222 0.111111 0.111111 0.0555556 0.111111 0.0555556 
0.0555556 0.0555556 0.111111 0.0555556 0.0555556 ], [ 0.0555556 0.111111 0.0555556 0.0555556 0.111111 0.222222 0.111111 0.111111 0.0555556 0.111111 0.0555556 0.0555556 0.0555556 0.111111 0.0555556 0.0555556 ],dims=3), 4,4,3,1,1)) ) ) @testset "more maxpool and meanpool tests" begin # issue #205 function check(config, T) # CHECK DEFAULT pdims = PoolDims(config.x, config.size; stride=config.stride, padding=config.pad) x = T.(config.x) y_maxpool = NNlib.maxpool(x, pdims) y_meanpool = NNlib.meanpool(x, pdims) dy = ones(T, size(y_maxpool)...) # size(y_maxpool) == size(y_meanpool) @test isapprox(config.dx_maxpool, NNlib.∇maxpool(dy, y_maxpool, x, pdims), rtol=1e-5) @test isapprox(config.dx_meanpool, NNlib.∇meanpool(dy, y_meanpool, x, pdims), rtol=1e-5) # CHECK DIRECT y_maxpool_dir = NNlib.maxpool_direct(x, pdims) y_meanpool_dir = NNlib.meanpool_direct(x, pdims) @test y_maxpool_dir ≈ y_maxpool atol = 1e-6 @test isapprox(config.dx_maxpool, NNlib.∇maxpool_direct(dy, y_maxpool_dir, x, pdims), rtol=1e-5) @test isapprox(config.dx_meanpool, NNlib.∇meanpool_direct(dy, y_meanpool_dir, x, pdims), rtol=1e-5) end for (rank_name, config_dict) in maxpool_answer_nature for (setting_name, config) in config_dict for T in (Float32, Float64) check(config, T) end end end # issue 210 x, k = rand(Float32, 5, 2, 1, 3), (2, 1) pdims1 = NNlib.PoolDims(x, k, padding=1, stride=1) pdims2 = NNlib.PoolDims(x, k, padding=(1, 0, 0, 0), stride=1) @test maxpool(x, pdims1) isa Array{Float32,4} @test maxpool(x, pdims2) isa Array{Float32,4} # issue #229 x = ones(Float32, 4, 4, 1, 1) .* -1 pool = meanpool(x, PoolDims(x, 2, padding=1)) valid = reshape([ -0.25, -0.5, -0.25, -0.5, -1.0, -0.5, -0.25, -0.5, -0.25], (3, 3, 1, 1)) @test all(pool .== valid) # issue #484 # Description: some in-place pooling functions only accepted arrays with the same eltype. # The strict method signatures were based on assumption on the return type of `similar`. # For ReverseDiff, this caused problems, e.g. 
with taking derivatives of pooling # operations. # Now, if explicitly calling an in-place pooling functions, a different `yT` is allowed. for xT in (Int32, Int64, Float16, Float32, Float64, BigFloat) for (xsz, psz) in ( # test a few different data and kernel sizes ((1,1), (1,1)), ((1,2), (1,1)), ((1,2), (1,2)), ((2,1), (1,1)), ((2,1), (2,1)), ((2,2), (1,1)), ((2,2), (1,2)), ((2,2), (2,1)), ) x = ones(xT, xsz..., 1, 1) pdims = PoolDims(x, psz) for yT in (Float16, Float32, Float64, BigFloat) # `yT` is the target eltype and we do not test integer types here # because those cannot always store the pooling results. y = similar(x, yT, NNlib.output_size(pdims)..., NNlib.channels_out(pdims), size(x, 4)) @test maxpool!(y, x, pdims) isa Array{yT} @test meanpool!(y, x, pdims) isa Array{yT} @test lpnormpool!(y, x, pdims; p=2) isa Array{yT} @test lpnormpool!(y, x, pdims; p=1.0) isa Array{yT} end end end # This is how to test #484 with ReverseDiff: x = reshape(Float32[ 1 2; 3 4 ], (2,2,1,1)) @test only(maxpool(x, (2,2))) == 4 # define typemin, because of https://github.com/JuliaDiff/ReverseDiff.jl/issues/225 Base.typemin(tr::Type{<:T}) where{V, T<:RD.TrackedReal{V, <:Any, <:Any}} = T(typemin(V)) @test RD.gradient(_x -> only(maxpool(_x,(2,2))), x)[:,:,1,1] == [0 0; 0 1] @test only(meanpool(x, (2,2))) == 2.5 @test all(==(0.25), RD.gradient(_x -> only(meanpool(_x,(2,2))), x)) end @testset "AutoDiff: spatial_rank=$spatial_rank" for spatial_rank in (1, 2) x = rand(rng, repeat([10], spatial_rank)..., 3, 2) pdims = PoolDims(x, 2) gradtest(x -> maxpool(x, pdims), x; skip = spatial_rank==2) gradtest(x -> meanpool(x, pdims), x) gradtest(x -> sum(maxpool(x, pdims)), x, skip = spatial_rank==2) gradtest(x -> sum(meanpool(x, pdims)), x) #https://github.com/FluxML/NNlib.jl/issues/188 k = ntuple(_ -> 2, spatial_rank) # Kernel size of pool in ntuple format gradtest(x -> maxpool(x, k), x; skip = spatial_rank==2) gradtest(x -> meanpool(x, k), x) gradtest(x -> sum(maxpool(x, k)), x, skip = 
spatial_rank==2) gradtest(x -> sum(meanpool(x, k)), x) end @static if Test_Enzyme @testset "EnzymeRules: pooling! $pool spatial_rank=$spatial_rank " for spatial_rank in (1, 2), (pool, pool!) in ((maxpool, maxpool!), (meanpool, meanpool!)) x = rand(rng, repeat([10], spatial_rank)..., 3, 2) pdims = PoolDims(x, 2) y = pool(x, pdims) for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), Tsrc in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) Tret == EnzymeCore.Const && continue # ERROR EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tsrc) || continue EnzymeTestUtils.test_reverse(pool!, Tret, (y, Tdst), (x, Tsrc), (pdims, EnzymeCore.Const)) end end end ================================================ FILE: test/runtests.jl ================================================ using NNlib, Test, Statistics, Random using ChainRulesCore, ChainRulesTestUtils using Base.Broadcast: broadcasted import EnzymeTestUtils using EnzymeCore import FiniteDifferences import ForwardDiff import Zygote using Zygote: gradient using StableRNGs using Documenter using Adapt using ImageTransformations using Interpolations: Constant using KernelAbstractions using FFTW import ReverseDiff as RD # used in `pooling.jl` import Pkg using SpecialFunctions const Test_Enzyme = VERSION <= v"1.12-" DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib, UnicodePlots); recursive=true) # ENV["NNLIB_TEST_CUDA"] = "true" # uncomment to run CUDA tests # ENV["NNLIB_TEST_AMDGPU"] = "true" # uncomment to run AMDGPU tests # ENV["NNLIB_TEST_METAL"] = "true" # uncomment to run Metal tests # ENV["NNLIB_TEST_CPU"] = "false" # uncomment to skip CPU tests const rng = StableRNG(123) include("test_utils.jl") macro conditional_testset(name, skip_tests, expr) esc(quote @testset $name begin if $name ∉ $skip_tests $expr else @test_skip false end end end) end cpu(x) = adapt(CPU(), x) include("testsuite/gather.jl") 
include("testsuite/scatter.jl")
include("testsuite/upsample.jl")
include("testsuite/rotation.jl")
include("testsuite/spectral.jl")
include("testsuite/fold.jl")

# Run the backend-agnostic test suites against `Backend` (CPU, CUDABackend, ...).
# Individual suites can be skipped by listing their `@conditional_testset`
# names in `skip_tests`.
function nnlib_testsuite(Backend; skip_tests = Set{String}())
    @conditional_testset "Upsample" skip_tests begin
        upsample_testsuite(Backend)
    end
    @conditional_testset "rotation" skip_tests begin
        rotation_testsuite(Backend)
    end
    @conditional_testset "Gather" skip_tests begin
        gather_testsuite(Backend)
    end
    @conditional_testset "Scatter" skip_tests begin
        scatter_testsuite(Backend)
    end
    @conditional_testset "Spectral" skip_tests begin
        spectral_testsuite(Backend)
    end
    @conditional_testset "Fold" skip_tests begin
        fold_testsuite(Backend)
    end
end

@testset verbose=true "NNlib.jl" begin
    # CPU tests run by default; set NNLIB_TEST_CPU=false to skip (as the
    # GPU CI pipelines do).
    if get(ENV, "NNLIB_TEST_CPU", "true") == "true"
        @testset "CPU" begin
            @testset "Doctests" begin
                doctest(NNlib, manual=false)
            end

            nnlib_testsuite(CPU)

            # `should_use_spawn` should report true only on multi-threaded
            # runs, and `@disallow_spawns` must force it to false either way.
            if Threads.nthreads(:default) > 1
                @test NNlib.should_use_spawn()
                NNlib.@disallow_spawns begin
                    @test NNlib.should_use_spawn() == false
                end
            else
                @test NNlib.should_use_spawn() == false
            end

            @testset "Activation Functions" begin
                include("activations.jl")
                include("bias_act.jl")
            end
            @testset "Attention" begin
                include("attention.jl")
            end
            @testset "Batched Multiplication" begin
                include("batchedmul.jl")
            end
            @testset "Convolution" begin
                include("conv.jl")
                include("conv_bias_act.jl")
            end
            @testset "CTC Loss" begin
                include("ctc.jl")
            end
            @testset "Dropout" begin
                include("dropout.jl")
            end
            @testset "Inference" begin
                include("inference.jl")
            end
            @testset "Pooling" begin
                include("pooling.jl")
            end
            @testset "Padding" begin
                include("padding.jl")
            end
            @testset "Softmax" begin
                include("softmax.jl")
            end
            @testset "Utilities" begin
                include("utils.jl")
            end
            @testset "Grid Sampling" begin
                include("sampling.jl")
            end
            @testset "Functions" begin
                include("functions.jl")
            end
        end
    else
        @info "Skipping CPU tests, set NNLIB_TEST_CPU=true to run them."
end if get(ENV, "NNLIB_TEST_CUDA", "false") == "true" Pkg.add(["CUDA", "cuDNN"]) using CUDA if CUDA.functional() @testset "CUDA" begin nnlib_testsuite(CUDABackend; skip_tests=Set(("Scatter", "Gather"))) include("ext_cuda/runtests.jl") end else @info "Insufficient version or CUDA not found; Skipping CUDA tests" end else @info "Skipping CUDA tests, set NNLIB_TEST_CUDA=true to run them" end if get(ENV, "NNLIB_TEST_AMDGPU", "false") == "true" Pkg.add("AMDGPU") using AMDGPU AMDGPU.versioninfo() if AMDGPU.functional() && AMDGPU.functional(:MIOpen) @testset "AMDGPU" begin nnlib_testsuite(ROCBackend) AMDGPU.synchronize(; blocking=false, stop_hostcalls=true) include("ext_amdgpu/runtests.jl") AMDGPU.synchronize(; blocking=false, stop_hostcalls=true) end else @info "AMDGPU.jl package is not functional. Skipping AMDGPU tests." end else @info "Skipping AMDGPU tests, set NNLIB_TEST_AMDGPU=true to run them." end if get(ENV, "NNLIB_TEST_METAL", "false") == "true" Pkg.add("Metal") using Metal if Metal.functional() @testset "Metal" begin # nnlib_testsuite(MetalBackend) include("ext_metal/runtests.jl") end else @info "Insufficient version or Metal not found; Skipping Metal tests" end else @info "Skipping Metal tests, set NNLIB_TEST_METAL=true to run them" end end ================================================ FILE: test/sampling.jl ================================================ @testset "Known gradients" begin x = ones(Float64, (2, 2, 1, 1)) grid = Array{Float64}(undef, 2, 2, 2, 1) grid[:, 1, 1, 1] .= (-1, -1) grid[:, 2, 1, 1] .= (1, -1) grid[:, 1, 2, 1] .= (-1, 1) grid[:, 2, 2, 1] .= (1, 1) ∇grid_true = Array{Float64}(undef, size(grid)) ∇grid_true[:, :, 1, 1] = [[0.0, 0.0] [-0.5, 0.0]] ∇grid_true[:, :, 2, 1] = [[0.0, -0.5] [-0.5, -0.5]] padding_mode = :zeros sampled = grid_sample(x, grid; padding_mode=padding_mode) @test x == sampled @test eltype(sampled) == Float64 external_grad = ones(size(sampled)) ∇input, ∇grid = ∇grid_sample(external_grad, x, grid; 
        padding_mode=padding_mode)
    @test ∇input == x
    @test ∇grid == ∇grid_true
    @test eltype(∇input) == Float64
    @test eltype(∇grid) == Float64
    # ∇grid from FiniteDifferences is incorrect in the case of 0-padding.
    # gradtest(grid_sample, x, grid; fkwargs=(padding_mode=padding_mode,))

    padding_mode = :border
    fill!(∇grid_true, 0.0)

    sampled = grid_sample(x, grid; padding_mode=padding_mode)
    @test x == sampled
    @test eltype(sampled) == Float64

    external_grad = ones(size(sampled))
    ∇input, ∇grid = ∇grid_sample(external_grad, x, grid; padding_mode=padding_mode)
    @test ∇input == x
    @test ∇grid == ∇grid_true
    @test eltype(∇input) == Float64
    @test eltype(∇grid) == Float64
    gradtest(grid_sample, x, grid; fkwargs=(padding_mode=padding_mode,))
end

@testset "Test out-of-bounds for different paddings" begin
    x = ones(Float64, (2, 2, 1, 1))
    # Grid with several sampling points far outside the [-1, 1] range.
    grid = Array{Float64}(undef, 2, 3, 2, 1)
    grid[:, 1, 1, 1] .= (-3, -1)
    grid[:, 2, 1, 1] .= (0, -1)
    grid[:, 3, 1, 1] .= (3, -1)
    grid[:, 1, 2, 1] .= (-1, 3)
    grid[:, 2, 2, 1] .= (0, 1)
    grid[:, 3, 2, 1] .= (1, 3)
    # With 0-padding, out-of-bound values will contribute nothing to
    # the output values, because they are too far from any bound.
    y = grid_sample(x, grid; padding_mode=:zeros)
    y_true = reshape(Float64[[0, 1, 0] [0, 1, 0]], size(y))
    @test y_true == y
    # With border-padding, out-of-bound values simply become border values
    # and the result should be all ones.
y = grid_sample(x, grid; padding_mode=:border) y_true = ones(Float64, size(y)) @test y_true == y end @testset "Known gradients 3D" begin x = ones(Float64, (2, 2, 2, 1, 1)) # 3D input with depth=2 grid = Array{Float64}(undef, 3, 2, 2, 2, 1) # 3D grid with depth=2 grid[:, 1, 1, 1, 1] .= (-1, -1, -1) grid[:, 2, 1, 1, 1] .= (1, -1, -1) grid[:, 1, 2, 1, 1] .= (-1, 1, -1) grid[:, 2, 2, 1, 1] .= (1, 1, -1) grid[:, 1, 1, 2, 1] .= (-1, -1, 1) grid[:, 2, 1, 2, 1] .= (1, -1, 1) grid[:, 1, 2, 2, 1] .= (-1, 1, 1) grid[:, 2, 2, 2, 1] .= (1, 1, 1) ∇grid_true = Array{Float64}(undef, size(grid)) ∇grid_true[:, 1, 1, 1, 1] .= (0.0, 0.0, 0.0) ∇grid_true[:, 2, 1, 1, 1] .= (-0.5, 0.0, 0.0) ∇grid_true[:, 1, 2, 1, 1] .= (0.0, -0.5, 0.0) ∇grid_true[:, 2, 2, 1, 1] .= (-0.5, -0.5, 0.0) ∇grid_true[:, 1, 1, 2, 1] .= (0.0, 0.0, -0.5) ∇grid_true[:, 2, 1, 2, 1] .= (-0.5, 0.0, -0.5) ∇grid_true[:, 1, 2, 2, 1] .= (0.0, -0.5, -0.5) ∇grid_true[:, 2, 2, 2, 1] .= (-0.5, -0.5, -0.5) # ∇grid_true[:, :, :, 1, 1] = [ # [[0.0, 0.0, 0.0], [-0.5, 0.0, 0.0]], # [[0.0, -0.5, 0.0], [-0.5, -0.5, 0.0]] # ] # ∇grid_true[:, :, :, 2, 1] = [ # [[0.0, 0.0, -0.5], [-0.5, 0.0, -0.5]] # [[0.0, -0.5, -0.5], [-0.5, -0.5, -0.5]] # ] padding_mode = :zeros sampled = grid_sample(x, grid; padding_mode=padding_mode) @test x == sampled @test eltype(sampled) == Float64 external_grad = ones(size(sampled)) ∇input, ∇grid = ∇grid_sample(external_grad, x, grid; padding_mode=padding_mode) @test ∇input == x @test ∇grid == ∇grid_true @test eltype(∇input) == Float64 @test eltype(∇grid) == Float64 # ∇grid from FiniteDifferences is incorrect in case when 0-padding. 
# gradtest(grid_sample, x, grid; fkwargs=(padding_mode=padding_mode,)) padding_mode = :border fill!(∇grid_true, 0.0) sampled = grid_sample(x, grid; padding_mode=padding_mode) @test x == sampled @test eltype(sampled) == Float64 external_grad = ones(size(sampled)) ∇input, ∇grid = ∇grid_sample(external_grad, x, grid; padding_mode=padding_mode) @test ∇input == x @test ∇grid == ∇grid_true @test eltype(∇input) == Float64 @test eltype(∇grid) == Float64 gradtest(grid_sample, x, grid; fkwargs=(padding_mode=padding_mode,)) end @testset "Test out-of-bounds for different paddings 3D" begin x = ones(Float64, (2, 2, 2, 1, 1)) # 3D input with depth=2 grid = Array{Float64}(undef, 3, 2, 2, 2, 1) # 3D grid with depth=2 grid[:, 1, 1, 1, 1] .= (-3, -1, -1) grid[:, 2, 1, 1, 1] .= (0, -1, -1) grid[:, 1, 2, 1, 1] .= (-1, 3, -1) grid[:, 2, 2, 1, 1] .= (0, 1, -1) grid[:, 1, 1, 2, 1] .= (-1, -1, 3) grid[:, 2, 1, 2, 1] .= (0, -1, 3) grid[:, 1, 2, 2, 1] .= (-1, 1, 3) grid[:, 2, 2, 2, 1] .= (0, 1, 3) # With 0-padding, out-of-bound values will contribute nothing to # the output values, because they are too far from any bound. y = grid_sample(x, grid; padding_mode=:zeros) y_true = reshape(Float64[[0, 1] [0, 1] [0, 0] [0, 0]], size(y)) @test y_true == y # With border-padding, out-of-bound values simply become border values # and the result should be all ones. y = grid_sample(x, grid; padding_mode=:border) y_true = ones(Float64, size(y)) @test y_true == y end ================================================ FILE: test/softmax.jl ================================================ using Statistics: mean using NNlib: ∇softmax_data, ∇logsoftmax_data @testset "softmax integer input" begin @test softmax(Int[0, 0]) == [0.5, 0.5] end @testset "softmax on different dims" begin xs = rand(fill(2, 5)...) 
out = similar(xs) for (fn!, fn) in [(softmax!, softmax), (logsoftmax!, logsoftmax)], i = 1:ndims(xs) @test fn!(out, xs; dims = i) == fn(xs; dims = i) end end @testset "softmax" begin xs = rand(5, 5) @test all(sum(softmax(xs), dims = 1) .≈ 1) @test all(sum(softmax(xs; dims = 2), dims = 2) .≈ 1) @test sum(softmax(vec(xs))) ≈ 1 @test log.(softmax(xs; dims = 2)) ≈ logsoftmax(xs; dims = 2) xs = [-100_000.0, -100_000.0] @test softmax(xs) ≈ [0.5, 0.5] @test logsoftmax(xs) ≈ log.([0.5, 0.5]) xs = rand(5) @test softmax(xs) ≈ exp.(xs) ./ sum(exp.(xs)) @test logsoftmax(xs) ≈ log.(softmax(xs)) xs = Float32[1, 2, 3000.0] @test logsoftmax(xs) ≈ [-2999, -2998, 0] xs = Float32[1 2 3; 1000 2000 3000] @test logsoftmax(xs) ≈ [-999 -1998 -2997; 0 0 0.0] y = logsoftmax(xs) @test ∇logsoftmax_data(ones(Float32, size(xs)), y) ≈ Float32[1 1 1; -1 -1 -1] y = softmax(xs) @test ∇softmax_data(ones(Float32, size(xs)), y) ≈ zeros(Float32, size(xs)) # These values precalculated using PyTorch's nn.LogSoftmax xs = [ -0.238639 0.748142 -0.283194 -0.525461 -1.5348 -0.797842 0.690384 0.211427 0.254794 -0.213572 -0.314174 -0.372663 -1.146370 -0.577988 0.718952 0.919720 -0.620773 0.929977 ] ys = [ 0.237703 -0.621474 0.448193 0.546047 0.564185 0.632273 -0.930163 0.0519798 0.0549979 0.3799 -0.477112 0.437428 0.69246 0.569494 -0.503191 -0.925947 -0.0870738 -1.0697 ] y = logsoftmax(xs) @test ∇logsoftmax_data(ones(size(xs)), y) ≈ ys rtol = 1e-6 y = softmax(xs) @test ∇softmax_data(ones(size(xs)), y) ≈ zeros(size(xs)) atol = 1e-6 end @testset "softmax with Inf, NaN" begin @test softmax(Float32[1 2; 3 Inf]) ≈ Float32[0.11920292 0.0; 0.880797 1.0] @test softmax(Float32[1 -Inf; 3 Inf]) ≈ Float32[0.11920292 0.0; 0.880797 1.0] @test softmax(Float32[1 Inf; 3 Inf]) ≈ Float32[0.11920292 0.5; 0.880797 0.5] @test softmax(Float32[1 2; 3 NaN]) ≈ Float32[0.11920292 NaN; 0.880797 NaN] nans=true @test softmax(Float32[1 2; 3 Inf]; dims=2) ≈ Float32[0.26894143 0.7310586; 0.0 1.0] @test softmax(Float32[1 2; 3 Inf]; dims=(:)) ≈ 
Float32[0.0 0.0; 0.0 1.0] @test softmax(Float32[1 2; 3 Inf]; dims=(1,2)) ≈ Float32[0.0 0.0; 0.0 1.0] @test exp.(logsoftmax(Float32[1 2; 3 Inf])) ≈ softmax(Float32[1 2; 3 Inf]) @test exp.(logsoftmax(Float32[1 -Inf; 3 Inf])) ≈ softmax(Float32[1 -Inf; 3 Inf]) @test exp.(logsoftmax(Float32[1 Inf; 3 Inf])) ≈ softmax(Float32[1 Inf; 3 Inf]) end @testset "mutating softmax" begin map([ Float64[1 2 3; 5 6 7], Float64[ -0.238639 0.748142 -0.283194 -0.525461 -1.5348 -0.797842 0.690384 0.211427 0.254794 -0.213572 -0.314174 -0.372663 -1.146370 -0.577988 0.718952 0.919720 -0.620773 0.929977 ], ]) do xs out = similar(xs) softmax!(out, xs) @test out ≈ softmax(xs) rtol = 1e-6 logsoftmax!(out, xs) @test out ≈ logsoftmax(xs) rtol = 1e-6 @testset "$fn(Float64, $(size(xs)))" for fn in [zeros, ones, rand] Δ = fn(Float64, size(xs)) y = softmax(xs) ∇softmax!(out, Δ, xs, y) # deprecated @test out ≈ ∇softmax_data(Δ, y) rtol = 1e-6 y = logsoftmax(xs) ∇logsoftmax!(out, Δ, xs, y) # deprecated @test out ≈ ∇logsoftmax_data(Δ, y) rtol = 1e-6 end end end @testset "logsumexp" begin flogsoft(x; dims) = mean(x .- logsoftmax(x; dims = dims), dims = dims) x = rand(3, 4) @test logsumexp(x) ≈ flogsoft(x, dims = :) @test logsumexp(x; dims = 1) ≈ flogsoft(x, dims = 1) end @testset "AutoDiff" begin for f in (softmax, logsoftmax), d in (:, 1, 2) gradtest(f, (3,4); fkwargs = (dims = d,), check_rrule = true) end gradtest(x -> softmax(x) .* (1:3), 3) gradtest(x -> softmax(x) .* (1:3), (3,5), atol = 1e-4) gradtest(x -> softmax(x, dims = 2) .* (1:3), (3,5), atol = 1e-4) gradtest(x -> logsoftmax(x) .* (1:3), 3) gradtest(x -> logsoftmax(x) .* (1:3), (3,5)) gradtest(x -> logsoftmax(x, dims = 2) .* (1:3), (3,5)) for d in (:, 1, 2) gradtest(logsumexp, (3,4), fkwargs = (dims = d,)) end end @testset "Second derivatives" begin x = [1 2 3; 6 5 4] H = Zygote.hessian_dual(x -> sum(sin, softmax(x)), x) @test H ≈ Zygote.hessian_reverse(x -> sum(sin, softmax(x)), x) H2 = Zygote.hessian_dual(x -> sum(sin, logsoftmax(x)), x) 
@test H2 ≈ Zygote.hessian_reverse(x -> sum(sin, logsoftmax(x)), x) H3 = Zygote.hessian_dual(x -> sum(sin, logsumexp(x)), x) @test H3 ≈ Zygote.hessian_reverse(x -> sum(sin, logsumexp(x)), x) end ================================================ FILE: test/test_utils.jl ================================================ const IntOrTuple = Union{Int, NTuple{N,Int} where N} gradtest(f, dims::IntOrTuple...; kw...) = gradtest(f, randn.(Ref(rng), Float64, dims)...; kw...) # julia v1.3 compat # gradtest(f, randn.(rng, Float64, dims)...; kw...) """ Compare numerical gradient and automatic gradient given by Zygote. `f` has to be a scalar valued function. Applies also `ChainRulesTestUtils.test_rrule` if the rrule for `f` is explicitly defined. """ function gradtest( f, xs...; atol = 1e-6, rtol = 1e-6, fkwargs = NamedTuple(), check_rrule = false, fdm = :central, check_broadcast = false, skip = false, broken = false, ) # TODO: revamp when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/pull/166 # is merged if check_rrule test_rrule(f, xs...; fkwargs = fkwargs) end if check_broadcast length(fkwargs) > 0 && @warn("CHECK_BROADCAST: dropping keywords args") h = (xs...) -> sum(f.(xs...)) else h = (xs...) -> sum(f(xs...; fkwargs...)) end y_true = h(xs...) if fdm == :central fdm_obj = FiniteDifferences.central_fdm(5, 1) elseif fdm == :forward fdm_obj = FiniteDifferences.forward_fdm(5, 1) elseif fdm == :backward fdm_obj = FiniteDifferences.backward_fdm(5, 1) end # @show fdm fdm_obj gs_fd = FiniteDifferences.grad(fdm_obj, h, xs...) y_ad, pull = Zygote.pullback(h, xs...) gs_ad = pull(one(y_ad)) @test y_true ≈ y_ad atol = atol rtol = rtol for (g_ad, g_fd) in zip(gs_ad, gs_fd) if skip @test_skip g_ad ≈ g_fd atol = atol rtol = rtol elseif broken @test_broken g_ad ≈ g_fd atol = atol rtol = rtol else @test g_ad ≈ g_fd atol = atol rtol = rtol end end return true end """ gputest(f, xs...; checkgrad=true, atol=1e-6, kws...) Compare gradients computed on the device vs CPU. 
`xs...` should already be on the device. """ function gputest(f, xs...; checkgrad=true, atol=1e-6, kws...) cpu_xs = map(x -> adapt(CPU(), x), xs) cpu_y = f(cpu_xs...; kws...) y = f(xs...; kws...) @test collect(cpu_y) ≈ collect(y) if checkgrad cpu_grad = gradient((x...) -> sum(f(x...; kws...)), cpu_xs...) gpu_grad = gradient((x...) -> sum(f(x...; kws...)), xs...) for (cpu_g, gpu_g) in zip(cpu_grad, adapt(CPU(), gpu_grad)) if cpu_g === nothing @test gpu_g === nothing else @test collect(cpu_g) ≈ collect(gpu_g) atol=atol end end end end ================================================ FILE: test/testsuite/fold.jl ================================================ import NNlib function fold_testsuite(Backend) device(x) = adapt(Backend(), x) gradtest_fn = Backend == CPU ? gradtest : gputest @testset "unfold wrapper" begin x = device(rand(rng, 16, 16, 3, 10)) w = device(rand(rng, 5, 5, 3, 2)) @test size(NNlib.unfold(x, size(w))) == (144, 75, 10) @test size(NNlib.unfold(x, size(w); pad=2)) == (256, 75, 10) @test size(NNlib.unfold(x, size(w); stride=2)) == (36, 75, 10) @test size(NNlib.unfold(x, size(w); dilation=2)) == (64, 75, 10) end @testset "Inverses: spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3) x = device(rand(rng, repeat([8], spatial_rank)..., 3, 2)) w = device(rand(rng, repeat([3], spatial_rank)..., 3, 3)) cdims = DenseConvDims(x, w; padding=1) y = NNlib.unfold(x, cdims) z = NNlib.fold(y, size(x), cdims) o = device(ones(eltype(x), size(x)...)) divisor = NNlib.fold(NNlib.unfold(o, cdims), size(x), cdims) @test isapprox(z ./ divisor, x, rtol=1.0e-7) # introduce stride cdims = DenseConvDims(x, w; padding=1, stride=2) y = NNlib.unfold(x, cdims) z = NNlib.fold(y, size(x), cdims) divisor = NNlib.fold(NNlib.unfold(o, cdims), size(x), cdims) @test isapprox(z ./ divisor, x, rtol=1.0e-7) end @testset "AutoDiff: spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3) x = device(rand(rng, repeat([5], spatial_rank)..., 3, 2)) w = device(rand(rng, repeat([3], 
spatial_rank)..., 3, 3)) cdims = DenseConvDims(x, w) gradtest_fn(x -> NNlib.unfold(x, cdims), x) Backend == CPU && test_rrule(NNlib.unfold, x, cdims) y = NNlib.unfold(x, cdims) gradtest_fn(y -> NNlib.fold(y, size(x), cdims), y) Backend == CPU && test_rrule(NNlib.fold, y, size(x), cdims) end end ================================================ FILE: test/testsuite/gather.jl ================================================ using NNlib: gather, gather! import EnzymeTestUtils using EnzymeCore function gather_testsuite(Backend) device(x) = adapt(Backend(), x) gradtest_fn = Backend == CPU ? gradtest : gputest T = Float32 @testset "gather scalar index" begin ## 1d src, 2d index of ints -> 2d output src = device(T[3, 4, 5, 6, 7]) index = device([ 1 2 3 4; 4 2 1 3; 3 5 5 3]) output = T[ 3 4 5 6; 6 4 3 5; 5 7 7 5] y = cpu(gather(src, index)) @test y isa Array{T,2} @test size(y) == size(index) @test y == output dst = device(T.(zero(index))) @test cpu(gather!(dst, src, index)) == output dst = device(zeros(T, 3, 5)) @test_throws ArgumentError gather!(dst, src, index) if Backend == CPU index2 = [1 2 3 4; 4 2 1 3; 3 6 5 3] @test_throws BoundsError gather!(T.(zero(index)), src, index2) end ## 1d src, 3d index of ints -> 3d output src = device(T[3, 4, 5, 6, 7]) index = device([ 1 2 3 4; 4 2 1 3; 3 5 5 3][:,:,1:1]) output = T[ 3 4 5 6; 6 4 3 5; 5 7 7 5][:,:,1:1] y = cpu(gather(src, index)) @test y isa Array{T,3} @test size(y) == size(index) @test y == output ## 2d src, 2d index of ints -> 3d output src = device(T[ 3 5 7 4 6 8]) index = device([ 1 2 3; 2 2 1; 3 1 3]) output = zeros(T, 2, 3, 3) output[:,:,1] = [ 3 5 7 4 6 8] output[:,:,2] = [ 5 5 3 6 6 4] output[:,:,3] = [ 7 3 7 8 4 8] y = cpu(gather(src, index)) M = NNlib.typelength(eltype(index)) Nsrc = ndims(src) @test y isa Array{T,3} @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) 
@test y == output end @testset "gather tuple index" begin ## 2d src, 1d index of 2-tuples -> 1d output src = device(T[ 3 5 7 4 6 8]) index = device([(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)]) output = T[3, 5, 7, 4, 6, 8] y = cpu(gather(src, index)) M = NNlib.typelength(eltype(index)) Nsrc = ndims(src) @test y isa Array{T,1} @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) @test y == output ## 3d src, 2d index of 2-tuples -> 3d output n1, nsrc, nidx = 2, 3, 6 src = device(rand(T, n1, nsrc, nsrc)) index = device([ (rand(1:nsrc), rand(1:nsrc)) for i=1:nidx, j=1:nidx]) y = cpu(gather(src, index)) M = NNlib.typelength(eltype(index)) Nsrc = ndims(src) @test y isa Array{T,3} @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) end @testset "gather cartesian index" begin ## 2d src, 1d index of 2-tuples -> 1d output src = device(T[ 3 5 7 4 6 8]) index = device(CartesianIndex.([(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)])) output = T[3, 5, 7, 4, 6, 8] y = cpu(gather(src, index)) M = NNlib.typelength(eltype(index)) Nsrc = ndims(src) @test y isa Array{T,1} @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) @test y == output ## 3d src, 2d index of 2-tuples -> 3d output n1, nsrc, nidx = 2, 3, 6 src = device(rand(Float32, n1, nsrc, nsrc)) index = device([ CartesianIndex((rand(1:nsrc), rand(1:nsrc))) for i=1:nidx, j=1:nidx]) y = cpu(gather(src, index)) M = NNlib.typelength(eltype(index)) Nsrc = ndims(src) @test y isa Array{T,3} @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) end @testset "gather gradient for scalar index" begin src = device(Float64[3, 4, 5, 6, 7]) idx = device([ 1 2 3 4; 4 2 1 3; 3 5 5 3]) dst = device(Float64[ 3 4 5 6; 6 4 3 5; 5 7 7 5]) Backend == CPU ? gradtest_fn(xs -> gather!(dst, xs, idx), src) : gradtest_fn((d, s, i) -> gather!(d, s, i), dst, src, idx) Backend == CPU ? gradtest_fn(xs -> gather(xs, idx), src) : gradtest_fn((s, i) -> gather(s, i), src, idx) end @static if Test_Enzyme @testset "EnzymeRules: gather! 
gradient for scalar index" begin src = device(Float64[3, 4, 5, 6, 7]) idx = device([ 1 2 3 4; 4 2 1 3; 3 5 5 3]) dst = gather(src, idx) for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), Tsrc in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) Tret == EnzymeCore.Const && continue # ERROR EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tsrc) || continue EnzymeTestUtils.test_reverse(gather!, Tret, (dst, Tdst), (src, Tsrc), (idx, EnzymeCore.Const)) end end end @testset "gather gradient for tuple index" begin src = device(Float64[ 3 5 7 4 6 8]) idx = device([(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)]) dst = device(Float64[3, 5, 7, 4, 6, 8]) Backend == CPU ? gradtest_fn(xs -> gather!(dst, xs, idx), src) : gradtest_fn((d, s, i) -> gather!(d, s, i), dst, src, idx) Backend == CPU ? gradtest_fn(xs -> gather(xs, idx), src) : gradtest_fn((s, i) -> gather(s, i), src, idx) end @testset "gather(src, IJK...)" begin x = device(reshape([1:15;], 3, 5)) i, j = device([1,2]), device([2,4]) y = gather(x, i, j) @test cpu(y) == [4, 11] y = gather(x, device([1, 2])) @test cpu(y) == [ 1 4 2 5 3 6] end end ================================================ FILE: test/testsuite/rotation.jl ================================================ function rotation_testsuite(Backend) device(x) = adapt(Backend(), x) gradtest_fn = Backend == CPU ? gradtest : gputest T = Float64 atol = T == Float32 ? 1e-3 : 1e-6 rtol = T == Float32 ? 
1f-3 : 1f-6 angles = deg2rad.([0, 0.0001, 35, 90, -90, -90.0123, 170, 180, 270, 360, 450, 1234.1234]) @testset "imrotate" begin @testset "Simple test" begin arr = device(zeros((6, 6, 1, 1))); arr[3:4, 4, 1, 1] .= 1; @test all(cpu(NNlib.imrotate(arr, deg2rad(45))) .≈ [0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.29289321881345254 0.585786437626905 0.0; 0.0 0.0 0.08578643762690495 1.0 0.2928932188134524 0.0; 0.0 0.0 0.0 0.08578643762690495 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0]) end @testset "Compare with ImageTransformations" begin for sz in [(51,51,1,1), (52,52,1,1), (51,52,1,1), (52,51,1,1)] rotation_center = (sz[1:2] .+ 1) ./ 2 arr1 = device(zeros(T, sz)) arr1[15:40, 15:40, :, :] .= device(1 .+ randn((26, 26))) arr2 = device(zeros(T, (sz[1], sz[2], sz[3], 3))) arr2[15:40, 15:40, :, :] .= device(arr1[15:40, 15:40, :, :]) for method in [:nearest, :bilinear] @testset "$method" begin for angle in angles res1 = cpu(NNlib.imrotate(arr1, angle; method, rotation_center=rotation_center)) res2 = cpu(NNlib.imrotate(arr2, angle; method, rotation_center=rotation_center)) if method == :nearest res_IT = ImageTransformations.imrotate(cpu(arr1)[:, :, 1, 1], angle, axes(arr1)[1:2], method=Constant(), fillvalue=0) elseif method == :bilinear res_IT = ImageTransformations.imrotate(cpu(arr1)[:, :, 1, 1], angle, axes(arr1)[1:2], fillvalue=0) end if method == :nearest @test ≈(1 .+ res1[:, :, :, :], 1 .+ res_IT[:, :], rtol=0.5) @test ≈(1 .+ res1[:, :, :, :], 1 .+ res2[:, :,:, 1], rtol=0.5) @test ≈(1 .+ res1[:, :, :, :], 1 .+ res2[:, :,:, 2], rtol=0.5) else @test all(.≈(1 .+ res1[:, :, :, :], 1 .+ res_IT[:, :], rtol=rtol)) @test all(.≈(1 .+ res1[:, :, :, :], 1 .+ res2[:, :,:, 1], rtol=rtol)) @test all(.≈(1 .+ res1[:, :, :, :], 1 .+ res2[:, :,:, 2], rtol=rtol)) end end end end end end @testset "Compare for plausibilty" begin @testset "Special cases of rotation" begin arr = device(zeros(T, (10, 10, 1, 3))) arr[6, 6, :, 1] .= 1 arr[6, 6, :, 2] .= 2 arr[6, 6, :, 3] .= 3 for 
method in [:bilinear, :nearest] @test all(.≈(arr , NNlib.imrotate(arr, deg2rad(0); method))) @test all(.≈(arr , NNlib.imrotate(arr, deg2rad(90); method))) @test all(.≈(arr , NNlib.imrotate(arr, deg2rad(180); method))) @test all(.≈(arr , NNlib.imrotate(arr, deg2rad(270); method))) @test all(.≈(arr , NNlib.imrotate(arr, deg2rad(360); method))) end end end @testset "Test gradients" begin for method in [:nearest, :bilinear] for angle in angles gradtest_fn( x -> NNlib.imrotate(x, angle; method), device(rand(T, 11,11,1,1)); atol) gradtest_fn( x -> NNlib.imrotate(x, angle; method), device(rand(T, 10,10,1,1)); atol) end end end end end ================================================ FILE: test/testsuite/scatter.jl ================================================ using NNlib: scatter, scatter! dsts = Dict( 0 => [3, 4, 5, 6, 7], 1 => [3 3 4 4 5; 5 5 6 6 7], ) srcs = Dict( (0, true) => ones(Int, 3, 4), (0, false) => ones(Int, 3) * collect(1:4)', (1, true) => ones(Int, 2, 3, 4), (1, false) => [1, 2] .* reshape(ones(Int, 3) * collect(1:4)', 1,3,4), ) idxs = Dict( :int => [1 2 3 4; 4 2 1 3; 3 5 5 3], :tup => [(1,) (2,) (3,) (4,); (4,) (2,) (1,) (3,); (3,) (5,) (5,) (3,)], :car => CartesianIndex.( [(1,) (2,) (3,) (4,); (4,) (2,) (1,) (3,); (3,) (5,) (5,) (3,)]), ) res = Dict( (+, 0, true) => [5, 6, 9, 8, 9], (+, 1, true) => [5 5 8 6 7; 7 7 10 8 9], (+, 0, false) => [4, 4, 12, 5, 5], (+, 1, false) => [4 4 12 5 5; 8 8 24 10 10], (-, 0, true) => [1, 2, 1, 4, 5], (-, 1, true) => [1 1 0 2 3; 3 3 2 4 5], (-, 0, false) => [-4, -4, -12, -5, -5], (-, 1, false) => [-4 -4 -12 -5 -5; -8 -8 -24 -10 -10], (max, 0, true) => [3, 4, 5, 6, 7], (max, 1, true) => [3 3 4 4 5; 5 5 6 6 7], (max, 0, false) => [3, 2, 4, 4, 3], (max, 1, false) => [3 2 4 4 3; 6 4 8 8 6], (min, 0, true) => [1, 1, 1, 1, 1], (min, 1, true) => [1 1 1 1 1; 1 1 1 1 1], (min, 0, false) => [1, 2, 1, 1, 2], (min, 1, false) => [1 2 1 1 2; 2 4 2 2 4], (*, 0, true) => [3, 4, 5, 6, 7], (*, 1, true) => [3 3 4 4 5; 5 5 6 6 7], (*, 0, 
false) => [3, 4, 48, 4, 6],
    (*, 1, false) => [3 4 48 4 6; 12 16 768 16 24],
    (/, 0, true) => [0.75, 1., 0.3125, 1.5, 1.75],
    (/, 1, true) => [0.75 0.75 0.25 1. 1.25; 1.25 1.25 0.375 1.5 1.75],
    (/, 0, false) => [1//3, 1//4, 1//48, 1//4, 1//6],
    (/, 1, false) => [1//3 1//4 1//48 1//4 1//6; 1//12 1//16 1//768 1//16 1//24],
    (mean, 0, true) => [4., 5., 6., 7., 8.],
    (mean, 1, true) => [4. 4. 5. 5. 6.; 6. 6. 7. 7. 8.],
    (mean, 0, false) => [2, 2, 3, 2.5, 2.5],
    (mean, 1, false) => [2. 2. 3. 2.5 2.5; 4. 4. 6. 5. 5.],
) # NOTE(review): tail of the `res` expectation table, keyed by (op, dims, mutated);
  # its opening and earlier entries are above this chunk — presumably alongside the
  # `dsts`/`srcs`/`idxs` fixtures used below. Confirm against the full file.

# Run the scatter!/scatter correctness tests for one backend.
# - `device`: moves an array to the backend under test (e.g. `adapt(Backend(), x)`).
# - `types`: element types to test; `ops`: reduction operators to test.
# - `pt`: type that `promote_type(T, pt)` yields for mixed-eltype destinations.
# - `ops_skip_types`: per-op list of eltypes for which the non-mutating
#   `scatter` check is skipped.
# Expected outputs come from the file-level `res` table; inputs come from the
# file-level `dsts`/`srcs`/`idxs` fixtures.
function test_scatter(device, types, ops; pt, ops_skip_types)
    for T in types, IT in (Int8, Int64)
        PT = promote_type(T, pt)
        @testset "eltype $T - idx eltype $IT - $op" for op in ops
            skip_types = get(ops_skip_types, op, [])
            for idx = values(idxs), dims = [0, 1]
                # Tests with indices of different types.
                eltype(idx) == Int && (idx = IT.(idx);)
                idx = device(idx)
                dst = device(dsts[dims])

                mutated = true
                target_y = res[(op, dims, mutated)]
                src = device(srcs[(dims, mutated)])
                if op == /
                    # Doubling `src` keeps the division results exactly representable.
                    src = src .* T(2)
                end
                # Mutating scatter!: same-eltype and promoted-eltype variants.
                @test cpu(scatter!(op, T.(dst), T.(src), idx)) == T.(target_y)
                @test cpu(scatter!(op, T.(dst), src, idx)) == PT.(target_y)
                if op == /
                    @test cpu(scatter!(op, T.(dst), T.(src), idx)) == PT.(target_y)
                else
                    @test cpu(scatter!(op, copy(dst), T.(src), idx)) == PT.(target_y)
                end

                if T ∉ skip_types
                    # Non-mutating `scatter` (allocates its own destination).
                    mutated = false
                    src = device(srcs[(dims, mutated)])
                    @test cpu(scatter(op, T.(src), idx)) == T.(res[(op, dims, mutated)])
                end
            end
        end
    end
end

# Full scatter test suite for a KernelAbstractions-style `Backend`.
# Chooses eltypes/ops per backend capability (GPU atomics restrictions),
# then checks error paths (CPU only), gradients, and Enzyme rules.
function scatter_testsuite(Backend)
    device(x) = adapt(Backend(), x)
    gradtest_fn = Backend == CPU ? gradtest : gputest

    ops_skip_types = Dict(
        (+) => [],
        (-) => [UInt8, UInt16, UInt32, UInt64, UInt128],
        (*) => [UInt8, Int8],
        max => [BigInt],
        min => [BigInt])
    types = if Backend == CPU
        [UInt8, UInt32, UInt64, Int32, Int64, Float16, Float32, Float64, BigFloat, Rational]
    elseif Symbol(Backend) == :CUDABackend
        [Int32, Int64, Float32, Float64]
    else
        # Need LLVM 15+ for atomic fmin/fmax:
        # https://reviews.llvm.org/D127041
        # But fmin/fmax can be done by reinterpreting an array to `UInt`.
        [Int32, Int64, UInt32, UInt64]
    end
    ops = Backend == CPU ? (+, -, max, min, *) : (+, -, max, min)
    test_scatter(device, types, ops; pt=Int, ops_skip_types)

    # Floating-point ops (division / mean) with a Float64 promotion type.
    types = Backend == CPU ? [Float16, Float32, BigFloat, Rational] : [Float32, Float64]
    ops = if Backend == CPU
        (/, mean)
    elseif Symbol(Backend) == :CUDABackend
        (*, /, mean)
    else
        # LLVM does not support atomic fmul/fdiv:
        # https://llvm.org/docs/LangRef.html#atomicrmw-instruction
        (mean,)
    end
    test_scatter(device, types, ops; pt=Float64, ops_skip_types=Dict())

    if Backend == CPU
        @testset "scatter exceptions" begin
            # Size-mismatched dst and out-of-bounds indices must throw.
            idx = [1 2 3 4; 4 2 1 3; 6 7 8 9]
            @test_throws AssertionError scatter!(+, copy(dsts[0]), srcs[(1, true)], idxs[:int])
            @test_throws BoundsError scatter!(+, copy(dsts[1]), srcs[(1, true)], idx)
        end
    end

    @testset "∇scatter" begin
        T = Float64
        # `min` needs backward finite differences to avoid ties at the minimum.
        fdm(op) = op == min ? :backward : :forward
        @testset "dstsize" begin
            idx = device([2, 2, 3, 4, 4])
            src = device(ones(T, 3, 5))
            y = scatter(+, src, idx, dstsize = (3, 6))
            @test eltype(y) == T
            @test size(y) == (3, 6)
            Backend == CPU ?
                gradtest_fn(x -> scatter(+, x, idx; dstsize=(3, 6)), src) :
                gradtest_fn((x, i) -> scatter(+, x, i; dstsize=(3, 6)), src, idx)
        end
        @testset "∂dst" begin
            ops = if Backend == CPU || Symbol(Backend) == :CUDABackend
                (+, -, *, /, mean, max, min)
            else
                (+, -, mean, max, min)
            end
            for op in ops, i in (0, 1), IT in (Int8, Int64)
                PT = (
                    # If not CPU and CUDA -> use Int64 for min/max.
                    Backend != CPU && Symbol(Backend) != :CUDABackend &&
                    (op == max || op == min)) ? Int64 : T
                src = device(srcs[(i, true)])
                idx = device(IT.(idxs[:int]))
                dst = device(PT.(dsts[i]))
                Backend == CPU ?
                    gradtest_fn(x -> scatter!(op, copy(x), src, idx), dst; fdm=fdm(op)) :
                    gradtest_fn((x, s, i) -> scatter!(op, x, s, i), dst, src, idx)
            end
        end
        @testset "∂src" begin
            ops = if Backend == CPU || Symbol(Backend) == :CUDABackend
                (+, -, *, /, mean, max, min)
            else
                (+, -, mean, max, min)
            end
            for op in ops, i in (0, 1), IT in (Int8, Int64)
                PT = (
                    # If not CPU and CUDA -> use Int64 for min/max.
                    Backend != CPU && Symbol(Backend) != :CUDABackend &&
                    (op == max || op == min)) ? Int64 : T
                src = PT.(device(srcs[(i, false)]))
                idx = device(IT.(idxs[:int]))
                Backend == CPU ?
                    gradtest_fn(xs -> scatter(op, xs, idx), src; fdm=fdm(op)) :
                    gradtest_fn((xs, i) -> scatter(op, xs, i), src, idx)
            end
        end
        @static if Test_Enzyme
            @testset "EnzymeRules" begin
                idx = device([2, 2, 3, 4, 4])
                src = device(ones(T, 3, 5))
                for op in (+, -)
                    dst = scatter(op, src, idx)
                    # Exercise all activity combinations that EnzymeTestUtils accepts.
                    for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
                        Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
                        Tsrc in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated)

                        Tret == EnzymeCore.Const && continue # ERROR
                        EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tsrc) || continue
                        EnzymeTestUtils.test_reverse(scatter!, Tret, (op, EnzymeCore.Const),
                            (dst, Tdst), (src, Tsrc), (idx, EnzymeCore.Const))
                    end
                end
            end
        end
    end
end
================================================ FILE: test/testsuite/spectral.jl ================================================
# Spectral (STFT / spectrogram / window) test suite for one backend.
function spectral_testsuite(Backend)
    cpu(x) = adapt(CPU(), x)
    device(x) = adapt(Backend(), x)
    gradtest_fn = Backend == CPU ?
gradtest : gputest

    @testset "Window functions" begin
        for window_fn in (hann_window, hamming_window)
            @inferred window_fn(10, Float32)
            @inferred window_fn(10, Float64)
            w = window_fn(10)
            @test length(w) == 10
            @test eltype(w) == Float32
            # A non-periodic (symmetric) window mirrors around its midpoint.
            wp = window_fn(10; periodic=false)
            @test wp[1:5] ≈ reverse(wp[6:10])
            # A periodic window of length N equals the symmetric window of
            # length N + 1 with its last sample dropped.
            @test window_fn(10; periodic=true) ≈ window_fn(10 + 1; periodic=false)[1:10]
        end
    end

    @testset "STFT" for batch in ((), (3,))
        @testset "Grads" begin
            # Gradient checks are GPU-only here: `gputest` compares the GPU
            # gradient against the CPU one.
            if Backend != CPU
                x = rand(Float32, 16, batch...)
                window = hann_window(16)
                gradtest_fn(s -> abs.(stft(s; n_fft=16)), x)
                gradtest_fn((s, w) -> abs.(stft(s; n_fft=16, window=w)), x, window)

                x = rand(Float32, 2045, batch...)
                n_fft = 256
                window = hann_window(n_fft)
                gradtest_fn((s, w) -> abs.(stft(s; n_fft, window=w)), x, window)
                gradtest_fn((s, w) -> abs.(stft(s; n_fft, window=w, center=false)), x, window)
                gradtest_fn((s, w) -> abs.(stft(s; n_fft, window=w, normalized=true)), x, window)
            end
        end
        @testset "Batch $batch" begin
            x = device(ones(Float32, 16, batch...))
            # TODO fix type stability for pad_reflect
            # @inferred stft(x; n_fft=16)
            bd = ntuple(_ -> Colon(), length(batch))
            y = stft(x; n_fft=16)
            @test size(y) == (9, 5, batch...)
            # DC bin of an all-ones signal equals the frame length.
            @test all(real(cpu(y))[1, :, bd...] .≈ 16)
            # Round-trip: istft(stft(x)) reconstructs x.
            xx = istft(y; n_fft=16)
            @test size(xx) == (16, batch...)
            @test cpu(x) ≈ cpu(xx)
            # Test multiple hops.
            x = device(rand(Float32, 2048, batch...))
            y = stft(x; n_fft=1024)
            xx = istft(y; n_fft=1024)
            @test cpu(x) ≈ cpu(xx)
            # Test odd sizes.
            x = device(rand(Float32, 1111, batch...))
            y = stft(x; n_fft=256)
            xx = istft(y; n_fft=256, original_length=size(x, 1))
            @test cpu(x) ≈ cpu(xx)
            # Output from inverse is cropped on the right
            # without knowing the original size.
            xx = istft(y; n_fft=256)
            @test length(xx) < length(x)
            @test cpu(x)[[1:s for s in size(xx)]...] ≈ cpu(xx)
            # Test different options.
            # Normalized.
            x = device(rand(Float32, 1234, batch...))
            y = stft(x; n_fft=512, normalized=true)
            xx = istft(y; n_fft=512, normalized=true, original_length=size(x, 1))
            @test cpu(x) ≈ cpu(xx)
            # With window.
            window = device(hann_window(512))
            y = stft(x; n_fft=512, window)
            xx = istft(y; n_fft=512, window, original_length=size(x, 1))
            @test cpu(x) ≈ cpu(xx)
            # Hop.
            for hop_length in (32, 33, 255, 256, 511, 512)
                y = stft(x; n_fft=512, hop_length)
                xx = istft(y; n_fft=512, hop_length, original_length=size(x, 1))
                @test cpu(x) ≈ cpu(xx)
            end
            # N FFT.
            for n_fft in (32, 33, 64, 65, 128, 129, 512)
                y = stft(x; n_fft)
                xx = istft(y; n_fft, original_length=size(x, 1))
                @test cpu(x) ≈ cpu(xx)
            end
        end
    end

    @testset "Spectrogram" begin
        x = device(rand(Float32, 1024))
        window = device(hann_window(1024))
        # A spectrogram is the squared magnitude of the STFT.
        y = stft(x; n_fft=1024, hop_length=128, window, center=true, normalized=false)
        spec = spectrogram(x; n_fft=1024, hop_length=128, window, center=true, normalized=false)
        @test abs.(y).^2 ≈ spec
        # Gradient with `0`s in spectrogram.
        # We add small ϵ to spectrogram before computing power
        # to prevent `NaN` in gradient due to `abs(0)`.
        x = device(ones(Float32, 1024))
        g = Zygote.gradient(x) do x
            sum(spectrogram(x; n_fft=1024, hop_length=128, window, center=true, normalized=false))
        end
        @test !any(isnan.(g[1]))
        # Batched.
        x = device(rand(Float32, 1024, 3))
        spec = spectrogram(x; n_fft=1024, hop_length=128, window, center=true, normalized=false)
        for i in 1:3
            y = stft(x[:, i]; n_fft=1024, hop_length=128, window, center=true, normalized=false)
            @test abs.(y).^2 ≈ spec[:, :, i]
        end
        if Backend != CPU
            @testset "Grads" begin
                for batch in ((), (3,))
                    x = rand(Float32, 2045, batch...)
                    n_fft = 256
                    window = hann_window(n_fft)
                    gradtest_fn((s, w) -> spectrogram(s; n_fft, hop_length=n_fft ÷ 4, window=w), x, window)
                    gradtest_fn((s, w) -> spectrogram(s; n_fft, hop_length=n_fft ÷ 4, window=w, center=false), x, window)
                    gradtest_fn((s, w) -> spectrogram(s; n_fft, hop_length=n_fft ÷ 4, window=w, normalized=true), x, window)
                end
            end
        end
    end

    @testset "Power to dB" begin
        x = device(rand(Float32, 1024))
        window = device(hann_window(1024))
        spec = spectrogram(x; pad=0, n_fft=1024, hop_length=128, window)
        # db_to_power must invert power_to_db, and both must be type-stable.
        @test spec ≈ NNlib.db_to_power(NNlib.power_to_db(spec))
        @inferred NNlib.power_to_db(spec)
        @inferred NNlib.db_to_power(NNlib.power_to_db(spec))
    end
end
================================================ FILE: test/testsuite/upsample.jl ================================================
# Upsampling (nearest / linear / pixel_shuffle) test suite for one backend.
function upsample_testsuite(Backend)
    device(x) = adapt(Backend(), x)
    gradtest_fn = Backend == CPU ? gradtest : gputest
    T = Float32 # TODO test against all supported eltypes for each backend.
    atol = T == Float32 ?
1e-3 : 1e-6

    @testset "upsample_nearest, integer scale via reshape" begin
        x = device(reshape(T[1 2; 3 4], (2,2,1,1)))
        @test cpu(upsample_nearest(x, (3,3)))[1,:] == [1,1,1, 2,2,2]

        y = upsample_nearest(x, (2,3))
        @test size(y) == (4,6,1,1)
        # `scale` and `size` keywords must agree.
        y2 = upsample_nearest(x, size=(4,6))
        @test cpu(y) ≈ cpu(y2)

        @test cpu(∇upsample_nearest(y, (2,3)))[:, :, 1, 1] == [6 12; 18 24]
        gradtest_fn(
            x -> upsample_nearest(x, (2,3)),
            device(rand(T, 2,2,1,1)); atol)
        gradtest_fn(
            x -> upsample_nearest(x, size=(4,6)),
            device(rand(T, 2,2,1,1)); atol)

        # Invalid scales / sizes must throw.
        @test_throws ArgumentError ∇upsample_nearest(y, (2,4))
        @test_throws ArgumentError upsample_nearest(x, (1,2,3,4,5))
        @test_throws ArgumentError upsample_nearest(x, size=(3,4))
    end

    @testset "Linear upsampling (1D)" begin
        x = T[1,2,3,4]
        x = hcat(x,x,x)[:,:,:]
        # align_corners=true: endpoints map exactly, interior is 1/3 steps.
        y = collect(1:1//3:4)
        y = hcat(y,y,y)[:,:,:]
        xd = device(x)
        @test y ≈ cpu(upsample_linear(xd, 2.5))
        @test y ≈ cpu(upsample_linear(xd; size=10))
        gradtest_fn(x -> upsample_linear(x, 2.5), xd; atol)
    end

    @testset "Bilinear upsampling (2D)" begin
        x = Float32[1 2; 3 4][:,:,:,:]
        x = cat(x,x; dims=3)
        x = cat(x,x; dims=4)
        # this output matches the one of pytorch v1.5.0
        # nn.UpsamplingBilinear2d(scale_factor=(3,2), align_corners=True)
        # for above x
        y_true = Float32[1//1  4//3   5//3   2//1;
                         7//5  26//15 31//15 12//5;
                         9//5  32//15 37//15 14//5;
                         11//5 38//15 43//15 16//5;
                         13//5 44//15 49//15 18//5;
                         3//1  10//3  11//3  4//1][:,:,:,:]
        y_true = cat(y_true, y_true; dims=3)
        y_true = cat(y_true, y_true; dims=4)
        xd = device(x)
        y = upsample_bilinear(xd, (3, 2))
        @test size(y) == size(y_true)
        @test eltype(y) == Float32
        @test cpu(y) ≈ y_true
        gradtest_fn(x -> upsample_bilinear(x, (3, 2)), xd; atol)
        # additional grad check, also compliant with pytorch
        o = ones(Float32,6,4,2,1)
        grad_true = 6*ones(Float32,2,2,2,1)
        @test cpu(∇upsample_bilinear(device(o); size = (2,2))) ≈ grad_true
        # CPU only tests.
        y_true_2 = Rational{Int}[1//1 5//4  6//4  7//4  2//1;
                                 3//2 7//4  8//4  9//4  5//2;
                                 4//2 9//4  10//4 11//4 6//2;
                                 5//2 11//4 12//4 13//4 7//2;
                                 3//1 13//4 14//4 15//4 4//1][:,:,:,:]
        y_true_2 = cat(y_true_2, y_true_2; dims=3)
        y_true_2 = cat(y_true_2, y_true_2; dims=4)
        # check for real-valued single-number argument and type stability for rationals
        y_rational = upsample_bilinear(Rational{Int}.(x), 2.5)
        @test eltype(y_rational) == Rational{Int}
        @test y_rational == y_true_2
        # check Integer support for forward pass
        # grads are always assumed to be floats, so no extension there
        x = UInt8[1 3; 3 5][:,:,:,:]
        y_true_int = UInt8[1 2 3; 2 3 4; 3 4 5][:,:,:,:]
        y = upsample_bilinear(x, 1.5)
        @test eltype(y) == UInt8
        @test y == y_true_int
    end

    @testset "Trilinear upsampling (3D)" begin
        # Layout: WHDCN, where D is depth
        # we generate data which is constant along W & H and differs in D
        # then we upsample along all dimensions
        x = ones(T, 3,3,3,1,1)
        x[:,:,1,:,:] .= 1.
        x[:,:,2,:,:] .= 2.
        x[:,:,3,:,:] .= 3.
        y_true = ones(T, 5,5,5,1,1)
        y_true[:,:,1,:,:] .= 1.
        y_true[:,:,2,:,:] .= 1.5
        y_true[:,:,3,:,:] .= 2.
        y_true[:,:,4,:,:] .= 2.5
        y_true[:,:,5,:,:] .= 3.
        xd = device(x)
        y = upsample_trilinear(xd; size=(5,5,5))
        @test size(y) == size(y_true)
        @test eltype(y) == T
        @test collect(y) ≈ collect(y_true)
        gradtest_fn(
            x -> upsample_trilinear(x, (2,2,2)), xd;
            atol=(T == Float32) ? 1e-2 : 1e-5)
        # This test only works when `align_corners=false`.
        o = device(ones(Float32,8,8,8,1,1))
        grad_true = 8 * ones(Float32,4,4,4,1,1)
        @test cpu(∇upsample_trilinear(o; size=(4,4,4), align_corners=false)) ≈ grad_true
    end

    @testset "pixel_shuffle" begin
        x = reshape(1:16, (2, 2, 4, 1))
        # [:, :, 1, 1] =
        # 1 3
        # 2 4
        # [:, :, 2, 1] =
        # 5 7
        # 6 8
        # [:, :, 3, 1] =
        # 9 11
        # 10 12
        # [:, :, 4, 1] =
        # 13 15
        # 14 16
        y_true = [1 9 3 11;
                  5 13 7 15;
                  2 10 4 12;
                  6 14 8 16][:,:,:,:]
        y = pixel_shuffle(device(x), 2)
        @test size(y) == size(y_true)
        @test y_true == cpu(y)

        x = reshape(1:32, (2, 2, 8, 1))
        y_true = zeros(Int, 4, 4, 2, 1)
        y_true[:,:,1,1] .= [ 1 9 3 11;
                             5 13 7 15;
                             2 10 4 12;
                             6 14 8 16 ]
        y_true[:,:,2,1] .= [ 17 25 19 27;
                             21 29 23 31;
                             18 26 20 28;
                             22 30 24 32]
        y = pixel_shuffle(device(x), 2)
        @test size(y) == size(y_true)
        @test y_true == cpu(y)

        x = reshape(1:4*3*27*2, (4,3,27,2))
        y = pixel_shuffle(device(x), 3)
        @test size(y) == (12, 9, 3, 2)
        # batch dimension is preserved
        x1 = x[:,:,:,[1]]
        x2 = x[:,:,:,[2]]
        y1 = pixel_shuffle(device(x1), 3)
        y2 = pixel_shuffle(device(x2), 3)
        @test cpu(cat(y1, y2, dims=4)) == cpu(y)

        # Random shapes for 1D/2D/3D pixel shuffle.
        for d in [1, 2, 3]
            r = rand(1:5)
            n = rand(1:5)
            c = rand(1:5)
            insize = rand(1:5, d)
            x = rand(insize..., r^d*c, n)
            xd = device(x)
            y = pixel_shuffle(xd, r)
            @test size(y) == ((r .* insize)..., c, n)
            gradtest_fn(x -> pixel_shuffle(x, r), xd)
        end
    end

    @testset "Complex-valued upsample" begin
        # Complex upsampling must act independently on real and imaginary parts.
        for (d, method) in zip([1, 2, 3], [upsample_linear, upsample_bilinear, upsample_trilinear])
            for (k, interp) in zip((2, ntuple(_ -> 2, d)), [method, upsample_nearest])
                x = device(randn(Complex{Float32}, (4,8,12)[1:d]..., 1, 1))
                upsize = (8, 16, 24)[1:d]
                xup = interp(x, k)
                @test size(xup)[1:d] == upsize
                @test cpu(real(xup)) == cpu(interp(real(x), k))
                @test cpu(imag(xup)) == cpu(interp(imag(x), k))
                upsize = (8,24,48)[1:d]
                xup = interp(x; size=upsize)
                @test size(xup)[1:d] == upsize
                @test cpu(real(xup)) == cpu(interp(real(x), size=upsize))
                @test cpu(imag(xup)) == cpu(interp(imag(x), size=upsize))
            end
        end
    end
end
================================================
FILE: test/utils.jl ================================================
@testset "within_gradient" begin
    # Outside AD, plain arrays report `false`.
    @test NNlib.within_gradient([1.0]) === false
    # Under Zygote it behaves as the constant `true`, so d(x*1)/dx == 1.
    @test gradient(x -> NNlib.within_gradient(x) * x, 2.0) == (1.0,)
    # ForwardDiff dual numbers also count as "inside a gradient".
    @test NNlib.within_gradient([ForwardDiff.Dual(1.0, 2)]) === true
end

@testset "maximum_dims" begin
    # Plain integer vector: one-element tuple holding the maximum.
    flat = [1,2,3,4,5,6]
    @test NNlib.maximum_dims(flat) == (6,)
    # Vector of tuples: component-wise maximum.
    tuples = [(3,4,5), (1,2,3), (2,3,9)]
    @test NNlib.maximum_dims(tuples) == (3,4,9)
    # Matrix of tuples behaves the same way...
    tuple_mat = [(3,4,5) (1,2,3) (2,3,9); (4,6,2) (5,3,2) (4,4,4)]
    @test NNlib.maximum_dims(tuple_mat) == (5,6,9)
    # ...as does a matrix of CartesianIndex.
    cart_mat = CartesianIndex.([(3,4,5) (1,2,3) (2,3,9); (4,6,2) (5,3,2) (4,4,4)])
    @test NNlib.maximum_dims(cart_mat) == (5,6,9)
end

@testset "reverse_indices" begin
    # For each destination index k, `reverse_indices` lists the positions
    # inside `idx` (as CartesianIndex into `idx`) whose value is k.
    expected = [
        CartesianIndex.([(1,1), (2,3)]),
        CartesianIndex.([(1,2), (2,2)]),
        CartesianIndex.([(3,1), (1,3), (2,4), (3,4)]),
        CartesianIndex.([(2,1), (1,4)]),
        CartesianIndex.([(3,2), (3,3)]),
    ]
    int_idx = [1 2 3 4; 4 2 1 3; 3 5 5 3]
    tup_idx = [(1,) (2,) (3,) (4,); (4,) (2,) (1,) (3,); (3,) (5,) (5,) (3,)]
    cart_idx = CartesianIndex.(tup_idx)
    # Result value AND concrete type must agree for Int, Tuple and
    # CartesianIndex index eltypes.
    for indices in (int_idx, tup_idx, cart_idx)
        @test NNlib.reverse_indices(indices) == expected
        @test NNlib.reverse_indices(indices) isa typeof(expected)
    end
end