Repository: FluxML/NNlib.jl
Branch: master
Commit: d3c38a49d9cf
Files: 141
Total size: 634.0 KB
Directory structure:
gitextract_0xi0432k/
├── .buildkite/
│ └── pipeline.yml
├── .codecov.yml
├── .github/
│ ├── copilot-instructions.md
│ ├── dependabot.yml
│ └── workflows/
│ ├── BenchmarkTrigger.yml
│ ├── CompatHelper.yml
│ ├── Downstream.yml
│ ├── TagBot.yml
│ ├── ci.yml
│ ├── clean_preview.yml
│ └── pr_comment.yml
├── .gitignore
├── LICENSE.md
├── Project.toml
├── README.md
├── benchmark/
│ ├── Project.toml
│ ├── benchmarks.jl
│ ├── perf_report.jl
│ └── runbenchmarks.jl
├── docs/
│ ├── .gitignore
│ ├── Project.toml
│ ├── make.jl
│ └── src/
│ ├── assets/
│ │ ├── flux.css
│ │ └── jfk.flac
│ ├── audio.md
│ ├── index.md
│ └── reference.md
├── ext/
│ ├── NNlibAMDGPUExt/
│ │ ├── NNlibAMDGPUExt.jl
│ │ ├── activations.jl
│ │ ├── batched_mul.jl
│ │ ├── conv.jl
│ │ └── pool.jl
│ ├── NNlibCUDACUDNNExt/
│ │ ├── NNlibCUDACUDNNExt.jl
│ │ ├── activations.jl
│ │ ├── batchnorm.jl
│ │ ├── conv.jl
│ │ ├── pooling.jl
│ │ └── softmax.jl
│ ├── NNlibCUDAExt/
│ │ ├── NNlibCUDAExt.jl
│ │ ├── activations.jl
│ │ ├── batchedadjtrans.jl
│ │ ├── batchedmul.jl
│ │ ├── ctc.jl
│ │ ├── sampling.jl
│ │ ├── scatter.jl
│ │ └── utils.jl
│ ├── NNlibEnzymeCoreExt/
│ │ └── NNlibEnzymeCoreExt.jl
│ ├── NNlibFFTWExt/
│ │ ├── NNlibFFTWExt.jl
│ │ └── stft.jl
│ ├── NNlibForwardDiffExt.jl
│ ├── NNlibMetalExt.jl
│ └── NNlibSpecialFunctionsExt.jl
├── src/
│ ├── NNlib.jl
│ ├── activations.jl
│ ├── attention.jl
│ ├── audio/
│ │ ├── mel.jl
│ │ ├── spectrogram.jl
│ │ └── stft.jl
│ ├── batched/
│ │ ├── batchedadjtrans.jl
│ │ └── batchedmul.jl
│ ├── bias_act.jl
│ ├── conv.jl
│ ├── conv_bias_act.jl
│ ├── ctc.jl
│ ├── deprecations.jl
│ ├── dim_helpers/
│ │ ├── ConvDims.jl
│ │ ├── DenseConvDims.jl
│ │ ├── DepthwiseConvDims.jl
│ │ └── PoolDims.jl
│ ├── dim_helpers.jl
│ ├── dropout.jl
│ ├── fold.jl
│ ├── functions.jl
│ ├── gather.jl
│ ├── gemm.jl
│ ├── impl/
│ │ ├── conv_direct.jl
│ │ ├── conv_im2col.jl
│ │ ├── depthwiseconv_direct.jl
│ │ ├── depthwiseconv_im2col.jl
│ │ ├── padding_edges.jl
│ │ └── pooling_direct.jl
│ ├── normalization.jl
│ ├── padding.jl
│ ├── pooling.jl
│ ├── rotation.jl
│ ├── sampling.jl
│ ├── scatter.jl
│ ├── softmax.jl
│ ├── upsample.jl
│ └── utils.jl
└── test/
├── Project.toml
├── activations.jl
├── attention.jl
├── batchedmul.jl
├── bias_act.jl
├── conv.jl
├── conv_bias_act.jl
├── ctc.jl
├── dropout.jl
├── ext_amdgpu/
│ ├── activations.jl
│ ├── attention.jl
│ ├── batched_mul.jl
│ ├── batched_repr.jl
│ ├── conv.jl
│ ├── dropout.jl
│ ├── pool.jl
│ ├── runtests.jl
│ ├── softmax.jl
│ └── storage_type.jl
├── ext_cuda/
│ ├── activations.jl
│ ├── batchedadjtrans.jl
│ ├── batchedmul.jl
│ ├── batchnorm.jl
│ ├── conv.jl
│ ├── ctc.jl
│ ├── dropout.jl
│ ├── fold.jl
│ ├── gather.jl
│ ├── pooling.jl
│ ├── runtests.jl
│ ├── sampling.jl
│ ├── scatter.jl
│ ├── softmax.jl
│ └── test_utils.jl
├── ext_metal/
│ ├── activations.jl
│ └── runtests.jl
├── functions.jl
├── inference.jl
├── padding.jl
├── pooling.jl
├── runtests.jl
├── sampling.jl
├── softmax.jl
├── test_utils.jl
├── testsuite/
│ ├── fold.jl
│ ├── gather.jl
│ ├── rotation.jl
│ ├── scatter.jl
│ ├── spectral.jl
│ └── upsample.jl
└── utils.jl
================================================
FILE CONTENTS
================================================
================================================
FILE: .buildkite/pipeline.yml
================================================
steps:
- label: ":julia: Julia {{matrix.julia}} - CUDA GPU"
command:
- echo 'CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"' >> test/Project.toml
- echo 'cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"' >> test/Project.toml
plugins:
- JuliaCI/julia#v1:
version: "{{matrix.julia}}"
- JuliaCI/julia-test#v1:
test_args: "--quickfail"
- JuliaCI/julia-coverage#v1:
codecov: true
dirs:
- src
- ext
agents:
queue: "juliagpu"
cuda: "*"
env:
JULIA_NUM_THREADS: 4
NNLIB_TEST_CUDA: "true"
NNLIB_TEST_CPU: "false"
if: build.message !~ /\[skip tests\]/
timeout_in_minutes: 180
matrix:
setup:
julia:
- "1.10"
- "1"
- "nightly"
adjustments:
- with:
julia: "nightly"
soft_fail: true
- label: ":julia: Julia {{matrix.julia}} - AMD GPU"
command:
- echo 'AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"' >> test/Project.toml
plugins:
- JuliaCI/julia#v1:
version: "1"
- JuliaCI/julia-test#v1:
test_args: "--quickfail"
- JuliaCI/julia-coverage#v1:
codecov: true
dirs:
- src
- ext
agents:
queue: "juliagpu"
rocm: "*"
rocmgpu: "*"
timeout_in_minutes: 180
env:
JULIA_AMDGPU_CORE_MUST_LOAD: "1"
JULIA_AMDGPU_HIP_MUST_LOAD: "1"
JULIA_AMDGPU_DISABLE_ARTIFACTS: "1"
NNLIB_TEST_AMDGPU: "true"
NNLIB_TEST_CPU: "false"
JULIA_NUM_THREADS: 4
matrix:
setup:
julia:
# - "1.10"
- "1"
# - "nightly"
# adjustments:
# - with:
# julia: "nightly"
# soft_fail: true
- label: ":julia: Julia {{matrix.julia}} - Metal GPU"
command:
- echo 'Metal = "dde4c033-4e86-420c-a63e-0dd931031962"' >> test/Project.toml
plugins:
- JuliaCI/julia#v1:
version: "{{matrix.julia}}"
- JuliaCI/julia-test#v1:
test_args: "--quickfail"
- JuliaCI/julia-coverage#v1:
codecov: true
dirs:
- src
- ext
agents:
queue: "juliaecosystem"
os: "macos"
arch: "aarch64"
timeout_in_minutes: 180
env:
NNLIB_TEST_METAL: "true"
NNLIB_TEST_CPU: "false"
JULIA_NUM_THREADS: 4
matrix:
setup:
julia:
# - "1.10"
- "1"
# - "nightly "
# adjustments:
# - with:
# julia: "nightly"
# soft_fail: true
- label: "Benchmarks"
plugins:
- JuliaCI/julia#v1:
version: 1
env:
JULIA_NUM_THREADS: 4
command:
- julia --project=benchmark -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()'
- julia --project=benchmark benchmark/runbenchmarks.jl
- printf '%b\n' "$(cat benchmark/report.md)" | buildkite-agent annotate --style 'info'
agents:
queue: "juliagpu"
if: build.pull_request.labels includes "benchmark"
timeout_in_minutes: 30
env:
SECRET_CODECOV_TOKEN: "IlEMvDI6RciJQr5eX7qBBpHYFAe8+Svf3lNJh9gZi0MeJZQvMZWzHfW/lVncA9d9K+gDBBTv/zwqF86xOaIFLuACNdcGZiGgHS+NGeXN5CEppjqLnqKuaeHmLgJ43jygxRwgF88LhwTGcHG7pmESIp1Bn3Jd23UUv4t8hJLBDF+KJLZMefzCXnEVzfwJYxhJktnKJPA4dOv59w33Vj1x5uCYZbQlLP54IJPBm8UGdXS+JrUX8Z7lhxbkJUi6c+R6cvVBw27uRjF0pUJY26mt1frx8MzTGTOweXTpi+Kc5JhzlokMlan17j6T/b7qMC13IuKopfqu1GhkSBQD3ZhQqA==;U2FsdGVkX19l7JMB48k4oJHLoaqC7/MmvQWmaiBxRN472ZC6AcQ0uCBRy6Fw8tI0YcjIxKDScaBnJ2v/deOfhg=="
================================================
FILE: .codecov.yml
================================================
comment: false
================================================
FILE: .github/copilot-instructions.md
================================================
# NNlib.jl Copilot Instructions
## Repository Overview
NNlib.jl is a library providing fundamental neural network operations and primitives for Julia. It is primarily used by Flux.jl but can be used independently. The library provides:
- Activation functions (sigmoid, relu, gelu, etc.)
- Convolution and pooling operations
- Attention mechanisms
- Batched matrix operations
- Neural network utilities (dropout, normalization, etc.)
- GPU acceleration support (CUDA, AMDGPU)
## Project Structure
```
NNlib.jl/
├── src/ # Core library implementation
│ ├── NNlib.jl # Main module file
│ ├── activations.jl # Activation functions
│ ├── attention.jl # Attention mechanisms
│ ├── conv.jl # Convolution operations
│ ├── pooling.jl # Pooling operations
│ ├── batched/ # Batched operations
│ └── impl/ # Implementation details
├── ext/ # Package extensions for GPU backends
│ ├── NNlibCUDAExt/ # CUDA-specific implementations
│ ├── NNlibAMDGPUExt/ # AMDGPU-specific implementations
│ └── NNlibCUDACUDNNExt/ # cuDNN-specific implementations
├── test/ # Test suite
└── docs/ # Documentation
```
## Julia Version
- Minimum Julia version: 1.10
- CI tests on: the minimum supported Julia version, the latest stable release (1.x), and pre-release versions
## Coding Standards
### Julia Conventions
1. **Naming**:
- Functions: lowercase with underscores (e.g., `dot_product_attention`)
- Types: PascalCase (e.g., `ConvDims`, `PoolDims`)
- Constants: UPPERCASE with underscores (e.g., `ACTIVATIONS`)
2. **Documentation**:
- Use Julia docstrings (""" ... """) for all exported functions
- Include examples in docstrings where appropriate
- Keep documentation up-to-date with implementation changes
3. **Type Annotations**:
- Use type parameters and abstract types for generic implementations
- Leverage Julia's multiple dispatch for specialized implementations
- Define clear type hierarchies (e.g., `DenseConvDims`, `DepthwiseConvDims`)
4. **Performance**:
- Prefer in-place operations where appropriate (functions ending with `!`)
- Use `@inbounds` judiciously when bounds checking is verified
- Consider thread safety for multi-threaded operations
- Use `NNlib.@disallow_spawns` to control threading behavior
### Code Organization
1. **Core Implementations**: CPU implementations go in `src/`
2. **GPU Extensions**: GPU-specific code belongs in `ext/` as package extensions
3. **Tests**: Mirror the structure of `src/` in `test/`
4. **Gradients**: Define gradients using ChainRules.jl (`rrule` functions)
## Testing
### Test Infrastructure
- Uses the standard Julia `Test` framework
- Tests are organized to mirror the source structure
- GPU tests are conditional (controlled by environment variables)
### Running Tests
```julia
# Run all CPU tests
julia --project -e 'using Pkg; Pkg.test()'
# Run tests with threading
JULIA_NUM_THREADS=4 julia --project -e 'using Pkg; Pkg.test()'
```
### Test Patterns
1. **Activation Functions**: Test at specific values (0.0, 1.0, -1.0) and verify expected outputs
2. **Gradient Tests**: Use `ChainRulesTestUtils` for gradient correctness
3. **Type Stability**: Use `@inferred` where appropriate
4. **GPU Tests**: Conditional testing based on environment variables:
- `ENV["NNLIB_TEST_CUDA"]` for CUDA tests
- `ENV["NNLIB_TEST_AMDGPU"]` for AMDGPU tests
### Writing New Tests
- Include tests for edge cases (zero inputs, negative values, boundary conditions)
- Test both forward pass and gradients (using ChainRulesTestUtils)
- For array operations, test multiple dimensions and batch sizes
- Include tests for type stability when performance-critical
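A minimal sketch of these patterns for one activation function (illustrative only, not copied from the test suite):
```julia
using Test, NNlib

@testset "relu (illustrative)" begin
    # Key values
    @test relu(0.0) == 0.0
    @test relu(1.0) == 1.0
    @test relu(-1.0) == 0.0
    # Type stability on scalars and arrays
    @test @inferred(relu(1.0f0)) === 1.0f0
    @test @inferred(relu.(randn(Float32, 4))) isa Vector{Float32}
    # Edge cases: large negative inputs stay at zero
    @test all(iszero, relu.([-1e3, -1e6]))
    # Gradient correctness would additionally be checked with
    # ChainRulesTestUtils.test_rrule against whichever rrule the operation defines.
end
```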
## Dependencies
### Core Dependencies
- **ChainRulesCore**: For automatic differentiation support
- **KernelAbstractions**: For GPU kernel abstractions
- **Adapt**: For moving data between CPU/GPU
- **GPUArraysCore**: GPU array interface
### Weak Dependencies (Extensions)
- **CUDA.jl/cuDNN**: NVIDIA GPU support
- **AMDGPU.jl**: AMD GPU support
- **FFTW**: Fast Fourier transforms
- **ForwardDiff**: Forward-mode AD support
- **EnzymeCore**: Enzyme AD support
### Adding New Dependencies
- Consider whether the dependency should be a weak dependency (extension)
- Update `Project.toml` with version constraints
- Ensure compatibility with supported Julia versions
- Run full test suite after adding dependencies
## GPU Support
NNlib uses Julia's package extension system for GPU backends:
1. **CUDA**: Load with `using NNlib, CUDA, cuDNN`
2. **AMDGPU**: Load with `using NNlib, AMDGPU`
### GPU Implementation Guidelines
- Keep GPU-specific code in appropriate extensions (`ext/` directory)
- Provide CPU fallback implementations in `src/`
- Test GPU implementations separately (conditional on hardware availability)
- Use KernelAbstractions for portable GPU kernels when possible
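As a rough illustration of the KernelAbstractions pattern (the kernel name and the scaling operation below are made up for this example):
```julia
using KernelAbstractions

# An element-wise kernel; the same code runs on CPU and GPU backends.
@kernel function scale!(y, x, α)
    i = @index(Global)
    @inbounds y[i] = α * x[i]
end

x = rand(Float32, 1024)
y = similar(x)
backend = get_backend(x)                        # CPU() here; CUDABackend() for a CuArray
scale!(backend)(y, x, 2f0; ndrange = length(x)) # instantiate for the backend, then launch
KernelAbstractions.synchronize(backend)
```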
## Build and CI/CD
### Continuous Integration
- **CI Workflow**: `.github/workflows/ci.yml`
- Tests on Linux (always), Windows, and macOS
- Tests with different Julia versions (LTS, stable, pre-release)
- Tests with different thread counts
### Additional Workflows
- **TagBot**: Automatic release tagging
- **CompatHelper**: Dependency compatibility updates
- **Downstream**: Tests dependent packages
- **BenchmarkTrigger**: Performance regression testing
## Common Tasks
### Adding a New Activation Function
1. Add function to `src/activations.jl`
2. Add to `ACTIVATIONS` tuple for automatic export
3. Define gradient with `@scalar_rule` or `rrule`
4. Add tests in `test/activations.jl` at key values
5. Document with docstring and example
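For illustration, a sketch of these steps for a hypothetical `shifted_softplus` activation (not an NNlib function); in the real source it would also be appended to the `ACTIVATIONS` tuple so it is exported automatically:
```julia
using NNlib, ChainRulesCore

"""
    shifted_softplus(x)

Softplus shifted down by `log(2)` so that `shifted_softplus(0) ≈ 0`.
"""
shifted_softplus(x) = softplus(x) - oftype(float(x), log(2))

# Gradient via @scalar_rule, as NNlib does for most scalar activations:
# d/dx softplus(x) = sigmoid(x), and the constant shift drops out.
ChainRulesCore.@scalar_rule shifted_softplus(x) sigmoid(x)

# In test/activations.jl, check key values, e.g.
# @test shifted_softplus(0.0) ≈ 0.0
```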
### Adding a New Operation
1. Implement in appropriate file in `src/`
2. Export from `src/NNlib.jl`
3. Define gradients using ChainRules
4. Add comprehensive tests
5. Add GPU implementations in extensions if applicable
6. Document in appropriate file in `docs/src/`
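A sketch of a hypothetical element-wise operation with a hand-written `rrule` (the operation and names are illustrative):
```julia
using ChainRulesCore

# Hypothetical operation: element-wise square, scaled by the array length.
scaled_square(x::AbstractArray) = x .^ 2 ./ length(x)

function ChainRulesCore.rrule(::typeof(scaled_square), x::AbstractArray)
    y = scaled_square(x)
    function scaled_square_pullback(ȳ)
        x̄ = 2 .* x .* unthunk(ȳ) ./ length(x)   # d/dxᵢ (xᵢ² / n) = 2xᵢ / n
        return NoTangent(), x̄
    end
    return y, scaled_square_pullback
end
```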
### Modifying Existing Functions
1. Check for dependent code in Flux.jl and other downstream packages
2. Maintain backward compatibility or document breaking changes
3. Update tests to cover new behavior
4. Update gradients if needed
5. Consider performance implications
## Performance Considerations
1. **Memory Allocation**: Minimize allocations in hot paths
2. **Threading**: NNlib uses Julia threads for parallel operations
- Control with `NNlib.@disallow_spawns` if needed
- Thread count controlled by `JULIA_NUM_THREADS`
3. **GPU Kernels**: Optimize kernel launch parameters and memory access patterns
4. **Type Stability**: Ensure type-stable code for performance-critical paths
## Documentation
- Documentation source: `docs/src/`
- Built with Documenter.jl
- Includes API reference, examples, and guides
- Documentation tests run via `DocTestSetup`
## Related Projects
- **Flux.jl**: Primary consumer of NNlib
- **Zygote.jl**: Automatic differentiation (uses ChainRules)
- **ChainRules.jl**: Gradient definitions
- **KernelAbstractions.jl**: GPU kernel abstraction
## Getting Help
- Documentation: https://fluxml.ai/NNlib.jl/dev/
- Issues: https://github.com/FluxML/NNlib.jl/issues
- FluxML community: https://github.com/FluxML
================================================
FILE: .github/dependabot.yml
================================================
# To get started with Dependabot version updates, you'll need to specify which
# package ecosystems to update and where the package manifests are located.
# Please see the documentation for all configuration options:
# https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/" # Location of package manifests
schedule:
interval: "weekly"
================================================
FILE: .github/workflows/BenchmarkTrigger.yml
================================================
name: Benchmark Trigger
on:
pull_request_target:
types: [ labeled ]
workflow_dispatch:
inputs:
pr_id:
type: string
description: id of the pull request that triggers this workflow
target_url:
type: string
description: url of target
baseline_url:
type: string
description: url of baseline
jobs:
benchmark_trigger:
if: ${{ github.event.label.name == 'benchmark' }}
runs-on: ubuntu-latest
env:
REPOSITORY: ${{ github.event.repository.full_name }}
PR_ID: ${{ github.event.inputs.pr_id || github.event.pull_request.number }}
TARGET_URL: ${{ github.event.inputs.target_url || format('{0}#{1}', github.event.pull_request.head.repo.html_url, github.event.pull_request.head.sha) }}
BASELINE_URL: ${{ github.event.inputs.baseline_url || format('{0}#{1}', github.event.pull_request.base.repo.html_url, github.event.pull_request.base.sha) }}
steps:
-
name: Get app installation token (ghs)
id: get-app-token
uses: tibdex/github-app-token@v2
with:
app_id: ${{ secrets.BENCH_APP_ID }}
installation_id: ${{ secrets.BENCH_INSTALL_ID }}
private_key: ${{ secrets.BENCH_PRIVATE_KEY }}
-
uses: benc-uk/workflow-dispatch@v1
with:
repo: FluxML/FluxMLBenchmarks.jl
ref: refs/heads/main
workflow: Benchmark.yml
token: ${{ steps.get-app-token.outputs.token }}
inputs: '{ "repository": "${{ env.REPOSITORY }}", "pr_id": "${{ env.PR_ID }}", "target_url": "${{ env.TARGET_URL }}", "baseline_url": "${{ env.BASELINE_URL }}" }'
================================================
FILE: .github/workflows/CompatHelper.yml
================================================
name: CompatHelper
on:
schedule:
- cron: 0 0 * * *
workflow_dispatch:
permissions:
contents: write
pull-requests: write
jobs:
CompatHelper:
runs-on: ubuntu-latest
steps:
- name: Check if Julia is already available in the PATH
id: julia_in_path
run: which julia
continue-on-error: true
- name: Install Julia, but only if it is not already available in the PATH
uses: julia-actions/setup-julia@v2
with:
version: '1'
arch: ${{ runner.arch }}
if: steps.julia_in_path.outcome != 'success'
- name: "Add the General registry via Git"
run: |
import Pkg
ENV["JULIA_PKG_SERVER"] = ""
Pkg.Registry.add("General")
shell: julia --color=yes {0}
- name: "Install CompatHelper"
run: |
import Pkg
name = "CompatHelper"
uuid = "aa819f21-2bde-4658-8897-bab36330d9b7"
version = "3"
Pkg.add(; name, uuid, version)
shell: julia --color=yes {0}
- name: "Run CompatHelper"
run: |
import CompatHelper
CompatHelper.main()
shell: julia --color=yes {0}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }}
# COMPATHELPER_PRIV: ${{ secrets.COMPATHELPER_PRIV }}
================================================
FILE: .github/workflows/Downstream.yml
================================================
name: IntegrationTest
on:
push:
branches: [master]
tags: [v*]
pull_request:
# needed to allow julia-actions/cache to delete old caches that it has created
permissions:
actions: write
contents: read
jobs:
test:
name: ${{ matrix.package.repo }}/${{ matrix.package.group }}
runs-on: ${{ matrix.os }}
env:
GROUP: ${{ matrix.package.group }}
strategy:
fail-fast: false
matrix:
julia-version: [1]
os: [ubuntu-latest]
package:
- {user: FluxML, repo: Flux.jl, group: All}
- {user: FluxML, repo: Tracker.jl, group: All}
- {user: LuxDL, repo: Lux.jl, group: All}
steps:
- uses: actions/checkout@v6
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.julia-version }}
arch: x64
- uses: julia-actions/cache@v2
- uses: julia-actions/julia-buildpkg@latest
- name: Clone Downstream
uses: actions/checkout@v6
with:
repository: ${{ matrix.package.user }}/${{ matrix.package.repo }}
path: downstream
- name: Load this and run the downstream tests
shell: julia --color=yes --project=downstream {0}
run: |
using Pkg
try
# force it to use this PR's version of the package
Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps
Pkg.update()
Pkg.test() # resolver may fail with test time deps
catch err
err isa Pkg.Resolve.ResolverError || rethrow()
# If we can't resolve that means this is incompatible by SemVer and this is fine
# It means we marked this as a breaking change, so we don't need to worry about
# Mistakenly introducing a breaking change, as we have intentionally made one
@info "Not compatible with this release. No problem." exception=err
exit(0) # Exit immediately, as a success
end
env:
RETESTITEMS_NWORKERS: 4
BACKEND_GROUP: CPU # for Lux.jl
================================================
FILE: .github/workflows/TagBot.yml
================================================
name: TagBot
on:
issue_comment:
types:
- created
workflow_dispatch:
inputs:
lookback:
default: 3
permissions:
actions: read
checks: read
contents: write
deployments: read
issues: read
discussions: read
packages: read
pages: read
pull-requests: read
repository-projects: read
security-events: read
statuses: read
jobs:
TagBot:
if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot'
runs-on: ubuntu-latest
steps:
- uses: JuliaRegistries/TagBot@v1
with:
token: ${{ secrets.GITHUB_TOKEN }}
# Edit the following line to reflect the actual name of the GitHub Secret containing your private key
ssh: ${{ secrets.DOCUMENTER_KEY }}
# ssh: ${{ secrets.NAME_OF_MY_SSH_PRIVATE_KEY_SECRET }}
================================================
FILE: .github/workflows/ci.yml
================================================
name: CI
on:
push:
branches:
- master
- staging
- trying
tags: '*'
pull_request:
# needed to allow julia-actions/cache to delete old caches that it has created
permissions:
actions: write
contents: read
defaults:
run:
shell: bash
jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.julia-threads }} thread(s)
runs-on: ${{ matrix.os }}
env:
JULIA_NUM_THREADS: ${{ matrix.julia-threads }}
strategy:
fail-fast: false
matrix:
version:
- '1.10' # minimum supported Julia version
- '1' # automatically expands to the latest stable 1.x release of Julia
- 'nightly'
os:
- ubuntu-latest
# - macOS-latest
# - windows-latest
julia-threads:
- '1'
include:
- os: windows-latest
version: '1'
julia-threads: '1'
- os: macOS-latest
version: '1'
julia-threads: '1'
- os: ubuntu-latest
version: '1'
julia-threads: '2'
steps:
- uses: actions/checkout@v6
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.version }}
- uses: julia-actions/cache@v2
- uses: julia-actions/julia-buildpkg@v1
- name: "Run test without coverage"
uses: julia-actions/julia-runtest@v1
if: ${{ !contains(fromJson('["1"]'), matrix.version) || matrix.os != 'ubuntu-latest' }}
with:
coverage: false
- name: "Run test with coverage"
uses: julia-actions/julia-runtest@v1
if: contains(fromJson('["1"]'), matrix.version) && matrix.os == 'ubuntu-latest'
- uses: julia-actions/julia-processcoverage@v1
if: contains(fromJson('["1"]'), matrix.version) && matrix.os == 'ubuntu-latest'
- uses: codecov/codecov-action@v5
if: contains(fromJson('["1"]'), matrix.version) && matrix.os == 'ubuntu-latest'
with:
file: lcov.info
docs:
name: Documentation
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- uses: julia-actions/setup-julia@v2
with:
version: '1.10'
- uses: julia-actions/cache@v2
- run: |
julia --project=docs -e '
using Pkg
Pkg.develop(PackageSpec(path=pwd()))
Pkg.instantiate()'
- run: julia --project=docs docs/make.jl
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}
================================================
FILE: .github/workflows/clean_preview.yml
================================================
# from https://github.com/CliMA/ClimaTimeSteppers.jl
name: Doc Preview Cleanup
on:
pull_request:
types: [closed]
jobs:
doc-preview-cleanup:
runs-on: ubuntu-latest
steps:
- name: Checkout gh-pages branch
uses: actions/checkout@v6
with:
ref: gh-pages
- name: Delete preview and history + push changes
run: |
if [ -d "previews/PR$PRNUM" ]; then
git config user.name "Documenter.jl"
git config user.email "documenter@juliadocs.github.io"
git rm -rf "previews/PR$PRNUM"
git commit -m "delete preview"
git branch gh-pages-new $(echo "delete history" | git commit-tree HEAD^{tree})
git push --force origin gh-pages-new:gh-pages
fi
env:
PRNUM: ${{ github.event.number }}
================================================
FILE: .github/workflows/pr_comment.yml
================================================
name: pr_comment
on:
pull_request:
types: [labeled]
jobs:
pr_comment:
runs-on: ubuntu-latest
steps:
- name: Create PR comment
if: github.event_name == 'pull_request' && github.repository == github.event.pull_request.head.repo.full_name && github.event.label.name == 'documentation' # if this is a pull request build AND the pull request is NOT made from a fork
uses: thollander/actions-comment-pull-request@24bffb9b452ba05a4f3f77933840a6a841d1b32b
with:
message: 'Once the documentation build has completed, you can preview any updated documentation at this URL: https://fluxml.ai/NNlib.jl/previews/PR${{ github.event.number }}/'
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
================================================
FILE: .gitignore
================================================
*.jl.cov
*.jl.*.cov
*.jl.mem
*.o
*.so
*.dylib
*.dll
*~
\#*
deps/usr
deps.jl
*.log
.vscode/
/Manifest.toml
test/Manifest.toml
benchmark/Manifest.toml
benchmark/*.json
benchmark/report.md
================================================
FILE: LICENSE.md
================================================
The NNlib.jl package is licensed under the MIT "Expat" License:
> Copyright (c) 2017-19: Julia Computing, Inc., Mike J Innes, and Contributors
>
> Permission is hereby granted, free of charge, to any person obtaining a copy
> of this software and associated documentation files (the "Software"), to deal
> in the Software without restriction, including without limitation the rights
> to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
> copies of the Software, and to permit persons to whom the Software is
> furnished to do so, subject to the following conditions:
>
> The above copyright notice and this permission notice shall be included in all
> copies or substantial portions of the Software.
>
> THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
> IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
> FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
> AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
> LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
> OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
> SOFTWARE.
>
================================================
FILE: Project.toml
================================================
name = "NNlib"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.9.34"
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
[extensions]
NNlibAMDGPUExt = "AMDGPU"
NNlibCUDACUDNNExt = ["CUDA", "cuDNN"]
NNlibCUDAExt = "CUDA"
NNlibEnzymeCoreExt = "EnzymeCore"
NNlibFFTWExt = "FFTW"
NNlibForwardDiffExt = "ForwardDiff"
NNlibMetalExt = "Metal"
NNlibSpecialFunctionsExt = "SpecialFunctions"
[compat]
AMDGPU = "1, 2"
Adapt = "3.2, 4"
Atomix = "0.1, 1"
CUDA = "4, 5, 6"
ChainRulesCore = "1.25"
EnzymeCore = "0.7, 0.8"
FFTW = "1.8.0"
ForwardDiff = "1"
GPUArraysCore = "0.2"
KernelAbstractions = "0.9.2"
LinearAlgebra = "1"
Metal = "1.6"
Random = "1"
ScopedValues = "1.3.0"
SpecialFunctions = "2"
Statistics = "1"
cuDNN = "1, 6"
julia = "1.10"
================================================
FILE: README.md
================================================
<img align="right" width="200px" src="https://github.com/FluxML/NNlib.jl/raw/master/docs/src/assets/logo.png">
# NNlib.jl
[![Documentation][docs-dev-img]][docs-dev-url]
[CI](https://github.com/FluxML/NNlib.jl/actions/workflows/ci.yml)
[Coverage](https://codecov.io/gh/FluxML/NNlib.jl)
[docs-stable-img]: https://img.shields.io/badge/docs-stable-blue.svg
[docs-stable-url]: https://fluxml.ai/NNlib.jl/stable/
[docs-dev-img]: https://img.shields.io/badge/docs-latest-blue.svg
[docs-dev-url]: https://fluxml.ai/NNlib.jl/dev/
This package provides a library of functions useful for neural networks, such as softmax, sigmoid, batched multiplication, convolutions and pooling. Many of these are used by [Flux.jl](https://github.com/FluxML/Flux.jl), which loads this package, but they may be used independently.
For use with automatic differentiation, this package defines gradients using [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl). These will be seen by various packages including [Zygote.jl](https://github.com/FluxML/Zygote.jl).
GPU support is provided as package extensions (see the `ext/` folder). In order to load the extensions, use the imports
```julia
using NNlib, CUDA, cuDNN
```
for CUDA support, or
```julia
using NNlib, AMDGPU
```
for AMDGPU support.
================================================
FILE: benchmark/Project.toml
================================================
[deps]
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
BenchmarkCI = "20533458-34a3-403d-a444-e18f38190b5b"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
[compat]
# No compat bounds for NNlib because we may test breaking versions
ArgParse = "1"
BenchmarkCI = "0.1"
BenchmarkTools = "1.3"
PkgBenchmark = "0.2"
julia = "1.6"
================================================
FILE: benchmark/benchmarks.jl
================================================
using BenchmarkTools
using NNlib
using NNlib.ChainRulesCore: rrule
using Random
Random.seed!(1234567890)
const SUITE = BenchmarkGroup()
SUITE["activations"] = BenchmarkGroup()
for et in (Float16, Float32, Float64)
et_suite = BenchmarkGroup()
SUITE["activations"][string(et)] = et_suite
let x = rand(et, 1024, 1024), y = similar(x)
for f in NNlib.ACTIVATIONS
act = @eval($f)
et_suite[string(f)] = @benchmarkable broadcast!($act, $y, $x)
end
end
end
for (fn!, fn_bw) in [(softmax!, NNlib.∇softmax_data), (logsoftmax!, NNlib.∇logsoftmax_data)]
fn_suite = BenchmarkGroup()
SUITE[rstrip(string(fn!), '!')] = fn_suite
let SIZES = [
(128, 384, 8),
(512, 784, 8),
(768, 1024, 4),
(1024, 2048, 4),
(2048, 2048, 2),
(4096, 2048, 2),
(4096, 4096, 2),
(12288, 2048, 1)
]
for et in (Float16, Float32)
et_suite = BenchmarkGroup("fw" => BenchmarkGroup(), "bw" => BenchmarkGroup())
fn_suite[string(et)] = et_suite
for sz in SIZES
x = randn(et, sz)
y = similar(x)
dy = zero(x)
fn!(y, x)
et_suite["fw"][string(sz)] = @benchmarkable $fn!($y, $x)
et_suite["bw"][string(sz)] = @benchmarkable $fn_bw($dy, $y)
end
end
end
end
================================================
FILE: benchmark/perf_report.jl
================================================
using JLD2, NNlib, BenchmarkTools
# TODO organize and compare benchmarks using BenchmarkGroups
# We need things to go quickly here
BenchmarkTools.DEFAULT_PARAMETERS.samples = 20
BenchmarkTools.DEFAULT_PARAMETERS.seconds = 2.5
results = Dict()
function add_result(val, keys...)
r = results
for k in keys[1:end-1]
if !haskey(r, k)
r[k] = Dict()
end
r = r[k]
end
r[keys[end]] = val
return r
end
# Modify these as needed
for rank in (2,),
N in (20, 40, 80),
C_in in (1,),
C_out in (1,),
K in (3,),
stride in (1,),
dilation in (1,),
padding in (0, 2)
benchmark_items = [
(NNlib.conv_direct!, NNlib.∇conv_data_direct!, NNlib.∇conv_filter_direct!, DenseConvDims, "direct"),
(NNlib.conv_im2col!, NNlib.∇conv_data_im2col!, NNlib.∇conv_filter_im2col!, DenseConvDims, "im2col"),
(NNlib.depthwiseconv_direct!, NNlib.∇depthwiseconv_data_direct!, NNlib.∇depthwiseconv_filter_direct!, DepthwiseConvDims, "direct"),
(NNlib.depthwiseconv_im2col!, NNlib.∇depthwiseconv_data_im2col!, NNlib.∇depthwiseconv_filter_im2col!, DepthwiseConvDims, "im2col"),
]
for (conv!, ∇conv_data!, ∇conv_filter!, cT, backend) in benchmark_items
x = zeros(Float32, repeat([N], rank)..., C_in, 1)
if cT == DenseConvDims
w = zeros(Float32, repeat([K], rank)..., C_in, C_out)
else
w = zeros(Float32, repeat([K], rank)..., C_out, C_in)
end
cdims = try
cT(x, w; stride=stride, dilation=dilation, padding=padding)
catch
continue
end
if cT == DenseConvDims
y = zeros(Float32, NNlib.output_size(cdims)..., C_out, 1)
else
y = zeros(Float32, NNlib.output_size(cdims)..., C_out*C_in, 1)
end
dx = similar(x)
dw = similar(w)
dy = similar(y)
t_fwd = @benchmark $(conv!)($y, $x, $w, $cdims)
t_dx = @benchmark $(∇conv_data!)($dx, $y, $w, $cdims)
t_dw = @benchmark $(∇conv_filter!)($dw, $x, $y, $cdims)
add_result(t_fwd, "conv$(rank)d", backend, cdims)
add_result(t_dx, "conv$(rank)d_data", backend, cdims)
add_result(t_dw, "conv$(rank)d_filter", backend, cdims)
@show(cdims)
@save "results.jld2" results
end
end
# Modify these as needed
for rank in (2,),
N in (20,),
K in (2, 4),
stride in (1, 2, 4)
x = zeros(Float32, repeat([N], rank)..., 1, 1)
pdims = PoolDims(x, K; stride=stride)
y = zeros(Float32, NNlib.output_size(pdims)..., 1, 1)
dx = similar(x)
for (pool, ∇pool, name) in (
(NNlib.maxpool!, NNlib.∇maxpool!, "maxpool"),
(NNlib.meanpool!, NNlib.∇meanpool!, "meanpool"),
(NNlib.lpnormpool!, NNlib.∇lpnormpool!, "lpnormpool"),
)
t_fwd = @benchmark $(pool)( $y, $x, $pdims)
t_data = @benchmark $(∇pool)($dx, $y, $y, $x, $pdims)
add_result(t_fwd, "$(name)$(rank)d", "direct", pdims)
add_result(t_data, "$(name)$(rank)d_data", "direct", pdims)
@show(pdims)
@save "results.jld2" results
end
end
================================================
FILE: benchmark/runbenchmarks.jl
================================================
# Adapted from
# https://github.com/kul-forbes/ProximalOperators.jl/tree/master/benchmark
using ArgParse
using PkgBenchmark
using BenchmarkCI: displayjudgement, printresultmd, CIResult
using Markdown
function markdown_report(judgement)
md = sprint(printresultmd, CIResult(judgement = judgement))
md = replace(md, ":x:" => "❌")
md = replace(md, ":white_check_mark:" => "✅")
return md
end
function parse_commandline()
s = ArgParseSettings()
@add_arg_table! s begin
"--target"
help = "the branch/commit/tag to use as target"
default = "HEAD"
"--baseline"
help = "the branch/commit/tag to use as baseline"
default = "master"
"--retune"
help = "force re-tuning (ignore existing tuning data)"
action = :store_true
end
return parse_args(s)
end
function main()
parsed_args = parse_commandline()
mkconfig(; kwargs...) =
BenchmarkConfig(
env = Dict(
"JULIA_NUM_THREADS" => get(ENV, "JULIA_NUM_THREADS", "1"),
);
kwargs...
)
target = parsed_args["target"]
group_target = benchmarkpkg(
dirname(@__DIR__),
mkconfig(id = target),
resultfile = joinpath(@__DIR__, "result-$(target).json"),
retune = parsed_args["retune"],
)
baseline = parsed_args["baseline"]
group_baseline = benchmarkpkg(
dirname(@__DIR__),
mkconfig(id = baseline),
resultfile = joinpath(@__DIR__, "result-$(baseline).json"),
)
judgement = judge(group_target, group_baseline)
report_md = markdown_report(judgement)
write(joinpath(@__DIR__, "report.md"), report_md)
display(Markdown.parse(report_md))
end
main()
================================================
FILE: docs/.gitignore
================================================
build/
site/
Manifest.toml
================================================
FILE: docs/Project.toml
================================================
[deps]
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FLAC = "abae9e3b-a9a0-4778-b5c6-ca109b507d99"
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
================================================
FILE: docs/make.jl
================================================
using Documenter, NNlib
DocMeta.setdocmeta!(NNlib, :DocTestSetup,
:(using FFTW, NNlib, UnicodePlots); recursive = true)
makedocs(modules = [NNlib],
sitename = "NNlib.jl",
doctest = true,
pages = ["Home" => "index.md",
"Reference" => "reference.md",
"Audio" => "audio.md"],
format = Documenter.HTML(
canonical = "https://fluxml.ai/NNlib.jl/stable/",
# analytics = "UA-36890222-9",
assets = ["assets/flux.css"],
prettyurls = get(ENV, "CI", nothing) == "true"),
warnonly=[:missing_docs,]
)
deploydocs(repo = "github.com/FluxML/NNlib.jl.git",
target = "build",
push_preview = true)
================================================
FILE: docs/src/assets/flux.css
================================================
@import url('https://fonts.googleapis.com/css?family=Lato:400,400i');
body {
font-family: Lato, "Segoe UI",Roboto,"Helvetica Neue",Arial,sans-serif;
}
nav.toc {
padding-top: 0;
background: rgb(240, 240, 240);
line-height: 2em;
cursor: default;
user-select: none;
}
h1+h2 {
margin-top: 0;
}
/* Green banner in ToC */
nav.toc > h1 {
margin-top: 0;
padding-top: 0.4em;
padding-bottom: 0.5em;
border-bottom: 5px solid white;
box-shadow: 0px -2px 5px rgb(60,60,60);
margin-bottom: 0.5em;
background: rgb(60, 150, 60);
font-style: italic;
font-weight: normal;
font-size: 50pt;
text-transform: lowercase;
text-shadow: 2px 2px 5px rgba(0,0,0,0.2);
color: white;
}
/* Reduce ToC font size */
.toctext {
font-size: 10pt;
}
/* Fade out non-clickable ToC headers */
nav.toc ul span.toctext {
color: rgb(180, 180, 180);
}
nav.toc ul .toctext {
color: rgb(100, 100, 100);
}
nav.toc ul a.toctext:hover {
color: inherit;
background: rgb(220, 220, 220);
cursor: default;
}
nav.toc li.current > .toctext {
background: linear-gradient(90deg, rgb(245,245,245) 0%, white 90%);
font-weight: normal;
}
nav.toc ul.internal li.toplevel {
font-weight: normal;
}
/* Content */
article { max-width: none; }
article > p, article > ul {
max-width: 45em;
}
/* Links */
a, a:visited { color: rgb(0, 120, 0); }
article p a { border-bottom: 1px solid rgb(200, 230, 200); }
a:hover, a:visited:hover { color: rgb(0, 80, 0); }
/* Article Links */
article p a { border-bottom: 1px solid rgb(200, 230, 200); }
article p a:hover, article a:visited:hover { color: rgb(0, 120, 0); }
article p a:hover { border-bottom: 1px solid rgb(150, 200, 150); }
/* Docstrings */
article section.docstring {
padding: 0.5em 0;
border-left: none;
border-right: none;
border-bottom: none;
}
/* Code */
article pre, article p > code {
background: rgb(245, 250, 245);
}
article pre {
border: none;
max-width: none;
padding: 1em;
border-radius: 10px 0px 0px 10px;
}
.hljs-comment {
font-style: italic;
}
.hljs-number {
color: rgb(0, 150, 150);
}
================================================
FILE: docs/src/audio.md
================================================
# Reference
!!! note
Spectral functions require importing the `FFTW` package to enable them.
## Window functions
```@docs
hann_window
hamming_window
```
## Spectral
```@docs
stft
istft
NNlib.power_to_db
NNlib.db_to_power
```
## Spectrogram
```@docs
melscale_filterbanks
spectrogram
```
Example:
```@example 1
using FFTW # <- required for STFT support.
using NNlib
using FileIO
using Makie, CairoMakie
CairoMakie.activate!()
waveform, sampling_rate = load("./assets/jfk.flac")
fig = lines(reshape(waveform, :))
save("waveform.png", fig)
# Spectrogram.
n_fft = 1024
spec = spectrogram(waveform; n_fft, hop_length=n_fft ÷ 4, window=hann_window(n_fft))
fig = heatmap(transpose(NNlib.power_to_db(spec)[:, :, 1]))
save("spectrogram.png", fig)
# Mel-scale spectrogram.
n_freqs = n_fft ÷ 2 + 1
fb = melscale_filterbanks(; n_freqs, n_mels=128, sample_rate=Int(sampling_rate))
mel_spec = permutedims(spec, (2, 1, 3)) ⊠ fb # (time, n_mels)
fig = heatmap(NNlib.power_to_db(mel_spec)[:, :, 1])
save("mel-spectrogram.png", fig)
nothing # hide
```
|Waveform|Spectrogram|Mel Spectrogram|
|:---:|:---:|:---:|
|![](waveform.png)|![](spectrogram.png)|![](mel-spectrogram.png)|
================================================
FILE: docs/src/index.md
================================================
# NNlib.jl
`NNlib` provides a library of functions useful for neural networks, such as softmax, sigmoid, batched multiplication, convolutions and pooling. Many of these are used by [Flux.jl](https://github.com/FluxML/Flux.jl), which loads this package, but they may be used independently.
For use with automatic differentiation, this package defines gradients using [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl). These will be seen by various packages including [Zygote.jl](https://github.com/FluxML/Zygote.jl).
GPU support is provided as package extensions. In order to load the extensions, use the imports
```julia
using NNlib, CUDA, cuDNN
```
for CUDA support, or
```julia
using NNlib, AMDGPU
```
for AMDGPU support.
## Threading
Various `NNlib` functions use the available Julia threads for divisible workloads. To disable this, use
the `ScopedValue`-backed switch `NNlib.@disallow_spawns`, e.g.
```julia
NNlib.@disallow_spawns function_that_uses_nnlib()
```
================================================
FILE: docs/src/reference.md
================================================
# Reference
The API reference of `NNlib`.
## Activation Functions
Non-linearities that go between layers of your model. Note that, unless otherwise stated, activation functions operate on scalars. To apply them to an array you can call `σ.(xs)`, `relu.(xs)` and so on.
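For example:
```julia
using NNlib

relu.([-2.0, 0.0, 3.0])   # [0.0, 0.0, 3.0]
sigmoid.([0.0, 1.0])      # [0.5, 0.731…]
```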
```@docs
celu
elu
gelu
gelu_tanh
gelu_sigmoid
gelu_erf
hardsigmoid
sigmoid_fast
hardtanh
tanh_fast
leakyrelu
lisht
logcosh
logsigmoid
mish
relu
relu6
rrelu
selu
sigmoid
softplus
softshrink
softsign
swish
hardswish
tanhshrink
trelu
```
## Attention
```@docs
dot_product_attention
dot_product_attention_scores
make_causal_mask
```
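A minimal usage sketch, assuming the documented `(features, sequence_length, batch)` layout for `q`, `k`, `v`:
```julia
using NNlib

d, len, batch = 8, 10, 2
q = rand(Float32, d, len, batch)
k = rand(Float32, d, len, batch)
v = rand(Float32, d, len, batch)

y, α = dot_product_attention(q, k, v)   # y: attention output, α: attention scores
size(y)                                 # (8, 10, 2)
```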
## Softmax
`Flux`'s `logitcrossentropy` uses `NNlib.logsoftmax` internally.
```@docs
softmax
logsoftmax
```
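For example:
```julia
using NNlib

x = randn(Float32, 5, 3)
p = softmax(x)                     # softmax along dims=1 by default
sum(p; dims=1)                     # each column sums to ≈ 1
logsoftmax(x) ≈ log.(softmax(x))   # true, but computed more stably
```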
## Pooling
`Flux`'s `AdaptiveMaxPool`, `AdaptiveMeanPool`, `GlobalMaxPool`, `GlobalMeanPool`, `MaxPool`, `MeanPool` and `lpnormpool` use `NNlib.PoolDims`, `NNlib.maxpool`, `NNlib.meanpool` and `NNlib.lpnormpool` as their backend.
```@docs
PoolDims
maxpool
meanpool
lpnormpool
```
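For example, with the usual `(width, height, channels, batch)` layout:
```julia
using NNlib

x = rand(Float32, 8, 8, 3, 1)
size(maxpool(x, (2, 2)))              # (4, 4, 3, 1)
size(meanpool(x, (2, 2); stride=2))   # (4, 4, 3, 1)
```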
## Padding
```@docs
pad_reflect
pad_symmetric
pad_circular
pad_repeat
pad_constant
pad_zeros
```
## Convolution
`Flux`'s `Conv` and `CrossCor` layers use `NNlib.DenseConvDims` and `NNlib.conv` internally.
`NNlib.conv` supports complex datatypes on CPU and CUDA devices.
!!! note "AMDGPU MIOpen supports only cross-correlation (`flipkernel=true`)."
Therefore, for every regular convolution (`flipkernel=false`)
the kernel is flipped before the computation.
For better performance, use cross-correlation (`flipkernel=true`)
and manually flip the kernel before the `NNlib.conv` call.
`Flux` handles this automatically; it is only needed for direct calls.
```@docs
conv
ConvDims
depthwiseconv
DepthwiseConvDims
DenseConvDims
NNlib.unfold
NNlib.fold
```
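A minimal usage sketch (output sizes in the comments assume the defaults `pad=0`, `stride=1` unless given):
```julia
using NNlib

x = rand(Float32, 28, 28, 3, 1)    # (width, height, channels_in, batch)
w = rand(Float32, 5, 5, 3, 7)      # (kernel_w, kernel_h, channels_in, channels_out)

size(conv(x, w))                   # (24, 24, 7, 1)
size(conv(x, w; pad=2, stride=2))  # (14, 14, 7, 1)
size(conv(x, w; flipped=true))     # cross-correlation (flipkernel=true), same size
```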
## Upsampling
`Flux`'s `Upsample` layer uses `NNlib.upsample_nearest`, `NNlib.upsample_bilinear`, and `NNlib.upsample_trilinear` as its backend. Additionally, `Flux`'s `PixelShuffle` layer uses `NNlib.pixel_shuffle` as its backend.
```@docs
upsample_nearest
∇upsample_nearest
upsample_linear
∇upsample_linear
upsample_bilinear
∇upsample_bilinear
upsample_trilinear
∇upsample_trilinear
pixel_shuffle
```
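For example:
```julia
using NNlib

x = rand(Float32, 4, 4, 3, 1)
size(upsample_nearest(x, (2, 2)))    # (8, 8, 3, 1)
size(upsample_bilinear(x, (2, 2)))   # (8, 8, 3, 1)

y = rand(Float32, 4, 4, 12, 1)       # channels divisible by r^2 = 4
size(pixel_shuffle(y, 2))            # (8, 8, 3, 1)
```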
## Rotation
Rotate images in the first two dimensions of an array.
```@docs
imrotate
∇imrotate
```
## Batched Operations
`Flux`'s `Bilinear` layer uses `NNlib.batched_mul` internally.
```@docs
batched_mul
batched_mul!
batched_adjoint
batched_transpose
batched_vec
```
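For example, multiplying a batch of matrix pairs:
```julia
using NNlib

A = rand(Float32, 3, 4, 10)    # ten 3×4 matrices
B = rand(Float32, 4, 5, 10)    # ten 4×5 matrices

C = batched_mul(A, B)          # size (3, 5, 10); C[:, :, k] == A[:, :, k] * B[:, :, k]
C == A ⊠ B                     # ⊠ (\boxtimes) is an infix alias for batched_mul
size(batched_mul(batched_transpose(A), A))   # (4, 4, 10)
```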
## Gather and Scatter
`Flux`'s `Embedding` layer uses `NNlib.gather` as its backend.
```@docs
NNlib.gather
NNlib.gather!
NNlib.scatter
NNlib.scatter!
```
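For example, gathering columns of a matrix and scattering values with accumulation:
```julia
using NNlib

src = Float32[1 2 3 4 5;
              6 7 8 9 10]
NNlib.gather(src, [1, 3, 3])    # columns 1, 3, 3 of src → 2×3 matrix

NNlib.scatter(+, Float32[1, 10, 100], [1, 2, 1])   # [101.0, 10.0]
```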
## Sampling
```@docs
grid_sample
∇grid_sample
```
## Losses
```@docs
ctc_loss
```
## Miscellaneous
```@docs
logsumexp
NNlib.glu
NNlib.within_gradient
bias_act!
```
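For example:
```julia
using NNlib

x = randn(Float32, 5, 3)
logsumexp(x; dims=1)              # size (1, 3)
logsumexp(x) ≈ log(sum(exp.(x)))  # true, but stable for large entries
```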
================================================
FILE: ext/NNlibAMDGPUExt/NNlibAMDGPUExt.jl
================================================
module NNlibAMDGPUExt
using Adapt
using AMDGPU
using ChainRulesCore
using NNlib
using NNlib: BatchedAdjoint, BatchedTranspose, BatchedAdjOrTrans
using NNlib: DenseConvDims, PoolDims
const MIOPENFloat = Union{Float16, Float32}
const ROCBatchedAdjoint{T} = BatchedAdjoint{T, <: ROCArray{T}}
const ROCBatchedTranspose{T} = BatchedTranspose{T, <: ROCArray{T}}
const ROCBatchedAdjOrTrans{T} = Union{ROCBatchedAdjoint{T}, ROCBatchedTranspose{T}}
const WrappedROCBatchedAdjOrTrans{T, N} = Adapt.WrappedArray{T, N, ROCBatchedAdjOrTrans{T}, ROCBatchedAdjOrTrans{T}}
const AnyROCBatchedAdjOrTrans = Union{ROCBatchedAdjOrTrans, WrappedROCBatchedAdjOrTrans}
function Base.convert(::Type{T}, b::AnyROCBatchedAdjOrTrans) where {T <: Array}
Base.convert(T, adapt(Array, b))
end
function Base.Array{T, N}(b::AnyROCBatchedAdjOrTrans) where {T, N}
Array{T, N}(adapt(Array, b))
end
Base.collect(b::AnyROCBatchedAdjOrTrans) = collect(adapt(Array, b))
function Base.show(
io::IO, mime::MIME{Symbol("text/plain")}, x::AnyROCBatchedAdjOrTrans,
)
show(io, mime, adapt(Array, x))
end
Base.show(io::IO, x::AnyROCBatchedAdjOrTrans) = show(io, adapt(Array, x))
Base.display(x::AnyROCBatchedAdjOrTrans) = display(adapt(Array, x))
function nnlib_padding(dims)
pd = NNlib.padding(dims)
if !all(pd[1:2:end] .== pd[2:2:end])
@warn """
MIOpen does not support asymmetric padding, defaulting to symmetric choice:
$pd -> $(pd[1:2:end]).
""" maxlog=1
end
pd[1:2:end]
end
include("batched_mul.jl")
@static if AMDGPU.functional(:MIOpen)
using AMDGPU.MIOpen
include("conv.jl")
include("pool.jl")
include("activations.jl")
else
@warn """
ROCm MIOpen is not available for AMDGPU.
NNlib has limited functionality for AMDGPU.
"""
end
end
================================================
FILE: ext/NNlibAMDGPUExt/activations.jl
================================================
for (f, op) in [
NNlib.relu => MIOpen.relu,
NNlib.relu6 => x -> MIOpen.clippedrelu(x, 6),
NNlib.softplus => MIOpen.softrelu,
NNlib.σ => MIOpen.sigmoid,
Base.tanh => MIOpen.tanh,
# TODO define for leakyrelu, elu, etc.?
], N in 1:5
@eval function Base.materialize(
bc::Broadcast.Broadcasted{<:Any,<:Any,typeof($f),<:Tuple{ROCArray{<:MIOPENFloat,$N}}}
)
return $op(bc.args[1])
end
end
Base.broadcasted(::typeof(identity), x::ROCArray{T}) where {T<:MIOPENFloat} = x
================================================
FILE: ext/NNlibAMDGPUExt/batched_mul.jl
================================================
function _blas_at(x)
Base.stride(x, 1) == 1 && return x, 'N'
Base.stride(x, 2) == 1 && return batched_transpose(x), 'T'
throw(ArgumentError("""
Unsupported array layout for batched mul.
- Size: $(size(x))
- Strides: $(strides(x))
"""))
end
function NNlib._batched_mul!(
::Type{AT}, C, A, B, α::Float16, β::Float16,
) where AT <: ROCArray{Float16}
blasA, transA = _blas_at(A)
blasB, transB = _blas_at(B)
NNlib._batched_gemm!(AT, transA, transB, α, blasA, blasB, β, C)
C
end
function NNlib._batched_gemm!(
::Type{<:ROCArray{T}}, transA::Char, transB::Char, α::T, A, B, β::T, C,
) where T <: Union{MIOPENFloat, Float64}
AMDGPU.rocBLAS.gemm_batched!(transA, transB, α, A, B, β, C)
end
================================================
FILE: ext/NNlibAMDGPUExt/conv.jl
================================================
function NNlib.conv!(
y::ROCArray{T, N}, x::ROCArray{T, N}, w::ROCArray{T, N}, cdims::DenseConvDims,
) where {T <: MIOPENFloat, N}
if !NNlib.flipkernel(cdims)
@warn """
MIOpen supports only cross-correlation (flipkernel=true).
Therefore for every regular convolution (flipkernel=false)
kernel is flipped before calculation.
For better performance, use cross-correlation (flipkernel=true)
and manually flip the kernel before `NNlib.conv` call.
""" maxlog=1
flip_dims = ntuple(
i -> (i ≤ ndims(w) - 2) ? (size(w, i):-1:1) : Colon(),
ndims(w))
w = w[flip_dims...]
end
nd = max(0, 4 - N)
ncdims = NNlib.insert_singleton_spatial_dimension(cdims, nd)
MIOpen.convolution!(
NNlib.insert_singleton_spatial_dimension(y, nd),
NNlib.insert_singleton_spatial_dimension(x, nd),
NNlib.insert_singleton_spatial_dimension(w, nd);
padding=nnlib_padding(ncdims), stride=NNlib.stride(ncdims),
dilation=NNlib.dilation(ncdims), groups=NNlib.groupcount(ncdims))
return y
end
function NNlib.∇conv_data!(
dx::ROCArray{T, N}, dy::ROCArray{T, N}, w::ROCArray{T, N}, cdims::DenseConvDims,
) where {T <: MIOPENFloat, N}
if !NNlib.flipkernel(cdims)
@warn """
MIOpen supports only cross-correlation (flipkernel=true).
Therefore for every regular convolution (flipkernel=false)
kernel is flipped before calculation.
For better performance, use cross-correlation (flipkernel=true)
and manually flip the kernel before `NNlib.conv` call.
""" maxlog=1
flip_dims = ntuple(
i -> (i ≤ ndims(w) - 2) ? (size(w, i):-1:1) : Colon(),
ndims(w))
w = w[flip_dims...]
end
nd = max(0, 4 - N)
ncdims = NNlib.insert_singleton_spatial_dimension(cdims, nd)
MIOpen.∇convolution_data!(
NNlib.insert_singleton_spatial_dimension(dx, nd),
NNlib.insert_singleton_spatial_dimension(dy, nd),
NNlib.insert_singleton_spatial_dimension(w, nd);
padding=nnlib_padding(ncdims), stride=NNlib.stride(ncdims),
dilation=NNlib.dilation(ncdims), groups=NNlib.groupcount(ncdims))
return dx
end
function NNlib.∇conv_filter!(
dw::ROCArray{T, N}, x::ROCArray{T, N}, dy::ROCArray{T, N}, cdims::DenseConvDims,
) where {T <: MIOPENFloat, N}
nd = max(0, 4 - N)
ncdims = NNlib.insert_singleton_spatial_dimension(cdims, nd)
MIOpen.∇convolution_weight!(
NNlib.insert_singleton_spatial_dimension(dw, nd),
NNlib.insert_singleton_spatial_dimension(dy, nd),
NNlib.insert_singleton_spatial_dimension(x, nd);
padding=nnlib_padding(ncdims), stride=NNlib.stride(ncdims),
dilation=NNlib.dilation(ncdims), groups=NNlib.groupcount(ncdims))
if !NNlib.flipkernel(cdims)
@warn """
MIOpen supports only cross-correlation (flipkernel=true).
Therefore for every regular convolution (flipkernel=false)
kernel is flipped before calculation.
For better performance, use cross-correlation (flipkernel=true)
and manually flip the kernel before `NNlib.conv` call.
""" maxlog=1
flip_dims = ntuple(
i -> (i ≤ ndims(dw) - 2) ? (size(dw, i):-1:1) : Colon(),
ndims(dw))
dw = dw[flip_dims...]
end
return dw
end
================================================
FILE: ext/NNlibAMDGPUExt/pool.jl
================================================
for poolname in (:maxpool, :meanpool)
@eval function NNlib.$(poolname)(
x::ROCArray{T, N}, pdims::PoolDims,
) where {T <: MIOPENFloat, N}
y = similar(x, NNlib.output_size(pdims)..., NNlib.channels_out(pdims), size(x, N))
nd = max(0, 4 - N)
npdims = NNlib.insert_singleton_spatial_dimension(pdims, nd)
MIOpen.$(Symbol("$(poolname)!"))(
NNlib.insert_singleton_spatial_dimension(y, nd),
NNlib.insert_singleton_spatial_dimension(x, nd);
dims=NNlib.kernel_size(npdims), padding=nnlib_padding(npdims),
stride=NNlib.stride(npdims), do_backward=false)
return y
end
@eval function ChainRulesCore.rrule(
::typeof(NNlib.$(poolname)), x::ROCArray{T, N}, pdims::PoolDims,
) where {T <: MIOPENFloat, N}
y = similar(x, NNlib.output_size(pdims)..., NNlib.channels_out(pdims), size(x, N))
nd = max(0, 4 - N)
npdims = NNlib.insert_singleton_spatial_dimension(pdims, nd)
# `workspace` is used in the pullback.
_, workspace = MIOpen.$(Symbol("$(poolname)!"))(
NNlib.insert_singleton_spatial_dimension(y, nd),
NNlib.insert_singleton_spatial_dimension(x, nd);
dims=NNlib.kernel_size(npdims), padding=nnlib_padding(npdims),
stride=NNlib.stride(npdims))
function _pooling_pullback(Δ)
dx = similar(x)
MIOpen.$(Symbol("∇$(poolname)!"))(
NNlib.insert_singleton_spatial_dimension(dx, nd),
NNlib.insert_singleton_spatial_dimension(unthunk(Δ), nd),
NNlib.insert_singleton_spatial_dimension(y, nd),
NNlib.insert_singleton_spatial_dimension(x, nd);
dims=NNlib.kernel_size(npdims), padding=nnlib_padding(npdims),
stride=NNlib.stride(npdims), workspace)
return NoTangent(), dx, NoTangent()
end
y, _pooling_pullback
end
end
================================================
FILE: ext/NNlibCUDACUDNNExt/NNlibCUDACUDNNExt.jl
================================================
module NNlibCUDACUDNNExt
using NNlib
using cuDNN
using CUDA
using Random, Statistics
using cuDNN: handle, with_workspace, cudnnTensorDescriptor, cudnnFilterDescriptor,
cudnnDataType, math_mode, CUDNN_DEFAULT_REORDER, CUDNN_CROSS_CORRELATION,
CUDNN_NOT_PROPAGATE_NAN, CUDNN_TENSOR_NCHW, dim4
cudnnversion() = cuDNN.version()
function nnlibPadding(dims)
pd = NNlib.padding(dims)
if !all(pd[1:2:end] .== pd[2:2:end])
@warn "cuDNN does not support asymmetric padding; defaulting to symmetric choice" maxlog=1
end
return pd[1:2:end]
end
include("conv.jl")
include("pooling.jl")
include("softmax.jl")
include("activations.jl")
include("batchnorm.jl")
end # module
================================================
FILE: ext/NNlibCUDACUDNNExt/activations.jl
================================================
# Activation
using Base.Broadcast
using cuDNN: cudnnActivationForward!, cudnnOpTensor!,
CUDNN_ACTIVATION_TANH, CUDNN_ACTIVATION_SIGMOID, CUDNN_ACTIVATION_ELU,
CUDNN_ACTIVATION_RELU, CUDNN_ACTIVATION_CLIPPED_RELU, CUDNN_OP_TENSOR_MAX,
CUDNN_ACTIVATION_IDENTITY
for (f, op) in [
CUDA.tanh => (src,dst)->cudnnActivationForward!(dst, src, mode=CUDNN_ACTIVATION_TANH),
NNlib.σ => (src,dst)->cudnnActivationForward!(dst, src, mode=CUDNN_ACTIVATION_SIGMOID),
NNlib.elu => (src,dst)->cudnnActivationForward!(dst, src, mode=CUDNN_ACTIVATION_ELU),
NNlib.relu => (src,dst)->cudnnActivationForward!(dst, src, mode=CUDNN_ACTIVATION_RELU),
# NNlib.relu6 => (src,dst)->cudnnActivationForward!(dst, src, mode=CUDNN_ACTIVATION_CLIPPED_RELU, coef=6.0),
# NNlib.leakyrelu => (src,dst)->cudnnOpTensor!(dst, src, src; op=CUDNN_OP_TENSOR_MAX, alpha1=0.01),
]
@eval begin
# in-place
function Base.materialize!(dst::DenseCuArray{<:CUDNNFloat},
bc::Broadcast.Broadcasted{<:Any,<:Any,typeof($f),<:Tuple{DenseCuArray{<:CUDNNFloat}}})
$op(bc.args[1], dst)
return dst
end
# out of place
function Base.materialize(bc::Broadcast.Broadcasted{<:Any,<:Any,typeof($f),<:Tuple{DenseCuArray{<:CUDNNFloat}}})
ElType = Broadcast.combine_eltypes(bc.f, bc.args)
dst = similar(bc, ElType)
$op(bc.args[1], dst)
return dst
end
end
end
# CUDNN_ACTIVATION_IDENTITY does not work with cudnnActivationForward
# FIXME: put this optimization in GPUArrays' `copyto!` (like Base.Broadcast's `copyto!`)
Base.broadcasted(::typeof(identity), x::DenseCuArray{T}) where {T<:CUDNNFloat} = x
================================================
FILE: ext/NNlibCUDACUDNNExt/batchnorm.jl
================================================
using cuDNN: CUDNN_BN_MIN_EPSILON, cudnnBatchNormalizationBackward,
cudnnBatchNormalizationForwardInference, CUDNN_BATCHNORM_SPATIAL,
cudnnBatchNormalizationForwardTraining
import NNlib: batchnorm, ∇batchnorm
# TODO: replace with new cudnn normalization interface
# https://github.com/JuliaGPU/CUDA.jl/blob/master/lib/cudnn/normalization.jl
mutable struct BNCache
mean
ivar
end
BNCache() = BNCache(nothing, nothing)
@inline _wsize(x::AbstractArray{<:Any,N}) where N = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
function batchnorm(g::Nothing, b::Nothing, x::DenseCuArray,
running_mean, running_var, momentum; kws...)
affine_sz = _wsize(x)
g = fill!(similar(x, affine_sz), 1)
b = fill!(similar(x, affine_sz), 0)
return batchnorm(g, b, x, running_mean, running_var, momentum; kws...)
end
# NOTE: CuDNN supports only 4D and 5D Tensors for BatchNorm Operations
# so reshape a 2D Tensor into 4D
function batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T,2},
running_mean, running_var, momentum; kws...) where T<:CUDNNFloat
x = reshape(x, 1, 1, size(x, 1), size(x, 2))
y = batchnorm(g, b, x, running_mean, running_var, momentum; kws...)
return dropdims(y, dims = (1, 2))
end
function batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::Union{DenseCuArray{T,4},DenseCuArray{T,5}},
running_mean, running_var, momentum; kws...) where T<:CUDNNFloat
cudnnBNForward!(similar(x), g, b, x, running_mean, running_var, momentum; kws...)
end
function cudnnBNForward!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T},
running_mean, running_var, momentum;
cache = nothing,
alpha = T(1), beta = T(0),
eps = T(1e-5),
training = true,
affine = true,
track_stats = true) where T<:CUDNNFloat
dims = _wsize(x)
if eps < CUDNN_BN_MIN_EPSILON
@warn "eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON"
eps = CUDNN_BN_MIN_EPSILON
end
if running_mean === nothing || running_var === nothing
running_mean !== running_var && throw(ArgumentError("both or neither of running_mean and running_var must be nothing"))
if track_stats || !training
running_mean = fill!(similar(x, dims), 0)
running_var = fill!(similar(x, dims), 1)
end
end
xd = cudnnTensorDescriptor(x)
yd = cudnnTensorDescriptor(y)
gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(dims)), dim4(dims,Val(CUDNN_TENSOR_NCHW)))
if training
if !track_stats
running_mean = CU_NULL
running_var = CU_NULL
end
if cache !== nothing
mean = fill!(similar(x, dims), 0)
ivar = fill!(similar(x, dims), 1)
else
mean = CU_NULL
ivar = CU_NULL
end
cudnnBatchNormalizationForwardTraining(handle(), CUDNN_BATCHNORM_SPATIAL, scalingParameter(T, alpha), scalingParameter(T, beta), xd, x, yd, y, gd, g, b, momentum, running_mean, running_var, eps, mean, ivar)
if cache !== nothing
cache.mean = mean
cache.ivar = ivar
end
else
if track_stats
cudnnBatchNormalizationForwardInference(handle(), CUDNN_BATCHNORM_SPATIAL, scalingParameter(T, alpha), scalingParameter(T, beta), xd, x, yd, y, gd, g, b, running_mean, running_var, eps)
else
# cudnnBatchNormalizationForwardInference does not accept CU_NULL for running_mean
# and running_var. We could calculate mean and var of `x` here, but instead use
# cudnnBatchNormalizationForwardTraining. cudnnBatchNormalizationForwardTraining does
# accept CU_NULL and will calculate mean and var itself.
cudnnBatchNormalizationForwardTraining(handle(), CUDNN_BATCHNORM_SPATIAL, scalingParameter(T, alpha), scalingParameter(T, beta), xd, x, yd, y, gd, g, b, momentum, CU_NULL, CU_NULL, eps, CU_NULL, CU_NULL)
end
end
return y
end
function ∇batchnorm(g::Nothing, b::Nothing, x::DenseCuArray, dy::DenseCuArray,
running_mean, running_var, momentum; kws...)
affine_sz = _wsize(x)
g = fill!(similar(x, affine_sz), 1)
b = fill!(similar(x, affine_sz), 0)
return ∇batchnorm(g, b, x, dy, running_mean, running_var, momentum; kws...)
end
function ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T, 2}, dy::DenseCuArray{T, 2},
running_mean, running_var, momentum;
kws...) where T<:CUDNNFloat
dg, db, dx = ∇batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), reshape(dy, 1, 1, size(dy, 1),
size(dy, 2)), running_mean, running_var, momentum; kws...)
(dg, db, dropdims(dx, dims = (1, 2)))
end
function ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
running_mean, running_var, momentum;
affine=true, kws...) where T<:CUDNNFloat
dg = similar(g)
db = similar(b)
dx = similar(x)
cudnnBNBackward!(dg, g, db, dx, x, dy, running_mean, running_var, T(momentum); kws...)
if affine
(dg, db, dx)
else
# cuDNN always calculates dg and db, therefore we just have to drop them
(nothing, nothing, dx)
end
end
function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuArray{T},
dx::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
running_mean, running_var,
momentum; cache = nothing, eps = T(1e-5),
alpha = T(1), beta = T(0),
dalpha = T(1), dbeta = T(0), training = true,
track_stats = true) where T<:CUDNNFloat
if !track_stats
running_mean = CU_NULL
running_var = CU_NULL
end
xd = cudnnTensorDescriptor(x)
dyd = cudnnTensorDescriptor(dy)
dxd = cudnnTensorDescriptor(dx)
gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(_wsize(x))), dim4(_wsize(x),Val(CUDNN_TENSOR_NCHW)))
if cache !== nothing
@debug "fetching mean and ivar from the cache"
mean, ivar = cache.mean, cache.ivar
else
mean, ivar = CU_NULL, CU_NULL
end
if eps < CUDNN_BN_MIN_EPSILON
@warn "eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON"
eps = CUDNN_BN_MIN_EPSILON
end
cudnnBatchNormalizationBackward(handle(), CUDNN_BATCHNORM_SPATIAL,
scalingParameter(T, alpha), scalingParameter(T, beta), scalingParameter(T, dalpha), scalingParameter(T, dbeta),
xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps, mean, ivar)
end
================================================
FILE: ext/NNlibCUDACUDNNExt/conv.jl
================================================
using NNlib: DenseConvDims
import NNlib: conv!, ∇conv_filter!, ∇conv_data!, conv_bias_act!
using cuDNN: scalingParameter, CUDNN_CONVOLUTION, convdims,
cudnnConvolutionBwdDataAlgoPerf,
cudnnConvolutionForward!, cudnnConvolutionBwdFilterAlgoPerf,
cudnnConvolutionBackwardData, cudnnConvolutionBackwardFilter,
cudnnConvolutionBackwardBias
import cuDNN: cudnnConvolutionDescriptor
const CUDNNFloat = Union{Float16,Float32,Float64}
const CUDNNComplexFloat = Union{ComplexF16,ComplexF32,ComplexF64}
function cudnnConvolutionDescriptorAndPaddedInput(cdims::DenseConvDims, x::DenseCuArray{T}) where T
# The main purpose of this function is to catch asymmetric padding which cudnn does not support
# If we find asymmetric padding we'll make a copy of x which is manually padded so that we can
# call cudnn with symmetric padding.
pad = NNlib.padding(cdims)
sdims = NNlib.spatial_dims(cdims)
all(i -> pad[i] .== pad[i+1], 1:2:2sdims) && return (cudnnConvolutionDescriptor(cdims, x), x, identity)
# Naive implementation, is there a faster way?
# How much we need to pad x manually: The absolute difference between pad_left and pad_right, pad_top
# and pad_bottom etc. respectively. We keep the sign here though because we use it below to figure out
# which side of x to pad. Oh, and we use a CartesianIndex as we will mainly use this to index in x
pad_manual = CartesianIndex(ntuple(i -> i > sdims ? 0 : pad[2(i-1)+1] - pad[2(i-1)+2], ndims(x)))
# How much we can let cudnn pad: The smallest padding amount between pad_left and pad_right, pad_top
# and pad_bottom etc. respectively
pad_cudnn = ntuple(i -> min(pad[2(i-1)+1], pad[2(i-1)+2]), sdims)
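# Illustration (hypothetical 2D case): for pad = (1, 2, 0, 0) on a 4D `x`,
# pad_manual == CartesianIndex(-1, 0, 0, 0) and pad_cudnn == (1, 0), so cudnn applies the
# symmetric part and one extra zero slice is appended at the high-index end of dim 1 below.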
x_padded_size = ntuple(i -> i <= sdims ? size(x, i) + abs(pad_manual[i]) : size(x ,i), ndims(x))
x_padded = similar(x, x_padded_size)
fill!(x_padded, 0)
# This is a bit yucky, but we are basically figuring out where in x_padded we shall insert x
# Haven't benchmarked if this has any advantages over a more readable solution, e.g. writing dim
# by dim to an array in a loop
xIs = CartesianIndices(x)
xI_first = first(xIs)
xI_last = last(xIs)
xIs_pad = max(xI_first, xI_first + pad_manual) : max(xI_last, xI_last + pad_manual)
x_padded[xIs_pad] = x
return cudnnConvolutionDescriptor(cdims, x_padded, pad_cudnn), x_padded, _x -> _x[xIs_pad]
end
function cudnnConvolutionDescriptor(cdims::DenseConvDims, x::DenseCuArray{T}, pad = nnlibPadding(cdims)) where T
mode=(NNlib.flipkernel(cdims) ? CUDNN_CROSS_CORRELATION : CUDNN_CONVOLUTION)
cudnnConvolutionDescriptor(convdims(pad, size(x),0),
convdims(NNlib.stride(cdims),size(x),1),
convdims(NNlib.dilation(cdims),size(x),1),
mode,
cudnnDataType(real(T)),
math_mode(),
CUDNN_DEFAULT_REORDER,
Cint(NNlib.groupcount(cdims)))
end
@inline function _complex!(y::DenseCuArray{T1}, yr::DenseCuArray{T2}, yi::DenseCuArray{T2}; bias=zero(T1), alpha=one(T1), beta=zero(T1), σ=identity) where {T1 <: CUDNNComplexFloat, T2<:CUDNNFloat}
# if y is from similar(), it may have NaNs, and beta*NaN will propagate.
if beta != 0
@. y = σ(alpha*(yr + im*yi) + bias + beta*y)
else
@. y = σ(alpha*(yr + im*yi) + bias)
end
return y
end
function conv!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T}, cdims::DenseConvDims;
alpha=1, beta=0, algo=-1) where T<:CUDNNFloat
if cudnnversion() < v"6"
all(x -> x == 1, dilation(cdims)) || error("Only dilation = 1 is supported in cuDNN version < 6")
end
if algo != -1
@warn "algo option has been deprecated, the fastest algo is computed automatically" maxlog=1
end
d, x, _ = cudnnConvolutionDescriptorAndPaddedInput(cdims, x)
cudnnConvolutionForward!(y, w, x, d; alpha, beta, z=y)
end
# Complex convolution with Gauss's trick (1 complex mul === 3 real mul):
# Consider x = xr + im*xi, y = yr + im*yi,
# so x*y = (xr*yr - xi*yi) + im*(xr*yi + xi*yr).
# Let a = xr*yr,
# b = xi*yi,
# c = (xr + xi)*(yr + yi) = xr*yr + xr*yi + xi*yr + xi*yi.
# Then,
# x*y = (a - b) + im*(c - a - b).
# Convolution is linear so this multiplication trick translates to convolution.
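# Scalar illustration of the trick: for x = 1 + 2im and w = 3 + 4im,
# a = 1*3 = 3, b = 2*4 = 8, c = (1+2)*(3+4) = 21, and indeed
# x*w = (a - b) + im*(c - a - b) = -5 + 10im.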
function conv!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T}, cdims::DenseConvDims;
alpha=1, beta=0, algo=-1) where T<:CUDNNComplexFloat
xr, xi = reim(x)
wr, wi = reim(w)
a = conv!(similar(real(y)), xr, wr, cdims; algo=algo)
b = conv!(similar(a), xi, wi, cdims; algo=algo)
c = conv!(similar(a), xr + xi, wr + wi, cdims; algo=algo)
return _complex!(y, a - b, c - a - b; alpha=alpha, beta=beta)
end
# (xr + im*xi) * w = xr*w + im*(xi*w)
function conv!(y::DenseCuArray{T1}, x::DenseCuArray{T1}, w::DenseCuArray{T2}, cdims::DenseConvDims;
alpha=1, beta=0, algo=-1) where {T1<:CUDNNComplexFloat, T2<:CUDNNFloat}
xr, xi = reim(x)
yr = conv!(similar(real(y)), xr, w, cdims; algo=algo)
yi = conv!(similar(yr), xi, w, cdims; algo=algo)
return _complex!(y, yr, yi; alpha=alpha, beta=beta)
end
# x * (wr + im*wi) = x*wr + im*(x*wi)
function conv!(y::DenseCuArray{T1}, x::DenseCuArray{T2}, w::DenseCuArray{T1}, cdims::DenseConvDims;
alpha=1, beta=0, algo=-1) where {T1<:CUDNNComplexFloat, T2<:CUDNNFloat}
wr, wi = reim(w)
yr = conv!(similar(real(y)), x, wr, cdims; algo=algo)
yi = conv!(similar(yr), x, wi, cdims; algo=algo)
return _complex!(y, yr, yi; alpha=alpha, beta=beta)
end
function conv_bias_act!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T},
cdims::DenseConvDims, bias::DenseCuArray{T}, σ=identity;
z::DenseCuArray{T}=y, alpha=1, beta=0, algo=-1) where T<:CUDNNFloat
if cudnnversion() < v"6"
all(x -> x == 1, dilation(cdims)) || error("Only dilation = 1 is supported in cuDNN version < 6")
end
if algo != -1
@warn "The algo option has been deprecated, the fastest algo is computed automatically" maxlog=1
end
d, x, _ = cudnnConvolutionDescriptorAndPaddedInput(cdims, x)
# only relu and identity are supported by cudnnConvolutionForward!
activation = (σ == NNlib.relu ? CUDNN_ACTIVATION_RELU : CUDNN_ACTIVATION_IDENTITY)
cudnnConvolutionForward!(y, w, x, d; z, bias, activation, alpha, beta)
if activation === CUDNN_ACTIVATION_IDENTITY && σ ∉ (nothing, identity)
@. y = σ(y)
end
return y
end
function conv_bias_act!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T},
cdims::DenseConvDims, bias::DenseCuArray{T}, σ=identity;
z::DenseCuArray{T}=y, alpha=1, beta=0, algo=-1) where T<:CUDNNComplexFloat
xr, xi = reim(x)
wr, wi = reim(w)
a = conv!(similar(real(y)), xr, wr, cdims; alpha=1, beta=0, algo=algo)
b = conv!(similar(a), xi, wi, cdims; alpha=1, beta=0, algo=algo)
c = conv!(similar(a), xr + xi, wr + wi, cdims; alpha=1, beta=0, algo=algo)
return _complex!(y, a - b, c - a - b; bias=bias, alpha=alpha, beta=beta, σ=σ)
end
function ∇conv_data!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, w::DenseCuArray{T},
cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where T<:CUDNNFloat
if cudnnversion() < v"6"
all(x -> x == 1, dilation(cdims)) || error("Only dilation = 1 is supported in cuDNN version < 6")
end
if algo != -1
@warn "The algo option has been deprecated, the fastest algo is computed automatically" maxlog=1
end
alpha, beta = scalingParameter(T,alpha), scalingParameter(T,beta);
convDesc, dx, depad = cudnnConvolutionDescriptorAndPaddedInput(cdims, dx)
xDesc, yDesc, wDesc = cudnnTensorDescriptor(dx), cudnnTensorDescriptor(dy), cudnnFilterDescriptor(w)
p = cudnnConvolutionBwdDataAlgoPerf(wDesc, w, yDesc, dy, convDesc, xDesc, dx, beta!=0)
with_workspace(p.memory) do workspace
cudnnConvolutionBackwardData(handle(), alpha, wDesc, w, yDesc, dy, convDesc, p.algo, workspace, sizeof(workspace), beta, xDesc, dx)
end
return depad(dx)
end
function ∇conv_data!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, w::DenseCuArray{T},
cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where T<:CUDNNComplexFloat
dyr, dyi = reim(dy)
wr, wi = reim(w)
# note: w is conjugated, i.e. wi is negated below
a = ∇conv_data!(similar(real(dx)), dyr, wr, cdims; alpha=1, beta=0, algo=algo)
b = ∇conv_data!(similar(a), dyi, -wi, cdims; alpha=1, beta=0, algo=algo)
c = ∇conv_data!(similar(a), dyr + dyi, wr - wi, cdims; alpha=1, beta=0, algo=algo)
return _complex!(dx, a - b, c - a - b; alpha=alpha, beta=beta)
end
# dx = (dyr + im*dyi)*w = dyr*w + im*(dyi*w)
function ∇conv_data!(dx::DenseCuArray{T1}, dy::DenseCuArray{T1}, w::DenseCuArray{T2},
cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where {T1<:CUDNNComplexFloat, T2<:CUDNNFloat}
dyr, dyi = reim(dy)
dxr = ∇conv_data!(similar(real(dx)), dyr, w, cdims; alpha=1, beta=0, algo=algo)
dxi = ∇conv_data!(similar(dxr), dyi, w, cdims; alpha=1, beta=0, algo=algo)
return _complex!(dx, dxr, dxi; alpha=alpha, beta=beta)
end
function ∇conv_filter!(dw::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where T<:CUDNNFloat
if cudnnversion() < v"6"
all(x -> x == 1, dilation(cdims)) || error("Only dilation = 1 is supported in cuDNN version < 6")
end
if algo != -1
@warn "The algo option has been deprecated, the fastest algo is computed automatically" maxlog=1
end
alpha, beta = scalingParameter(T,alpha), scalingParameter(T,beta);
convDesc, x, _ = cudnnConvolutionDescriptorAndPaddedInput(cdims, x)
xDesc, yDesc, wDesc = cudnnTensorDescriptor(x), cudnnTensorDescriptor(dy), cudnnFilterDescriptor(dw)
p = cudnnConvolutionBwdFilterAlgoPerf(xDesc, x, yDesc, dy, convDesc, wDesc, dw, beta!=0);
with_workspace(p.memory) do workspace
cudnnConvolutionBackwardFilter(handle(), alpha, xDesc, x, yDesc, dy, convDesc, p.algo, workspace, sizeof(workspace), beta, wDesc, dw);
end
return dw
end
function ∇conv_filter!(dw::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where T<:CUDNNComplexFloat
xr, xi = reim(x)
dyr, dyi = reim(dy)
# note: x is conjugated, i.e. xi is negated below
a = ∇conv_filter!(similar(real(dw)), xr, dyr, cdims; alpha=1, beta=0, algo=algo)
b = ∇conv_filter!(similar(a), -xi, dyi, cdims; alpha=1, beta=0, algo=algo)
c = ∇conv_filter!(similar(a), xr - xi, dyr + dyi, cdims; alpha=1, beta=0, algo=algo)
return _complex!(dw, a - b, c - a - b; alpha=alpha, beta=beta)
end
# dw = x*(dyr + im*dyi) = x*dyr + im*(x*dyi)
function ∇conv_filter!(dw::DenseCuArray{T1}, x::DenseCuArray{T2}, dy::DenseCuArray{T1},
cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where {T1<:CUDNNComplexFloat, T2<:CUDNNFloat}
dyr, dyi = reim(dy)
dwr = ∇conv_filter!(similar(real(dw)), x, dyr, cdims; alpha=1, beta=0, algo=algo)
dwi = ∇conv_filter!(similar(dwr), x, dyi, cdims; alpha=1, beta=0, algo=algo)
return _complex!(dw, dwr, dwi; alpha=alpha, beta=beta)
end
function ∇conv_bias!(db::DenseCuArray{T}, dy::DenseCuArray{T}; alpha=1, beta=0) where T<:CUDNNFloat
alpha,beta = scalingParameter(T,alpha), scalingParameter(T,beta)
bDesc, yDesc = cudnnTensorDescriptor.((db,dy))
cudnnConvolutionBackwardBias(handle(), alpha, yDesc, dy, beta, bDesc, db)
return db
end
function ∇conv_bias!(db::DenseCuArray{T}, dy::DenseCuArray{T}; alpha=1, beta=0) where T<:CUDNNComplexFloat
dyr, dyi = reim(dy)
dbr = ∇conv_bias!(similar(real(db)), dyr; alpha=1, beta=0)
dbi = ∇conv_bias!(similar(dbr), dyi; alpha=1, beta=0)
return _complex!(db, dbr, dbi; alpha=alpha, beta=beta)
end
================================================
FILE: ext/NNlibCUDACUDNNExt/pooling.jl
================================================
using cuDNN: cudnnPoolingMode_t, CUDNN_POOLING_MAX,
CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING,
cudnnPoolingForward!, pooldims, cudnnPoolingBackward
import NNlib: maxpool!, ∇maxpool!, meanpool!, ∇meanpool!
import cuDNN: cudnnPoolingDescriptor
function cudnnPoolingDescriptor(pdims::PoolDims, x::DenseCuArray{T}, mode::cudnnPoolingMode_t) where T
window, padding, stride = NNlib.kernel_size(pdims), nnlibPadding(pdims), NNlib.stride(pdims)
nanOpt = CUDNN_NOT_PROPAGATE_NAN
cudnnPoolingDescriptor(mode, nanOpt, Cint(ndims(x)-2), pooldims(window,size(x)), pooldims(padding,size(x)), pooldims(stride,size(x)))
end
function maxpool!(y::DenseCuArray{T}, x::DenseCuArray{T}, pdims::PoolDims) where T<:CUDNNFloat
d = cudnnPoolingDescriptor(pdims, x, CUDNN_POOLING_MAX)
cudnnPoolingForward!(y, x, d)
end
function ∇maxpool!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, y::DenseCuArray{T}, x::DenseCuArray{T}, pdims::PoolDims) where T<:CUDNNFloat
xDesc, yDesc = cudnnTensorDescriptor.((x, y))
d = cudnnPoolingDescriptor(pdims, x, CUDNN_POOLING_MAX)
alpha, beta = scalingParameter(T,1), scalingParameter(T,0)
cudnnPoolingBackward(handle(), d, alpha, yDesc, y, yDesc, dy, xDesc, x, beta, xDesc, dx)
return dx
end
function meanpool!(y::DenseCuArray{T}, x::DenseCuArray{T}, pdims::PoolDims) where T<:CUDNNFloat
d = cudnnPoolingDescriptor(pdims, x, CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING)
cudnnPoolingForward!(y, x, d)
end
function ∇meanpool!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, y::DenseCuArray{T}, x::DenseCuArray{T}, pdims::PoolDims) where T<:CUDNNFloat
xDesc, yDesc = cudnnTensorDescriptor.((x, y))
d = cudnnPoolingDescriptor(pdims, x, CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING)
alpha, beta = scalingParameter(T,1), scalingParameter(T,0)
cudnnPoolingBackward(handle(), d, alpha, yDesc, y, yDesc, dy, xDesc, x, beta, xDesc, dx)
return dx
end
### Since CUDA.jl does not support 1D pooling, we have to convert to 2d
add1d(x) = reshape(x, 1, size(x)...)
function fix_pooldims_1d(pdims::PoolDims{1,K,S,P,D}) where {K,S,P,D}
PoolDims{2, K + 1, S + 1, P + 2, D + 1}((1, NNlib.input_size(pdims)...),
(1, NNlib.kernel_size(pdims)...),
NNlib.channels_in(pdims),
(1, NNlib.stride(pdims)...),
(0, 0, NNlib.padding(pdims)...),
(1, NNlib.dilation(pdims)...))
end
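# Illustration (hypothetical dims): a 1D PoolDims with input (10,), kernel (2,), stride (2,),
# padding (0, 0) and dilation (1,) becomes a 2D PoolDims with input (1, 10), kernel (1, 2),
# stride (1, 2), padding (0, 0, 0, 0) and dilation (1, 1).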
function maxpool!(y::DenseCuArray{T,3}, x::DenseCuArray{T,3}, pdims::PoolDims) where T<:CUDNNFloat
maxpool!(add1d(y), add1d(x), fix_pooldims_1d(pdims))
return y
end
function meanpool!(y::DenseCuArray{T,3}, x::DenseCuArray{T,3}, pdims::PoolDims) where T<:CUDNNFloat
meanpool!(add1d(y), add1d(x), fix_pooldims_1d(pdims))
return y
end
function ∇maxpool!(dx::DenseCuArray{T,3}, dy::DenseCuArray{T,3}, y::DenseCuArray{T,3}, x::DenseCuArray{T,3}, pdims::PoolDims) where T<:CUDNNFloat
∇maxpool!(add1d(dx), add1d(dy), add1d(y), add1d(x), fix_pooldims_1d(pdims))
return dx
end
function ∇meanpool!(dx::DenseCuArray{T,3}, dy::DenseCuArray{T,3}, y::DenseCuArray{T,3}, x::DenseCuArray{T,3}, pdims::PoolDims) where T<:CUDNNFloat
∇meanpool!(add1d(dx), add1d(dy), add1d(y), add1d(x), fix_pooldims_1d(pdims))
return dx
end
================================================
FILE: ext/NNlibCUDACUDNNExt/softmax.jl
================================================
import NNlib: softmax, softmax!, ∇softmax, ∇softmax!,
logsoftmax, logsoftmax!, ∇logsoftmax, ∇logsoftmax!
using cuDNN: CUDNN_SOFTMAX_LOG, CUDNN_SOFTMAX_MODE_CHANNEL,
CUDNN_SOFTMAX_FAST, CUDNN_SOFTMAX_ACCURATE, cudnnSoftmaxForward!,
cudnnSoftmaxBackward
# Softmax
# @denizyuret: do not do inplace operations with softmax/logsoftmax because (1) the cpu version is not in-place, (2) one can use softmax! explicitly
function softmax(x::T; dims=1) where {T<:DenseCuArray}
softmax!(similar(x), x; dims)
end
function ∇softmax(dy::T, x::T, y::T; dims=1) where {T<:DenseCuArray}
∇softmax!(similar(x), dy, x, y; dims)
end
function logsoftmax(x::T; dims=1) where {T<:DenseCuArray}
logsoftmax!(similar(x), x; dims)
end
function ∇logsoftmax(dy::T, x::T, y::T; dims=1) where {T<:DenseCuArray}
∇logsoftmax!(similar(x), dy, x, y; dims)
end
# @denizyuret: backup implementations for unsupported/slow size/dims combinations:
function _softmax!(y::T, x::T; dims) where {T<:DenseCuArray}
y .= exp.(x .- maximum(x; dims))
y ./= sum(y; dims)
end
function _∇softmax!(dx::T, dy::T, x::T, y::T; dims) where {T<:DenseCuArray}
dx .= y .* (dy .- sum(dy .* y; dims))
end
function _logsoftmax!(y::T, x::T; dims) where {T<:DenseCuArray}
y .= x .- maximum(x; dims)
y .-= log.(sum(exp.(y); dims))
end
function _∇logsoftmax!(dx::T, dy::T, x::T, y::T; dims) where {T<:DenseCuArray}
dx .= dy .- sum(dy; dims) .* exp.(y)
end
# Trick by @norci to use cudnn for softmax dims args that are contiguous:
# If dims=(dmin:dmax) then CUDNN_SOFTMAX_MODE_CHANNEL does the trick with reshape
# (1, prod(size(x)[1:dmin-1]), prod(size(x)[dmin:dmax]), :)
# softmaxdims returns nothing when the backup implementation should be used.
function softmaxdims(x, dims)
dims === Colon() && return (1, 1, length(x), 1)
mind,maxd = minimum(dims),maximum(dims)
all(i in dims for i in mind:maxd) || return nothing # cannot handle if not contiguous
stride = dimsize = 1
for i in 1:(mind-1); stride *= size(x,i); end # Using size(x,i) assumes trailing dims = 1, robust to maxd > ndims(x)
for i in mind:maxd; dimsize *= size(x,i); end
batchsize = length(x)÷(stride*dimsize)
# Here is a region where cudnn is slower, so we go with the backup:
batchsize == 1 && 64 <= stride <= 4096 && 64 <= dimsize <= 4096 && return nothing
return (1, stride, dimsize, batchsize)
end
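# Worked example (hypothetical size): for size(x) == (3, 4, 5), softmaxdims(x, 2) returns
# (1, 3, 4, 5) (stride=3, dimsize=4, batchsize=5), and softmaxdims(x, (1, 2)) returns (1, 1, 12, 5).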
# Determine softmax algo based on math_mode
softmaxalgo() = (CUDA.math_mode()===CUDA.FAST_MATH ? CUDNN_SOFTMAX_FAST : CUDNN_SOFTMAX_ACCURATE)
# Main implementations:
function softmax!(y::T, x::T = y; dims=1) where {T<:DenseCuArray}
s = softmaxdims(x, dims)
s === nothing && return _softmax!(y, x; dims)
cudnnSoftmaxForward!(reshape(y,s), reshape(x,s); mode = CUDNN_SOFTMAX_MODE_CHANNEL, algo = softmaxalgo())
return y
end
function ∇softmax!(dx::T, dy::T, x::T, y::T; dims=1) where {R,T<:DenseCuArray{R}}
s = softmaxdims(x, dims)
s === nothing && return _∇softmax!(dx, dy, x, y; dims)
xDesc = cudnnTensorDescriptor(reshape(x,s))
alpha, beta = scalingParameter(R,1), scalingParameter(R,0)
cudnnSoftmaxBackward(handle(), softmaxalgo(), CUDNN_SOFTMAX_MODE_CHANNEL,
alpha, xDesc, y, xDesc, dy, beta, xDesc, dx)
return dx
end
function logsoftmax!(y::T, x::T = y; dims=1) where {T<:DenseCuArray}
s = softmaxdims(x, dims)
s === nothing && return _logsoftmax!(y, x; dims)
cudnnSoftmaxForward!(reshape(y,s), reshape(x,s); mode = CUDNN_SOFTMAX_MODE_CHANNEL, algo = CUDNN_SOFTMAX_LOG)
return y
end
function ∇logsoftmax!(dx::T, dy::T, x::T, y::T; dims=1) where {R,T<:DenseCuArray{R}}
s = softmaxdims(x, dims)
s === nothing && return _∇logsoftmax!(dx, dy, x, y; dims)
xDesc = cudnnTensorDescriptor(reshape(x,s))
alpha, beta = scalingParameter(R,1), scalingParameter(R,0)
cudnnSoftmaxBackward(handle(), CUDNN_SOFTMAX_LOG, CUDNN_SOFTMAX_MODE_CHANNEL,
alpha, xDesc, y, xDesc, dy, beta, xDesc, dx)
return dx
end
================================================
FILE: ext/NNlibCUDAExt/NNlibCUDAExt.jl
================================================
module NNlibCUDAExt
using NNlib
using CUDA
using Random, Statistics
include("sampling.jl")
include("activations.jl")
include("batchedadjtrans.jl")
include("batchedmul.jl")
include("ctc.jl")
include("scatter.jl")
include("utils.jl")
end # module
================================================
FILE: ext/NNlibCUDAExt/activations.jl
================================================
# Activation functions
# Some activation functions need a wrapper for GPU support
# https://github.com/JuliaGPU/CuArrays.jl/issues/614
# @cufunc softplus(x::Real) = ifelse(x > 0, x + log1p(exp(-x)), log1p(exp(x)))
# @cufunc logσ(x::Real) = -softplus(-x)
# @cufunc function gelu(x::Real)
# p = oftype(x / 1, π)
# λ = oftype(x / 1, √(2 / p))
# α = oftype(x / 1, 0.044715)
# h = oftype(x / 1, 0.5)
# h * x * (one(x) + tanh(λ * (x + α * x^3)))
# end
# @cufunc lisht(x::Real) = x * tanh(x)
# @cufunc logcosh(x::Real) = x + softplus(-2x) - log(oftype(x, 2))
# @cufunc mish(x::Real) = x * tanh(softplus(x))
# @cufunc tanhshrink(x::Real) = x - tanh(x)
================================================
FILE: ext/NNlibCUDAExt/batchedadjtrans.jl
================================================
using NNlib: BatchedAdjoint, BatchedTranspose, BatchedAdjOrTrans
using Adapt
using Adapt: WrappedArray
const CuBatchedAdjoint{T} = BatchedAdjoint{T, <: CuArray{T}}
const CuBatchedTranspose{T} = BatchedTranspose{T, <: CuArray{T}}
const CuBatchedAdjOrTrans{T} = Union{CuBatchedAdjoint{T}, CuBatchedTranspose{T}}
const WrappedCuBatchedAdjOrTrans{T, N} = WrappedArray{T, N, CuBatchedAdjOrTrans{T}, CuBatchedAdjOrTrans{T}}
Base.print_array(io::IO, b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}) = Base.print_array(io, adapt(Array, b))
Base._show_nonempty(io::IO, b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}, prefix::String) = Base._show_nonempty(io, adapt(Array, b), prefix)
Base.show_vector(io::IO, b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}, opn, cls) = Base.show_vector(io, adapt(Array, b), opn, cls)
Base.convert(::Type{T}, b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}) where {T<:Array} = Base.convert(T, adapt(Array, b))
Base.Array{T, N}(b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}) where {T, N} = Array{T, N}(adapt(Array, b))
Base.collect(b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}) = collect(adapt(Array, b))
================================================
FILE: ext/NNlibCUDAExt/batchedmul.jl
================================================
# Batched matrix multiplication
# 1st argument is produced by NNlib.storage_type(A)
NNlib._batched_gemm!(::Type{<:CuArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C) =
CUBLAS.gemm_strided_batched!(transA, transB, α, A, B, β, C)
Base.unsafe_convert(::Type{CuPtr{T}}, A::NNlib.BatchedAdjOrTrans{T}) where {T} =
Base.unsafe_convert(CuPtr{T}, parent(A))
================================================
FILE: ext/NNlibCUDAExt/ctc.jl
================================================
# CTC loss moved from Flux.jl to NNlib
import NNlib: ctc_loss, ctc_alpha, ∇ctc_loss
## GPU implementation
# a port of the GPU kernels from Baidu's C++ warp-ctc package,
# which itself is Copyright 2015-2016 Baidu USA LLC
# and available under the Apache 2.0 license
#
# Apache 2.0 license: https://www.apache.org/licenses/LICENSE-2.0
# GitHub: https://github.com/baidu-research/warp-ctc/
# paper: https://arxiv.org/pdf/1512.02595.pdf
const MAX_THREADS = 256
@inline function log_plus_f(p1, p2)
isinf(p1) && return p2
isinf(p2) && return p1
if p1 < p2
p1, p2 = p2, p1
end
return p1 + log(1+exp(p2 - p1))
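# log_plus_f(p1, p2) is a numerically stable log(exp(p1) + exp(p2)),
# e.g. log_plus_f(log(2f0), log(3f0)) ≈ log(5f0).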
end
function count_repeats(A)
repeats = 0
for (i,elem) in enumerate(A)
if i > 1 && A[i] == A[i-1]
repeats += 1
end
end
return repeats
end
function compute_alpha_kernel(probs, labelSize, uttLength, repeats, labelsWithoutBlanks, labelsWithBlanks, alpha, blankLabel)
tid = threadIdx().x
L = labelSize
T = uttLength
S = length(labelsWithBlanks)
if L + repeats > T
return nothing
end
labels = labelsWithBlanks
# Corner-case checking
start = (L + repeats <= T) ? 0 : 1
last = S > 1 ? 2 : 1
# Fill in first column (time step)
i = tid
while i <= last - start
alpha[start+i, 1] = probs[labels[start+i], 1]
i += blockDim().x
end
sync_threads()
# Fill in coefficients for each time step
for t=2:T
# Corner-case checking
if tid == 1 && !(1 < S - 2*(T-t) - 1)
if start == 0
alpha[1, t] = alpha[1, t-1] + probs[blankLabel, t]
elseif start == 1
alpha[1, t] = alpha[1, t-1]
end
end
sync_threads()
# Fill in coefficients for each label class in the target output sequence;
# each thread will process the calculations for one class
idx = tid+1
while idx <= S
prevSum = log_plus_f(alpha[idx, t-1], alpha[idx-1, t-1])
if labels[idx] != blankLabel && idx != 2 && labels[idx] != labels[idx-2]
prevSum = log_plus_f(prevSum, alpha[idx-2, t-1])
end
if idx < S - 2*(T-t) - 1
alpha[idx, t] = -Inf32
else
alpha[idx, t] = prevSum + probs[labels[idx], t]
end
idx += blockDim().x
end
sync_threads()
end
return nothing
end
function compute_beta_and_grad_kernel(probs, labelSize, uttLength,
repeatsInLabel, labelsWithBlanks,
alphas, beta, output, accum,
grad, blankLabel, loss)
tid = threadIdx().x
L = labelSize
T = uttLength
S = 2*L + 1
repeats = repeatsInLabel
labels = labelsWithBlanks
if (L+repeats) > T
return nothing
end
# Corner-case checking
start = S > 1 ? S-2 : 0
last = L + repeats < T ? S : S-1
sync_threads()
i = tid
# Calculate coefficients for last column (time step)
# then determine alpha and beta product
while i <= last - start
beta[i+start, T] = 0
output[i+start, T] = beta[i+start, T] + alphas[i+start, T]
i += blockDim().x
end
sync_threads()
# Fill in `accum` for last column (time step)
if tid == 1
for i=1:S
labelIdx = labels[i]
accum[labelIdx, T] = log_plus_f(accum[labelIdx, T], output[i, T])
end
end
sync_threads()
# Fill in `grad` for last column (time step)
idx = tid
while idx <= size(grad, 1)
s = -Inf32
for i=1:S
s = log_plus_f(s, output[i, T])
end
# ∂L/∂a (where a is activation before logsoftmax)
grad[idx, T] = exp(probs[idx, T]) - exp(accum[idx, T] - s)
idx += blockDim().x
end
sync_threads()
# Fill in the rest of the coefficients
t = T-1
while t >= 1
if t < T
idx = tid
while idx <= S
nextSum = probs[labels[idx], t+1] + beta[idx, t+1]
if idx < S
nextSum = log_plus_f(nextSum,
probs[labels[idx+1], t+1] + beta[idx+1, t+1])
end
if labels[idx] != blankLabel && idx != S-1 && labels[idx] != labels[idx+2]
nextSum = log_plus_f(nextSum,
probs[labels[idx+2], t+1] + beta[idx + 2, t+1])
end
if idx > 2*t
beta[idx, t] = -Inf32
else
beta[idx, t] = nextSum
end
idx += blockDim().x
end
sync_threads()
idx = tid
while idx <= S
output[idx, t] = alphas[idx, t] + beta[idx, t]
idx += blockDim().x
end
sync_threads()
end
sync_threads()
# Calculate accumulated alpha-beta products for each label class for
# each time step; used in calculating gradients
if tid == 1
for i=1:S
labelIdx = labels[i]
accum[labelIdx, t] = log_plus_f(accum[labelIdx, t], output[i, t])
end
end
sync_threads()
idx = tid
# Calculate gradients
while idx <= size(grad, 1)
# ∂L/∂a (where a is activation before logsoftmax)
grad[idx, t] = exp(probs[idx, t]) - exp(accum[idx, t] + loss)
idx += blockDim().x
end
sync_threads()
t -= 1
sync_threads()
end
return nothing
end
function ctc_alpha(ŷ::CuArray, y)
ŷ = logsoftmax(ŷ)
blank = size(ŷ, 1)
ycu = cu(y)
z′ = CUDA.fill(blank, 2 * length(y) + 1)
z′[eachindex(y) .* 2] .= ycu
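# z′ interleaves labels with blanks: hypothetically, for y = [2, 3] and blank = 5,
# z′ == [5, 2, 5, 3, 5].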
T = size(ŷ, 2)
U′ = 2*length(y) + 1
alphas = CUDA.fill(log(zero(eltype(ŷ))), U′,T)
nRepeats = count_repeats(CUDA.adapt(Array, y))
nThreads = min(U′, MAX_THREADS)
@cuda blocks=1 threads=nThreads compute_alpha_kernel(ŷ, length(y), T, nRepeats, ycu, z′, alphas, blank)
return (loss=-1 * logsumexp(alphas[end-1:end]), alpha=alphas, z′=z′, yhat=ŷ, nRepeats=nRepeats)
end
ctc_loss(ŷ::CuArray, y) = ctc_alpha(ŷ::CuArray, y).loss
function ∇ctc_loss(ŷ::CuArray, y, out)
loss, alphas, z′, ŷ, nRepeats = out
U′, T = size(alphas)
blank = size(ŷ, 1)
typed_zero = zero(eltype(ŷ))
betas = CUDA.fill(log(typed_zero), U′, T)
output = CUDA.fill(log(typed_zero), U′, T)
nThreads = min(U′, MAX_THREADS)
grads = CUDA.fill(log(typed_zero), size(ŷ))
accum = CUDA.fill(log(typed_zero), size(ŷ))
@cuda blocks=1 threads=nThreads compute_beta_and_grad_kernel(ŷ, length(y), T, nRepeats, CuArray(z′), alphas, betas, output, accum, grads, blank, loss)
return grads
end
================================================
FILE: ext/NNlibCUDAExt/sampling.jl
================================================
@inline function NNlib._safe_add!(dx::CuDeviceArray{T, 4}, value, ix, iy, c, n) where T
@inbounds CUDA.@atomic dx[ix, iy, c, n] += value
end
function grid_sample_kernel!(n_elem, output::AbstractArray{T, 4}, input::AbstractArray{T, 4}, grid::AbstractArray{V, 4}, padding_mode) where {T,V}
index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x
if index < n_elem
iW, iH, iC, _ = size(input)
_, gW, gH, _ = size(grid)
w = index % gW + 1
h = (index ÷ gW) % gH + 1
n = index ÷ (gW * gH) + 1
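# e.g. (hypothetical) with gW = 3, gH = 2 the zero-based index 4 decodes to w = 2, h = 2, n = 1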
NNlib._grid_sample_kernel!(output, input, grid, padding_mode, w, h, n, iW, iH, iC)
end
nothing
end
function ∇grid_sample_kernel!(n_elem, dx::AbstractArray{T, 4}, dgrid::AbstractArray{V, 4}, Δ::AbstractArray{T, 4}, input::AbstractArray{T, 4}, grid::AbstractArray{V, 4}, padding_mode) where {T,V}
index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x
if index < n_elem
iW, iH, iC, _ = size(input)
_, gW, gH, _ = size(grid)
w = index % gW + 1
h = (index ÷ gW) % gH + 1
n = index ÷ (gW * gH) + 1
NNlib._∇grid_sample_kernel!(dx, dgrid, Δ, input, grid, padding_mode, w, h, n, iW, iH, iC)
end
nothing
end
function NNlib.grid_sample(x::CuArray{T, 4}, grid::CuArray{V, 4}; padding_mode = :zeros) where {T, V}
pad = Val(padding_mode)
_, _, xC, xN = size(x)
_, gW, gH, _ = size(grid)
n_elem = gW * gH * xN
y = similar(x, T, (gW, gH, xC, xN))
kernel = @cuda launch=false grid_sample_kernel!(n_elem, y, x, grid, pad)
config = launch_configuration(kernel.fun; max_threads=256)
threads = min(n_elem, config.threads)
blocks = cld(n_elem, threads)
kernel(n_elem, y, x, grid, pad; threads=threads, blocks=blocks)
y
end
function NNlib.∇grid_sample(Δ::CuArray{T, 4}, x::CuArray{T, 4}, grid::CuArray{V, 4}; padding_mode = :zeros) where {T, V}
pad = Val(padding_mode)
xN = size(x, 4)
_, gW, gH, _ = size(grid)
n_elem = gW * gH * xN
dx, dgrid = CUDA.zeros(T, size(x)), similar(grid)
kernel = @cuda launch=false ∇grid_sample_kernel!(n_elem, dx, dgrid, Δ, x, grid, pad)
config = launch_configuration(kernel.fun; max_threads=256)
threads = min(n_elem, config.threads)
blocks = cld(n_elem, threads)
kernel(n_elem, dx, dgrid, Δ, x, grid, pad; threads=threads, blocks=blocks)
dx, dgrid
end
@inline function NNlib._safe_add!(dx::CuDeviceArray{T, 5}, value, ix, iy, iz, c, n) where T
@inbounds CUDA.@atomic dx[ix, iy, iz, c, n] += value
end
function grid_sample_kernel!(n_elem, output::AbstractArray{T, 5}, input::AbstractArray{T, 5}, grid::AbstractArray{V, 5}, padding_mode) where {T,V}
index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x
if index < n_elem
iW, iH,iD, iC, _ = size(input)
_, gW, gH, gD, _ = size(grid)
w = index % gW + 1
h = (index ÷ gW) % gH + 1
d = (index ÷ (gW * gH)) % gD + 1
n = index ÷ (gW * gH * gD) + 1
# n = index ÷ (gW * gH) + 1
# d = (index ÷ (gW * gH * n)) + 1
NNlib._grid_sample_kernel!(output, input, grid, padding_mode, w, h, d, n, iW, iH, iD, iC)
end
nothing
end
function ∇grid_sample_kernel!(n_elem, dx::AbstractArray{T, 5}, dgrid::AbstractArray{V, 5}, Δ::AbstractArray{T, 5}, input::AbstractArray{T, 5}, grid::AbstractArray{V, 5}, padding_mode) where {T,V}
index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x
if index < n_elem
iW, iH, iD, iC, _ = size(input)
_, gW, gH, gD, _ = size(grid)
w = index % gW + 1
h = (index ÷ gW) % gH + 1
d = (index ÷ (gW * gH)) % gD + 1
n = index ÷ (gW * gH * gD) + 1
# n = index ÷ (gW * gH) + 1
# d = (index ÷ (gW * gH * n)) + 1
NNlib._∇grid_sample_kernel!(dx, dgrid, Δ, input, grid, padding_mode, w, h, d, n, iW, iH, iD, iC)
end
nothing
end
function NNlib.grid_sample(x::CuArray{T, 5}, grid::CuArray{V, 5}; padding_mode = :zeros) where {T, V}
pad = Val(padding_mode)
_, _, _, xC, xN = size(x)
_, gW, gH, gD, _ = size(grid)
n_elem = gW * gH * gD * xN
y = similar(x, T, (gW, gH, gD, xC, xN))
kernel = @cuda launch=false grid_sample_kernel!(n_elem, y, x, grid, pad)
config = launch_configuration(kernel.fun; max_threads=256)
threads = min(n_elem, config.threads)
blocks = cld(n_elem, threads)
kernel(n_elem, y, x, grid, pad; threads=threads, blocks=blocks)
y
end
function NNlib.∇grid_sample(Δ::CuArray{T, 5}, x::CuArray{T, 5}, grid::CuArray{V, 5}; padding_mode = :zeros) where {T, V}
pad = Val(padding_mode)
xN = size(x, 5)
_, gW, gH, gD, _ = size(grid)
n_elem = gW * gH * gD * xN
dx, dgrid = CUDA.zeros(T, size(x)), similar(grid)
kernel = @cuda launch=false ∇grid_sample_kernel!(n_elem, dx, dgrid, Δ, x, grid, pad)
config = launch_configuration(kernel.fun; max_threads=256)
threads = min(n_elem, config.threads)
blocks = cld(n_elem, threads)
kernel(n_elem, dx, dgrid, Δ, x, grid, pad; threads=threads, blocks=blocks)
dx, dgrid
end
================================================
FILE: ext/NNlibCUDAExt/scatter.jl
================================================
# supported op: +, -, *, /, max, min, &, |, mean
## TODO support sparse dst/src/idx
## See issue https://github.com/FluxML/NNlib.jl/issues/647
# import CUDA.CUSPARSE: AbstractCuSparseMatrix, CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO, AnyCuSparseVector
# const AnyCuSparseMatrix{Tv,Ti} = Union{
# AbstractCuSparseMatrix{Tv,Ti},
# CUDA.CuSparseMatrixCSC{Tv,Ti}, # these types do not inherit from AbstractCuSparseMatrix
# CUDA.CuSparseMatrixCSR{Tv,Ti}, # but from GPUArrays.AbstractGPUSparseMatrixXXX
# CUDA.CuSparseMatrixCOO{Tv,Ti},
# }
# const AnyCuSparseArray{Tv,Ti} = Union{AnyCuSparseVector{Tv,Ti},AnyCuSparseMatrix{Tv,Ti}}
function scatter_kernel!(op::OP, dst, src, idx) where OP
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
@inbounds if index <= length(idx)
CUDA.@atomic dst[idx[index]...] = op(dst[idx[index]...], src[index])
end
return nothing
end
function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}) where OP
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
@inbounds if index <= length(idx)
li = Base._to_linear_index(dst, Tuple(idx[index])...)
CUDA.@atomic dst[li] = op(dst[li], src[index])
end
return nothing
end
function scatter_kernel!(op::OP, dst, src, idx, max_idx, max_dims_idx, dims_size) where OP
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
@inbounds if index <= max_idx
j, k = divrem(index-1, max_dims_idx)
dims_i = CartesianIndices(dims_size)[k+1]
CUDA.@atomic dst[Tuple(dims_i)..., idx[j+1]...] = op(dst[Tuple(dims_i)..., idx[j+1]...], src[index])
end
return nothing
end
function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex},
max_idx, max_dims_idx, dims_size) where OP
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
@inbounds if index <= max_idx
j, k = divrem(index-1, max_dims_idx)
dims_i = CartesianIndices(dims_size)[k+1]
li = Base._to_linear_index(dst, Tuple(dims_i)..., Tuple(idx[j+1])...)
CUDA.@atomic dst[li] = op(dst[li], src[index])
end
return nothing
end
function NNlib.scatter!(op::OP, dst::AnyCuArray,
src::AnyCuArray,
idx::AnyCuArray) where OP
isempty(idx) && return dst
dims = NNlib.scatter_dims(dst, src, idx)
args = if dims == 0
max_idx = length(idx)
op, dst, src, idx
else
dims_size = size(dst)[1:dims]
max_dims_idx = prod(dims_size)
max_idx = max_dims_idx * length(idx)
op, dst, src, idx, max_idx, max_dims_idx, dims_size
end
kernel = @cuda launch=false scatter_kernel!(args...)
config = launch_configuration(kernel.fun; max_threads=256)
threads = min(max_idx, config.threads)
blocks = cld(max_idx, threads)
kernel(args...; threads=threads, blocks=blocks)
return dst
end
function NNlib.scatter!(op::typeof(mean), dst::AnyCuArray,
src::AnyCuArray,
idx::AnyCuArray)
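# mean-scatter: accumulate sums and counts separately, then divide;
# e.g. (hypothetical) dst = zeros(3), src = ones(4), idx = [1, 1, 2, 3] gives dst == [1, 1, 1].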
Ns = NNlib.scatter!(+, zero(dst), one.(src), idx)
dst_ = NNlib.scatter!(+, zero(dst), src, idx)
dst .+= NNlib.safe_div.(dst_, Ns)
return dst
end
## Gradients
function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx,
rev_idx, max_idx, T::Type{TT}) where {OP,TT}
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
@inbounds if index <= max_idx
cart_j = CartesianIndices(idx)[index]
# get the indices whose source values are aggregated together, including this element's own index
inds = rev_idx[idx[cart_j]...]
# multiply all values to be aggregated but not itself
x = one(T)
for k in inds
x *= src[k]
end
x /= src[cart_j]
# apply `op` on `Δsrc[i, k]` and `x`
Δsrc[cart_j] = op(Δsrc[cart_j], x)
end
return nothing
end
function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex},
rev_idx, max_idx, T::Type{TT}) where {OP,TT}
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
@inbounds if index <= max_idx
cart_j = CartesianIndices(idx)[index]
# get the indices whose source values are aggregated together, including this element's own index
inds = rev_idx[Tuple(idx[cart_j])...]
# multiply all values to be aggregated but not itself
x = one(T)
for k in inds
x *= src[k]
end
x /= src[cart_j]
# apply `op` on `Δsrc[i, k]` and `x`
Δsrc[cart_j] = op(Δsrc[cart_j], x)
end
return nothing
end
function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx,
rev_idx, pre_cart_idx, max_dims_idx, max_idx, T::Type{TT}) where {OP,TT}
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
@inbounds if index <= max_idx
i, j = fldmod1(index, max_dims_idx)
cart_i = CartesianIndices(idx)[i]
cart_j = pre_cart_idx[j]
# get the indices whose source values are aggregated together, including this element's own index
inds = rev_idx[idx[cart_i]...]
# multiply all values to be aggregated but not itself
x = one(T)
for k in inds
jk = Base._to_linear_index(src, Tuple(cart_j)..., Tuple(k)...)
x *= src[jk]
end
x /= src[index]
# apply `op` on `Δsrc[i, k]` and `x`
Δsrc[index] = op(Δsrc[index], x)
end
return nothing
end
function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex},
rev_idx, pre_cart_idx, max_dims_idx, max_idx, T::Type{TT}) where {OP,TT}
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
@inbounds if index <= max_idx
i, j = fldmod1(index, max_dims_idx)
cart_i = CartesianIndices(idx)[i]
cart_j = pre_cart_idx[j]
# get the indices whose source values are aggregated together, including this element's own index
inds = rev_idx[Tuple(idx[cart_i])...]
# multiply all values to be aggregated but not itself
x = one(T)
for k in inds
jk = Base._to_linear_index(src, Tuple(cart_j)..., Tuple(k)...)
x *= src[jk]
end
x /= src[index]
# apply `op` on `Δsrc[i, k]` and `x`
Δsrc[index] = op(Δsrc[index], x)
end
return nothing
end
function NNlib.∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst,
src::AnyCuArray,
idx::AnyCuArray)
dims = ndims(src) - ndims(idx)
Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), src)
rev_idx = NNlib.reverse_indices(idx)
rev_idx = CuArray(map(CUDA.cudaconvert, rev_idx))
if dims == 0
max_idx = length(idx)
args = op, Δsrc, src, idx, rev_idx, max_idx, eltype(src)
else
pre_cart_idx = CartesianIndices(axes(src)[1:dims])
max_dims_idx = length(pre_cart_idx)
max_idx = max_dims_idx * length(idx)
args = op, Δsrc, src, idx, rev_idx, pre_cart_idx, max_dims_idx, max_idx, eltype(src)
end
kernel = @cuda launch=false ∇scatter_src_kernel!(args...)
config = launch_configuration(kernel.fun; max_threads=256)
threads = min(max_idx, config.threads)
blocks = cld(max_idx, threads)
kernel(args...; threads=threads, blocks=blocks)
CUDA.unsafe_free!(rev_idx)
return Δsrc
end
================================================
FILE: ext/NNlibCUDAExt/utils.jl
================================================
NNlib._rng_from_array(::CuArray) = CUDA.default_rng()
NNlib._rng_compat_array(rng::CUDA.RNG, A::CuArray) = nothing
NNlib._rng_compat_array(rng::AbstractRNG, A::CuArray) = throw(ArgumentError(
"cannot use rng::$(typeof(rng)) with array::CuArray, only CUDA's own RNG type works"))
function divide_kernel!(xs, ys, max_idx)
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
@inbounds if index <= max_idx
xs[index] = xs[index] / ys[index]
end
return nothing
end
function divide_kernel!(xs, counts, max_idx, max_dims_idx, dims_size)
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
@inbounds if index <= max_idx
j, k = divrem(index-1, max_dims_idx)
dims_i = Tuple(CartesianIndices(dims_size)[k+1])
CUDA.@atomic xs[dims_i..., j+1] = xs[dims_i..., j+1] / counts[j+1]
end
return nothing
end
function NNlib.reverse_indices(idx::AnyCuArray{<:Any,N}) where N
max_dims = NNlib.maximum_dims(idx)
T = CartesianIndex{N}
rev = Array{Vector{T}}(undef, max_dims...)
for i in eachindex(rev)
rev[i] = T[]
end
NNlib.reverse_indices!(rev, idx)
return map(cu, rev)
end
================================================
FILE: ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl
================================================
module NNlibEnzymeCoreExt
using NNlib
import EnzymeCore
using Random
using EnzymeCore.EnzymeRules
for (name, dataname, filtername) in (
(typeof(NNlib.conv!), NNlib.∇conv_data!, NNlib.∇conv_filter!),
(typeof(NNlib.depthwiseconv!), NNlib.∇depthwiseconv_data!, NNlib.∇depthwiseconv_filter!),
(typeof(NNlib.∇conv_data!), NNlib.conv!, NNlib.∇conv_filter!),
(typeof(NNlib.∇conv_filter!), NNlib.∇conv_data!, NNlib.conv!),
)
@eval begin
function EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{$name}, ::Type{RT},
y::EnzymeCore.Annotation{<:AbstractArray{yT, N}},
x::EnzymeCore.Annotation{<:AbstractArray{xT, N}},
w::EnzymeCore.Annotation{<:AbstractArray{wT, N}},
cdims; kwargs...) where {RT, yT, xT, wT, N}
if typeof(y) <: EnzymeCore.Duplicated || typeof(y) <: EnzymeCore.BatchDuplicated
func.val(y.val, x.val, w.val, cdims.val; kwargs...)
end
primal = if EnzymeRules.needs_primal(config)
y.val
else
nothing
end
shadow = if EnzymeRules.needs_shadow(config)
y.dval
else
nothing
end
# Cache x if it's overwritten and w is active (and thus required)
cache_x = ( EnzymeRules.overwritten(config)[3]
&& !(typeof(w) <: EnzymeCore.Const)
&& !(typeof(y) <: EnzymeCore.Const)
) ? copy(x.val) : nothing
# Cache w if it's overwritten and x is active (and thus required)
cache_w = ( EnzymeRules.overwritten(config)[4]
&& !(typeof(x) <: EnzymeCore.Const)
&& !(typeof(y) <: EnzymeCore.Const)
) ? copy(w.val) : nothing
cache = (cache_x, cache_w)
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end
function EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, ::Type{RT}, cache,
y::EnzymeCore.Annotation{<:AbstractArray{yT, N}},
x::EnzymeCore.Annotation{<:AbstractArray{xT, N}},
w::EnzymeCore.Annotation{<:AbstractArray{wT, N}},
cdims; kwargs...) where {RT, yT, xT, wT, N}
cache_x, cache_w = cache
# Don't cache x if not overwritten and w is active (and thus required)
if !(typeof(w) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const)
if !EnzymeRules.overwritten(config)[3]
cache_x = x.val
end
end
# Don't cache w if not overwritten and x is active (and thus required)
if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const)
if !EnzymeRules.overwritten(config)[4]
cache_w = w.val
end
end
dys = y.dval
dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval
dws = (typeof(w) <: EnzymeCore.Const) ? dys : w.dval
if EnzymeRules.width(config) == 1
dys = (dys,)
dxs = (dxs,)
dws = (dws,)
end
for (dy, dx, dw) in zip(dys, dxs, dws)
if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val
if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val
# dx += grad wrt x.val
$dataname(dx, $(name != typeof(NNlib.∇conv_filter!) ? :dy : :cache_w), $(name != typeof(NNlib.∇conv_filter!) ? :cache_w : :dy), cdims.val; alpha=xT(1), beta=xT(1), kwargs...)
end
if !(typeof(w) <: EnzymeCore.Const) && dw !== w.val
# dw += grad wrt w.val
$filtername(dw, $(name != typeof(NNlib.∇conv_data!) ? :cache_x : :dy), $(name != typeof(NNlib.∇conv_data!) ? :dy : :cache_x), cdims.val; alpha=wT(1), beta=wT(1), kwargs...)
end
dy .= 0
end
end
return (nothing, nothing, nothing, nothing)
end
end
end
function EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT}
if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
func.val(dst.val, src.val, idx.val)
end
primal = if EnzymeRules.needs_primal(config)
dst.val
else
nothing
end
shadow = if EnzymeRules.needs_shadow(config)
dst.dval
else
nothing
end
# Cache idx if it's overwritten
cache_idx = ( EnzymeRules.overwritten(config)[4]
&& !(typeof(src) <: EnzymeCore.Const)
&& !(typeof(dst) <: EnzymeCore.Const)
) ? copy(idx.val) : nothing
return EnzymeRules.AugmentedReturn(primal, shadow, cache_idx)
end
function EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, cache_idx, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT}
# Don't cache idx if not overwritten
if !(typeof(src) <: EnzymeCore.Const) && !(typeof(dst) <: EnzymeCore.Const)
if !EnzymeRules.overwritten(config)[4]
cache_idx = idx.val
end
end
ddsts = dst.dval
dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval
if EnzymeRules.width(config) == 1
ddsts = (ddsts,)
dsrcs = (dsrcs,)
end
for (ddst, dsrc) in zip(ddsts, dsrcs)
if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val
if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val
NNlib.scatter!(+, dsrc, ddst, cache_idx)
end
ddst .= 0
end
end
return (nothing, nothing, nothing)
end
function EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.scatter!)}, ::Type{RT}, op::EnzymeCore.Const, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT}
@assert !(OutType <: EnzymeCore.Const)
if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
func.val(op.val, dst.val, src.val, idx.val)
end
primal = if EnzymeRules.needs_primal(config)
dst.val
else
nothing
end
shadow = if EnzymeRules.needs_shadow(config)
dst.dval
else
nothing
end
# Cache idx if it's overwritten
cache_idx = ( EnzymeRules.overwritten(config)[4]
&& !(typeof(src) <: EnzymeCore.Const)
&& !(typeof(dst) <: EnzymeCore.Const)
) ? copy(idx.val) : nothing
return EnzymeRules.AugmentedReturn(primal, shadow, cache_idx)
end
function EnzymeRules.reverse(config,
func::EnzymeCore.Const{typeof(NNlib.scatter!)},
::Type{RT},
cache_idx,
op::Union{EnzymeCore.Const{typeof(+)},EnzymeCore.Const{typeof(-)}}, dst::OutType,
src,
idx::EnzymeCore.Const) where {OutType, RT}
# Don't cache idx if not overwritten
if !(typeof(src) <: EnzymeCore.Const) && !(typeof(dst) <: EnzymeCore.Const)
if !EnzymeRules.overwritten(config)[4]
cache_idx = idx.val
end
end
ddsts = dst.dval
dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval
if EnzymeRules.width(config) == 1
ddsts = (ddsts,)
dsrcs = (dsrcs,)
end
for (ddst, dsrc) in zip(ddsts, dsrcs)
if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val
if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val
if eltype(typeof(op)) == typeof(+)
dsrc .+= NNlib.gather(ddst, cache_idx)
else
@assert eltype(typeof(op)) == typeof(-)
dsrc .-= NNlib.gather(ddst, cache_idx)
end
end
end
end
return (nothing, nothing, nothing, nothing)
end
for pool in [:maxpool, :meanpool, :lpnormpool]
pool! = Symbol(pool, :!)
∇pool = Symbol(:∇, pool, :!)
@eval begin
function EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof($pool!)}, ::Type{RT}, y::OutType, x, dims; kwargs...) where {OutType, RT}
if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
func.val(y.val, x.val, dims.val; kwargs...)
end
primal = if EnzymeRules.needs_primal(config)
y.val
else
nothing
end
shadow = if EnzymeRules.needs_shadow(config)
y.dval
else
nothing
end
cache_y = ( EnzymeRules.overwritten(config)[2]
&& !(typeof(x) <: EnzymeCore.Const)
&& !(typeof(y) <: EnzymeCore.Const)
) ? copy(y.val) : nothing
cache_x = ( EnzymeRules.overwritten(config)[3]
&& !(typeof(x) <: EnzymeCore.Const)
&& !(typeof(y) <: EnzymeCore.Const)
) ? copy(x.val) : nothing
cache = (cache_y, cache_x)
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end
function EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof($pool!)}, ::Type{RT}, cache, y, x, dims; kwargs...) where {RT}
cache_y, cache_x = cache
# Don't cache y if not overwritten
if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const)
if !EnzymeRules.overwritten(config)[2]
cache_y = y.val
end
end
# Don't cache x if not overwritten
if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const)
if !EnzymeRules.overwritten(config)[3]
cache_x = x.val
end
end
dys = y.dval
dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval
if EnzymeRules.width(config) == 1
dys = (dys,)
dxs = (dxs,)
end
for (dy, dx) in zip(dys, dxs)
if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val
if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val
NNlib.$(∇pool)(dx, dy, cache_y, cache_x, dims.val; alpha=eltype(dx)(1), beta=eltype(dx)(1), kwargs...)
end
dy .= 0
end
end
return (nothing, nothing, nothing)
end
end
end
function EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, rng, dst::OutType, src, p, dims) where {OutType, RT}
T = float(real(eltype(dst.val)))
val = convert(T, 1/(1-p.val))
keep = if dims.val isa Colon
similar(dst.val, T, size(dst.val))
else
similar(dst.val, T, ntuple(d -> d in dims.val ? size(dst.val,d) : 1, ndims(dst.val)))
end
rand!(rng.val, keep)
keep = keep .> p.val
if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
dst.val .= (keep .* val) .* src.val
end
primal = if EnzymeRules.needs_primal(config)
dst.val
else
nothing
end
shadow = if EnzymeRules.needs_shadow(config)
dst.dval
else
nothing
end
if typeof(dst) <: EnzymeCore.Const || typeof(src) <: EnzymeCore.Const
keep = nothing
end
return EnzymeRules.AugmentedReturn(primal, shadow, keep)
end
function EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, keep, rng, dst::OutType, src, p, dims) where {OutType, RT}
T = float(real(eltype(dst.val)))
val = convert(T, 1/(1-p.val))
ddsts = dst.dval
dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval
if EnzymeRules.width(config) == 1
ddsts = (ddsts,)
dsrcs = (dsrcs,)
end
for (ddst, dsrc) in zip(ddsts, dsrcs)
if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val
if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val
dsrc .+= (keep .* val) .* ddst
end
ddst .= 0
end
end
dp = if typeof(p) <: EnzymeCore.Active
typeof(p.val)(0)
else
nothing
end
return (nothing, nothing, nothing, dp, nothing)
end
end
================================================
FILE: ext/NNlibFFTWExt/NNlibFFTWExt.jl
================================================
module NNlibFFTWExt
using FFTW
using NNlib
using KernelAbstractions
include("stft.jl")
end
================================================
FILE: ext/NNlibFFTWExt/stft.jl
================================================
function NNlib.stft(x;
n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing,
center::Bool = true, normalized::Bool = false,
)
kab = get_backend(x)
use_window = !isnothing(window)
use_window && kab != get_backend(window) && throw(ArgumentError(
"`window` must be on the same device as stft input `x` ($kab), \
instead: `$(get_backend(window))`."))
use_window && !(0 < length(window) ≤ n_fft) && throw(ArgumentError(
"Expected `0 < length(window) ≤ n_fft=$n_fft`, \
but got `length(window)=$(length(window))`."))
hop_length < 0 && throw(ArgumentError(
"Expected `hop_length > 0`, but got `hop_length=$hop_length`."))
# Pad window on both sides with `0` to `n_fft` length if needed.
if use_window && length(window) < n_fft
left = ((n_fft - length(window)) ÷ 2) + 1
tmp = KernelAbstractions.zeros(kab, eltype(window), n_fft)
tmp[left:left + length(window) - 1] .= window
window = tmp
end
if center
pad_amount = n_fft ÷ 2
x = pad_reflect(x, pad_amount; dims=1)
end
n = size(x, 1)
(0 < n_fft ≤ n) || throw(ArgumentError(
"Expected `0 < n_fft ≤ size(x, 1)=$n`, but got `n_fft=$n_fft`."))
n_frames = 1 + (n - n_fft) ÷ hop_length
# time2col.
# Reshape `x` to (n_fft, n_frames, B) if needed.
# Each row in `n_frames` is shifted by `hop_length`.
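# Worked example (hypothetical values): with n_fft = 4, hop_length = 2 and size(x, 1) = 8,
# n_frames = 3 and ids == [1 3 5; 2 4 6; 3 5 7; 4 6 8], so x[ids, ...] has size (4, 3, ...).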
if n_frames > 1
# TODO can be more efficient if we support something like torch.as_strided
ids = [
row + hop_length * col
for row in 1:n_fft, col in 0:(n_frames - 1)]
x = @inbounds x[ids, ntuple(_ -> Colon(), ndims(x) - 1)...]
end
region = 1
use_window && (x = x .* window;)
y = eltype(x) <: Complex ? fft(x, region) : rfft(x, region)
normalized && (y = y .* eltype(y)(n_fft^-0.5);)
return y
end
function NNlib.istft(y;
n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing,
center::Bool = true, normalized::Bool = false,
return_complex::Bool = false,
original_length::Union{Nothing, Int} = nothing,
)
kab = get_backend(y)
use_window = !isnothing(window)
use_window && kab != get_backend(window) && throw(ArgumentError(
"`window` must be on the same device as istft input `y` ($kab), \
instead: `$(get_backend(window))`."))
use_window && !(0 < length(window) ≤ n_fft) && throw(ArgumentError(
"Expected `0 < length(window) ≤ n_fft=$n_fft`, \
but got `length(window)=$(length(window))`."))
hop_length < 0 && throw(ArgumentError(
"Expected `hop_length > 0`, but got `hop_length=$hop_length`."))
# TODO check `y` eltype is complex
n_frames = size(y, 2)
# Pad window on both sides with `0` to `n_fft` length if needed.
if use_window && length(window) < n_fft
left = ((n_fft - length(window)) ÷ 2) + 1
tmp = KernelAbstractions.zeros(kab, eltype(window), n_fft)
tmp[left:left + length(window) - 1] .= window
window = tmp
end
# Denormalize.
normalized && (y = y .* eltype(y)(n_fft^0.5);)
region = 1
x = return_complex ? ifft(y, region) : irfft(y, n_fft, region)
# De-apply window.
use_window && (x = x ./ window;)
# col2time.
expected_output_len = n_fft + hop_length * (n_frames - 1)
ids = Vector{Int}(undef, expected_output_len)
in_idx, out_idx = 0, 0
prev_e, v = 0, 0
for col in 0:(n_frames - 1)
for row in 1:n_fft
in_idx += 1
v = row + hop_length * col
v > prev_e || continue
out_idx += 1
ids[out_idx] = in_idx
end
prev_e = v
end
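# Hypothetical example: with n_fft = 4, hop_length = 2, n_frames = 3, expected_output_len == 8
# and ids == [1, 2, 3, 4, 7, 8, 11, 12], i.e. each output sample is taken exactly once from the
# column-major (n_fft, n_frames) frame layout.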
# In case of batched input, reshape it from (n_fft, n_frames, batch) -> (:, batch).
nd = ntuple(_ -> Colon(), ndims(x) - 2)
ndims(x) == 3 && (x = reshape(x, (:, size(x, 3)));)
x = @inbounds x[ids, nd...]
# Trim padding.
left = center ? (n_fft ÷ 2 + 1) : 1
right = if isnothing(original_length)
center ? (size(x, 1) - n_fft ÷ 2) : expected_output_len
else
left + original_length - 1
end
x = x[left:right, nd...]
return x
end
================================================
FILE: ext/NNlibForwardDiffExt.jl
================================================
module NNlibForwardDiffExt
using ForwardDiff: ForwardDiff
using NNlib: NNlib
NNlib.within_gradient(x::ForwardDiff.Dual) = true
NNlib.within_gradient(x::AbstractArray{<:ForwardDiff.Dual}) = true
end
================================================
FILE: ext/NNlibMetalExt.jl
================================================
module NNlibMetalExt
using Metal: method_table, @device_override
using NNlib: NNlib
@device_override NNlib.tanh_fast(x) = Base.FastMath.tanh_fast(x)
end
================================================
FILE: ext/NNlibSpecialFunctionsExt.jl
================================================
module NNlibSpecialFunctionsExt
using NNlib: NNlib, oftf
using SpecialFunctions: erf
# Full gelu (gelu_erf)
NNlib.gelu_erf(x) = x/2*(1 + erf(x/sqrt(oftf(x,2))))
function NNlib.deriv_gelu_erf(x)
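# Derivative of gelu_erf, i.e. d/dx of x*Φ(x): Φ(x) + x*φ(x),
# where Φ is the standard normal cdf and φ(x) = exp(-x^2/2) / sqrt(2π).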
SQRT2 = sqrt(oftf(x,2))
Φ = (1 + erf(x/SQRT2))/2
Φ + x/SQRT2*exp(-(x^2)/2)/sqrt(oftf(x,π))
end
end
================================================
FILE: src/NNlib.jl
================================================
module NNlib
import Atomix
import ChainRulesCore: rrule
using Base.Broadcast: broadcasted
using Base.Threads
using ChainRulesCore
using GPUArraysCore
using KernelAbstractions
using KernelAbstractions: @atomic
using LinearAlgebra
using LinearAlgebra.BLAS: @blasfunc, BlasInt
using LinearAlgebra: AdjOrTransAbsMat, Adjoint, BlasFloat, Transpose
using Random
using ScopedValues
using Statistics
using Statistics: mean
const Numeric = Union{AbstractArray{<:T}, T} where {T<:Number}
# internal. TODO: change to an approach where amount of threading is controlled, not just on/off
const ALLOW_SPAWNS = ScopedValue(true)
should_use_spawn() = Threads.nthreads(:default) > 1 && ALLOW_SPAWNS[]
"""
@disallow_spawns ex
Disallow NNlib from using `@spawn` on divisible workloads, e.g. within `conv`.
"""
macro disallow_spawns(ex)
quote
@with ALLOW_SPAWNS => false $(esc(ex))
end
end
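# Hedged usage sketch: wrapping a call site disables task spawning inside it, e.g.
#     NNlib.@disallow_spawns conv(x, w, cdims)
# evaluates the `conv` call with `ALLOW_SPAWNS` scoped to `false` for its duration.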
# Include APIs
include("dim_helpers.jl")
export ConvDims, DenseConvDims, PoolDims, DepthwiseConvDims
include("activations.jl")
for f in ACTIVATIONS
@eval export $(f)
end
export sigmoid, hardsigmoid, logsigmoid, thresholdrelu, gelu # Aliases
include("attention.jl")
export dot_product_attention, dot_product_attention_scores, make_causal_mask
include("dropout.jl")
export dropout, dropout!
include("softmax.jl")
export softmax, softmax!, ∇softmax, ∇softmax!, logsoftmax,
logsoftmax!, ∇logsoftmax, ∇logsoftmax!, logsumexp
include("batched/batchedadjtrans.jl")
include("batched/batchedmul.jl")
export batched_mul, batched_mul!, ⊠, batched_vec,
batched_transpose, batched_adjoint
include("gemm.jl")
export grid_sample, ∇grid_sample
include("conv.jl")
export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter,
∇conv_filter!, depthwiseconv, depthwiseconv!,
∇depthwiseconv_data, ∇depthwiseconv_data!,
∇depthwiseconv_filter, ∇depthwiseconv_filter!
include("conv_bias_act.jl")
export conv_bias_act, conv_bias_act!
include("bias_act.jl")
export bias_act!
include("fold.jl")
include("ctc.jl")
export ctc_loss
include("pooling.jl")
export maxpool, maxpool!, meanpool, meanpool!, lpnormpool, lpnormpool!,
∇maxpool, ∇maxpool!, ∇meanpool, ∇meanpool!, ∇lpnormpool, ∇lpnormpool!
include("padding.jl")
export pad_constant, pad_repeat, pad_reflect, pad_zeros, pad_symmetric, pad_circular
include("upsample.jl")
export upsample_nearest, ∇upsample_nearest,
upsample_linear, ∇upsample_linear,
upsample_bilinear, ∇upsample_bilinear,
upsample_trilinear, ∇upsample_trilinear,
pixel_shuffle
include("gather.jl")
include("scatter.jl")
include("utils.jl")
include("sampling.jl")
include("functions.jl")
include("normalization.jl")
# export batchnorm, ∇batchnorm
## Include implementations
include("impl/padding_edges.jl")
# Direct implementations of convolutional and depthwise-convolutional algorithms
include("impl/conv_direct.jl")
include("impl/depthwiseconv_direct.jl")
# im2col implementations of convolutional and depthwise-convolutional algorithms
include("impl/conv_im2col.jl")
include("impl/depthwiseconv_im2col.jl")
# Direct implementations of pooling
include("impl/pooling_direct.jl")
include("deprecations.jl")
include("rotation.jl")
export imrotate, ∇imrotate
include("audio/stft.jl")
include("audio/spectrogram.jl")
include("audio/mel.jl")
export stft, istft, hann_window, hamming_window, spectrogram, melscale_filterbanks
end # module NNlib
================================================
FILE: src/activations.jl
================================================
## Activation functions
#
# Some of the activation functions have their own GPU wrapper functions in NNlibCUDAExt.jl.
# https://github.com/JuliaGPU/CuArrays.jl/issues/614
ACTIVATIONS = [
:σ, :hardσ, :hardtanh, :relu,
:leakyrelu, :relu6, :rrelu, :elu, :gelu_tanh, :gelu_sigmoid, :gelu_erf, :swish, :hardswish, :selu,
:celu, :softplus, :softsign, :logσ, :logcosh,
:mish, :tanhshrink, :softshrink, :trelu, :lisht,
:tanh_fast, :sigmoid_fast,
]
# of type float (to allow for integer inputs)
oftf(x, y) = oftype(float(x), y)
# oftype contains control flow on 1.10+, which can lead to type instabilities under AD
function rrule(::typeof(oftf), x, y)
proj_y = ChainRulesCore.ProjectTo(y)
oftf_pullback(Δ) = (NoTangent(), NoTangent(), proj_y(Δ))
return oftf(x, y), oftf_pullback
end
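# e.g. oftf(1f0, pi) === Float32(pi), while oftf(1, pi) === Float64(pi) because
# float(1) is a Float64. Activation functions use this to keep literal constants in
# the input's floating-point type rather than promoting Float32 inputs to Float64.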
"""
σ(x) = 1 / (1 + exp(-x))
Classic [sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function) activation
function.
Unicode `σ` can be entered as `\\sigma` then tab, in many editors.
The ascii name `sigmoid` is also exported.
See also [`sigmoid_fast`](@ref).
```julia-repl
julia> using UnicodePlots
julia> lineplot(sigmoid, -5, 5, height=7)
┌────────────────────────────────────────┐
1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⣀⡠⠤⠖⠒⠒⠋⠉⠉⠉⠉⠉⠉│ σ(x)
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⢀⡠⠖⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⣀⠔⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⡏⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡔⠋⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠊⠁⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
0 │⣀⣀⣀⣀⣀⣀⣀⠤⠤⠤⠒⠊⠉⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
└────────────────────────────────────────┘
⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
julia> sigmoid === σ
true
```
"""
function σ(x)
t = exp(-abs(x))
ifelse(x ≥ 0, inv(1 + t), t / (1 + t))
end
const sigmoid = σ
"""
hardσ(x) = max(0, min(1, (x + 3) / 6))
Piecewise linear approximation of [`sigmoid`](@ref).
```julia-repl
julia> lineplot(hardsigmoid, -5, 5, height=7)
┌────────────────────────────────────────┐
1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⢀⡠⠖⠋⠉⠉⠉⠉⠉⠉⠉⠉│ hardσ(x)
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⣀⡤⠒⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⡠⠔⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⡗⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠊⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡤⠖⠋⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
0 │⣀⣀⣀⣀⣀⣀⣀⣀⣀⠤⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
└────────────────────────────────────────┘
⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
julia> lineplot(sigmoid, -5, 5, height=7)
┌────────────────────────────────────────┐
1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⣀⡠⠤⠖⠒⠒⠋⠉⠉⠉⠉⠉⠉│ σ(x)
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⢀⡠⠖⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⣀⠔⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⡏⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡔⠋⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠊⠁⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
0 │⣀⣀⣀⣀⣀⣀⣀⠤⠤⠤⠒⠊⠉⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
└────────────────────────────────────────┘
⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
```
"""
hardσ(x) = clamp((x + 3) / 6, 0, 1)
# https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html
const hardsigmoid = hardσ
"""
logσ(x)
Return `log(σ(x))` which is computed in a numerically stable way.
```julia-repl
julia> lineplot(logsigmoid, -5, 5, height=7)
┌────────────────────────────────────────┐
0 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡧⠤⠔⠒⠒⠒⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉│ logσ(x)
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠖⠊⠉⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠒⠉⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
f(x) │⠀⠀⠀⠀⠀⠀⢀⡤⠖⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⣀⠔⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⡤⠖⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
-6 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
└────────────────────────────────────────┘
⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
```
"""
logσ(x) = -softplus(-x)
const logsigmoid = logσ
"""
hardtanh(x) = max(-1, min(1, x))
Segment-wise linear approximation of `tanh`, much cheaper to compute.
See ["Large Scale Machine Learning"](https://ronan.collobert.com/pub/matos/2004_phdthesis_lip6.pdf).
See also [`tanh_fast`](@ref).
```julia-repl
julia> lineplot(hardtanh, -2, 2, height=7)
┌────────────────────────────────────────┐
1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⣀⠔⠋⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉│ hardtanh(x)
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⣀⡤⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⢀⡤⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
f(x) │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⡤⡷⠥⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⠖⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⠖⠋⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
-1 │⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⠔⠋⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
└────────────────────────────────────────┘
⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x
julia> lineplot(tanh, -2, 2, height=7)
┌────────────────────────────────────────┐
1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⣀⡠⠤⠤⠒⠒⠒⠊⠉⠉⠉│ tanh(x)
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⢀⡠⠔⠊⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⢀⡤⠒⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
f(x) │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⡤⡷⠥⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡤⠖⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡠⠔⠊⠁⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
-1 │⣀⣀⣀⡠⠤⠤⠤⠖⠒⠊⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
└────────────────────────────────────────┘
⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
```
"""
hardtanh(x) = clamp(x, oftype(x, -1), oftype(x, 1)) # clamp(x, -1, 1) is type-stable, but would promote Int32, for which we have tests
"""
relu(x) = max(0, x)
[Rectified Linear Unit](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))
activation function.
```julia-repl
julia> lineplot(relu, -2, 2, height=7)
┌────────────────────────────────────────┐
2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠋│ relu(x)
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠊⠁⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡤⠊⠁⠀⠀⠀⠀⠀│
f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⢀⡤⠖⠁⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⡠⠖⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⡠⠖⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
0 │⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣇⠔⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
└────────────────────────────────────────┘
⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
```
"""
relu(x) = ifelse(x<0, zero(x), x) # faster than max(zero(x), x), still preserves NaN
"""
leakyrelu(x, a=0.01) = max(a*x, x)
Leaky [Rectified Linear Unit](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))
activation function.
You can also specify the coefficient explicitly, e.g. `leakyrelu(x, 0.01)`.
```julia-repl
julia> lineplot(x -> leakyrelu(x, 0.5), -2, 2, height=7)
┌────────────────────────────────────────┐
2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠒⠉│ #42(x)
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊⠉⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⣀⡤⠖⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀│
f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⣀⠤⠖⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⣤⣤⡤⡧⠶⠭⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│
│⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⣀⠤⠤⠒⠒⠋⠉⠁⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
-1 │⣀⣀⠤⠤⠒⠒⠊⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
└────────────────────────────────────────┘
⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
julia> leakyrelu(-10f0, 0.2)
-2.0f0
julia> leakyrelu(-10f0, 0.02)
-0.5f0
```
"""
leakyrelu(x, a=oftf(x, leakyrelu_a)) = ifelse(x>0, float(x), oftf(x, a*x)) # max(a*x, x) is 3x slower
const leakyrelu_a = 0.01 # also used in gradient below
"""
relu6(x) = min(max(0, x), 6)
[Rectified Linear Unit](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))
activation function capped at 6.
See ["Convolutional Deep Belief Networks"](https://www.cs.toronto.edu/~kriz/conv-cifar10-aug2010.pdf) from CIFAR-10.
```julia-repl
julia> lineplot(relu6, -10, 10, height=7)
┌────────────────────────────────────────┐
6 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠎⠉⠉⠉⠉⠉⠉⠉⠉│ relu6(x)
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⢀⡔⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⡤⠃⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⡠⠎⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⢀⠖⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⡔⠃⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
0 │⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⡧⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
└────────────────────────────────────────┘
⠀-10⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀10⠀
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
```
"""
relu6(x) = clamp(x, oftype(x, 0), oftype(x, 6)) # clamp promotes, but clamp(x, 0, 6) would promote x::Int32
"""
rrelu(x, lo=1/8, hi=1/3) = max(a*x, x)
# where `a` is randomly sampled from uniform distribution `U(lo, hi)`
Randomized Leaky Rectified Linear Unit activation function.
See ["Empirical Evaluation of Rectified Activations"](https://arxiv.org/abs/1505.00853)
You can also specify the bound explicitly, e.g. `rrelu(x, 0.0, 1.0)`.
```julia-repl
julia> lineplot(rrelu, -20, 10, height=7)
┌────────────────────────────────────────┐
10 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠖⠋│ rrelu(x)
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⢀⡠⠖⠋⠁⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⢀⡠⠔⠊⠁⠀⠀⠀⠀⠀⠀⠀│
f(x) │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⡤⠤⣤⣤⢤⣤⣤⠤⠤⠤⢼⠮⠥⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│
│⣰⢀⣆⡄⣄⡄⡠⡰⠦⠷⡜⢢⠷⠳⠢⠊⠉⠉⠀⠀⠁⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠃⠉⠙⠘⠃⠈⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
-10 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
└────────────────────────────────────────┘
⠀-20⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀10⠀
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
julia> extrema(rrelu.(fill(-10f0, 1000)))
(-3.3316886f0, -1.2548422f0)
```
"""
function rrelu(x::T, l=oftf(x,1/8), u=oftf(x,1/3)) where T<:Number
a = (u - l) * rand(float(T)) + l
return leakyrelu(x, a)
end
"""
elu(x, α=1) = x > 0 ? x : α * (exp(x) - 1)
Exponential Linear Unit activation function.
See ["Fast and Accurate Deep Network Learning by Exponential Linear Units"](https://arxiv.org/abs/1511.07289).
You can also specify the coefficient explicitly, e.g. `elu(x, 1)`.
```julia-repl
julia> lineplot(elu, -2, 2, height=7)
┌────────────────────────────────────────┐
2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠒⠉│ elu(x)
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊⠉⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⣀⡤⠖⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀│
f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⣀⠤⠖⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⡤⡧⠶⠭⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⣀⠤⠔⠒⠋⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
-1 │⠤⠤⠤⠤⠔⠒⠒⠒⠊⠉⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
└────────────────────────────────────────┘
⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
julia> elu(-10f0)
-0.9999546f0
julia> elu(-10f0, 2)
-1.9999092f0
```
"""
elu(x, α=1) = ifelse(x ≥ 0, float(x), @fastmath oftf(x, α) * (exp(x) - 1))
deriv_elu(Ω, α=1) = ifelse(Ω ≥ 0, one(Ω), Ω + oftype(Ω, α))
"""
gelu_tanh(x) = 0.5x * (1 + tanh(√(2/π) * (x + 0.044715x^3)))
Activation function from ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415) using tanh approximation.
This implementation uses `tanh` which allows for better pattern matching and fusion in optimizing
compilers compared to the sigmoid-based implementation. For a potentially faster implementation
that uses `sigmoid_fast`, see [`gelu_sigmoid`](@ref).
```julia-repl
julia> lineplot(gelu_tanh, -2, 2, height=7)
┌────────────────────────────────────────┐
2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊│ gelu_tanh(x)
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠊⠁⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠒⠉⠀⠀⠀⠀⠀⠀⠀│
f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⣀⡠⠤⠒⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⣤⣤⣤⣤⣤⣤⣤⣤⡤⠤⠤⠤⠤⠤⠤⠤⣤⣤⣤⡤⡧⠶⠶⠭⠥⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│
│⠀⠀⠀⠀⠀⠀⠀⠀⠈⠉⠉⠉⠉⠉⠉⠉⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
-1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
└────────────────────────────────────────┘
⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
julia> lineplot(gelu_tanh, -5, 0, height=7);
julia> lineplot!(ans, swish)
┌────────────────────────────────────────┐
0 │⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠒⠒⠤⣄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸│ gelu_tanh(x)
│⠑⠒⠢⠤⣄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠓⢄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇│ swish(x)
│⠀⠀⠀⠀⠀⠈⠉⠒⠤⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠑⢆⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣸⠁│
f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠒⢄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠑⢄⠀⠀⠀⠀⠀⠀⠀⠀⢠⡇⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠓⢄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠓⣄⠀⠀⠀⠀⠀⢠⡞⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠓⢄⣀⣀⡤⢣⠃⠀⠀│
-0.2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠓⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢠⠇⠀⠀⠀│
└────────────────────────────────────────┘
⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀0⠀
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
```
"""
function gelu_tanh(x)
α = oftf(x, 0.044715)
λ = oftf(x, gelu_λ)
x/2 * (1 + tanh_fast(λ * (x + α * x^3)))
end
const gelu_λ = √(2 / π)
const gelu_2λ = √(8 / π)
function deriv_gelu_tanh(x)
α = oftf(x, 0.044715)
α2 = oftf(x, 0.08943)
λ = oftf(x, gelu_λ)
x2 = x * x
t = muladd(x2, α, one(x))
z = λ * x * t
Ω = tanh_fast(z)
sech2 = 1 - Ω^2
(1 + Ω)/2 + x * λ * muladd(x2, α2, t) * sech2 / 2
end
"""
gelu_sigmoid(x) = x * σ(√(8/π) * (x + 0.044715x^3))
Alternative implementation of the GELU activation function using `sigmoid` instead of `tanh`.
This is mathematically equivalent to [`gelu_tanh`](@ref) but may be faster in some cases.
The sigmoid-based implementation may prevent pattern matching and fusion in some optimizing
compilers. Use [`gelu_tanh`](@ref) if you need better compiler optimization support.
See ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415).
"""
function gelu_sigmoid(x)
α = oftf(x, 0.044715)
λλ = oftf(x, gelu_2λ)
x * sigmoid_fast(λλ * x * muladd(x^2, α, one(x)))
end
function deriv_gelu_sigmoid(x)
α = oftf(x, 0.044715)
α2 = oftf(x, 0.08943)
λλ = oftf(x, gelu_2λ)
x2 = x * x
t = muladd(x2, α, one(x))
Ω = sigmoid_fast(λλ * x * t)
dσ = conj(Ω * (1 - Ω))
muladd(dσ * λλ * muladd(x2, α2, t), x, Ω)
end
"""
gelu_erf(x) = xΦ(x) = 0.5x * (1 + erf(x/√2))
Activation function from ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415) without approximation.
The SpecialFunctions.jl package needs to be loaded to use this function.
"""
function gelu_erf end
function deriv_gelu_erf end
"""
gelu(x) = gelu_tanh(x)
Activation function from ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415).
See [`gelu_tanh`](@ref).
"""
const gelu = gelu_tanh
# Need to alias the type as well to ensure serialization libraries still work
# See https://github.com/FluxML/NNlib.jl/issues/631
const var"#gelu" = typeof(gelu_tanh)
const deriv_gelu = deriv_gelu_tanh
"""
swish(x) = x * σ(x)
Self-gated activation function.
See ["Swish: a Self-Gated Activation Function"](https://arxiv.org/abs/1710.05941).
```julia-repl
julia> lineplot(swish, -2, 2, height=7)
┌────────────────────────────────────────┐
2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤│ swish(x)
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠖⠋⠁⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠖⠋⠁⠀⠀⠀⠀⠀│
f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⢀⣀⡤⠔⠊⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⣤⣤⡤⡧⠴⠶⠯⠥⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│
│⠉⠑⠒⠒⠒⠒⠒⠒⠒⠒⠒⠒⠉⠉⠉⠉⠁⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
-1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
└────────────────────────────────────────┘
⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
```
"""
@inline swish(x) = x * sigmoid_fast(x)
"""
hardswish(x) = x * hardσ(x)
Hard-Swish activation function.
See ["Searching for MobileNetV3"](https://arxiv.org/abs/1905.02244).
```julia-repl
julia> lineplot(hardswish, -2, 5, height = 7)
┌────────────────────────────────────────┐
5 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠔⠒⠉│ hardswish(x)
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡠⠔⠒⠉⠁⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠖⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀│
f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⢀⣀⠤⠖⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣇⣤⣤⣖⣚⣉⣁⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀│
-1 │⠉⠒⠒⠒⠒⠉⠉⠉⠉⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
└────────────────────────────────────────┘
⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
julia> lineplot(hardswish, -4, 0, height = 7);
julia> lineplot!(ans, swish)
┌────────────────────────────────────────┐
0 │⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⢣⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡜│ hardswish(x)
│⠒⠒⠢⠤⢄⣀⡀⠀⠀⠀⠀⠱⡄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠎⠀│ swish(x)
│⠀⠀⠀⠀⠀⠀⠈⠉⠑⠒⠦⢄⣘⢄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡴⠃⠀⠀│
f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠑⡖⠦⢄⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⢔⠏⠁⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠣⣄⠀⠉⠑⠒⠦⠤⢄⣀⣀⣀⣀⡠⠤⠖⣊⠕⠁⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠓⠤⡀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡤⠖⠁⠀⠀⠀⠀⠀⠀⠀│
-0.4 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠉⠒⠢⠤⠤⠔⠒⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
└────────────────────────────────────────┘
⠀-4⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀0⠀
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
julia> hardswish.(-5:5)'
1×11 adjoint(::Vector{Float64}) with eltype Float64:
-0.0 -0.0 -0.0 -0.333333 -0.333333 0.0 0.666667 1.66667 3.0 4.0 5.0
```
"""
@inline hardswish(x) = x * hardσ(x)
deriv_hardswish(x) = ifelse(x < -3, oftf(x,0), ifelse(x > 3, oftf(x,1), x/3 + oftf(x,1/2)))
"""
lisht(x) = x * tanh(x)
Activation function from
["LiSHT: Non-Parametric Linearly Scaled Hyperbolic Tangent ..."](https://arxiv.org/abs/1901.05894)
```julia-repl
julia> lineplot(lisht, -2, 2, height=7)
┌────────────────────────────────────────┐
2 │⠢⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔│ lisht(x)
│⠀⠈⠑⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡤⠊⠁⠀│
│⠀⠀⠀⠀⠈⠣⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠁⠀⠀⠀⠀│
f(x) │⠀⠀⠀⠀⠀⠀⠀⠑⢆⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⠊⠁⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠢⡄⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⢀⠔⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠓⢄⡀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⢀⡠⠖⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
0 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠓⠦⣄⣀⣀⣇⣀⣀⠤⠒⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
└────────────────────────────────────────┘
⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
julia> lineplot!(ans, logcosh)
┌────────────────────────────────────────┐
2 │⠢⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔│ lisht(x)
│⠀⠈⠑⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡤⠊⠁⠀│ logcosh(x)
│⠢⣄⠀⠀⠈⠣⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠁⠀⠀⣀⠔│
f(x) │⠀⠈⠑⠢⣀⠀⠀⠑⢆⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⠊⠁⠀⣀⠔⠊⠁⠀│
│⠀⠀⠀⠀⠀⠉⠢⢄⡀⠉⠢⡄⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⢀⠔⠋⠀⡠⠔⠋⠁⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠉⠓⠦⣌⡓⢄⡀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⢀⡠⠖⣁⠤⠒⠉⠀⠀⠀⠀⠀⠀⠀⠀│
0 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠓⠪⠷⣦⣄⣀⣀⣇⣀⣀⣤⠶⠕⠒⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
└────────────────────────────────────────┘
⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
```
"""
lisht(x) = x * tanh_fast(x)
"""
selu(x) = λ * (x ≥ 0 ? x : α * (exp(x) - 1))
λ ≈ 1.05070...
α ≈ 1.67326...
Scaled exponential linear units.
See ["Self-Normalizing Neural Networks"](https://arxiv.org/abs/1706.02515).
```julia-repl
julia> lineplot(selu, -3, 2, height=7)
┌────────────────────────────────────────┐
3 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ selu(x)
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠤⠒│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⢀⣀⠤⠖⠊⠉⠀⠀⠀⠀│
f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⣀⡠⠤⠒⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⣉⠭⠛⡏⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⣀⡤⠤⠒⠊⠉⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
-2 │⠤⠤⠖⠒⠒⠒⠒⠒⠒⠒⠉⠉⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
└────────────────────────────────────────┘
⠀-3⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
julia> selu(-10f0)
-1.7580194f0
```
"""
function selu(x)
λ = oftf(x, selu_λ)
α = oftf(x, selu_α)
λ * ifelse(x > 0, x, @fastmath α * (exp(x) - 1))
end
const selu_λ = 1.0507009873554804934193349852946
const selu_α = 1.6732632423543772848170429916717
function deriv_selu(Ω)
λ = oftf(Ω, selu_λ)
α = oftf(Ω, selu_α)
ifelse(Ω > 0, λ, Ω + α * λ)
end
"""
celu(x, α=1) = x ≥ 0 ? x : α * (exp(x/α) - 1)
Activation function from ["Continuously Differentiable Exponential Linear Units"](https://arxiv.org/abs/1704.07483).
```julia-repl
julia> lineplot(celu, -2, 2, height=7)
┌────────────────────────────────────────┐
2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠒⠉│ celu(x)
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊⠉⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⣀⡤⠖⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀│
f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⣀⠤⠖⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⡤⡧⠶⠭⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⣀⠤⠔⠒⠋⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
-1 │⠤⠤⠤⠤⠔⠒⠒⠒⠊⠉⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
└────────────────────────────────────────┘
⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
julia> celu(-10f0)
-0.9999546f0
```
"""
celu(x, α=1) = ifelse(x ≥ 0, float(x), oftf(x,α) * (exp(x/oftf(x,α)) - 1))
deriv_celu(Ω, α=1) = ifelse(Ω > 0, oftf(Ω, 1), Ω / oftf(Ω, α) + 1)
"""
trelu(x, theta=1) = x > theta ? x : 0
Threshold gated rectified linear activation function.
See ["Zero-bias autoencoders and the benefits of co-adapting features"](https://arxiv.org/abs/1402.3337)
```julia-repl
julia> lineplot(trelu, -2, 4, height=7)
┌────────────────────────────────────────┐
4 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠖⠋│ trelu(x)
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠖⠋⠁⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊⠁⠀⠀⠀⠀⠀⠀⠀│
f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠴⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⣠⠤⠒⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⡏⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
0 │⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣇⣀⣀⣀⣀⣀⣀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
└────────────────────────────────────────┘
⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀4⠀
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
```
"""
trelu(x, theta=1) = ifelse(x <= theta, zero(x), x)
const thresholdrelu = trelu
"""
softsign(x) = x / (1 + |x|)
See ["Quadratic Polynomials Learn Better Image Features"](http://www.iro.umontreal.ca/~lisa/publications2/index.php/attachments/single/205) (2009).
```julia-repl
julia> lineplot(softsign, -5, 5, height=7)
┌────────────────────────────────────────┐
1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣀⣀⣀⣀⠤⠤⠤⠤⠤│ softsign(x)
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⣀⡤⠖⠒⠋⠉⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⡔⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
f(x) │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⡯⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠁⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣀⣀⠤⠤⠒⠋⠁⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
-1 │⠒⠒⠒⠒⠒⠊⠉⠉⠉⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
└────────────────────────────────────────┘
⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
julia> lineplot!(ans, tanh)
┌────────────────────────────────────────┐
1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⢀⡤⠖⠊⠉⠉⠉⣉⣉⣉⣉⣉⠭⠭⠭⠭⠭│ softsign(x)
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⡔⣃⡤⠖⠒⠋⠉⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀│ tanh(x)
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣧⡞⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
f(x) │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⡯⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡴⠃⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣀⣀⠤⠤⠒⢋⠕⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
-1 │⣒⣒⣒⣒⣒⣊⣉⣉⣉⣉⣁⣀⣀⡠⠤⠒⠁⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
└────────────────────────────────────────┘
⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
julia> softsign(1f0)
0.5f0
julia> softsign(100f0)
0.990099f0
```
"""
softsign(x) = x / (1 + abs(x))
deriv_softsign(x) = 1 / (1 + abs(x))^2
"""
softplus(x) = log(exp(x) + 1)
See ["Deep Sparse Rectifier Neural Networks"](http://proceedings.mlr.press/v15/glorot11a/glorot11a.pdf), JMLR 2011.
```julia-repl
julia> lineplot(softplus, -3, 3, height=7)
┌────────────────────────────────────────┐
4 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ softplus(x)
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊⠁⠀│
f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠔⠊⠁⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⣀⡠⠤⠒⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣀⡧⠤⠒⠊⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
0 │⣀⣀⣀⣀⣀⣀⣀⡠⠤⠤⠤⠤⠔⠒⠒⠚⠉⠉⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
└────────────────────────────────────────┘
⠀-3⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀3⠀
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
julia> lineplot!(ans, relu)
┌────────────────────────────────────────┐
4 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ softplus(x)
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣠│ relu(x)
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣠⡴⠞⠋⠁│
f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣤⡴⠞⠋⠁⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⣀⡠⢤⡲⠝⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣀⡧⠤⠒⠊⣉⠥⠚⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
0 │⣀⣀⣀⣀⣀⣀⣀⣠⣤⣤⣤⣤⣔⣒⣒⣚⣉⣉⣁⣀⣇⠴⠒⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
└────────────────────────────────────────┘
⠀-3⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀3⠀
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
julia> softplus(16f0)
16.0f0
```
"""
softplus(x) = log1p(exp(-abs(x))) + relu(x)
"""
logcosh(x)
Return `log(cosh(x))` which is computed in a numerically stable way.
```julia-repl
julia> lineplot(logcosh, -5, 5, height=7)
┌────────────────────────────────────────┐
5 │⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ logcosh(x)
│⠉⠢⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠋│
│⠀⠀⠀⠑⠢⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠊⠁⠀⠀│
f(x) │⠀⠀⠀⠀⠀⠀⠑⠦⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠊⠁⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠑⠦⡀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⢀⡤⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠓⠦⡀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⢀⡤⠒⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
0 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠑⠢⢄⣀⣀⣇⣀⡠⠔⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
└────────────────────────────────────────┘
⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
```
"""
logcosh(x) = x + softplus(-2x) - oftf(x, log2)
const log2 = log(2)
"""
mish(x) = x * tanh(softplus(x))
Activation function from ["Mish: A Self Regularized Non-Monotonic Neural Activation Function"](https://arxiv.org/abs/1908.08681).
```julia-repl
julia> lineplot(mish, -5, 5, height=7)
┌────────────────────────────────────────┐
5 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠖⠋│ mish(x)
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠒⠁⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⠔⠋⠁⠀⠀⠀⠀⠀⠀│
f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⢀⡠⠖⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⢀⡤⠖⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣧⣔⣊⣁⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀│
-1 │⠀⠀⠀⠀⠀⠀⠀⠀⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
└────────────────────────────────────────┘
⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
```
"""
mish(x) = x * tanh(softplus(x))
"""
tanhshrink(x) = x - tanh(x)
See ["Tanhshrink Activation Function"](https://www.gabormelli.com/RKB/Tanhshrink_Activation_Function).
```julia-repl
julia> lineplot(tanhshrink, -3, 3, height=7)
┌────────────────────────────────────────┐
3 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ tanhshrink(x)
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡠⠤⠖⠊│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⢀⣀⡠⠤⠒⠊⠉⠁⠀⠀⠀⠀│
f(x) │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⣤⡤⠤⠤⠤⠤⠤⠤⡷⠶⠶⠶⠶⠶⠮⠭⠥⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│
│⠀⠀⠀⠀⠀⣀⡠⠴⠒⠊⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⡠⠴⠒⠊⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
-3 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
└────────────────────────────────────────┘
⠀-3⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀3⠀
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
julia> tanhshrink.((-10f0, 10f0))
(-9.0f0, 9.0f0)
```
"""
tanhshrink(x) = x - tanh_fast(x)
"""
softshrink(x, λ=0.5) =
(x ≥ λ ? x - λ : (-λ ≥ x ? x + λ : 0))
See ["Softshrink Activation Function"](https://www.gabormelli.com/RKB/Softshrink_Activation_Function).
```julia-repl
julia> lineplot(softshrink, -2, 2, height=7)
┌────────────────────────────────────────┐
2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀│ softshrink(x)
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⡤⠔⠒⠉⠁│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⣀⡤⠤⠒⠋⠁⠀⠀⠀⠀⠀⠀│
f(x) │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⣤⡤⠤⠤⠤⠤⠤⠤⡧⠤⠤⠤⠤⠶⠮⠭⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│
│⠀⠀⠀⠀⠀⠀⢀⣀⠤⠖⠒⠉⠁⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⣀⠤⠔⠒⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
-2 │⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
└────────────────────────────────────────┘
⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
julia> lineplot!(ans, tanhshrink)
┌────────────────────────────────────────┐
2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀│ softshrink(x)
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⡤⠔⠒⣉⡡│ tanhshrink(x)
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⣀⡤⠤⣒⣋⠥⠤⠒⠊⠉⠁⠀│
f(x) │⠤⠤⠤⠤⠤⠤⠤⠤⠤⣤⣤⣤⣤⡤⠤⠤⠤⠤⠤⠤⡷⠶⠶⠶⠶⠶⠾⠿⠯⠭⠭⠤⠤⠤⠤⠤⠤⠤⠤⠤│
│⠀⢀⣀⡠⠤⠖⢒⣋⠭⠗⠒⠉⠁⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠊⣉⠤⠔⠒⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
-2 │⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
└────────────────────────────────────────┘
⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀
julia> softshrink.((-10f0, 10f0))
(-9.5f0, 9.5f0)
```
"""
function softshrink(x, λ = 0.5)
lo = x - oftf(x, λ)
hi = x + oftf(x, λ)
ifelse(hi > 0, ifelse(lo < 0, zero(hi), lo), hi)
end
# Define broadcasts for activation functions on arrays
for f in ACTIVATIONS
@eval $(f)(x::AbstractArray, args...) = $(f).(x, args...)
end
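# e.g. relu([-1, 2]) == [0, 2] and leakyrelu([-1.0, 2.0], 0.1) == [-0.1, 2.0],
# via the elementwise broadcasting fallbacks defined just above for every name in ACTIVATIONS.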
## Faster, less accurate, versions of some.
"""
tanh_fast(x)
This is a faster but slightly less accurate version of `tanh`.
Where Julia's `tanh` function has an error under 2 eps, this
may be wrong by 5 eps, a reduction by less than one decimal digit.
For `x::Float32` this is usually about 10 times faster,
with a smaller speedup for `x::Float64`.
For any other number types, it just calls `tanh`.
See also [`sigmoid_fast`](@ref).
```julia-repl
julia> tanh(0.5f0)
0.46211717f0
julia> tanh_fast(0.5f0)
0.46211714f0
julia> hardtanh(0.5f0)
0.5f0
```
"""
@inline function tanh_fast(x::Float32)
# This method added in NNlib.jl#345 by @mcabbott and @oscardssmith,
# with coefficients found using Remez.jl
x2 = abs2(x)
n = evalpoly(x2, (1.0f0, 0.1346604f0, 0.0035974074f0, 2.2332108f-5, 1.587199f-8))
d = evalpoly(x2, (1.0f0, 0.4679937f0, 0.026262015f0, 0.0003453992f0, 8.7767893f-7))
ifelse(x2 < 66f0, x * (n / d), sign(x))
end
@inline function tanh_fast(x::Float64)
exp2x = @fastmath exp(x + x)
y = (exp2x - 1) / (exp2x + 1)
# That has large errors near zero; using `expm1` would be more accurate, but about as slow as `tanh`.
# Instead, we switch to a polynomial, which is very accurate within its range:
x2 = x * x
ypoly = x * evalpoly(x2, (1.0, -0.33333333333324583, 0.13333333325511604, -0.05396823125794372, 0.02186660872609521, -0.008697141630499953))
ifelse(x2 > 900.0, sign(x), ifelse(x2 < 0.017, ypoly, y))
end
# These approximations are very badly behaved for Float16; none are fast.
# They are also a bit slower with ForwardDiff.Dual numbers, so let's use Base:
tanh_fast(x::Number) = Base.tanh(x)
"""
sigmoid_fast(x)
This is a faster, and very slightly less accurate, version of `sigmoid`.
For `x::Float32`, this is perhaps 3 times faster, with a maximum error of 2 eps instead of 1.
See also [`tanh_fast`](@ref).
```julia-repl
julia> sigmoid(0.2f0)
0.54983395f0
julia> sigmoid_fast(0.2f0)
0.54983395f0
julia> hardσ(0.2f0)
0.53333336f0
```
"""
function sigmoid_fast(x::Real)
@static if VERSION ≥ v"1.11-"
@inline
end
t = @fastmath exp(-abs(x))
y = ifelse(x ≥ 0, inv(1 + t), t / (1 + t))
ifelse(x > 40, one(y), ifelse(x < -80, zero(y), y))
end
# For x::Float32, this is not as quick as the rational tanh_fast(x) above,
# but that polynomial has poor relative accuracy for negative x.
sigmoid_fast(x::Float16) = sigmoid(x) # sigmoid_fast is extremely badly behaved at large x
function sigmoid_fast(x::Number)
Base.depwarn("sigmoid only makes sense on real numbers, got $(typeof(x))", :sigmoid_fast)
sigmoid(x)
end
"""
NNlib.fast_act(f, [x::AbstractArray])
Replaces `f == tanh` with [`tanh_fast`](@ref), etc.
Takes an optional 2nd argument, so that you can disable
this replacement for some array or element types.
"""
@inline fast_act(f::F, ::AbstractArray = 1:0) where {F<:Function} = f
@inline fast_act(::typeof(tanh), ::AbstractArray = 1:0) = tanh_fast
@inline fast_act(::typeof(sigmoid), ::AbstractArray = 1:0) = sigmoid_fast
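# e.g. fast_act(tanh) === tanh_fast and fast_act(sigmoid) === sigmoid_fast; any other
# function is returned unchanged, and the optional array argument exists so that
# particular array or element types can opt out of the replacement.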
## Define rrules for some activation functions, along with the
## broadcasted rrule activation functions.
## This is a performance hack specifically for Zygote, because it doesn't handle fused
## broadcasts well; but it generally should be good (or at least harmless) for any AD, as
## it saves ADing the broadcasting machinery.
## Related Issue https://github.com/JuliaDiff/ChainRulesCore.jl/issues/271
## TODO: add to the lists below all activations.
UNARY_ACTS = [ # f, dfdx
## In the same order as above!
(:σ, :(conj(Ω * (1 - Ω)))),
(:hardσ, :(ifelse((Ω>0)&(Ω<1), oftf(Ω, 1/6), oftf(Ω, 0)))),  # zero slope where hardσ saturates
(:logσ, :(sigmoid_fast(-x))),
(:hardtanh, :((Ω>-1) & (Ω<1))),
(:relu, :(Ω > 0)),
(:leakyrelu, :(ifelse(Ω > 0, oftf(Ω, 1), oftf(Ω, leakyrelu_a)))),
(:relu6, :((Ω>0) & (Ω<6))),
# rrelu is random, can't write a rule.
(:elu, :(deriv_elu(Ω))),
(:gelu_tanh, :(deriv_gelu_tanh(x))),
(:gelu_sigmoid, :(deriv_gelu_sigmoid(x))),
(:gelu_erf, :(deriv_gelu_erf(x))),
(:swish, :(Ω + sigmoid_fast(x) * (1 - Ω))),
(:hardswish, :(deriv_hardswish(x))),
# lisht
(:selu, :(deriv_selu(Ω))),
(:celu, :(deriv_celu(Ω))),
(:trelu, :(Ω > 0)),
(:softsign, :(deriv_softsign(x))),
(:softplus, :(sigmoid_fast(x))),
# (:softplus, :(1 - @fastmath exp(-Ω))), # slightly faster, check accuracy?
# logcosh
# mish
(:tanhshrink, :((x - Ω)^2)),
(:softshrink, :(Ω != 0)),
## Fast variants are the same!
(:tanh_fast, :(conj(1 - Ω^2))),
(:sigmoid_fast, :(conj(Ω * (1 - Ω)))),
]
for (f, dfdx) in UNARY_ACTS
@eval @scalar_rule($f(x), $dfdx)
pullback = Symbol(:broadcasted_, f, :_pullback)
@eval function rrule(::typeof(broadcasted),
::typeof($f), x::Union{Numeric, Broadcast.Broadcasted})
Ω = $f.(x)
function $pullback(dΩ)
x_thunk = InplaceableThunk(
dx -> @.(dx += dΩ * $dfdx),
@thunk @.(dΩ * $dfdx)
)
NoTangent(), NoTangent(), x_thunk
end
return Ω, $pullback
end
end
# NO_ACT_GRAD = ChainRulesCore.@not_implemented "for simplicity NNlib assumes the 2nd argument of this activation function is a constant"
NO_ACT_GRAD = NaN ## Still reminds you not to use this, but is perhaps more GPU friendly.
BINARY_ACTS = [ # f, dfdx1, dfdx2
## In the same order as above!
(:leakyrelu, :(ifelse(Ω > 0, oftf(Ω, 1), oftf(Ω, x2))), NO_ACT_GRAD),
(:elu, :(deriv_elu(Ω, x2)), NO_ACT_GRAD),
(:celu, :(deriv_celu(Ω, x2)), NO_ACT_GRAD),
(:trelu, :(Ω > 0), ZeroTangent()),
(:softshrink, :(Ω != 0), NO_ACT_GRAD),
]
for (f, dfdx1, dfdx2) in BINARY_ACTS
@eval @scalar_rule($f(x1, x2), ($dfdx1, $dfdx2))
pullback = Symbol(:broadcasted_, f, :_pullback_2arg)
@eval function rrule(::typeof(broadcasted),
::typeof($f),
x1::Union{Numeric, Broadcast.Broadcasted}, x2::Number)
Ω = $f.(x1, x2)
## Allowing x2::Array would allow size(Ω) != size(x1), which is not handled here:
$pullback(dΩ) = (NoTangent(), NoTangent(), @.(dΩ * $dfdx1), NO_ACT_GRAD)
return Ω, $pullback
end
end
================================================
FILE: src/attention.jl
================================================
const AA3{T} = AbstractArray{T,3}
const AA4{T} = AbstractArray{T,4}
const AA{N,T} = AbstractArray{T,N}
"""
dot_product_attention(query, key, value, [bias]; [fdrop, mask, nheads])
Multihead dot product attention used in transformer architectures.
The input arrays must have the first two dimensions given by the number of features
and the sequence length, then an arbitrary number of batch dimensions or none.
Returns the attention output array of size `(v_dim, q_len, batch_size...)` and the attention scores
of size `(kv_len, q_len, nheads, batch_size...)`.
See also [`dot_product_attention_scores`](@ref) if you only need the attention scores.
# Arguments
- `query`: Query array of size `(qk_dim, q_len, batch_size...)`.
- `key`: Key array of size `(qk_dim, kv_len, batch_size...)`.
- `value`: Value array of size `(v_dim, kv_len, batch_size...)`.
- `bias`: Either `nothing` or an array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
It will be added to the attention scores before applying the softmax. Default `nothing`.
- `fdrop`: A dropout function or layer to be applied on the attention scores right after the softmax.
Default `identity` (no dropout).
- `mask`: Either `nothing` or a boolean array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
The mask is applied to the attention scores just before the softmax.
See [`make_causal_mask`](@ref) for creating causal masks. Default `nothing`.
- `nheads`: Number of heads to split the input arrays into. Default `1`.
# Examples
```julia
q, k, v = rand(10, 20, 2), rand(10, 30, 2), rand(20, 30, 2)
y, α = dot_product_attention(q, k, v)
```
"""
function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}, args...; kws...) where N
batch_size = size(q)[3:end]
batch_size == size(k)[3:end] == size(v)[3:end] || throw(ArgumentError("Batch dimensions have to be the same."))
q, k, v = map(x -> reshape(x, size(x, 1), size(x, 2), :), (q, k, v))
x, α = dot_product_attention(q, k, v, args...; kws...)
x = reshape(x, size(x, 1), size(x, 2), batch_size...)
α = reshape(α, size(α)[1:3]..., batch_size...)
return x, α
end
function dot_product_attention(q::AA3, k::AA3, v::AA3, bias=nothing;
fdrop=identity, mask=nothing, nheads=1)
(all(size.((q, k, v), 1) .% nheads .== 0)) || throw(ArgumentError("""
First dimension in query, key and value must be divisible by `nheads`.
Instead:
- size(q): $(size(q))
- size(k): $(size(k))
- size(v): $(size(v))
- nheads: $nheads
"""))
(size(q, 3) == size(k, 3) == size(v, 3)) || throw(ArgumentError("""
Batch dimensions have to be the same. Instead:
- size(q): $(size(q))
- size(k): $(size(k))
- size(v): $(size(v))
"""))
size(q, 1) == size(k, 1) || throw(ArgumentError("""
First dimension in query and key has to be the same. Instead:
- size(q): $(size(q))
- size(k): $(size(k))
"""))
size(k, 2) == size(v, 2) || throw(ArgumentError("""
Second dimension in key and value has to be the same. Instead:
- size(k): $(size(k))
- size(v): $(size(v))
"""))
# Multihead attention. TODO create fastpath for singlehead attention.
q, k, v = split_heads.((q, k, v), nheads)
x, α = _dot_product_attention(q, k, v, bias, fdrop, mask)
return join_heads(x), α
end
function _dot_product_attention(q::AA4, k::AA4, v::AA4, bias, fdrop, mask)
# [q] = [qk_dim ÷ nheads, nheads, q_len, batch_size]
# [k] = [qk_dim ÷ nheads, nheads, kv_len, batch_size]
# [v] = [v_dim ÷ nheads, nheads, kv_len, batch_size]
α = dot_product_attention_scores(q, k, bias; fdrop, mask)
# [α] = [kv_len, q_len, nheads, batch_size]
# The following permutedims and batched_mul are equivalent to
# @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b]
vt = permutedims(v, (1, 3, 2, 4))
x = batched_mul(vt, α)
x = permutedims(x, (1, 3, 2, 4))
# [x] = [v_dim ÷ nheads, nheads, q_len, batch_size]
return x, α
end
"""
dot_product_attention_scores(query, key, [bias]; [fdrop, mask])
Return the attention scores for the [`dot_product_attention`](@ref).
Input arrays must have dimensions
`(num_features ÷ nheads, nheads, sequence_length, batch_size)`.
See [`dot_product_attention`](@ref) for more details.
"""
function dot_product_attention_scores(q::AA4{T}, k::AA4{T}, bias=nothing;
fdrop=identity, mask=nothing) where T
# The following permutedims and batched_mul are equivalent to
# @tullio logits[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] / √T(qk_dim)
kt = permutedims(k, (3, 1, 2, 4))
qt = permutedims(q, (1, 3, 2, 4)) ./ √T(size(q, 1))
logits = batched_mul(kt, qt)
# [logits] = [kv_len, q_len, nheads, batch_size]
logits = apply_attn_bias(logits, bias)
logits = apply_attn_mask(logits, mask)
α = softmax(logits, dims=1)
return fdrop(α)
end
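# Shape sketch (illustrative): for q and k of size (8, 2, 10, 3), i.e. 2 heads of 8
# features each, query/key length 10 and batch size 3, the returned α has size
# (10, 10, 2, 3) and sums to 1 along dims = 1 (before any dropout applied by `fdrop`).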
apply_attn_bias(logits, bias::Nothing) = logits
apply_attn_bias(logits, bias) = logits .+ bias
apply_attn_mask(logits, mask::Nothing) = logits
function apply_attn_mask(logits, mask)
neginf = typemin(eltype(logits))
ifelse.(mask, logits, neginf)
end
"""
make_causal_mask(x, dims=2)
Return a boolean square matrix `m` of the same type as `x` and of side `size(x, dims)`.
Its elements are set such that `m[i, j] == i ≤ j`.
Can be used to mask the attention scores in [`dot_product_attention`](@ref).
"""
function make_causal_mask(x::AbstractArray; dims::Int=2)
len = size(x, dims)
mask = triu(trues_like(x, (len, len)))
return mask
end
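# e.g. make_causal_mask(ones(2, 4)) returns a 4×4 Bool matrix with entry (i, j) true
# exactly when i ≤ j, so query position j can only attend to key positions 1:j.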
trues_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), true)
falses_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), false)
split_heads(x, nheads) = reshape(x, size(x, 1) ÷ nheads, nheads, size(x)[2:end]...)
join_heads(x) = reshape(x, :, size(x)[3:end]...)
@non_differentiable make_causal_mask(::Any...)
@non_differentiable trues_like(::Any...)
@non_differentiable falses_like(::Any...)
================================================
FILE: src/audio/mel.jl
================================================
"""
melscale_filterbanks(;
n_freqs::Int, n_mels::Int, sample_rate::Int,
fmin::Float32 = 0f0, fmax::Float32 = Float32(sample_rate ÷ 2))
Create triangular Mel scale filter banks
(ref: [Mel scale - Wikipedia](https://en.wikipedia.org/wiki/Mel_scale)).
Each column is a filterbank that highlights its own frequency.
# Arguments:
- `n_freqs::Int`: Number of frequencies to highlight.
- `n_mels::Int`: Number of mel filterbanks.
- `sample_rate::Int`: Sample rate of the audio waveform.
- `fmin::Float32`: Minimum frequency in Hz.
- `fmax::Float32`: Maximum frequency in Hz.
# Returns:
Filterbank matrix of shape `(n_freqs, n_mels)` where each column is a filterbank.
```jldoctest
julia> n_mels = 8;
julia> fb = melscale_filterbanks(; n_freqs=200, n_mels, sample_rate=16000);
julia> plot = lineplot(fb[:, 1]);
julia> for i in 2:n_mels
lineplot!(plot, fb[:, i])
end
julia> plot
┌────────────────────────────────────────┐
1 │⠀⡀⢸⠀⢸⠀⠀⣧⠀⠀⢸⡄⠀⠀⠀⣷⠀⠀⠀⠀⠀⣷⠀⠀⠀⠀⠀⠀⢀⣿⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⡇⢸⡆⢸⡇⠀⣿⠀⠀⡜⡇⠀⠀⢰⠋⡆⠀⠀⠀⢰⠁⡇⠀⠀⠀⠀⠀⡸⠀⢣⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⣿⢸⡇⡇⡇⢰⠹⡄⠀⡇⢱⠀⠀⢸⠀⢣⠀⠀⠀⡜⠀⢸⡀⠀⠀⠀⢀⠇⠀⠈⡇⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⣿⡇⡇⡇⡇⢸⠀⡇⢀⠇⠸⡀⠀⡇⠀⠸⡀⠀⢀⠇⠀⠀⢇⠀⠀⠀⡸⠀⠀⠀⠸⡄⠀⠀⠀⠀⠀⠀⠀│
│⢠⢻⡇⡇⡇⢱⢸⠀⢇⢸⠀⠀⡇⢀⠇⠀⠀⡇⠀⢸⠀⠀⠀⠸⡀⠀⢠⠇⠀⠀⠀⠀⢱⠀⠀⠀⠀⠀⠀⠀│
│⢸⢸⡇⢱⡇⢸⡇⠀⢸⢸⠀⠀⢣⢸⠀⠀⠀⢸⠀⡇⠀⠀⠀⠀⢇⠀⡜⠀⠀⠀⠀⠀⠈⢇⠀⠀⠀⠀⠀⠀│
│⢸⢸⡇⢸⠀⢸⡇⠀⢸⡇⠀⠀⢸⡎⠀⠀⠀⠈⣶⠁⠀⠀⠀⠀⠸⣤⠃⠀⠀⠀⠀⠀⠀⠘⡆⠀⠀⠀⠀⠀│
│⢸⠀⡇⢸⠀⠀⡇⠀⠀⡇⠀⠀⠀⡇⠀⠀⠀⠀⣿⠀⠀⠀⠀⠀⠀⣿⠀⠀⠀⠀⠀⠀⠀⠀⢱⡀⠀⠀⠀⠀│
│⢸⢸⡇⢸⠀⢸⡇⠀⢸⡇⠀⠀⢸⢇⠀⠀⠀⢀⠿⡀⠀⠀⠀⠀⢰⠛⡄⠀⠀⠀⠀⠀⠀⠀⠀⢣⠀⠀⠀⠀│
│⢸⢸⡇⡸⡇⢸⡇⠀⢸⢸⠀⠀⡜⢸⠀⠀⠀⢸⠀⡇⠀⠀⠀⠀⡎⠀⢣⠀⠀⠀⠀⠀⠀⠀⠀⠘⡆⠀⠀⠀│
│⢸⢸⡇⡇⡇⡸⢸⠀⡎⢸⠀⠀⡇⠈⡆⠀⠀⡇⠀⢸⠀⠀⠀⢰⠁⠀⠘⡆⠀⠀⠀⠀⠀⠀⠀⠀⠸⡄⠀⠀│
│⡇⢸⡇⡇⡇⡇⢸⠀⡇⠈⡆⢰⠁⠀⡇⠀⢰⠁⠀⠈⡆⠀⠀⡎⠀⠀⠀⢱⠀⠀⠀⠀⠀⠀⠀⠀⠀⢣⠀⠀│
│⡇⢸⢸⡇⡇⡇⠸⣰⠃⠀⡇⡸⠀⠀⢸⠀⡜⠀⠀⠀⢣⠀⢸⠁⠀⠀⠀⠈⡆⠀⠀⠀⠀⠀⠀⠀⠀⠈⢇⠀│
│⡇⡇⢸⠇⢸⡇⠀⣿⠀⠀⢣⡇⠀⠀⠸⣄⠇⠀⠀⠀⠸⡀⡇⠀⠀⠀⠀⠀⢱⠀⠀⠀⠀⠀⠀⠀⠀⠀⠸⡄│
0 │⣇⣇⣸⣀⣸⣀⣀⣟⣀⣀⣸⣃⣀⣀⣀⣿⣀⣀⣀⣀⣀⣿⣀⣀⣀⣀⣀⣀⣈⣇⣀⣀⣀⣀⣀⣀⣀⣀⣀⣱│
└────────────────────────────────────────┘
⠀0⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀200⠀
```
"""
function melscale_filterbanks(;
n_freqs::Int, n_mels::Int, sample_rate::Int,
fmin::Float32 = 0f0, fmax::Float32 = Float32(sample_rate ÷ 2),
)
mel_min, mel_max = _hz_to_mel(fmin), _hz_to_mel(fmax)
mel_points = range(mel_min, mel_max; length=n_mels + 2)
all_freqs = collect(range(0f0, Float32(sample_rate ÷ 2); length=n_freqs))
freq_points = _mel_to_hz.(mel_points)
filter_banks = _triangular_filterbanks(freq_points, all_freqs)
if any(maximum(filter_banks; dims=1) .≈ 0f0)
@warn """At least one mel filterbank has all zero values.
The value for `n_mels=$n_mels` may be set too high.
Or the value for `n_freqs=$n_freqs` may be set too low.
"""
end
return filter_banks
end
_hz_to_mel(freq::T) where T = T(2595) * log10(T(1) + (freq / T(700)))
_mel_to_hz(mel::T) where T = T(700) * (T(10)^(mel / T(2595)) - T(1))
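# Sanity check (illustrative): the two maps are inverses of each other, e.g.
# _mel_to_hz(_hz_to_mel(440f0)) ≈ 440f0, and 1000 Hz maps to roughly 1000 mel.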
"""
_triangular_filterbanks(
freq_points::Vector{Float32}, all_freqs::Vector{Float32})
Create triangular filter banks.
# Arguments:
- `freq_points::Vector{Float32}`: Filter midpoints of size `n_filters`.
- `all_freqs::Vector{Float32}`: Frequency points of size `n_freqs`.
# Returns:
Array of size `(n_freqs, n_filters)`.
"""
function _triangular_filterbanks(
freq_points::Vector{Float32}, all_freqs::Vector{Float32},
)
diff = @view(freq_points[2:end]) .- @view(freq_points[1:end - 1])
slopes = transpose(reshape(freq_points, :, 1) .- reshape(all_freqs, 1, :))
down_slopes = -(@view(slopes[:, 1:end - 2]) ./ reshape(@view(diff[1:end - 1]), 1, :))
up_slopes = @view(slopes[:, 3:end]) ./ reshape(@view(diff[2:end]), 1, :)
return max.(0f0, min.(down_slopes, up_slopes))
end
================================================
FILE: src/audio/spectrogram.jl
================================================
"""
spectrogram(waveform;
pad::Int = 0, n_fft::Int, hop_length::Int, window,
center::Bool = true, power::Real = 2.0,
normalized::Bool = false, window_normalized::Bool = false,
)
Create a spectrogram or a batch of spectrograms from a raw audio signal.
# Arguments
- `pad::Int`:
The amount of padding to apply on both sides.
- `window_normalized::Bool`:
Whether to normalize the waveform by the window’s L2 energy.
- `power::Real`:
Exponent for the magnitude spectrogram (must be ≥ 0)
e.g., `1` for magnitude, `2` for power, etc.
If `0`, the complex spectrum is returned instead.
See [`stft`](@ref) for other arguments.
# Returns
Spectrogram in the shape `(T, F, B)`, where
`T` is the number of window hops and `F = n_fft ÷ 2 + 1`.
"""
function spectrogram(waveform::AbstractArray{T};
pad::Int = 0, n_fft::Int, hop_length::Int, window,
center::Bool = true, power::Real = 2.0,
normalized::Bool = false, window_normalized::Bool = false,
) where T
pad > 0 && (waveform = pad_zeros(waveform, pad; dims=1);)
# Pack batch dimensions.
sz = size(waveform)
spec_ = stft(reshape(waveform, (sz[1], :));
n_fft, hop_length, window, center, normalized)
# Unpack batch dimensions.
spec = reshape(spec_, (size(spec_)[1:2]..., sz[2:end]...))
window_normalized && (spec = spec .* inv(norm(window));)
if power > 0
p = T(power)
spec = abs.(spec .+ eps(T)).^p
end
return spec
end
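# Hedged usage sketch (parameter values chosen only for illustration); requires
# FFTW.jl to be loaded, since the underlying `stft` lives in that package extension:
#     waveform = randn(Float32, 16_000)   # 1 second of audio at 16 kHz
#     spec = spectrogram(waveform; n_fft = 400, hop_length = 160,
#                        window = hann_window(400))   # power spectrogram (power = 2)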
"""
power_to_db(s; ref::Real = 1f0, amin::Real = 1f-10, top_db::Real = 80f0)
Convert a power spectrogram (amplitude squared) to decibel (dB) units.
# Arguments
- `s`: Input power.
- `ref`: Scalar w.r.t. which the input is scaled.
- `amin`: Minimum threshold for `s`.
- `top_db`: Threshold the output at `top_db` below the peak:
`max.(s_db, maximum(s_db) - top_db)`.
# Returns
`s_db ~= 10 * log10(s) - 10 * log10(ref)`
"""
function power_to_db(s; ref::Real = 1f0, amin::Real = 1f-10, top_db::Real = 80f0)
log_spec = 10f0 .* (log10.(max.(amin, s)) .- log10.(max.(amin, ref)))
return max.(log_spec, maximum(log_spec) - top_db)
end
"""
db_to_power(s_db; ref::Real = 1f0)
Inverse of [`power_to_db`](@ref).
"""
function db_to_power(s_db; ref::Real = 1f0)
return ref .* 10f0.^(s_db .* 0.1f0)
end
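# Round-trip note (illustrative): db_to_power(power_to_db(s)) ≈ s whenever all of `s`
# is ≥ `amin` and nothing is clipped by `top_db`; e.g. power_to_db([1f0, 10f0]) gives
# [0f0, 10f0] dB and db_to_power([0f0, 10f0]) gives [1f0, 10f0] back.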
================================================
FILE: src/audio/stft.jl
================================================
"""
hamming_window(
window_length::Int, ::Type{T} = Float32; periodic::Bool = true,
α::T = T(0.54), β::T = T(0.46),
) where T <: Real
Hamming window function
(ref: [Window function § Hann and Hamming windows - Wikipedia](https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows)).
Generalized version of `hann_window`.
``w[n] = \\alpha - \\beta \\cos(\\frac{2 \\pi n}{N - 1})``
Where ``N`` is the window length.
```julia-repl
julia> lineplot(hamming_window(100); width=30, height=10)
┌──────────────────────────────┐
1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⠚⠉⠉⠉⠢⡄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠎⠁⠀⠀⠀⠀⠀⠈⢢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⢠⠊⠀⠀⠀⠀⠀⠀⠀⠀⠀⢣⡀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⢰⠃⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠱⡀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⣠⠃⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠳⡀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⢰⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠱⡄⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⡰⠃⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠱⡀⠀⠀⠀⠀│
│⠀⠀⠀⢀⠴⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠘⢄⠀⠀⠀│
│⠀⢀⡠⠊⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠳⣀⠀│
0 │⠉⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉│
└──────────────────────────────┘
⠀0⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀100⠀
```
# Arguments:
- `window_length::Int`: Size of the window.
- `::Type{T}`: Element type of the window.
# Keyword Arguments:
- `periodic::Bool`: If `true` (default), returns a window to be used as
a periodic function. If `false`, returns a symmetric window.
The following always holds:
```jldoctest
julia> N = 256;
julia> hamming_window(N; periodic=true) ≈ hamming_window(N + 1; periodic=false)[1:end - 1]
true
```
- `α::Real`: Coefficient α in the equation above.
- `β::Real`: Coefficient β in the equation above.
# Returns:
Vector of length `window_length` and eltype `T`.
"""
function hamming_window(
window_length::Int, ::Type{T} = Float32; periodic::Bool = true,
α::T = T(0.54), β::T = T(0.46),
) where T <: Real
window_length < 1 && throw(ArgumentError(
"`window_length` must be > 0, instead: `$window_length`."))
n::T = ifelse(periodic, window_length, window_length - 1)
scale = T(2) * π / n
return [α - β * cos(scale * T(k)) for k in 0:(window_length - 1)]
end
"""
hann_window(
window_length::Int, ::Type{T} = Float32; periodic::Bool = true,
) where T <: Real
Hann window function
(ref: [Window function § Hann and Hamming windows - Wikipedia](https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows)).
``w[n] = \\frac{1}{2}[1 - \\cos(\\frac{2 \\pi n}{N - 1})]``
Where ``N`` is the window length.
```julia-repl
julia> lineplot(hann_window(100); width=30, height=10)
┌──────────────────────────────┐
1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣠⠚⠉⠉⠉⠢⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡔⠁⠀⠀⠀⠀⠀⠘⢄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠞⠀⠀⠀⠀⠀⠀⠀⠀⠈⢆⠀⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⢀⡎⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢣⠀⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⠀⠀⡎⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢦⠀⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⠀⢀⠞⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢆⠀⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⠀⢀⡜⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢇⠀⠀⠀⠀⠀│
│⠀⠀⠀⠀⢀⠎⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢦⠀⠀⠀⠀│
│⠀⠀⠀⢠⠊⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠣⡀⠀⠀│
0 │⣀⣀⠔⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢤⣀│
└──────────────────────────────┘
⠀0⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀100⠀
```
# Arguments:
- `window_length::Int`: Size of the window.
- `::Type{T}`: Element type of the window.
# Keyword Arguments:
- `periodic::Bool`: If `true` (default), returns a window to be used as
a periodic function. If `false`, returns a symmetric window.
The following always holds:
```jldoctest
julia> N = 256;
julia> hann_window(N; periodic=true) ≈ hann_window(N + 1; periodic=false)[1:end - 1]
true
julia> hann_window(N) ≈ hamming_window(N; α=0.5f0, β=0.5f0)
true
```
# Returns:
Vector of length `window_length` and eltype `T`.
"""
function hann_window(
window_length::Int, ::Type{T} = Float32; periodic::Bool = true,
) where T <: Real
hamming_window(window_length, T; periodic, α=T(0.5), β=T(0.5))
end
"""
stft(x;
n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing,
center::Bool = true, normalized::Bool = false,
)
Short-time Fourier transform (STFT).
The STFT computes the Fourier transform of short overlapping windows of the input,
giving frequency components of the signal as they change over time.
``Y[\\omega, m] = \\sum_{k = 0}^{N - 1} \\text{window}[k] \\text{input}[m \\times \\text{hop length} + k] \\exp(-j \\frac{2 \\pi \\omega k}{\\text{n fft}})``
where ``N`` is the window length,
``\\omega`` is the frequency ``0 \\le \\omega < \\text{n fft}``
and ``m`` is the index of the sliding window.
# Arguments:
- `x`: Input, must be either a 1D time sequence (`(L,)` shape)
or a 2D batch of time sequence (`(L, B)` shape).
# Keyword Arguments:
- `n_fft::Int`: Size of Fourier transform.
- `hop_length::Int`: Distance between neighboring sliding window frames.
- `window`: Optional window function to apply.
Must be 1D vector `0 < length(window) ≤ n_fft`.
If window is shorter than `n_fft`, it is padded with zeros on both sides.
If `nothing` (default), then no window is applied.
- `center::Bool`: Whether to pad input on both sides so that ``t``-th frame
is centered at time ``t \\times \\text{hop length}``.
Padding is done with `pad_reflect` function.
- `normalized::Bool`: Whether to return normalized STFT,
i.e. multiplied with ``\\text{n fft}^{-0.5}``.
# Returns:
Complex array of shape `(n_fft, n_frames, B)`,
where `B` is the optional batch dimension.
"""
function stft end
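# Illustrative call (the implementation is provided by the FFTW.jl package extension):
#     y = stft(randn(Float32, 16_000); n_fft = 1024, hop_length = 256,
#              window = hann_window(1024), center = true)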
"""
istft(y;
n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing,
center::Bool = true, normalized::Bool = false,
return_complex::Bool = false,
original_length::Union{Nothing, Int} = nothing,
)
Inverse Short-time Fourier Transform.
Return the least-squares estimate of the original signal.
# Arguments:
- `y`: Input complex array in the `(n_fft, n_frames, B)` shape,
where `B` is the optional batch dimension.
# Keyword Arguments:
- `n_fft::Int`: Size of Fourier transform.
- `hop_length::Int`: Distance between neighboring sliding window frames.
- `window`: Window function that was applied to the input of `stft`.
If `nothing` (default), then no window was applied.
- `center::Bool`: Whether input to `stft` was padded on both sides
so that ``t``-th frame is centered at time ``t \\times \\text{hop length}``.
Padding is done with `pad_reflect` function.
- `normalized::Bool`: Whether input to `stft` was normalized.
- `return_complex::Bool`: Whether the output should be complex,
or if the input should be assumed to derive from a real signal and window.
- `original_length::Union{Nothing, Int}`: Optional size of the first dimension
of the input to `stft`. Helps restoring the exact `stft` input size.
Otherwise, the array might be a bit shorter.
"""
function istft end
================================================
FILE: src/batched/batchedadjtrans.jl
================================================
import Base: -
import Adapt: adapt_structure, adapt
_batched_doc = """
batched_transpose(A::AbstractArray{T,3})
batched_adjoint(A)
Equivalent to applying `transpose` or `adjoint` to each matrix `A[:,:,k]`.
These exist to control how `batched_mul` behaves,
as it operates on such matrix slices of an array with `ndims(A)==3`.
`PermutedDimsArray(A, (2,1,3))` is equivalent to `batched_transpose(A)`,
and is also understood by `batched_mul` (and more widely supported elsewhere).
BatchedTranspose{T, S} <: AbstractBatchedMatrix{T, 3}
BatchedAdjoint{T, S}
Lazy wrappers analogous to `Transpose` and `Adjoint`, returned by `batched_transpose` etc.
"""
@doc _batched_doc
struct BatchedTranspose{T, S} <: AbstractArray{T, 3}
parent::S
BatchedTranspose{T, S}(X::S) where {T, S} = new{T, S}(X)
end
@doc _batched_doc
batched_transpose(A::AbstractArray{T, 3}) where T = BatchedTranspose(A)
batched_transpose(A::BatchedTranspose) = A.parent
@doc _batched_doc
struct BatchedAdjoint{T, S} <: AbstractArray{T, 3}
parent::S
BatchedAdjoint{T, S}(X::S) where {T, S} = new{T, S}(X)
end
@doc _batched_doc
batched_adjoint(A::AbstractArray{T, 3}) where T = BatchedAdjoint(A)
batched_adjoint(A::BatchedAdjoint) = A.parent
batched_adjoint(A::BatchedTranspose{<:Real}) = A.parent
batched_transpose(A::BatchedAdjoint{<:Real}) = A.parent
batched_adjoint(A::PermutedDimsArray{<:Real,3,(2,1,3)}) = A.parent
batched_transpose(A::PermutedDimsArray{<:Number,3,(2,1,3)}) = A.parent
# if you can't unwrap, put BatchedAdjoint outside (for dispatch):
batched_transpose(A::BatchedAdjoint{<:Complex}) = BatchedAdjoint(BatchedTranspose(A.parent))
BatchedAdjoint(A) = BatchedAdjoint{Base.promote_op(adjoint,eltype(A)),typeof(A)}(A)
BatchedTranspose(A) = BatchedTranspose{Base.promote_op(transpose,eltype(A)),typeof(A)}(A)
const BatchedAdjOrTrans{T, S} = Union{BatchedTranspose{T, S}, BatchedAdjoint{T, S}}
LinearAlgebra.wrapperop(A::BatchedAdjoint) = batched_adjoint
LinearAlgebra.wrapperop(B::BatchedTranspose) = batched_transpose
# AbstractArray Interface
Base.length(A::BatchedAdjOrTrans) = length(A.parent)
Base.size(m::BatchedAdjOrTrans) = (size(m.parent, 2), size(m.parent, 1), size(m.parent, 3))
Base.axes(m::BatchedAdjOrTrans) = (axes(m.parent, 2), axes(m.parent, 1), axes(m.parent, 3))
Base.IndexStyle(::Type{<:BatchedAdjOrTrans}) = IndexCartesian()
Base.@propagate_inbounds Base.getindex(m::BatchedTranspose, i::Int, j::Int, k::Int) = getindex(m.parent, j, i, k)
Base.@propagate_inbounds Base.getindex(m::BatchedAdjoint, i::Int, j::Int, k::Int) = adjoint(getindex(m.parent, j, i, k))
Base.@propagate_inbounds Base.setindex!(m::BatchedTranspose, v, i::Int, j::Int, k::Int) = setindex!(m.parent, v, j, i, k)
Base.@propagate_inbounds Base.setindex!(m::BatchedAdjoint, v, i::Int, j::Int, k::Int) = setindex!(m.parent, adjoint(v), j, i, k)
Base.similar(A::BatchedAdjOrTrans, T::Type, dims::Dims) = similar(A.parent, T, dims)
Base.similar(A::BatchedAdjOrTrans, dims::Dims) = similar(A.parent, dims)
Base.similar(A::BatchedAdjOrTrans, T::Type) = similar(A.parent, T, size(A))
Base.similar(A::BatchedAdjOrTrans) = similar(A.parent, size(A))
Base.parent(A::BatchedAdjOrTrans) = A.parent
(-)(A::BatchedAdjoint) = BatchedAdjoint( -A.parent)
(-)(A::BatchedTranspose) = BatchedTranspose(-A.parent)
# C interface
function Base.strides(A::Union{BatchedTranspose, BatchedAdjoint{<:Real}})
sp = strides(A.parent)
(sp[2], sp[1], sp[3])
end
function Base.stride(A::Union{BatchedTranspose, BatchedAdjoint{<:Real}}, d::Integer)
d == 1 && return Base.stride(A.parent, 2)
d == 2 && return Base.stride(A.parent, 1)
Base.stride(A.parent, d)
end
Base.pointer(A::BatchedAdjOrTrans) = pointer(parent(A))
Base.unsafe_convert(::Type{Ptr{T}}, A::BatchedAdjOrTrans{T}) where {T} =
Base.unsafe_convert(Ptr{T}, parent(A))
# Gradients
function rrule(::typeof(batched_transpose), A::AbstractArray{<:Any,3})
b_transpose_back(Δ) = (NoTangent(), batched_transpose(unthunk(Δ)))
batched_transpose(A), b_transpose_back
end
function rrule(::typeof(batched_adjoint), A::AbstractArray{<:Any,3})
b_adjoint_back(Δ) = (NoTangent(), batched_adjoint(unthunk(Δ)))
batched_adjoint(A), b_adjoint_back
end
adapt_structure(to, x::BatchedAdjoint) = BatchedAdjoint(adapt(to, parent(x)))
adapt_structure(to, x::BatchedTranspose) = BatchedTranspose(adapt(to, parent(x)))
Broadcast.BroadcastStyle(::Type{<:BatchedAdjOrTrans{T, S}}) where {T, S} = Broadcast.BroadcastStyle(S)
================================================
FILE: src/batched/batchedmul.jl
================================================
_unbatch(A) = A
_unbatch(A::BatchedAdjOrTrans) = parent(A)
"""
batched_mul(A, B) -> C
A ⊠ B # \\boxtimes
Batched matrix multiplication. Result has `C[:,:,k...] == A[:,:,k...] * B[:,:,k...]` where `k...` represent
any indices in the last dimensions.
If `ndims(A) == ndims(B) == 3` and `size(B,3) == 1` then instead `C[:,:,k] == A[:,:,k] * B[:,:,1]`, and similarly for `A`.
To transpose each matrix, apply `batched_transpose` to the array,
or `batched_adjoint` for conjugate-transpose:
```jldoctest
julia> A, B = randn(2,5,17), randn(5,9,17);
julia> A ⊠ B |> size
(2, 9, 17)
julia> batched_adjoint(A) |> size
(5, 2, 17)
julia> batched_mul(A, batched_adjoint(randn(9,5,17))) |> size
(2, 9, 17)
julia> A ⊠ randn(5,9,1) |> size
(2, 9, 17)
julia> batched_transpose(A) == PermutedDimsArray(A, (2,1,3))
true
```
The equivalent `PermutedDimsArray` may be used in place of `batched_transpose`.
Other permutations are also handled by BLAS,
provided that the batch index `k` is not the first dimension of the underlying array.
Thus `PermutedDimsArray(::Array, (1,3,2))` and `PermutedDimsArray(::Array, (3,1,2))` are fine.
However, `A = PermutedDimsArray(::Array, (3,2,1))` is not acceptable to BLAS,
since the batch dimension is the contiguous one: `stride(A,3) == 1`.
This will be copied, as doing so is faster than `batched_mul_generic!`.
Both this `copy` and `batched_mul_generic!` produce `@debug` messages,
and setting for instance `ENV["JULIA_DEBUG"] = NNlib` will display them.
"""
function batched_mul(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
batch_size = size(x)[3:end]
@assert batch_size == size(y)[3:end] "batch size has to be the same for the two arrays."
x2 = reshape(x, size(x, 1), size(x, 2), :)
y2 = reshape(y, size(y, 1), size(y, 2), :)
z = batched_mul(x2, y2)
return reshape(z, size(z, 1), size(z, 2), batch_size...)
end
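# Illustrative sketch (not part of the library source): with ndims(A) == ndims(B) > 3,
# the method above folds every trailing dimension into one batch dimension, so the
# result agrees with slice-by-slice multiplication. Assumes NNlib is installed.
using NNlib: batched_mul

let A = randn(2, 5, 4, 3), B = randn(5, 7, 4, 3)
    C = batched_mul(A, B)                                   # size (2, 7, 4, 3)
    @assert size(C) == (2, 7, 4, 3)
    @assert C[:, :, 3, 2] ≈ A[:, :, 3, 2] * B[:, :, 3, 2]   # holds for every batch index
end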
function batched_mul(A::AbstractArray{T1, 3}, B::AbstractArray{T2, 3}) where {T1, T2}
size(A, 3) == size(B, 3) || size(A, 3) == 1 || size(B, 3) == 1 ||
throw(DimensionMismatch("batch size mismatch: A != B"))
_batched_mul(storage_typejoin(A, B), A, B)
end
const ⊠ = batched_mul
function _batched_mul(::Type, A, B)
T = promote_type(eltype(A), eltype(B))
C = similar(A, T, (size(A, 1), size(B, 2), max(size(A, 3), size(B, 3))))
batched_mul!(C, A, B)
C
end
function _batched_mul(::Type{<:DenseArray{T}}, A, B) where {T<:BlasFloat}
C = similar(A, T, (size(A, 1), size(B, 2), max(size(A, 3), size(B, 3))))
batched_mul!(C, _copy_if_faster(A), _copy_if_faster(B))
C
end
function _copy_if_faster(X::AbstractArray{<:Number, 3})
is_strided(X) || return X
if Base.stride(X, 3) == 1 && Base.stride(X, 1) != 1
@debug "copying to avoid batched_mul_generic!" typeof(X) size(X) strides(X)
return copy(X)
end
X
end
function _copy_if_faster(X::BatchedAdjoint{<:Complex})
Xbase = _unbatch(X)
is_strided(Xbase) || return X
if Base.stride(Xbase, 1) != 1
@debug "copying to avoid batched_mul_generic!" typeof(X) size(X) strides(_unbatch(X))
return copy(X) # or batched_adjoint(copy(Xbase)), may be better on GPU?
end
X
end
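# Illustrative sketch (not part of the library source): `_copy_if_faster` materialises
# an argument only when its batch dimension is the contiguous one, the layout BLAS
# cannot use, since copying is cheaper than falling back to `batched_mul_generic!`.
# Assumes NNlib is installed.
using NNlib: batched_mul

base = randn(4, 4, 8)
ok  = PermutedDimsArray(base, (3, 1, 2))   # stride 1 in dim 2: acceptable to BLAS as-is
bad = PermutedDimsArray(base, (3, 2, 1))   # stride 1 in the batch dim: will be copied
@assert stride(ok, 2) == 1
@assert stride(bad, 3) == 1
# Setting ENV["JULIA_DEBUG"] = "NNlib" beforehand makes the copy decision visible:
C = batched_mul(bad, randn(4, 4, 4))
@assert size(C) == (8, 4, 4)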
# Gradient, allowing that size(A,3)==1 means it's "broadcasted" out to size(B,3)
function rrule(::typeof(batched_mul), A::AbstractArray{<:Any,3}, B::AbstractArray{<:Any,3})
function batched_mul_pullback(_Δ)
Δ = unthunk(_Δ)
Athunk = @thunk begin
tmp = batched_mul(Δ, batched_adjoint(B))
size(A,3) == 1 ? sum(tmp, dims=3) : tmp
end
Bthunk = @thunk begin
tmp = batched_mul(batched_adjoint(A), Δ)
size(B,3) == 1 ? sum(tmp, dims=3) : tmp
end
return (NoTangent(), Athunk, Bthunk)
end
batched_mul(A, B), batched_mul_pullback
end
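# Illustrative check (not part of the library source): the pullback above is the usual
# matrix-product rule applied per batch slice, dA = Δ ⊠ Bᴴ and dB = Aᴴ ⊠ Δ, summed over
# the batch dimension whenever the corresponding input had size(·, 3) == 1.
using NNlib: batched_mul, batched_adjoint

A, B, Δ = randn(2, 3, 5), randn(3, 4, 5), randn(2, 4, 5)   # Δ plays the role of the cotangent
dA = batched_mul(Δ, batched_adjoint(B))
dB = batched_mul(batched_adjoint(A), Δ)
@assert dA[:, :, 1] ≈ Δ[:, :, 1] * B[:, :, 1]'
@assert dB[:, :, 1] ≈ A[:, :, 1]' * Δ[:, :, 1]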
"""
batched_mul(A::Array{T,3}, B::Matrix)
batched_mul(A::Matrix, B::Array{T,3})
A ⊠ B
This is always matrix-matrix multiplication, but
either `A` or `B` may lack a batch index.
* When `B` is a matrix, result has `C[:,:,k] == A[:,:,k] * B[:,:]` for all `k`.
* When `A` is a matrix, then `C[:,:,k] == A[:,:] * B[:,:,k]`.
This can also be done by reshaping and calling `*`,
for instance `A ⊡ B` using TensorCore.jl, but is implemented here using
`batched_gemm` instead of `gemm`.
```jldoctest
julia> randn(16,8,32) ⊠ randn(8,4) |> size
(16, 4, 32)
julia> randn(16,8,32) ⊠ randn(8,4,1) |> size # equivalent
(16, 4, 32)
julia> randn(16,8) ⊠ randn(8,4,32) |> size
(16, 4, 32)
```
See also `batched_vec` to regard `B` as a batch of vectors, `A[:,:,k] * B[:,k]`.
"""
batched_mul(A::AbstractArray{T,3} where T, B::AbstractMatrix) = _semi_batched_mul(A,B)
# Simplify signature of batched_mul by hiding dispatch on Adjoint etc:
_semi_batched_mul(A::AbstractArray{<:Any,3}, B::AbstractMatrix) =
batched_mul(A, reshape(B, size(B)..., 1))
_semi_batched_mul(A::AbstractArray{<:Any,3}, B::Adjoint{<:Number,<:AbstractMatrix}) =
batched_mul(A, batched_adjoint(reshape(parent(B), size(parent(B))..., 1)))
_semi_batched_mul(A::AbstractArray{<:Any,3}, B::Transpose{<:Number,<:AbstractMatrix}) =
batched_mul(A, batched_transpose(reshape(parent(B), size(parent(B))..., 1)))
batched_mul(A::AbstractMatrix, B::AbstractArray{T,3} where T) = _semi_batched_mul(A,B)
_semi_batched_mul(A::AbstractMatrix, B::AbstractArray{<:Any,3}) =
batched_mul(reshape(A, size(A)..., 1), B)
_semi_batched_mul(A::Adjoint{<:Number,<:AbstractMatrix}, B::AbstractArray{<:Any,3}) =
batched_mul(batched_adjoint(reshape(parent(A), size(parent(A))..., 1)), B)
_semi_batched_mul(A::Transpose{<:Number,<:AbstractMatrix}, B::AbstractArray{<:Any,3}) =
batched_mul(batched_transpose(reshape(parent(A), size(parent(A))..., 1)), B)
"""
batched_vec(A::AbstractArray{T,3}, B::AbstractMatrix)
batched_vec(A::AbstractArray{T,3}, b::AbstractVector)
batched_vec(A::AbstractArray, B::AbstractArray)
Batched matrix-vector multiplication. For the 3D case:
the result has `C[:,k] == A[:,:,k] * B[:,k]` for all `k`,
or else `C[:,k] == A[:,:,k] * b` for `b::Vector`.
For the general N-D case where `ndims(A) == ndims(B) + 1`:
the result has `C[:,k...] == A[:,:,k...] * B[:,k...]` for all batch indices `k...`.
The batch dimensions must match: `size(A)[3:end] == size(B)[2:end]`.
With the same argument types, `batched_mul(A, B)` would regard `B` as
a fixed matrix, not a batch of vectors. Both reshape and then
call `batched_mul(::Array{T,3}, ::Array{T,3})`.
```jldoctest
julia> A, B, b = randn(16,8,32), randn(8,32), randn(8);
julia> batched_vec(A,B) |> size
(16, 32)
julia> batched_vec(A,b) |> size
(16, 32)
julia> A4d, B3d = randn(16,8,10,32), randn(8,10,32); # 4D and 3D arrays
julia> batched_vec(A4d, B3d) |> size
(16, 10, 32)
```
"""
function batched_vec(A::AbstractArray, B::AbstractArray)
ndims(A) == ndims(B) + 1 || throw(DimensionMismatch(
"batched_vec requires ndims(A) == ndims(B) + 1, got ndims(A)=$(ndims(A)) and ndims(B)=$(ndims(B))"))
size(A)[3:end] == size(B)[2:end] || throw(DimensionMismatch(
"batch dimensions must match: size(A)[3:end]=$(size(A)[3:end]) != size(B)[2:end]=$(size(B)[2:end])"))
# Reshape B to add a singleton dimension for matrix multiplication
B_reshaped = reshape(B, size(B, 1), 1, size(B)[2:end]...)
# Perform batched multiplication
C = batched_mul(A, B_reshaped)
# Remove the singleton dimension
return dropdims(C, dims=2)
end
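# Illustrative sketch (not part of the library source): `batched_vec` multiplies each
# matrix slice of A by the matching column of B (or by a single fixed vector), then
# drops the singleton dimension introduced by the internal reshape. Assumes NNlib is installed.
using NNlib: batched_vec

A, B = randn(6, 3, 10), randn(3, 10)
C = batched_vec(A, B)                        # size (6, 10)
@assert C[:, 4] ≈ A[:, :, 4] * B[:, 4]

A4, B3 = randn(6, 3, 2, 10), randn(3, 2, 10) # general N-D case, handled by the method above
C3 = batched_vec(A4, B3)                     # size (6, 2, 10)
@assert C3[:, 1, 7] ≈ A4[:, :, 1, 7] * B3[:, 1, 7]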
batched_vec(A::AbstractArray{T,3} where T, B::AbstractMatrix) =
reshape(batched_mul(A, reshape(B, size(B,1), 1, size(B,2))), size(A,1), size(A,3))
# If B is transposed, then stride=1 is the batch dim, so we will end up copying anyway:
batched_vec(A::AbstractArray{T,3} where T, B::AdjOrTransAbsMat{<:BlasFloat, <:StridedMatrix}) =
batched_vec(A, copy(B))
batched_vec(A::AbstractArray{T,3} where T, b::AbstractVector) =
reshape(batched_mul(A, reshape(b, length(b), 1, 1)), size(A,1), size(A,3))
"""
batched_mul!(C, A, B) -> C
batched_mul!(C, A, B, α=1, β=0)
In-place batched matrix multiplication, equivalent to
`mul!(C[:,:,k], A[:,:,k], B[:,:,k], α, β)` for all `k`.
If `size(B,3) == 1` then every batch uses `B[:,:,1]` instead.
This will call `batched_gemm!` whenever possible. For real arrays this means that,
for `X ∈ [A,B,C]`, either `stride(X,1)==1` or `stride(X,2)==1`; the latter may
be caused by `batched_transpose`, or for instance by `PermutedDimsArray(::Array, (3,1,2))`.
Unlike `batched_mul`, this will never make a copy.
For complex arrays, the wrapper made by `batched_adjoint` must be outermost to be seen.
In this case the strides accepted by BLAS are more restricted: if `stride(C,1)==1`, then
only `stride(AorB::BatchedAdjoint,2) == 1` is accepted.
"""
function batched_mul!(C::AbstractArray{T,3}, A::AbstractArray{<:Any,3}, B::AbstractArray{<:Any,3},
α::Number=one(T), β::Number=zero(T)) where {T}
_batched_mul!(storage_typejoin(C,A,B), C, A, B, α, β)
C
end
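# Illustrative sketch (not part of the library source): in-place use with the scaling
# factors, computing C[:,:,k] = α*A[:,:,k]*B[:,:,k] + β*C[:,:,k] for a preallocated C.
# Assumes NNlib is installed.
using NNlib: batched_mul!
using LinearAlgebra: mul!

A, B = randn(4, 3, 8), randn(3, 5, 8)
C = ones(4, 5, 8)
batched_mul!(C, A, B, 2.0, 0.5)             # scale the product by 2 and the old C by 0.5
Cref = ones(4, 5, 8)
foreach(k -> mul!(view(Cref, :, :, k), A[:, :, k], B[:, :, k], 2.0, 0.5), 1:8)
@assert C ≈ Cref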
_batched_mul!(::Type, C, A, B, α::Number, β::Number) = batched_mul_generic!(C, A, B, α, β)
_batched_mul!(::Type{<:DenseArray{T}}, C, A, B, α::Number, β::Number) where {T<:BlasFloat} =
    _batched_try_gemm!(DenseArray{T}, C, A, B, α, β)