Full Code of FluxML/NNlib.jl for AI

Repository: FluxML/NNlib.jl
Branch: master
Commit: d3c38a49d9cf
Files: 141
Total size: 634.0 KB

Directory structure:
gitextract_0xi0432k/

├── .buildkite/
│   └── pipeline.yml
├── .codecov.yml
├── .github/
│   ├── copilot-instructions.md
│   ├── dependabot.yml
│   └── workflows/
│       ├── BenchmarkTrigger.yml
│       ├── CompatHelper.yml
│       ├── Downstream.yml
│       ├── TagBot.yml
│       ├── ci.yml
│       ├── clean_preview.yml
│       └── pr_comment.yml
├── .gitignore
├── LICENSE.md
├── Project.toml
├── README.md
├── benchmark/
│   ├── Project.toml
│   ├── benchmarks.jl
│   ├── perf_report.jl
│   └── runbenchmarks.jl
├── docs/
│   ├── .gitignore
│   ├── Project.toml
│   ├── make.jl
│   └── src/
│       ├── assets/
│       │   ├── flux.css
│       │   └── jfk.flac
│       ├── audio.md
│       ├── index.md
│       └── reference.md
├── ext/
│   ├── NNlibAMDGPUExt/
│   │   ├── NNlibAMDGPUExt.jl
│   │   ├── activations.jl
│   │   ├── batched_mul.jl
│   │   ├── conv.jl
│   │   └── pool.jl
│   ├── NNlibCUDACUDNNExt/
│   │   ├── NNlibCUDACUDNNExt.jl
│   │   ├── activations.jl
│   │   ├── batchnorm.jl
│   │   ├── conv.jl
│   │   ├── pooling.jl
│   │   └── softmax.jl
│   ├── NNlibCUDAExt/
│   │   ├── NNlibCUDAExt.jl
│   │   ├── activations.jl
│   │   ├── batchedadjtrans.jl
│   │   ├── batchedmul.jl
│   │   ├── ctc.jl
│   │   ├── sampling.jl
│   │   ├── scatter.jl
│   │   └── utils.jl
│   ├── NNlibEnzymeCoreExt/
│   │   └── NNlibEnzymeCoreExt.jl
│   ├── NNlibFFTWExt/
│   │   ├── NNlibFFTWExt.jl
│   │   └── stft.jl
│   ├── NNlibForwardDiffExt.jl
│   ├── NNlibMetalExt.jl
│   └── NNlibSpecialFunctionsExt.jl
├── src/
│   ├── NNlib.jl
│   ├── activations.jl
│   ├── attention.jl
│   ├── audio/
│   │   ├── mel.jl
│   │   ├── spectrogram.jl
│   │   └── stft.jl
│   ├── batched/
│   │   ├── batchedadjtrans.jl
│   │   └── batchedmul.jl
│   ├── bias_act.jl
│   ├── conv.jl
│   ├── conv_bias_act.jl
│   ├── ctc.jl
│   ├── deprecations.jl
│   ├── dim_helpers/
│   │   ├── ConvDims.jl
│   │   ├── DenseConvDims.jl
│   │   ├── DepthwiseConvDims.jl
│   │   └── PoolDims.jl
│   ├── dim_helpers.jl
│   ├── dropout.jl
│   ├── fold.jl
│   ├── functions.jl
│   ├── gather.jl
│   ├── gemm.jl
│   ├── impl/
│   │   ├── conv_direct.jl
│   │   ├── conv_im2col.jl
│   │   ├── depthwiseconv_direct.jl
│   │   ├── depthwiseconv_im2col.jl
│   │   ├── padding_edges.jl
│   │   └── pooling_direct.jl
│   ├── normalization.jl
│   ├── padding.jl
│   ├── pooling.jl
│   ├── rotation.jl
│   ├── sampling.jl
│   ├── scatter.jl
│   ├── softmax.jl
│   ├── upsample.jl
│   └── utils.jl
└── test/
    ├── Project.toml
    ├── activations.jl
    ├── attention.jl
    ├── batchedmul.jl
    ├── bias_act.jl
    ├── conv.jl
    ├── conv_bias_act.jl
    ├── ctc.jl
    ├── dropout.jl
    ├── ext_amdgpu/
    │   ├── activations.jl
    │   ├── attention.jl
    │   ├── batched_mul.jl
    │   ├── batched_repr.jl
    │   ├── conv.jl
    │   ├── dropout.jl
    │   ├── pool.jl
    │   ├── runtests.jl
    │   ├── softmax.jl
    │   └── storage_type.jl
    ├── ext_cuda/
    │   ├── activations.jl
    │   ├── batchedadjtrans.jl
    │   ├── batchedmul.jl
    │   ├── batchnorm.jl
    │   ├── conv.jl
    │   ├── ctc.jl
    │   ├── dropout.jl
    │   ├── fold.jl
    │   ├── gather.jl
    │   ├── pooling.jl
    │   ├── runtests.jl
    │   ├── sampling.jl
    │   ├── scatter.jl
    │   ├── softmax.jl
    │   └── test_utils.jl
    ├── ext_metal/
    │   ├── activations.jl
    │   └── runtests.jl
    ├── functions.jl
    ├── inference.jl
    ├── padding.jl
    ├── pooling.jl
    ├── runtests.jl
    ├── sampling.jl
    ├── softmax.jl
    ├── test_utils.jl
    ├── testsuite/
    │   ├── fold.jl
    │   ├── gather.jl
    │   ├── rotation.jl
    │   ├── scatter.jl
    │   ├── spectral.jl
    │   └── upsample.jl
    └── utils.jl

================================================
FILE CONTENTS
================================================

================================================
FILE: .buildkite/pipeline.yml
================================================
steps:
  - label: ":julia: Julia {{matrix.julia}} - CUDA GPU"
    command:
      - echo 'CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"' >> test/Project.toml
      - echo 'cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"' >> test/Project.toml
    plugins:
      - JuliaCI/julia#v1:
          version: "{{matrix.julia}}"
      - JuliaCI/julia-test#v1:
          test_args: "--quickfail"
      - JuliaCI/julia-coverage#v1:
          codecov: true
          dirs:
            - src
            - ext
    agents:
      queue: "juliagpu"
      cuda: "*"
    env:
      JULIA_NUM_THREADS: 4
      NNLIB_TEST_CUDA: "true"
      NNLIB_TEST_CPU: "false"
    if: build.message !~ /\[skip tests\]/
    timeout_in_minutes: 180
    matrix:
      setup:
        julia:
          - "1.10"
          - "1"
          - "nightly"
      adjustments:
        - with:
            julia: "nightly"
          soft_fail: true


  - label: ":julia: Julia {{matrix.julia}} - AMD GPU"
    command:
      - echo 'AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"' >> test/Project.toml
    plugins:
      - JuliaCI/julia#v1:
          version: "1"
      - JuliaCI/julia-test#v1:
          test_args: "--quickfail"
      - JuliaCI/julia-coverage#v1:
          codecov: true
          dirs:
            - src
            - ext
    agents:
      queue: "juliagpu"
      rocm: "*"
      rocmgpu: "*"
    timeout_in_minutes: 180
    env:
      JULIA_AMDGPU_CORE_MUST_LOAD: "1"
      JULIA_AMDGPU_HIP_MUST_LOAD: "1"
      JULIA_AMDGPU_DISABLE_ARTIFACTS: "1"
      NNLIB_TEST_AMDGPU: "true"
      NNLIB_TEST_CPU: "false"
      JULIA_NUM_THREADS: 4
    matrix:
      setup:
        julia:
          # - "1.10"  
          - "1"
          # - "nightly"
      # adjustments:
      #   - with:
      #       julia: "nightly"
      #     soft_fail: true


  - label: ":julia: Julia {{matrix.julia}} - Metal GPU"
    command:
      - echo 'Metal = "dde4c033-4e86-420c-a63e-0dd931031962"' >> test/Project.toml
    plugins:
      - JuliaCI/julia#v1:
          version: "{{matrix.julia}}"
      - JuliaCI/julia-test#v1:
          test_args: "--quickfail"
      - JuliaCI/julia-coverage#v1:
          codecov: true
          dirs:
            - src
            - ext
    agents:
      queue: "juliaecosystem"
      os: "macos"
      arch: "aarch64"
    timeout_in_minutes: 180
    env:
      NNLIB_TEST_METAL: "true"
      NNLIB_TEST_CPU: "false"
      JULIA_NUM_THREADS: 4
    matrix:
      setup:
        julia:
          # - "1.10"
          - "1"
          # - "nightly  "
      # adjustments:
      #   - with:
      #       julia: "nightly"
      #     soft_fail: true


  - label: "Benchmarks"
    plugins:
      - JuliaCI/julia#v1:
          version: 1
    env:
      JULIA_NUM_THREADS: 4
    command:
      - julia --project=benchmark -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()'
      - julia --project=benchmark benchmark/runbenchmarks.jl
      - printf '%b\n' "$(cat benchmark/report.md)" | buildkite-agent annotate --style 'info'
    agents:
      queue: "juliagpu"
    if: build.pull_request.labels includes "benchmark"
    timeout_in_minutes: 30

env:
  SECRET_CODECOV_TOKEN: "IlEMvDI6RciJQr5eX7qBBpHYFAe8+Svf3lNJh9gZi0MeJZQvMZWzHfW/lVncA9d9K+gDBBTv/zwqF86xOaIFLuACNdcGZiGgHS+NGeXN5CEppjqLnqKuaeHmLgJ43jygxRwgF88LhwTGcHG7pmESIp1Bn3Jd23UUv4t8hJLBDF+KJLZMefzCXnEVzfwJYxhJktnKJPA4dOv59w33Vj1x5uCYZbQlLP54IJPBm8UGdXS+JrUX8Z7lhxbkJUi6c+R6cvVBw27uRjF0pUJY26mt1frx8MzTGTOweXTpi+Kc5JhzlokMlan17j6T/b7qMC13IuKopfqu1GhkSBQD3ZhQqA==;U2FsdGVkX19l7JMB48k4oJHLoaqC7/MmvQWmaiBxRN472ZC6AcQ0uCBRy6Fw8tI0YcjIxKDScaBnJ2v/deOfhg=="


================================================
FILE: .codecov.yml
================================================
comment: false


================================================
FILE: .github/copilot-instructions.md
================================================
# NNlib.jl Copilot Instructions

## Repository Overview

NNlib.jl is a library providing fundamental neural network operations and primitives for Julia. It is primarily used by Flux.jl but can be used independently. The library provides:

- Activation functions (sigmoid, relu, gelu, etc.)
- Convolution and pooling operations
- Attention mechanisms
- Batched matrix operations
- Neural network utilities (dropout, normalization, etc.)
- GPU acceleration support (CUDA, AMDGPU)

## Project Structure

```
NNlib.jl/
├── src/              # Core library implementation
│   ├── NNlib.jl      # Main module file
│   ├── activations.jl # Activation functions
│   ├── attention.jl   # Attention mechanisms
│   ├── conv.jl        # Convolution operations
│   ├── pooling.jl     # Pooling operations
│   ├── batched/       # Batched operations
│   └── impl/          # Implementation details
├── ext/              # Package extensions for GPU backends
│   ├── NNlibCUDAExt/      # CUDA-specific implementations
│   ├── NNlibAMDGPUExt/    # AMDGPU-specific implementations
│   └── NNlibCUDACUDNNExt/ # cuDNN-specific implementations
├── test/             # Test suite
└── docs/             # Documentation
```

## Julia Version

- Minimum Julia version: 1.10
- CI tests on: the minimum supported Julia version, the latest stable 1.x release, and pre-release versions

## Coding Standards

### Julia Conventions

1. **Naming**:
   - Functions: lowercase with underscores (e.g., `dot_product_attention`)
   - Types: PascalCase (e.g., `ConvDims`, `PoolDims`)
   - Constants: UPPERCASE with underscores (e.g., `ACTIVATIONS`)

2. **Documentation**:
   - Use Julia docstrings (""" ... """) for all exported functions
   - Include examples in docstrings where appropriate
   - Keep documentation up-to-date with implementation changes

3. **Type Annotations**:
   - Use type parameters and abstract types for generic implementations
   - Leverage Julia's multiple dispatch for specialized implementations
   - Define clear type hierarchies (e.g., `DenseConvDims`, `DepthwiseConvDims`)

4. **Performance**:
   - Prefer in-place operations where appropriate (functions ending with `!`)
   - Use `@inbounds` judiciously when bounds checking is verified
   - Consider thread safety for multi-threaded operations
   - Use `NNlib.@disallow_spawns` to control threading behavior

### Code Organization

1. **Core Implementations**: CPU implementations go in `src/`
2. **GPU Extensions**: GPU-specific code belongs in `ext/` as package extensions
3. **Tests**: Mirror the structure of `src/` in `test/`
4. **Gradients**: Define gradients using ChainRules.jl (`rrule` functions)

## Testing

### Test Infrastructure

- Uses the standard Julia `Test` framework
- Tests are organized to mirror the source structure
- GPU tests are conditional (controlled by environment variables)

### Running Tests

```julia
# Run all CPU tests
julia --project -e 'using Pkg; Pkg.test()'

# Run tests with threading
JULIA_NUM_THREADS=4 julia --project -e 'using Pkg; Pkg.test()'
```

### Test Patterns

1. **Activation Functions**: Test at specific values (0.0, 1.0, -1.0) and verify expected outputs
2. **Gradient Tests**: Use `ChainRulesTestUtils` for gradient correctness
3. **Type Stability**: Use `@inferred` where appropriate
4. **GPU Tests**: Conditional testing based on environment variables:
   - `ENV["NNLIB_TEST_CUDA"]` for CUDA tests
   - `ENV["NNLIB_TEST_AMDGPU"]` for AMDGPU tests

### Writing New Tests

- Include tests for edge cases (zero inputs, negative values, boundary conditions)
- Test both forward pass and gradients (using ChainRulesTestUtils)
- For array operations, test multiple dimensions and batch sizes
- Include tests for type stability when performance-critical
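
A minimal sketch of these patterns, assuming the standard `Test` framework and the exported `relu` activation (values chosen for illustration only):

```julia
using Test, NNlib

@testset "relu basics" begin
    # key scalar values
    @test relu(0.0) == 0.0
    @test relu(1.0) == 1.0
    @test relu(-1.0) == 0.0                 # edge case: negative input
    # type stability on a performance-relevant element type
    @test @inferred(relu(1.0f0)) === 1.0f0
end
```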

## Dependencies

### Core Dependencies

- **ChainRulesCore**: For automatic differentiation support
- **KernelAbstractions**: For GPU kernel abstractions
- **Adapt**: For moving data between CPU/GPU
- **GPUArraysCore**: GPU array interface

### Weak Dependencies (Extensions)

- **CUDA.jl/cuDNN**: NVIDIA GPU support
- **AMDGPU.jl**: AMD GPU support
- **FFTW**: Fast Fourier transforms
- **ForwardDiff**: Forward-mode AD support
- **EnzymeCore**: Enzyme AD support

### Adding New Dependencies

- Consider whether the dependency should be a weak dependency (extension)
- Update `Project.toml` with version constraints
- Ensure compatibility with supported Julia versions
- Run full test suite after adding dependencies

## GPU Support

NNlib uses Julia's package extension system for GPU backends:

1. **CUDA**: Load with `using NNlib, CUDA, cuDNN`
2. **AMDGPU**: Load with `using NNlib, AMDGPU`

### GPU Implementation Guidelines

- Keep GPU-specific code in appropriate extensions (`ext/` directory)
- Provide CPU fallback implementations in `src/`
- Test GPU implementations separately (conditional on hardware availability)
- Use KernelAbstractions for portable GPU kernels when possible
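
A minimal sketch of the last point, written with KernelAbstractions (the kernel and wrapper names here are hypothetical, not part of NNlib):

```julia
using KernelAbstractions

# Element-wise max(x, 0); the same kernel runs on CPU(), CUDABackend(), ROCBackend(), ...
@kernel function relu_kernel!(y, @Const(x))
    i = @index(Global, Linear)
    @inbounds y[i] = max(x[i], zero(eltype(y)))
end

function relu_ka!(y::AbstractArray, x::AbstractArray)
    backend = KernelAbstractions.get_backend(x)   # pick the backend from the input array
    relu_kernel!(backend)(y, x; ndrange = length(x))
    KernelAbstractions.synchronize(backend)
    return y
end
```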

## Build and CI/CD

### Continuous Integration

- **CI Workflow**: `.github/workflows/ci.yml`
  - Tests on Linux (always), Windows, and macOS
  - Tests with different Julia versions (LTS, stable, pre-release)
  - Tests with different thread counts
  
### Additional Workflows

- **TagBot**: Automatic release tagging
- **CompatHelper**: Dependency compatibility updates
- **Downstream**: Tests dependent packages
- **BenchmarkTrigger**: Performance regression testing

## Common Tasks

### Adding a New Activation Function

1. Add function to `src/activations.jl`
2. Add to `ACTIVATIONS` tuple for automatic export
3. Define gradient with `@scalar_rule` or `rrule`
4. Add tests in `test/activations.jl` at key values
5. Document with docstring and example
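
A sketch of steps 1 and 3 for a hypothetical activation (the name and derivative below are illustrative; adding it to the `ACTIVATIONS` tuple, tests, and the docstring are omitted):

```julia
using NNlib, ChainRulesCore

# Hypothetical scalar activation, x * tanh(x).
tanhmul(x) = x * tanh(x)

# Step 3: scalar derivative rule, d/dx x*tanh(x) = tanh(x) + x*sech(x)^2.
ChainRulesCore.@scalar_rule(tanhmul(x), tanh(x) + x * sech(x)^2)
```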

### Adding a New Operation

1. Implement in appropriate file in `src/`
2. Export from `src/NNlib.jl`
3. Define gradients using ChainRules
4. Add comprehensive tests
5. Add GPU implementations in extensions if applicable
6. Document in appropriate file in `docs/src/`
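
A sketch of steps 1 and 3 for a toy operation (the names are illustrative; a real operation would live in a file under `src/` and be exported from `src/NNlib.jl`):

```julia
using ChainRulesCore

# Toy operation: double every element.
double(x::AbstractArray) = 2 .* x

# Step 3: gradient via ChainRules; the pullback maps the output cotangent to input cotangents.
function ChainRulesCore.rrule(::typeof(double), x::AbstractArray)
    y = double(x)
    double_pullback(ȳ) = (NoTangent(), 2 .* unthunk(ȳ))
    return y, double_pullback
end
```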

### Modifying Existing Functions

1. Check for dependent code in Flux.jl and other downstream packages
2. Maintain backward compatibility or document breaking changes
3. Update tests to cover new behavior
4. Update gradients if needed
5. Consider performance implications

## Performance Considerations

1. **Memory Allocation**: Minimize allocations in hot paths
2. **Threading**: NNlib uses Julia threads for parallel operations
   - Control with `NNlib.@disallow_spawns` if needed
   - Thread count controlled by `JULIA_NUM_THREADS`
3. **GPU Kernels**: Optimize kernel launch parameters and memory access patterns
4. **Type Stability**: Ensure type-stable code for performance-critical paths
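
For point 2, a usage sketch of the threading switch, mirroring the example in `docs/src/index.md` (the wrapped convolution call is arbitrary):

```julia
using NNlib

x = rand(Float32, 64, 64, 3, 8)   # (W, H, C_in, N)
w = rand(Float32, 3, 3, 3, 16)    # (kW, kH, C_in, C_out)

# Run the call without NNlib spawning extra tasks.
NNlib.@disallow_spawns conv(x, w)
```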

## Documentation

- Documentation source: `docs/src/`
- Built with Documenter.jl
- Includes API reference, examples, and guides
- Documentation tests run via `DocTestSetup`

## Related Projects

- **Flux.jl**: Primary consumer of NNlib
- **Zygote.jl**: Automatic differentiation (uses ChainRules)
- **ChainRules.jl**: Gradient definitions
- **KernelAbstractions.jl**: GPU kernel abstraction

## Getting Help

- Documentation: https://fluxml.ai/NNlib.jl/dev/
- Issues: https://github.com/FluxML/NNlib.jl/issues
- FluxML community: https://github.com/FluxML


================================================
FILE: .github/dependabot.yml
================================================
# To get started with Dependabot version updates, you'll need to specify which
# package ecosystems to update and where the package manifests are located.
# Please see the documentation for all configuration options:
# https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file

version: 2
updates:
  - package-ecosystem: "github-actions"
    directory: "/" # Location of package manifests
    schedule:
      interval: "weekly"


================================================
FILE: .github/workflows/BenchmarkTrigger.yml
================================================
name: Benchmark Trigger

on:
  pull_request_target:
    types: [ labeled ]
  workflow_dispatch:
    inputs:
      pr_id:
        type: string
        description: id of the pull request that triggers this workflow
      target_url:
        type: string
        description: url of target
      baseline_url:
        type: string
        description: url of baseline

jobs:
  benchmark_trigger:
    if: ${{ github.event.label.name == 'benchmark' }}
    runs-on: ubuntu-latest
    env:
      REPOSITORY: ${{ github.event.repository.full_name }}
      PR_ID: ${{ github.event.inputs.pr_id || github.event.pull_request.number }}
      TARGET_URL: ${{ github.event.inputs.target_url || format('{0}#{1}', github.event.pull_request.head.repo.html_url, github.event.pull_request.head.sha) }}
      BASELINE_URL: ${{ github.event.inputs.baseline_url || format('{0}#{1}', github.event.pull_request.base.repo.html_url, github.event.pull_request.base.sha) }}
    steps:
      -
        name: Get app installation token (ghs)
        id: get-app-token
        uses: tibdex/github-app-token@v2
        with: 
          app_id: ${{ secrets.BENCH_APP_ID }}
          installation_id: ${{ secrets.BENCH_INSTALL_ID }}
          private_key: ${{ secrets.BENCH_PRIVATE_KEY }}
      -
        uses: benc-uk/workflow-dispatch@v1
        with:
          repo: FluxML/FluxMLBenchmarks.jl
          ref: refs/heads/main
          workflow: Benchmark.yml
          token: ${{ steps.get-app-token.outputs.token }}
          inputs: '{ "repository": "${{ env.REPOSITORY }}", "pr_id": "${{ env.PR_ID }}", "target_url": "${{ env.TARGET_URL }}", "baseline_url": "${{ env.BASELINE_URL }}" }'


================================================
FILE: .github/workflows/CompatHelper.yml
================================================
name: CompatHelper
on:
  schedule:
    - cron: 0 0 * * *
  workflow_dispatch:
permissions:
  contents: write
  pull-requests: write
jobs:
  CompatHelper:
    runs-on: ubuntu-latest
    steps:
      - name: Check if Julia is already available in the PATH
        id: julia_in_path
        run: which julia
        continue-on-error: true
      - name: Install Julia, but only if it is not already available in the PATH
        uses: julia-actions/setup-julia@v2
        with:
          version: '1'
          arch: ${{ runner.arch }}
        if: steps.julia_in_path.outcome != 'success'
      - name: "Add the General registry via Git"
        run: |
          import Pkg
          ENV["JULIA_PKG_SERVER"] = ""
          Pkg.Registry.add("General")
        shell: julia --color=yes {0}
      - name: "Install CompatHelper"
        run: |
          import Pkg
          name = "CompatHelper"
          uuid = "aa819f21-2bde-4658-8897-bab36330d9b7"
          version = "3"
          Pkg.add(; name, uuid, version)
        shell: julia --color=yes {0}
      - name: "Run CompatHelper"
        run: |
          import CompatHelper
          CompatHelper.main()
        shell: julia --color=yes {0}
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }}
          # COMPATHELPER_PRIV: ${{ secrets.COMPATHELPER_PRIV }}

================================================
FILE: .github/workflows/Downstream.yml
================================================
name: IntegrationTest
on:
  push:
    branches: [master]
    tags: [v*]
  pull_request:

# needed to allow julia-actions/cache to delete old caches that it has created
permissions:
  actions: write
  contents: read

jobs:
  test:
    name: ${{ matrix.package.repo }}/${{ matrix.package.group }}
    runs-on: ${{ matrix.os }}
    env:
      GROUP: ${{ matrix.package.group }}
    strategy:
      fail-fast: false
      matrix:
        julia-version: [1]
        os: [ubuntu-latest]
        package:
          - {user: FluxML, repo: Flux.jl, group: All}
          - {user: FluxML, repo: Tracker.jl, group: All}
          - {user: LuxDL, repo: Lux.jl, group: All}
    steps:
      - uses: actions/checkout@v6
      - uses: julia-actions/setup-julia@v2
        with:
          version: ${{ matrix.julia-version }}
          arch: x64
      - uses: julia-actions/cache@v2
      - uses: julia-actions/julia-buildpkg@latest
      - name: Clone Downstream
        uses: actions/checkout@v6
        with:
          repository: ${{ matrix.package.user }}/${{ matrix.package.repo }}
          path: downstream
      - name: Load this and run the downstream tests
        shell: julia --color=yes --project=downstream {0}
        run: |
          using Pkg
          try
            # force it to use this PR's version of the package
            Pkg.develop(PackageSpec(path="."))  # resolver may fail with main deps
            Pkg.update()
            Pkg.test()  # resolver may fail with test time deps
          catch err
            err isa Pkg.Resolve.ResolverError || rethrow()
            # If we can't resolve that means this is incompatible by SemVer and this is fine
            # It means we marked this as a breaking change, so we don't need to worry about
            # Mistakenly introducing a breaking change, as we have intentionally made one
            @info "Not compatible with this release. No problem." exception=err
            exit(0)  # Exit immediately, as a success
          end
        env:
          RETESTITEMS_NWORKERS: 4
          BACKEND_GROUP: CPU  # for Lux.jl



================================================
FILE: .github/workflows/TagBot.yml
================================================
name: TagBot
on:
  issue_comment:
    types:
      - created
  workflow_dispatch:
    inputs:
      lookback:
        default: 3
permissions:
  actions: read
  checks: read
  contents: write
  deployments: read
  issues: read
  discussions: read
  packages: read
  pages: read
  pull-requests: read
  repository-projects: read
  security-events: read
  statuses: read
jobs:
  TagBot:
    if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot'
    runs-on: ubuntu-latest
    steps:
      - uses: JuliaRegistries/TagBot@v1
        with:
          token: ${{ secrets.GITHUB_TOKEN }}
          # Edit the following line to reflect the actual name of the GitHub Secret containing your private key
          ssh: ${{ secrets.DOCUMENTER_KEY }}
          # ssh: ${{ secrets.NAME_OF_MY_SSH_PRIVATE_KEY_SECRET }}

================================================
FILE: .github/workflows/ci.yml
================================================
name: CI

on:
  push:
    branches:
      - master
      - staging
      - trying
    tags: '*'
  pull_request:

# needed to allow julia-actions/cache to delete old caches that it has created
permissions:
  actions: write
  contents: read

defaults:
  run:
    shell: bash

jobs:
  test:
    name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.julia-threads }} thread(s) 
    runs-on: ${{ matrix.os }}
    env:
      JULIA_NUM_THREADS: ${{ matrix.julia-threads }}
    strategy:
      fail-fast: false
      matrix:
        version:
          - '1.10' # minimum supported Julia version
          - '1'   # automatically expands to the latest stable 1.x release of Julia
          - 'nightly'
        os:
          - ubuntu-latest
          # - macOS-latest
          # - windows-latest
        julia-threads:
          - '1'

        include:
          - os: windows-latest
            version: '1'
            julia-threads: '1'
          - os: macOS-latest
            version: '1'
            julia-threads: '1'
          - os: ubuntu-latest
            version: '1'
            julia-threads: '2'
  
    steps:
      - uses: actions/checkout@v6
      - uses: julia-actions/setup-julia@v2
        with:
          version: ${{ matrix.version }}
      - uses: julia-actions/cache@v2
      - uses: julia-actions/julia-buildpkg@v1

      - name: "Run test without coverage"
        uses: julia-actions/julia-runtest@v1
        if: ${{ !contains(fromJson('["1"]'), matrix.version) || matrix.os != 'ubuntu-latest' }}
        with:
          coverage: false

      - name: "Run test with coverage"
        uses: julia-actions/julia-runtest@v1
        if: contains(fromJson('["1"]'), matrix.version) && matrix.os == 'ubuntu-latest'
      - uses: julia-actions/julia-processcoverage@v1
        if: contains(fromJson('["1"]'), matrix.version) && matrix.os == 'ubuntu-latest'
      - uses: codecov/codecov-action@v5
        if: contains(fromJson('["1"]'), matrix.version) && matrix.os == 'ubuntu-latest'
        with:
          file: lcov.info

  docs:
    name: Documentation
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v6
      - uses: julia-actions/setup-julia@v2
        with:
          version: '1.10'
      - uses: julia-actions/cache@v2
      - run: |
          julia --project=docs -e '
            using Pkg
            Pkg.develop(PackageSpec(path=pwd()))
            Pkg.instantiate()'
      - run: julia --project=docs docs/make.jl
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}


================================================
FILE: .github/workflows/clean_preview.yml
================================================
# from https://github.com/CliMA/ClimaTimeSteppers.jl
name: Doc Preview Cleanup

on:
  pull_request:
    types: [closed]

jobs:
  doc-preview-cleanup:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout gh-pages branch
        uses: actions/checkout@v6
        with:
          ref: gh-pages
      - name: Delete preview and history + push changes
        run: |
            if [ -d "previews/PR$PRNUM" ]; then
              git config user.name "Documenter.jl"
              git config user.email "documenter@juliadocs.github.io"
              git rm -rf "previews/PR$PRNUM"
              git commit -m "delete preview"
              git branch gh-pages-new $(echo "delete history" | git commit-tree HEAD^{tree})
              git push --force origin gh-pages-new:gh-pages
            fi
        env:
            PRNUM: ${{ github.event.number }}


================================================
FILE: .github/workflows/pr_comment.yml
================================================
name: pr_comment
on:
  pull_request:
    types: [labeled]
jobs:
  pr_comment:
    runs-on: ubuntu-latest
    steps:
      - name: Create PR comment
        if: github.event_name == 'pull_request' && github.repository == github.event.pull_request.head.repo.full_name && github.event.label.name == 'documentation' # if this is a pull request build AND the pull request is NOT made from a fork
        uses: thollander/actions-comment-pull-request@24bffb9b452ba05a4f3f77933840a6a841d1b32b
        with:
          message: 'Once the documentation build has completed, you can preview any updated documentation at this URL: https://fluxml.ai/NNlib.jl/previews/PR${{ github.event.number }}/'
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}


================================================
FILE: .gitignore
================================================
*.jl.cov
*.jl.*.cov
*.jl.mem
*.o
*.so
*.dylib
*.dll
*~
\#*
deps/usr
deps.jl
*.log
.vscode/
/Manifest.toml
test/Manifest.toml
benchmark/Manifest.toml
benchmark/*.json
benchmark/report.md


================================================
FILE: LICENSE.md
================================================
The NNlib.jl package is licensed under the MIT "Expat" License:

> Copyright (c) 2017-19: Julia Computing, Inc., Mike J Innes, and Contributors
> 
> Permission is hereby granted, free of charge, to any person obtaining a copy
> of this software and associated documentation files (the "Software"), to deal
> in the Software without restriction, including without limitation the rights
> to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
> copies of the Software, and to permit persons to whom the Software is
> furnished to do so, subject to the following conditions:
> 
> The above copyright notice and this permission notice shall be included in all
> copies or substantial portions of the Software.
> 
> THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
> IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
> FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
> AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
> LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
> OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
> SOFTWARE.
> 


================================================
FILE: Project.toml
================================================
name = "NNlib"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.9.34"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[extensions]
NNlibAMDGPUExt = "AMDGPU"
NNlibCUDACUDNNExt = ["CUDA", "cuDNN"]
NNlibCUDAExt = "CUDA"
NNlibEnzymeCoreExt = "EnzymeCore"
NNlibFFTWExt = "FFTW"
NNlibForwardDiffExt = "ForwardDiff"
NNlibMetalExt = "Metal"
NNlibSpecialFunctionsExt = "SpecialFunctions"

[compat]
AMDGPU = "1, 2"
Adapt = "3.2, 4"
Atomix = "0.1, 1"
CUDA = "4, 5, 6"
ChainRulesCore = "1.25"
EnzymeCore = "0.7, 0.8"
FFTW = "1.8.0"
ForwardDiff = "1"
GPUArraysCore = "0.2"
KernelAbstractions = "0.9.2"
LinearAlgebra = "1"
Metal = "1.6"
Random = "1"
ScopedValues = "1.3.0"
SpecialFunctions = "2"
Statistics = "1"
cuDNN = "1, 6"
julia = "1.10"


================================================
FILE: README.md
================================================
<img align="right" width="200px" src="https://github.com/FluxML/NNlib.jl/raw/master/docs/src/assets/logo.png">

# NNlib.jl

[![Documentation][docs-dev-img]][docs-dev-url]
[![CI](https://github.com/FluxML/NNlib.jl/actions/workflows/ci.yml/badge.svg)](https://github.com/FluxML/NNlib.jl/actions/workflows/ci.yml)
[![Coverage](https://codecov.io/gh/FluxML/NNlib.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/FluxML/NNlib.jl) 

[docs-stable-img]: https://img.shields.io/badge/docs-stable-blue.svg
[docs-stable-url]: https://fluxml.ai/NNlib.jl/stable/

[docs-dev-img]: https://img.shields.io/badge/docs-latest-blue.svg
[docs-dev-url]: https://fluxml.ai/NNlib.jl/dev/

This package provides a library of functions useful for neural networks, such as softmax, sigmoid, batched multiplication, convolutions and pooling. Many of these are used by [Flux.jl](https://github.com/FluxML/Flux.jl), which loads this package, but they may be used independently.

For use with automatic differentiation, this package defines gradients using [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl). These will be seen by various packages including [Zygote.jl](https://github.com/FluxML/Zygote.jl).

GPU support is provided as package extensions (see the `ext/` folder). In order to load the extensions, use the imports
```julia
using NNlib, CUDA, cuDNN
```
for CUDA support, or
```julia
using NNlib, AMDGPU
```
for AMDGPU support.


================================================
FILE: benchmark/Project.toml
================================================
[deps]
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
BenchmarkCI = "20533458-34a3-403d-a444-e18f38190b5b"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"

[compat]
# No compat bounds for NNlib because we may test breaking versions
ArgParse = "1"
BenchmarkCI = "0.1"
BenchmarkTools = "1.3"
PkgBenchmark = "0.2"
julia = "1.6"


================================================
FILE: benchmark/benchmarks.jl
================================================
using BenchmarkTools
using NNlib
using NNlib.ChainRulesCore: rrule
using Random

Random.seed!(1234567890)

const SUITE = BenchmarkGroup()

SUITE["activations"] = BenchmarkGroup()
for et in (Float16, Float32, Float64)
    et_suite = BenchmarkGroup()
    SUITE["activations"][string(et)] = et_suite
    let x = rand(et, 1024, 1024), y = similar(x)
        for f in NNlib.ACTIVATIONS
            act = @eval($f)
            et_suite[string(f)] = @benchmarkable broadcast!($act, $y, $x)
        end
    end
end

for (fn!, fn_bw) in [(softmax!, NNlib.∇softmax_data), (logsoftmax!, NNlib.∇logsoftmax_data)]
    fn_suite = BenchmarkGroup()
    SUITE[rstrip(string(fn!), '!')] = fn_suite
    let SIZES = [
        (128, 384, 8),
        (512, 784, 8),
        (768, 1024, 4),
        (1024, 2048, 4),
        (2048, 2048, 2),
        (4096, 2048, 2),
        (4096, 4096, 2),
        (12288, 2048, 1)
    ]
        for et in (Float16, Float32)
            et_suite = BenchmarkGroup("fw" => BenchmarkGroup(), "bw" => BenchmarkGroup())
            fn_suite[string(et)] = et_suite
            for sz in SIZES
                x = randn(et, sz)
                y = similar(x)
                dy = zero(x)
                fn!(y, x)
                et_suite["fw"][string(sz)] = @benchmarkable $fn!($y, $x)
                et_suite["bw"][string(sz)] = @benchmarkable $fn_bw($dy, $y)
            end
        end
    end
end



================================================
FILE: benchmark/perf_report.jl
================================================
using JLD2, NNlib, BenchmarkTools

# TODO organize and compare benchmarks using BenchmarkGroups

# We need things to go quickly here
BenchmarkTools.DEFAULT_PARAMETERS.samples = 20
BenchmarkTools.DEFAULT_PARAMETERS.seconds = 2.5

results = Dict()

function add_result(val, keys...)
    r = results
    for k in keys[1:end-1]
        if !haskey(r, k)
            r[k] = Dict()
        end
        r = r[k]
    end
    r[keys[end]] = val
    return r
end

# Modify these as needed
for rank in (2,),
    N in (20, 40, 80),
    C_in in (1,),
    C_out in (1,),
    K in (3,),
    stride in (1,),
    dilation in (1,),
    padding in (0, 2)

    benchmark_items = [
            (NNlib.conv_direct!, NNlib.∇conv_data_direct!, NNlib.∇conv_filter_direct!, DenseConvDims, "direct"),
            (NNlib.conv_im2col!, NNlib.∇conv_data_im2col!, NNlib.∇conv_filter_im2col!, DenseConvDims, "im2col"),
            (NNlib.depthwiseconv_direct!, NNlib.∇depthwiseconv_data_direct!, NNlib.∇depthwiseconv_filter_direct!, DepthwiseConvDims, "direct"),
            (NNlib.depthwiseconv_im2col!, NNlib.∇depthwiseconv_data_im2col!, NNlib.∇depthwiseconv_filter_im2col!, DepthwiseConvDims, "im2col"),
    ]

    for (conv!, ∇conv_data!, ∇conv_filter!, cT, backend) in benchmark_items

        x = zeros(Float32, repeat([N], rank)..., C_in, 1)
        if cT == DenseConvDims
            w = zeros(Float32, repeat([K], rank)..., C_in, C_out)
        else
            w = zeros(Float32, repeat([K], rank)..., C_out, C_in)
        end
        cdims = try
            cT(x, w; stride=stride, dilation=dilation, padding=padding)
        catch
            continue
        end

        if cT == DenseConvDims
            y = zeros(Float32, NNlib.output_size(cdims)..., C_out, 1)
        else
            y = zeros(Float32, NNlib.output_size(cdims)..., C_out*C_in, 1)
        end

        dx = similar(x)
        dw = similar(w)
        dy = similar(y)

        t_fwd = @benchmark $(conv!)($y, $x, $w, $cdims)
        t_dx = @benchmark $(∇conv_data!)($dx, $y, $w, $cdims)
        t_dw = @benchmark $(∇conv_filter!)($dw, $x, $y, $cdims)

        add_result(t_fwd, "conv$(rank)d", backend, cdims)
        add_result(t_dx, "conv$(rank)d_data", backend, cdims)
        add_result(t_dw, "conv$(rank)d_filter", backend, cdims)

        @show(cdims)
        @save "results.jld2" results
    end
end


# Modify these as needed
for rank in (2,),
    N in (20,),
    K in (2, 4),
    stride in (1, 2, 4)

    x = zeros(Float32, repeat([N], rank)..., 1, 1)
    pdims = PoolDims(x, K; stride=stride)
    y = zeros(Float32, NNlib.output_size(pdims)..., 1, 1)
    dx = similar(x)

    for (pool, ∇pool, name) in (
            (NNlib.maxpool!, NNlib.∇maxpool!, "maxpool"),
            (NNlib.meanpool!, NNlib.∇meanpool!, "meanpool"),
            (NNlib.lpnormpool!, NNlib.∇lpnormpool!, "lpnormpool"),
        )

        t_fwd  = @benchmark $(pool)( $y, $x, $pdims)
        t_data = @benchmark $(∇pool)($dx, $y, $y, $x, $pdims)

        add_result(t_fwd, "$(name)$(rank)d", "direct", pdims)
        add_result(t_data, "$(name)$(rank)d_data", "direct", pdims)

        @show(pdims)
        @save "results.jld2" results
    end
end


================================================
FILE: benchmark/runbenchmarks.jl
================================================
# Adapted from
# https://github.com/kul-forbes/ProximalOperators.jl/tree/master/benchmark
using ArgParse
using PkgBenchmark
using BenchmarkCI: displayjudgement, printresultmd, CIResult
using Markdown

function markdown_report(judgement)
    md = sprint(printresultmd, CIResult(judgement = judgement))
    md = replace(md, ":x:" => "❌")
    md = replace(md, ":white_check_mark:" => "✅")
    return md
end

function parse_commandline()
    s = ArgParseSettings()

    @add_arg_table! s begin
        "--target"
            help = "the branch/commit/tag to use as target"
            default = "HEAD"
        "--baseline"
            help = "the branch/commit/tag to use as baseline"
            default = "master"
        "--retune"
            help = "force re-tuning (ignore existing tuning data)"
            action = :store_true
    end

    return parse_args(s)
end

function main()
    parsed_args = parse_commandline()

    mkconfig(; kwargs...) =
        BenchmarkConfig(
            env = Dict(
                "JULIA_NUM_THREADS" => get(ENV, "JULIA_NUM_THREADS", "1"),
            );
            kwargs...
        )

    target = parsed_args["target"]
    group_target = benchmarkpkg(
        dirname(@__DIR__),
        mkconfig(id = target),
        resultfile = joinpath(@__DIR__, "result-$(target).json"),
        retune = parsed_args["retune"],
    )

    baseline = parsed_args["baseline"]
    group_baseline = benchmarkpkg(
        dirname(@__DIR__),
        mkconfig(id = baseline),
        resultfile = joinpath(@__DIR__, "result-$(baseline).json"),
    )

    judgement = judge(group_target, group_baseline)
    report_md = markdown_report(judgement)
    write(joinpath(@__DIR__, "report.md"), report_md)
    display(Markdown.parse(report_md))
end

main()


================================================
FILE: docs/.gitignore
================================================
build/
site/
Manifest.toml


================================================
FILE: docs/Project.toml
================================================
[deps]
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FLAC = "abae9e3b-a9a0-4778-b5c6-ca109b507d99"
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"


================================================
FILE: docs/make.jl
================================================
using Documenter, NNlib

DocMeta.setdocmeta!(NNlib, :DocTestSetup,
    :(using FFTW, NNlib, UnicodePlots); recursive = true)

makedocs(modules = [NNlib],
    sitename = "NNlib.jl",
    doctest = true,
    pages = ["Home" => "index.md",
             "Reference" => "reference.md",
             "Audio" => "audio.md"],
    format = Documenter.HTML(
        canonical = "https://fluxml.ai/NNlib.jl/stable/",
        # analytics = "UA-36890222-9",
        assets = ["assets/flux.css"],
        prettyurls = get(ENV, "CI", nothing) == "true"),
    warnonly=[:missing_docs,]
)

deploydocs(repo = "github.com/FluxML/NNlib.jl.git",
           target = "build",
           push_preview = true)


================================================
FILE: docs/src/assets/flux.css
================================================
@import url('https://fonts.googleapis.com/css?family=Lato:400,400i');

body {
  font-family: Lato, "Segoe UI",Roboto,"Helvetica Neue",Arial,sans-serif;
}

nav.toc {
  padding-top: 0;
  background: rgb(240, 240, 240);
  line-height: 2em;
  cursor: default;
  user-select: none;
}

h1+h2 {
  margin-top: 0;
}

/* Green banner in ToC */
nav.toc > h1 {
  margin-top: 0;
  padding-top: 0.4em;
  padding-bottom: 0.5em;
  border-bottom: 5px solid white;
  box-shadow: 0px -2px 5px rgb(60,60,60);
  margin-bottom: 0.5em;
  background: rgb(60, 150, 60);

  font-style: italic;
  font-weight: normal;
  font-size: 50pt;
  text-transform: lowercase;
  text-shadow: 2px 2px 5px rgba(0,0,0,0.2);
  color: white;
}

/* Reduce ToC font size */
.toctext {
  font-size: 10pt;
}

/* Fade out non-clickable ToC headers */
nav.toc ul span.toctext {
  color: rgb(180, 180, 180);
}

nav.toc ul .toctext {
  color: rgb(100, 100, 100);
}

nav.toc ul a.toctext:hover {
  color: inherit;
  background: rgb(220, 220, 220);
  cursor: default;
}

nav.toc li.current > .toctext {
  background: linear-gradient(90deg, rgb(245,245,245) 0%, white 90%);
  font-weight: normal;
}

nav.toc ul.internal li.toplevel {
  font-weight: normal;
}

/* Content */

article { max-width: none; }

article > p, article > ul {
  max-width: 45em;
}

/* Links */
a, a:visited { color: rgb(0, 120, 0); }
article p a { border-bottom: 1px solid rgb(200, 230, 200); }
a:hover, a:visited:hover { color: rgb(0, 80, 0); }

/* Article Links */
article p a { border-bottom: 1px solid rgb(200, 230, 200); }
article p a:hover, article a:visited:hover { color: rgb(0, 120, 0); }
article p a:hover { border-bottom: 1px solid rgb(150, 200, 150); }

/* Doctstrings */
article section.docstring {
  padding: 0.5em 0;
  border-left: none;
  border-right: none;
  border-bottom: none;
}

/* Code */

article pre, article p > code {
  background: rgb(245, 250, 245);
}

article pre {
  border: none;
  max-width: none;
  padding: 1em;
  border-radius: 10px 0px 0px 10px;
}

.hljs-comment {
  font-style: italic;
}

.hljs-number {
  color: rgb(0, 150, 150);
}


================================================
FILE: docs/src/audio.md
================================================
# Audio

!!! note
    Spectral functions require loading the `FFTW` package to enable them.

## Window functions

```@docs
hann_window
hamming_window
```

## Spectral

```@docs
stft
istft
NNlib.power_to_db
NNlib.db_to_power
```

## Spectrogram

```@docs
melscale_filterbanks
spectrogram
```

Example:

```@example 1
using FFTW # <- required for STFT support.
using NNlib
using FileIO
using Makie, CairoMakie
CairoMakie.activate!()

waveform, sampling_rate = load("./assets/jfk.flac")
fig = lines(reshape(waveform, :))
save("waveform.png", fig)

# Spectrogram.

n_fft = 1024
spec = spectrogram(waveform; n_fft, hop_length=n_fft ÷ 4, window=hann_window(n_fft))
fig = heatmap(transpose(NNlib.power_to_db(spec)[:, :, 1]))
save("spectrogram.png", fig)

# Mel-scale spectrogram.

n_freqs = n_fft ÷ 2 + 1
fb = melscale_filterbanks(; n_freqs, n_mels=128, sample_rate=Int(sampling_rate))
mel_spec = permutedims(spec, (2, 1, 3)) ⊠ fb # (time, n_mels)
fig = heatmap(NNlib.power_to_db(mel_spec)[:, :, 1])
save("mel-spectrogram.png", fig)
nothing # hide
```

|Waveform|Spectrogram|Mel Spectrogram|
|:---:|:---:|:---:|
|![](waveform.png)|![](spectrogram.png)|![](mel-spectrogram.png)|


================================================
FILE: docs/src/index.md
================================================
# NNlib.jl

`NNlib` provides a library of functions useful for neural networks, such as softmax, sigmoid, batched multiplication, convolutions and pooling. Many of these are used by [Flux.jl](https://github.com/FluxML/Flux.jl), which loads this package, but they may be used independently.

For use with automatic differentiation, this package defines gradients using [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl). These will be seen by various packages including [Zygote.jl](https://github.com/FluxML/Zygote.jl).

GPU support is provided as package extensions. In order to load the extensions, use the imports
```julia
using NNlib, CUDA, cuDNN
```
for CUDA support, or
```julia
using NNlib, AMDGPU
```
for AMDGPU support.

## Threading

Various `NNlib` functions use the available Julia threads for divisible workloads. To disable this,
use the `ScopedValue`-backed switch `NNlib.@disallow_spawns`, e.g.
```julia
NNlib.@disallow_spawns function_that_uses_nnlib()
```


================================================
FILE: docs/src/reference.md
================================================
# Reference

The API reference of `NNlib`.

## Activation Functions

Non-linearities that go between layers of your model. Note that, unless otherwise stated, activation functions operate on scalars. To apply them to an array you can call `σ.(xs)`, `relu.(xs)` and so on.

```@docs
celu
elu
gelu
gelu_tanh
gelu_sigmoid
gelu_erf
hardsigmoid
sigmoid_fast
hardtanh
tanh_fast
leakyrelu
lisht
logcosh
logsigmoid
mish
relu
relu6
rrelu
selu
sigmoid
softplus
softshrink
softsign
swish
hardswish
tanhshrink
trelu
```

## Attention 

```@docs
dot_product_attention
dot_product_attention_scores
make_causal_mask
```

## Softmax

`Flux`'s `logitcrossentropy` uses `NNlib.softmax` internally.

```@docs
softmax
logsoftmax
```

## Pooling

`Flux`'s `AdaptiveMaxPool`, `AdaptiveMeanPool`, `GlobalMaxPool`, `GlobalMeanPool`, `MaxPool`, `MeanPool` and `lpnormpool` use `NNlib.PoolDims`, `NNlib.maxpool`, `NNlib.meanpool` and `NNlib.lpnormpool` as their backend.

```@docs
PoolDims
maxpool
meanpool
lpnormpool
```

## Padding

```@docs
pad_reflect
pad_symmetric
pad_circular
pad_repeat
pad_constant
pad_zeros
```

## Convolution

`Flux`'s `Conv` and `CrossCor` layers use `NNlib.DenseConvDims` and `NNlib.conv` internally.

`NNlib.conv` supports complex datatypes on CPU and CUDA devices.

!!! note "AMDGPU MIOpen supports only cross-correlation (`flipkernel=true`)."

    Therefore, for every regular convolution (`flipkernel=false`)
    the kernel is flipped before the computation.
    For better performance, use cross-correlation (`flipkernel=true`)
    and manually flip the kernel before the `NNlib.conv` call, as sketched below.
    `Flux` handles this automatically; this is only required for direct calls.
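
A minimal sketch of this manual flip, assuming hypothetical 4-d arrays `x` (input) and `w` (weights, spatial dimensions first) that already live on the AMD GPU:

```julia
wflip = reverse(w; dims=(1, 2))                   # flip the spatial dimensions once
cdims = DenseConvDims(x, wflip; flipkernel=true)  # request cross-correlation
y = NNlib.conv(x, wflip, cdims)                   # no extra flip inside the extension
```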

```@docs
conv
ConvDims
depthwiseconv
DepthwiseConvDims
DenseConvDims
NNlib.unfold
NNlib.fold
```

## Upsampling

`Flux`'s `Upsample` layer uses `NNlib.upsample_nearest`, `NNlib.upsample_bilinear`, and `NNlib.upsample_trilinear` as its backend. Additionally, `Flux`'s `PixelShuffle` layer uses `NNlib.pixel_shuffle` as its backend.

```@docs
upsample_nearest
∇upsample_nearest
upsample_linear
∇upsample_linear
upsample_bilinear
∇upsample_bilinear
upsample_trilinear
∇upsample_trilinear
pixel_shuffle
```

## Rotation
Rotate images in the first two dimensions of an array.

```@docs
imrotate
∇imrotate
```

## Batched Operations

`Flux`'s `Bilinear` layer uses `NNlib.batched_mul` internally.

```@docs
batched_mul
batched_mul!
batched_adjoint
batched_transpose
batched_vec
```

## Gather and Scatter

`Flux`'s `Embedding` layer uses `NNlib.gather` as its backend.

```@docs
NNlib.gather
NNlib.gather!
NNlib.scatter
NNlib.scatter!
```

## Sampling

```@docs
grid_sample
∇grid_sample
```

## Losses

```@docs
ctc_loss
```

## Miscellaneous

```@docs
logsumexp
NNlib.glu
NNlib.within_gradient
bias_act!
```


================================================
FILE: ext/NNlibAMDGPUExt/NNlibAMDGPUExt.jl
================================================
module NNlibAMDGPUExt

using Adapt
using AMDGPU
using ChainRulesCore
using NNlib
using NNlib: BatchedAdjoint, BatchedTranspose, BatchedAdjOrTrans
using NNlib: DenseConvDims, PoolDims

const MIOPENFloat = Union{Float16, Float32}

const ROCBatchedAdjoint{T} = BatchedAdjoint{T, <: ROCArray{T}}
const ROCBatchedTranspose{T} = BatchedTranspose{T, <: ROCArray{T}}
const ROCBatchedAdjOrTrans{T} = Union{ROCBatchedAdjoint{T}, ROCBatchedTranspose{T}}
const WrappedROCBatchedAdjOrTrans{T, N} = Adapt.WrappedArray{T, N, ROCBatchedAdjOrTrans{T}, ROCBatchedAdjOrTrans{T}}
const AnyROCBatchedAdjOrTrans = Union{ROCBatchedAdjOrTrans, WrappedROCBatchedAdjOrTrans}

function Base.convert(::Type{T}, b::AnyROCBatchedAdjOrTrans) where {T <: Array}
    Base.convert(T, adapt(Array, b))
end

function Base.Array{T, N}(b::AnyROCBatchedAdjOrTrans) where {T, N}
    Array{T, N}(adapt(Array, b))
end

Base.collect(b::AnyROCBatchedAdjOrTrans) = collect(adapt(Array, b))

function Base.show(
    io::IO, mime::MIME{Symbol("text/plain")}, x::AnyROCBatchedAdjOrTrans,
)
    show(io, mime, adapt(Array, x))
end

Base.show(io::IO, x::AnyROCBatchedAdjOrTrans) = show(io, adapt(Array, x))

Base.display(x::AnyROCBatchedAdjOrTrans) = display(adapt(Array, x))

function nnlib_padding(dims)
    pd = NNlib.padding(dims)
    if !all(pd[1:2:end] .== pd[2:2:end])
        @warn """
        MIOpen does not support asymmetric padding, defaulting to symmetric choice:
        $pd -> $(pd[1:2:end]).
        """ maxlog=1
    end
    pd[1:2:end]
end

include("batched_mul.jl")

@static if AMDGPU.functional(:MIOpen)
    using AMDGPU.MIOpen

    include("conv.jl")
    include("pool.jl")
    include("activations.jl")
else
    @warn """
    ROCm MIOpen is not available for AMDGPU.
    NNlib has limited functionality for AMDGPU.
    """
end

end


================================================
FILE: ext/NNlibAMDGPUExt/activations.jl
================================================
for (f, op) in [
        NNlib.relu => MIOpen.relu,
        NNlib.relu6 => x -> MIOpen.clippedrelu(x, 6),
        NNlib.softplus => MIOpen.softrelu,
        NNlib.σ => MIOpen.sigmoid,
        Base.tanh => MIOpen.tanh,
        # TODO define for leakyrelu, elu, etc.?
    ], N in 1:5
    @eval function Base.materialize(
        bc::Broadcast.Broadcasted{<:Any,<:Any,typeof($f),<:Tuple{ROCArray{<:MIOPENFloat,$N}}}
    )
        return $op(bc.args[1])
    end
end

Base.broadcasted(::typeof(identity), x::ROCArray{T}) where {T<:MIOPENFloat} = x


================================================
FILE: ext/NNlibAMDGPUExt/batched_mul.jl
================================================
function _blas_at(x)
    Base.stride(x, 1) == 1 && return x, 'N'
    Base.stride(x, 2) == 1 && return batched_transpose(x), 'T'
    throw(ArgumentError("""
    Unsupported array layout for batched mul.
    - Size: $(size(x))
    - Strides: $(strides(x))
    """))
end

function NNlib._batched_mul!(
    ::Type{AT}, C, A, B, α::Float16, β::Float16,
) where AT <: ROCArray{Float16}
    blasA, transA = _blas_at(A)
    blasB, transB = _blas_at(B)
    NNlib._batched_gemm!(AT, transA, transB, α, blasA, blasB, β, C)
    C
end

function NNlib._batched_gemm!(
    ::Type{<:ROCArray{T}}, transA::Char, transB::Char, α::T, A, B, β::T, C,
) where T <: Union{MIOPENFloat, Float64}
    AMDGPU.rocBLAS.gemm_batched!(transA, transB, α, A, B, β, C)
end


================================================
FILE: ext/NNlibAMDGPUExt/conv.jl
================================================
function NNlib.conv!(
    y::ROCArray{T, N}, x::ROCArray{T, N}, w::ROCArray{T, N}, cdims::DenseConvDims,
) where {T <: MIOPENFloat, N}
    if !NNlib.flipkernel(cdims)
        @warn """
        MIOpen supports only cross-correlation (flipkernel=true).
        Therefore for every regular convolution (flipkernel=false)
        kernel is flipped before calculation.
        For better performance, use cross-correlation (flipkernel=true)
        and manually flip the kernel before `NNlib.conv` call.
        """ maxlog=1
        flip_dims = ntuple(
            i -> (i ≤ ndims(w) - 2) ? (size(w, i):-1:1) : Colon(),
            ndims(w))
        w = w[flip_dims...]
    end

    nd = max(0, 4 - N)
    ncdims = NNlib.insert_singleton_spatial_dimension(cdims, nd)
    MIOpen.convolution!(
        NNlib.insert_singleton_spatial_dimension(y, nd),
        NNlib.insert_singleton_spatial_dimension(x, nd),
        NNlib.insert_singleton_spatial_dimension(w, nd);
        padding=nnlib_padding(ncdims), stride=NNlib.stride(ncdims),
        dilation=NNlib.dilation(ncdims), groups=NNlib.groupcount(ncdims))
    return y
end

function NNlib.∇conv_data!(
    dx::ROCArray{T, N}, dy::ROCArray{T, N}, w::ROCArray{T, N}, cdims::DenseConvDims,
) where {T <: MIOPENFloat, N}
    if !NNlib.flipkernel(cdims)
        @warn """
        MIOpen supports only cross-correlation (flipkernel=true).
        Therefore for every regular convolution (flipkernel=false)
        kernel is flipped before calculation.
        For better performance, use cross-correlation (flipkernel=true)
        and manually flip the kernel before `NNlib.conv` call.
        """ maxlog=1
        flip_dims = ntuple(
            i -> (i ≤ ndims(w) - 2) ? (size(w, i):-1:1) : Colon(),
            ndims(w))
        w = w[flip_dims...]
    end

    nd = max(0, 4 - N)
    ncdims = NNlib.insert_singleton_spatial_dimension(cdims, nd)
    MIOpen.∇convolution_data!(
        NNlib.insert_singleton_spatial_dimension(dx, nd),
        NNlib.insert_singleton_spatial_dimension(dy, nd),
        NNlib.insert_singleton_spatial_dimension(w, nd);
        padding=nnlib_padding(ncdims), stride=NNlib.stride(ncdims),
        dilation=NNlib.dilation(ncdims), groups=NNlib.groupcount(ncdims))
    return dx
end

function NNlib.∇conv_filter!(
    dw::ROCArray{T, N}, x::ROCArray{T, N}, dy::ROCArray{T, N}, cdims::DenseConvDims,
) where {T <: MIOPENFloat, N}
    nd = max(0, 4 - N)
    ncdims = NNlib.insert_singleton_spatial_dimension(cdims, nd)
    MIOpen.∇convolution_weight!(
        NNlib.insert_singleton_spatial_dimension(dw, nd),
        NNlib.insert_singleton_spatial_dimension(dy, nd),
        NNlib.insert_singleton_spatial_dimension(x, nd);
        padding=nnlib_padding(ncdims), stride=NNlib.stride(ncdims),
        dilation=NNlib.dilation(ncdims), groups=NNlib.groupcount(ncdims))

    if !NNlib.flipkernel(cdims)
        @warn """
        MIOpen supports only cross-correlation (flipkernel=true).
        Therefore for every regular convolution (flipkernel=false)
        kernel is flipped before calculation.
        For better performance, use cross-correlation (flipkernel=true)
        and manually flip the kernel before `NNlib.conv` call.
        """ maxlog=1
        flip_dims = ntuple(
            i -> (i ≤ ndims(dw) - 2) ? (size(dw, i):-1:1) : Colon(),
            ndims(dw))
        dw = dw[flip_dims...]
    end
    return dw
end


================================================
FILE: ext/NNlibAMDGPUExt/pool.jl
================================================
for poolname in (:maxpool, :meanpool)
    @eval function NNlib.$(poolname)(
        x::ROCArray{T, N}, pdims::PoolDims,
    ) where {T <: MIOPENFloat, N}
        y = similar(x, NNlib.output_size(pdims)..., NNlib.channels_out(pdims), size(x, N))
        nd = max(0, 4 - N)
        npdims = NNlib.insert_singleton_spatial_dimension(pdims, nd)
        MIOpen.$(Symbol("$(poolname)!"))(
            NNlib.insert_singleton_spatial_dimension(y, nd),
            NNlib.insert_singleton_spatial_dimension(x, nd);
            dims=NNlib.kernel_size(npdims), padding=nnlib_padding(npdims),
            stride=NNlib.stride(npdims), do_backward=false)
        return y
    end

    @eval function ChainRulesCore.rrule(
        ::typeof(NNlib.$(poolname)), x::ROCArray{T, N}, pdims::PoolDims,
    ) where {T <: MIOPENFloat, N}
        y = similar(x, NNlib.output_size(pdims)..., NNlib.channels_out(pdims), size(x, N))
        nd = max(0, 4 - N)
        npdims = NNlib.insert_singleton_spatial_dimension(pdims, nd)

        # `workspace` is used in the pullback.
        _, workspace = MIOpen.$(Symbol("$(poolname)!"))(
            NNlib.insert_singleton_spatial_dimension(y, nd),
            NNlib.insert_singleton_spatial_dimension(x, nd);
            dims=NNlib.kernel_size(npdims), padding=nnlib_padding(npdims),
            stride=NNlib.stride(npdims))

        function _pooling_pullback(Δ)
            dx = similar(x)
            MIOpen.$(Symbol("∇$(poolname)!"))(
                NNlib.insert_singleton_spatial_dimension(dx, nd),
                NNlib.insert_singleton_spatial_dimension(unthunk(Δ), nd),
                NNlib.insert_singleton_spatial_dimension(y, nd),
                NNlib.insert_singleton_spatial_dimension(x, nd);
                dims=NNlib.kernel_size(npdims), padding=nnlib_padding(npdims),
                stride=NNlib.stride(npdims), workspace)
            return NoTangent(), dx, NoTangent()
        end
        y, _pooling_pullback
    end
end


================================================
FILE: ext/NNlibCUDACUDNNExt/NNlibCUDACUDNNExt.jl
================================================
module NNlibCUDACUDNNExt

using NNlib
using cuDNN
using CUDA
using Random, Statistics

using cuDNN: handle, with_workspace, cudnnTensorDescriptor, cudnnFilterDescriptor,
             cudnnDataType, math_mode, CUDNN_DEFAULT_REORDER, CUDNN_CROSS_CORRELATION,
             CUDNN_NOT_PROPAGATE_NAN, CUDNN_TENSOR_NCHW, dim4

cudnnversion() = cuDNN.version()

function nnlibPadding(dims)
    pd = NNlib.padding(dims)
    if !all(pd[1:2:end] .== pd[2:2:end])
        @warn "cuDNN does not support asymmetric padding; defaulting to symmetric choice" maxlog=1
    end
    return pd[1:2:end]
end
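
# Worked sketch of the reduction above (illustrative values, not taken from this repo's tests):
# for NNlib-style padding (lo1, hi1, lo2, hi2) = (1, 1, 2, 2) the symmetry check passes and
# (1, 2) is returned; for (0, 1, 0, 1) the padding is asymmetric, the warning fires once,
# and the leading values (0, 0) are handed to cuDNN.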

include("conv.jl")
include("pooling.jl")
include("softmax.jl")
include("activations.jl")
include("batchnorm.jl")

end # module

================================================
FILE: ext/NNlibCUDACUDNNExt/activations.jl
================================================

# Activation

using Base.Broadcast
using cuDNN: cudnnActivationForward!, cudnnOpTensor!,
             CUDNN_ACTIVATION_TANH, CUDNN_ACTIVATION_SIGMOID, CUDNN_ACTIVATION_ELU,
             CUDNN_ACTIVATION_RELU, CUDNN_ACTIVATION_CLIPPED_RELU, CUDNN_OP_TENSOR_MAX,
             CUDNN_ACTIVATION_IDENTITY

for (f, op) in [
    CUDA.tanh       => (src,dst)->cudnnActivationForward!(dst, src, mode=CUDNN_ACTIVATION_TANH),
    NNlib.σ         => (src,dst)->cudnnActivationForward!(dst, src, mode=CUDNN_ACTIVATION_SIGMOID),
    NNlib.elu       => (src,dst)->cudnnActivationForward!(dst, src, mode=CUDNN_ACTIVATION_ELU),
    NNlib.relu      => (src,dst)->cudnnActivationForward!(dst, src, mode=CUDNN_ACTIVATION_RELU),
    # NNlib.relu6     => (src,dst)->cudnnActivationForward!(dst, src, mode=CUDNN_ACTIVATION_CLIPPED_RELU, coef=6.0),
    # NNlib.leakyrelu => (src,dst)->cudnnOpTensor!(dst, src, src; op=CUDNN_OP_TENSOR_MAX, alpha1=0.01),
    ]

    @eval begin
        # in-place
        function Base.materialize!(dst::DenseCuArray{<:CUDNNFloat},
                                   bc::Broadcast.Broadcasted{<:Any,<:Any,typeof($f),<:Tuple{DenseCuArray{<:CUDNNFloat}}})
            $op(bc.args[1], dst)
            return dst
        end

        # out of place
        function Base.materialize(bc::Broadcast.Broadcasted{<:Any,<:Any,typeof($f),<:Tuple{DenseCuArray{<:CUDNNFloat}}})
            ElType = Broadcast.combine_eltypes(bc.f, bc.args)
            dst = similar(bc, ElType)
            $op(bc.args[1], dst)
            return dst
        end
    end
end

# CUDNN_ACTIVATION_IDENTITY does not work with cudnnActivationForward
# FIXME: put this optimization in GPUArrays' `copyto!` (like Base.Broadcast's `copyto!`)
Base.broadcasted(::typeof(identity), x::DenseCuArray{T}) where {T<:CUDNNFloat} = x



================================================
FILE: ext/NNlibCUDACUDNNExt/batchnorm.jl
================================================
using cuDNN: CUDNN_BN_MIN_EPSILON, cudnnBatchNormalizationBackward,
             cudnnBatchNormalizationForwardInference, CUDNN_BATCHNORM_SPATIAL,
             cudnnBatchNormalizationForwardTraining
import NNlib: batchnorm, ∇batchnorm

# TODO: replace with new cudnn normalization interface
# https://github.com/JuliaGPU/CUDA.jl/blob/master/lib/cudnn/normalization.jl

mutable struct BNCache
  mean
  ivar
end

BNCache() = BNCache(nothing, nothing)

@inline _wsize(x::AbstractArray{<:Any,N}) where N = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)

function batchnorm(g::Nothing, b::Nothing, x::DenseCuArray,
                   running_mean, running_var, momentum; kws...)
  affine_sz = _wsize(x)
  g = fill!(similar(x, affine_sz), 1)
  b = fill!(similar(x, affine_sz), 0)
  return batchnorm(g, b, x, running_mean, running_var, momentum; kws...)
end

# NOTE: cuDNN supports only 4D and 5D tensors for batchnorm operations,
# so reshape a 2D tensor into 4D
function batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T,2},
                   running_mean, running_var, momentum; kws...) where T<:CUDNNFloat
  x = reshape(x, 1, 1, size(x, 1), size(x, 2))
  y = batchnorm(g, b, x, running_mean, running_var, momentum; kws...)
  return dropdims(y, dims = (1, 2))
end

function batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::Union{DenseCuArray{T,4},DenseCuArray{T,5}},
                   running_mean, running_var, momentum; kws...) where T<:CUDNNFloat
  cudnnBNForward!(similar(x), g, b, x, running_mean, running_var, momentum; kws...)
end

function cudnnBNForward!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T},
                        running_mean, running_var, momentum;
                        cache = nothing,
                        alpha = T(1), beta = T(0),
                        eps = T(1e-5),
                        training = true,
                        affine = true,
                        track_stats = true) where T<:CUDNNFloat
  dims = _wsize(x)
  if eps < CUDNN_BN_MIN_EPSILON
    @warn "eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON"
    eps = CUDNN_BN_MIN_EPSILON
  end

  if running_mean === nothing || running_var === nothing
    running_mean !== running_var && throw(ArgumentError("both or neither of running_mean and running_var must be nothing"))
    if track_stats || !training
      running_mean = fill!(similar(x, dims), 0)
      running_var = fill!(similar(x, dims), 1)
    end
  end

  xd = cudnnTensorDescriptor(x)
  yd = cudnnTensorDescriptor(y)
  gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(dims)), dim4(dims,Val(CUDNN_TENSOR_NCHW)))

  if training
    if !track_stats
      running_mean = CU_NULL
      running_var = CU_NULL
    end

    if cache !== nothing
      mean = fill!(similar(x, dims), 0)
      ivar = fill!(similar(x, dims), 1)
    else
      mean = CU_NULL
      ivar = CU_NULL
    end

    cudnnBatchNormalizationForwardTraining(handle(), CUDNN_BATCHNORM_SPATIAL, scalingParameter(T, alpha), scalingParameter(T, beta), xd, x, yd, y, gd, g, b, momentum, running_mean, running_var, eps, mean, ivar)

    if cache !== nothing
      cache.mean = mean
      cache.ivar = ivar
    end
  else
    if track_stats
      cudnnBatchNormalizationForwardInference(handle(), CUDNN_BATCHNORM_SPATIAL, scalingParameter(T, alpha), scalingParameter(T, beta), xd, x, yd, y, gd, g, b, running_mean, running_var, eps)
    else
      # cudnnBatchNormalizationForwardInference does not accept CU_NULL for running_mean
      # and running_var. We could calculate the mean and var of `x` here, but instead use
      # cudnnBatchNormalizationForwardTraining, which does accept CU_NULL and will
      # calculate the mean and var itself.
      cudnnBatchNormalizationForwardTraining(handle(), CUDNN_BATCHNORM_SPATIAL, scalingParameter(T, alpha), scalingParameter(T, beta), xd, x, yd, y, gd, g, b, momentum, CU_NULL, CU_NULL, eps, CU_NULL, CU_NULL)
    end
  end
  return y
end

function ∇batchnorm(g::Nothing, b::Nothing, x::DenseCuArray, dy::DenseCuArray,
                    running_mean, running_var, momentum; kws...)
  affine_sz = _wsize(x)
  g = fill!(similar(x, affine_sz), 1)
  b = fill!(similar(x, affine_sz), 0)
  return ∇batchnorm(g, b, x, dy, running_mean, running_var, momentum; kws...)
end

function ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T, 2}, dy::DenseCuArray{T, 2},
            running_mean, running_var, momentum;
            kws...) where T<:CUDNNFloat
  dg, db, dx = ∇batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), reshape(dy, 1, 1, size(dy, 1),
                          size(dy, 2)), running_mean, running_var, momentum; kws...)
  (dg, db, dropdims(dx, dims = (1, 2)))
end


function ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
                    running_mean, running_var, momentum;
                    affine=true, kws...) where T<:CUDNNFloat
  dg = similar(g)
  db = similar(b)
  dx = similar(x)
  cudnnBNBackward!(dg, g, db, dx, x, dy, running_mean, running_var, T(momentum); kws...)
  if affine
    (dg, db, dx)
  else
    # cuDNN always calculates dg and db, therefore we just have to drop them
    (nothing, nothing, dx)
  end
end

function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuArray{T},
                          dx::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
                          running_mean, running_var,
                          momentum; cache = nothing, eps = T(1e-5),
                          alpha = T(1), beta = T(0),
                          dalpha = T(1), dbeta = T(0), training = true,
                          track_stats = true) where T<:CUDNNFloat
  if !track_stats
    running_mean = CU_NULL
    running_var = CU_NULL
  end

  xd = cudnnTensorDescriptor(x)
  dyd = cudnnTensorDescriptor(dy)
  dxd = cudnnTensorDescriptor(dx)
  gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(_wsize(x))), dim4(_wsize(x),Val(CUDNN_TENSOR_NCHW)))
  if cache !== nothing
    @debug "fetching mean and ivar from the cache"
    mean, ivar = cache.mean, cache.ivar
  else
    mean, ivar = CU_NULL, CU_NULL
  end

  if eps < CUDNN_BN_MIN_EPSILON
    @warn "eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON"
    eps = CUDNN_BN_MIN_EPSILON
  end

  cudnnBatchNormalizationBackward(handle(), CUDNN_BATCHNORM_SPATIAL,
        scalingParameter(T, alpha), scalingParameter(T, beta), scalingParameter(T, dalpha), scalingParameter(T, dbeta),
        xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps, mean, ivar)
end


================================================
FILE: ext/NNlibCUDACUDNNExt/conv.jl
================================================

using NNlib: DenseConvDims
import NNlib: conv!, ∇conv_filter!, ∇conv_data!, conv_bias_act!

using cuDNN: scalingParameter, CUDNN_CONVOLUTION, convdims,
             cudnnConvolutionBwdDataAlgoPerf,
             cudnnConvolutionForward!, cudnnConvolutionBwdFilterAlgoPerf,
             cudnnConvolutionBackwardData, cudnnConvolutionBackwardFilter,
             cudnnConvolutionBackwardBias
import cuDNN: cudnnConvolutionDescriptor

const CUDNNFloat = Union{Float16,Float32,Float64}
const CUDNNComplexFloat = Union{ComplexF16,ComplexF32,ComplexF64}

function cudnnConvolutionDescriptorAndPaddedInput(cdims::DenseConvDims, x::DenseCuArray{T}) where T
    # The main purpose of this function is to catch asymmetric padding, which cudnn does not support.
    # If we find asymmetric padding, we make a manually padded copy of x so that we can
    # call cudnn with symmetric padding.
    pad = NNlib.padding(cdims)
    sdims = NNlib.spatial_dims(cdims)
    all(i -> pad[i] .== pad[i+1], 1:2:2sdims) && return (cudnnConvolutionDescriptor(cdims, x), x, identity)

    # Naive implementation; is there a faster way?
    # How much we need to pad x manually: the signed difference between pad_left and pad_right,
    # pad_top and pad_bottom, etc. The sign tells us below which side of x to pad, and we use a
    # CartesianIndex as we will mainly use this to index into x.
    pad_manual = CartesianIndex(ntuple(i -> i > sdims ? 0 : pad[2(i-1)+1] - pad[2(i-1)+2], ndims(x)))

    # How much we can let cudnn pad: The smallest padding amount between pad_left and pad_right, pad_top
    # and pad_bottom etc. respectively
    pad_cudnn = ntuple(i -> min(pad[2(i-1)+1], pad[2(i-1)+2]), sdims)

    x_padded_size = ntuple(i -> i <= sdims ? size(x, i) + abs(pad_manual[i]) : size(x ,i), ndims(x))
    x_padded = similar(x, x_padded_size)
    fill!(x_padded, 0)
    # This is a bit yucky, but we are basically figuring out where in x_padded we shall insert x
    # Haven't benchmarked if this has any advantages over a more readable solution, e.g. writing dim
    # by dim to an array in a loop
    xIs = CartesianIndices(x)
    xI_first = first(xIs)
    xI_last = last(xIs)
    xIs_pad = max(xI_first, xI_first + pad_manual) : max(xI_last, xI_last + pad_manual)
    x_padded[xIs_pad] = x

    return cudnnConvolutionDescriptor(cdims, x_padded, pad_cudnn), x_padded, _x -> _x[xIs_pad]
end
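
# Worked sketch (illustrative numbers, not from this repo's tests): with 2 spatial dims and
# padding (w_lo, w_hi, h_lo, h_hi) = (1, 2, 3, 3), `pad_manual` is CartesianIndex(-1, 0, 0, 0)
# and `pad_cudnn` is (1, 3); x is copied into the low end of an x_padded that has one extra
# zero slice at the high end of the width dimension, so cuDNN's symmetric (1, 3) padding plus
# that manual zero slice reproduces the requested (1, 2, 3, 3).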

function cudnnConvolutionDescriptor(cdims::DenseConvDims, x::DenseCuArray{T}, pad = nnlibPadding(cdims)) where T
    mode=(NNlib.flipkernel(cdims) ? CUDNN_CROSS_CORRELATION : CUDNN_CONVOLUTION)
    cudnnConvolutionDescriptor(convdims(pad, size(x),0),
                               convdims(NNlib.stride(cdims),size(x),1),
                               convdims(NNlib.dilation(cdims),size(x),1),
                               mode,
                               cudnnDataType(real(T)),
                               math_mode(),
                               CUDNN_DEFAULT_REORDER,
                               Cint(NNlib.groupcount(cdims)))
end

@inline function _complex!(y::DenseCuArray{T1}, yr::DenseCuArray{T2}, yi::DenseCuArray{T2}; bias=zero(T1), alpha=one(T1), beta=zero(T1), σ=identity) where {T1 <: CUDNNComplexFloat, T2<:CUDNNFloat}
    # if y is from similar(), it may have NaNs, and beta*NaN will propagate.
    if beta != 0
        @. y = σ(alpha*(yr + im*yi) + bias + beta*y)
    else
        @. y = σ(alpha*(yr + im*yi) + bias)
    end
    return y
end

function conv!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T}, cdims::DenseConvDims;
               alpha=1, beta=0, algo=-1) where T<:CUDNNFloat
    if cudnnversion() < v"6"
        all(x -> x == 1, dilation(cdims)) || error("Only dilation = 1 is supported in cuDNN version < 6")
    end
    if algo != -1
        @warn "algo option has been deprecated, the fastest algo is computed automatically" maxlog=1
    end
    d, x, _ = cudnnConvolutionDescriptorAndPaddedInput(cdims, x)
    cudnnConvolutionForward!(y, w, x, d; alpha, beta, z=y)
end

# Complex convolution with Gauss's trick (1 complex mul === 3 real mul):
# Consider x = xr + im*xi, y = yr + im*yi,
# so x*y = (xr*yr - xi*yi) + im*(xr*yi + xi*yr).
# Let a = xr*yr,
#     b = xi*yi,
#     c = (xr + xi)*(yr + yi) = xr*yr + xr*yi + xi*yr + xi*yi.
# Then,
# x*y = (a - b) + im*(c - a - b).
# Convolution is linear so this multiplication trick translates to convolution.
function conv!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T}, cdims::DenseConvDims;
               alpha=1, beta=0, algo=-1) where T<:CUDNNComplexFloat
    xr, xi = reim(x)
    wr, wi = reim(w)
    a = conv!(similar(real(y)), xr, wr, cdims; algo=algo)
    b = conv!(similar(a), xi, wi, cdims; algo=algo)
    c = conv!(similar(a), xr + xi, wr + wi, cdims; algo=algo)
    return _complex!(y, a - b, c - a - b; alpha=alpha, beta=beta)
end
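
# Scalar sketch of the identity above (an illustration, not part of NNlib's API): three real
# multiplications a, b, c reproduce one complex multiplication.
function _gauss_trick_sketch(x::Complex, y::Complex)
    a = real(x) * real(y)
    b = imag(x) * imag(y)
    c = (real(x) + imag(x)) * (real(y) + imag(y))
    return (a - b) + im * (c - a - b)   # equals x * y, up to floating-point rounding
end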

# (xr + im*xi) * w = xr*w + im*(xi*w)
function conv!(y::DenseCuArray{T1}, x::DenseCuArray{T1}, w::DenseCuArray{T2}, cdims::DenseConvDims;
               alpha=1, beta=0, algo=-1) where {T1<:CUDNNComplexFloat, T2<:CUDNNFloat}
    xr, xi = reim(x)
    yr = conv!(similar(real(y)), xr, w, cdims; algo=algo)
    yi = conv!(similar(yr), xi, w, cdims; algo=algo)
    return _complex!(y, yr, yi; alpha=alpha, beta=beta)
end

# x * (wr + im*wi) = x*wr + im*(x*wi)
function conv!(y::DenseCuArray{T1}, x::DenseCuArray{T2}, w::DenseCuArray{T1}, cdims::DenseConvDims;
               alpha=1, beta=0, algo=-1) where {T1<:CUDNNComplexFloat, T2<:CUDNNFloat}
    wr, wi = reim(w)
    yr = conv!(similar(real(y)), x, wr, cdims; algo=algo)
    yi = conv!(similar(yr), x, wi, cdims; algo=algo)
    return _complex!(y, yr, yi; alpha=alpha, beta=beta)
end

function conv_bias_act!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T},
                        cdims::DenseConvDims, bias::DenseCuArray{T}, σ=identity;
                        z::DenseCuArray{T}=y, alpha=1, beta=0, algo=-1) where T<:CUDNNFloat
    if cudnnversion() < v"6"
        all(x -> x == 1, dilation(cdims)) || error("Only dilation = 1 is supported in cuDNN version < 6")
    end
    if algo != -1
        @warn "The algo option has been deprecated, the fastest algo is computed automatically" maxlog=1
    end
    d, x, _ = cudnnConvolutionDescriptorAndPaddedInput(cdims, x)
    # only relu and identity are supported by cudnnConvolutionForward!
    activation = (σ == NNlib.relu ? CUDNN_ACTIVATION_RELU : CUDNN_ACTIVATION_IDENTITY)
    cudnnConvolutionForward!(y, w, x, d; z, bias, activation, alpha, beta)
    if activation === CUDNN_ACTIVATION_IDENTITY && σ ∉ (nothing, identity)
        @. y = σ(y)
    end
    return y
end

function conv_bias_act!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T},
                        cdims::DenseConvDims, bias::DenseCuArray{T}, σ=identity;
                        z::DenseCuArray{T}=y, alpha=1, beta=0, algo=-1) where T<:CUDNNComplexFloat
    xr, xi = reim(x)
    wr, wi = reim(w)
    a = conv!(similar(real(y)), xr, wr, cdims; alpha=1, beta=0, algo=algo)
    b = conv!(similar(a), xi, wi, cdims; alpha=1, beta=0, algo=algo)
    c = conv!(similar(a), xr + xi, wr + wi, cdims; alpha=1, beta=0, algo=algo)
    return _complex!(y, a - b, c - a - b; bias=bias, alpha=alpha, beta=beta, σ=σ)
end

function ∇conv_data!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, w::DenseCuArray{T},
                     cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where T<:CUDNNFloat
    if cudnnversion() < v"6"
        all(x -> x == 1, dilation(cdims)) || error("Only dilation = 1 is supported in cuDNN version < 6")
    end
    if algo != -1
        @warn "The algo option has been deprecated, the fastest algo is computed automatically" maxlog=1
    end
    alpha, beta = scalingParameter(T,alpha), scalingParameter(T,beta);
    convDesc, dx, depad = cudnnConvolutionDescriptorAndPaddedInput(cdims, dx)
    xDesc, yDesc, wDesc = cudnnTensorDescriptor(dx), cudnnTensorDescriptor(dy), cudnnFilterDescriptor(w)
    p = cudnnConvolutionBwdDataAlgoPerf(wDesc, w, yDesc, dy, convDesc, xDesc, dx, beta!=0)
    with_workspace(p.memory) do workspace
        cudnnConvolutionBackwardData(handle(), alpha, wDesc, w, yDesc, dy, convDesc, p.algo, workspace, sizeof(workspace), beta, xDesc, dx)
    end
    return depad(dx)
end

function ∇conv_data!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, w::DenseCuArray{T},
                     cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where T<:CUDNNComplexFloat
    dyr, dyi = reim(dy)
    wr, wi = reim(w)
    # note: w is conjugated, i.e. wi is negated below
    a = ∇conv_data!(similar(real(dx)), dyr, wr, cdims; alpha=1, beta=0, algo=algo)
    b = ∇conv_data!(similar(a), dyi, -wi, cdims; alpha=1, beta=0, algo=algo)
    c = ∇conv_data!(similar(a), dyr + dyi, wr - wi, cdims; alpha=1, beta=0, algo=algo)
    return _complex!(dx, a - b, c - a - b; alpha=alpha, beta=beta)
end

# dx = (dyr + im*dyi)*w = dyr*w + im*(dyi*w)
function ∇conv_data!(dx::DenseCuArray{T1}, dy::DenseCuArray{T1}, w::DenseCuArray{T2},
                     cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where {T1<:CUDNNComplexFloat, T2<:CUDNNFloat}
    dyr, dyi = reim(dy)
    dxr = ∇conv_data!(similar(real(dx)), dyr, w, cdims; alpha=1, beta=0, algo=algo)
    dxi = ∇conv_data!(similar(dxr), dyi, w, cdims; alpha=1, beta=0, algo=algo)
    return _complex!(dx, dxr, dxi; alpha=alpha, beta=beta)
end

function ∇conv_filter!(dw::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
                       cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where T<:CUDNNFloat
    if cudnnversion() < v"6"
        all(x -> x == 1, dilation(cdims)) || error("Only dilation = 1 is supported in cuDNN version < 6")
    end
    if algo != -1
        @warn "The algo option has been deprecated, the fastest algo is computed automatically" maxlog=1
    end
    alpha, beta = scalingParameter(T,alpha), scalingParameter(T,beta);
    convDesc, x, _ = cudnnConvolutionDescriptorAndPaddedInput(cdims, x)
    xDesc, yDesc, wDesc = cudnnTensorDescriptor(x), cudnnTensorDescriptor(dy), cudnnFilterDescriptor(dw)
    p = cudnnConvolutionBwdFilterAlgoPerf(xDesc, x, yDesc, dy, convDesc, wDesc, dw, beta!=0);
    with_workspace(p.memory) do workspace
        cudnnConvolutionBackwardFilter(handle(), alpha, xDesc, x, yDesc, dy, convDesc, p.algo, workspace, sizeof(workspace), beta, wDesc, dw);
    end
    return dw
end

function ∇conv_filter!(dw::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
                       cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where T<:CUDNNComplexFloat
    xr, xi = reim(x)
    dyr, dyi = reim(dy)
    # note: x is conjugated, i.e. xi is negated below
    a = ∇conv_filter!(similar(real(dw)), xr, dyr, cdims; alpha=1, beta=0, algo=algo)
    b = ∇conv_filter!(similar(a), -xi, dyi, cdims; alpha=1, beta=0, algo=algo)
    c = ∇conv_filter!(similar(a), xr - xi, dyr + dyi, cdims; alpha=1, beta=0, algo=algo)
    return _complex!(dw, a - b, c - a - b; alpha=alpha, beta=beta)
end

# dw = x*(dyr + im*dyi) = x*dyr + im*(x*dyi)
function ∇conv_filter!(dw::DenseCuArray{T1}, x::DenseCuArray{T2}, dy::DenseCuArray{T1},
                       cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where {T1<:CUDNNComplexFloat, T2<:CUDNNFloat}
    dyr, dyi = reim(dy)
    dwr = ∇conv_filter!(similar(real(dw)), x, dyr, cdims; alpha=1, beta=0, algo=algo)
    dwi = ∇conv_filter!(similar(dwr), x, dyi, cdims; alpha=1, beta=0, algo=algo)
    return _complex!(dw, dwr, dwi; alpha=alpha, beta=beta)
end

function ∇conv_bias!(db::DenseCuArray{T}, dy::DenseCuArray{T}; alpha=1, beta=0) where T<:CUDNNFloat
    alpha,beta = scalingParameter(T,alpha), scalingParameter(T,beta)
    bDesc, yDesc = cudnnTensorDescriptor.((db,dy))
    cudnnConvolutionBackwardBias(handle(), alpha, yDesc, dy, beta, bDesc, db)
    return db
end

function ∇conv_bias!(db::DenseCuArray{T}, dy::DenseCuArray{T}; alpha=1, beta=0) where T<:CUDNNComplexFloat
    dyr, dyi = reim(dy)
    dbr = ∇conv_bias!(similar(real(db)), dyr; alpha=1, beta=0)
    dbi = ∇conv_bias!(similar(dbr), dyi; alpha=1, beta=0)
    return _complex!(db, dbr, dbi; alpha=alpha, beta=beta)
end


================================================
FILE: ext/NNlibCUDACUDNNExt/pooling.jl
================================================
using cuDNN: cudnnPoolingMode_t, CUDNN_POOLING_MAX,
             CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING,
             cudnnPoolingForward!, pooldims, cudnnPoolingBackward

import NNlib: maxpool!, ∇maxpool!, meanpool!, ∇meanpool!
import cuDNN: cudnnPoolingDescriptor

function cudnnPoolingDescriptor(pdims::PoolDims, x::DenseCuArray{T}, mode::cudnnPoolingMode_t) where T
    window, padding, stride = NNlib.kernel_size(pdims), nnlibPadding(pdims), NNlib.stride(pdims)
    nanOpt = CUDNN_NOT_PROPAGATE_NAN
    cudnnPoolingDescriptor(mode, nanOpt, Cint(ndims(x)-2), pooldims(window,size(x)), pooldims(padding,size(x)), pooldims(stride,size(x)))
end

function maxpool!(y::DenseCuArray{T}, x::DenseCuArray{T}, pdims::PoolDims) where T<:CUDNNFloat
    d = cudnnPoolingDescriptor(pdims, x, CUDNN_POOLING_MAX)
    cudnnPoolingForward!(y, x, d)
end

function ∇maxpool!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, y::DenseCuArray{T}, x::DenseCuArray{T}, pdims::PoolDims) where T<:CUDNNFloat
    xDesc, yDesc = cudnnTensorDescriptor.((x, y))
    d = cudnnPoolingDescriptor(pdims, x, CUDNN_POOLING_MAX)
    alpha, beta = scalingParameter(T,1), scalingParameter(T,0)
    cudnnPoolingBackward(handle(), d, alpha, yDesc, y, yDesc, dy, xDesc, x, beta, xDesc, dx)
    return dx
end

function meanpool!(y::DenseCuArray{T}, x::DenseCuArray{T}, pdims::PoolDims) where T<:CUDNNFloat
    d = cudnnPoolingDescriptor(pdims, x, CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING)
    cudnnPoolingForward!(y, x, d)
end

function ∇meanpool!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, y::DenseCuArray{T}, x::DenseCuArray{T}, pdims::PoolDims) where T<:CUDNNFloat
    xDesc, yDesc = cudnnTensorDescriptor.((x, y))
    d = cudnnPoolingDescriptor(pdims, x, CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING)
    alpha, beta = scalingParameter(T,1), scalingParameter(T,0)
    cudnnPoolingBackward(handle(), d, alpha, yDesc, y, yDesc, dy, xDesc, x, beta, xDesc, dx)
    return dx
end

### Since cuDNN does not support 1D pooling, we have to convert to 2D

add1d(x) = reshape(x, 1, size(x)...)

function fix_pooldims_1d(pdims::PoolDims{1,K,S,P,D}) where {K,S,P,D}
    PoolDims{2, K + 1, S + 1, P + 2, D + 1}((1, NNlib.input_size(pdims)...),
                                            (1, NNlib.kernel_size(pdims)...),
                                            NNlib.channels_in(pdims),
                                            (1, NNlib.stride(pdims)...),
                                            (0, 0, NNlib.padding(pdims)...),
                                            (1, NNlib.dilation(pdims)...))
end
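
# Sketch of the 1D -> 2D lift above (illustrative, assuming x of size (W, C, N)): add1d
# reshapes x to (1, W, C, N) and fix_pooldims_1d turns a 1D PoolDims with kernel k, stride s,
# padding (p_lo, p_hi) and dilation d into a 2D PoolDims with kernel (1, k), stride (1, s),
# padding (0, 0, p_lo, p_hi) and dilation (1, d), so the existing 4D cuDNN path applies.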

function maxpool!(y::DenseCuArray{T,3}, x::DenseCuArray{T,3}, pdims::PoolDims) where T<:CUDNNFloat
    maxpool!(add1d(y), add1d(x), fix_pooldims_1d(pdims))
    return y
end

function meanpool!(y::DenseCuArray{T,3}, x::DenseCuArray{T,3}, pdims::PoolDims) where T<:CUDNNFloat
    meanpool!(add1d(y), add1d(x), fix_pooldims_1d(pdims))
    return y
end

function ∇maxpool!(dx::DenseCuArray{T,3}, dy::DenseCuArray{T,3}, y::DenseCuArray{T,3}, x::DenseCuArray{T,3}, pdims::PoolDims) where T<:CUDNNFloat
    ∇maxpool!(add1d(dx), add1d(dy), add1d(y), add1d(x), fix_pooldims_1d(pdims))
    return dx
end

function ∇meanpool!(dx::DenseCuArray{T,3}, dy::DenseCuArray{T,3}, y::DenseCuArray{T,3}, x::DenseCuArray{T,3}, pdims::PoolDims) where T<:CUDNNFloat
    ∇meanpool!(add1d(dx), add1d(dy), add1d(y), add1d(x), fix_pooldims_1d(pdims))
    return dx
end




================================================
FILE: ext/NNlibCUDACUDNNExt/softmax.jl
================================================
import NNlib: softmax, softmax!, ∇softmax, ∇softmax!,
              logsoftmax, logsoftmax!, ∇logsoftmax, ∇logsoftmax!

using cuDNN: CUDNN_SOFTMAX_LOG, CUDNN_SOFTMAX_MODE_CHANNEL,
             CUDNN_SOFTMAX_FAST, CUDNN_SOFTMAX_ACCURATE, cudnnSoftmaxForward!,
             cudnnSoftmaxBackward

# Softmax

# @denizyuret: do not make softmax/logsoftmax operate in place here, because (1) the CPU
# version is not in-place and (2) one can call softmax! explicitly when in-place behaviour is wanted.
function softmax(x::T; dims=1) where {T<:DenseCuArray}
    softmax!(similar(x), x; dims)
end

function ∇softmax(dy::T, x::T, y::T; dims=1) where {T<:DenseCuArray}
    ∇softmax!(similar(x), dy, x, y; dims)
end

function logsoftmax(x::T; dims=1) where {T<:DenseCuArray}
    logsoftmax!(similar(x), x; dims)
end

function ∇logsoftmax(dy::T, x::T, y::T; dims=1) where {T<:DenseCuArray}
    ∇logsoftmax!(similar(x), dy, x, y; dims)
end

# @denizyuret: backup implementations for unsupported/slow size/dims combinations:
function _softmax!(y::T, x::T; dims) where {T<:DenseCuArray}
    y .= exp.(x .- maximum(x; dims))
    y ./= sum(y; dims)
end

function _∇softmax!(dx::T, dy::T, x::T, y::T; dims) where {T<:DenseCuArray}
    dx .= y .* (dy .- sum(dy .* y; dims))
end

function _logsoftmax!(y::T, x::T; dims) where {T<:DenseCuArray}
    y .= x .- maximum(x; dims)
    y .-= log.(sum(exp.(y); dims))
end

function _∇logsoftmax!(dx::T, dy::T, x::T, y::T; dims) where {T<:DenseCuArray}
    dx .= dy .- sum(dy; dims) .* exp.(y)
end

# Trick by @norci to use cudnn for softmax dims args that are contiguous:
# If dims=(dmin:dmax) then CUDNN_SOFTMAX_MODE_CHANNEL does the trick with reshape
#    (1, prod(size(x)[1:dmin-1]), prod(size(x)[dmin:dmax]), :)
# softmaxdims returns nothing when the backup implementation should be used.

function softmaxdims(x, dims)
    dims === Colon() && return (1, 1, length(x), 1)
    mind,maxd = minimum(dims),maximum(dims)
    all(i in dims for i in mind:maxd) || return nothing # cannot handle if not contiguous
    stride = dimsize = 1
    for i in 1:(mind-1); stride *= size(x,i); end # Using size(x,i) assumes trailing dims = 1, robust to maxd > ndims(x)
    for i in mind:maxd; dimsize *= size(x,i); end
    batchsize = length(x)÷(stride*dimsize)
    # Here is a region where cudnn is slower, so we go with the backup:
    batchsize == 1 && 64 <= stride <= 4096 && 64 <= dimsize <= 4096 && return nothing
    return (1, stride, dimsize, batchsize)
end
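
# Worked sketch (illustrative values, not from this repo's tests): for x of size (3, 4, 5),
# softmaxdims(x, 2) == (1, 3, 4, 5), i.e. stride 3, softmax width 4, batch 5, so
# CUDNN_SOFTMAX_MODE_CHANNEL reduces over the third axis of the reshaped array;
# softmaxdims(x, (1, 3)) returns nothing (dims not contiguous) and the backup path is used.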

# Determine softmax algo based on math_mode

softmaxalgo() = (CUDA.math_mode()===CUDA.FAST_MATH ? CUDNN_SOFTMAX_FAST : CUDNN_SOFTMAX_ACCURATE)

# Main implementations:

function softmax!(y::T, x::T = y; dims=1) where {T<:DenseCuArray}
    s = softmaxdims(x, dims)
    s === nothing && return _softmax!(y, x; dims)
    cudnnSoftmaxForward!(reshape(y,s), reshape(x,s); mode = CUDNN_SOFTMAX_MODE_CHANNEL, algo = softmaxalgo())
    return y
end

function ∇softmax!(dx::T, dy::T, x::T, y::T; dims=1) where {R,T<:DenseCuArray{R}}
    s = softmaxdims(x, dims)
    s === nothing && return _∇softmax!(dx, dy, x, y; dims)
    xDesc = cudnnTensorDescriptor(reshape(x,s))
    alpha, beta = scalingParameter(R,1), scalingParameter(R,0)
    cudnnSoftmaxBackward(handle(), softmaxalgo(), CUDNN_SOFTMAX_MODE_CHANNEL,
                         alpha, xDesc, y, xDesc, dy, beta, xDesc, dx)
    return dx
end

function logsoftmax!(y::T, x::T = y; dims=1) where {T<:DenseCuArray}
    s = softmaxdims(x, dims)
    s === nothing && return _logsoftmax!(y, x; dims)
    cudnnSoftmaxForward!(reshape(y,s), reshape(x,s); mode = CUDNN_SOFTMAX_MODE_CHANNEL, algo = CUDNN_SOFTMAX_LOG)
    return y
end

function ∇logsoftmax!(dx::T, dy::T, x::T, y::T; dims=1) where {R,T<:DenseCuArray{R}}
    s = softmaxdims(x, dims)
    s === nothing && return _∇logsoftmax!(dx, dy, x, y; dims)
    xDesc = cudnnTensorDescriptor(reshape(x,s))
    alpha, beta = scalingParameter(R,1), scalingParameter(R,0)
    cudnnSoftmaxBackward(handle(), CUDNN_SOFTMAX_LOG, CUDNN_SOFTMAX_MODE_CHANNEL,
                         alpha, xDesc, y, xDesc, dy, beta, xDesc, dx)
    return dx
end


================================================
FILE: ext/NNlibCUDAExt/NNlibCUDAExt.jl
================================================
module NNlibCUDAExt

using NNlib
using CUDA
using Random, Statistics

include("sampling.jl")
include("activations.jl")
include("batchedadjtrans.jl")
include("batchedmul.jl")
include("ctc.jl")
include("scatter.jl")
include("utils.jl")

end # module


================================================
FILE: ext/NNlibCUDAExt/activations.jl
================================================
# Activation functions

# Some activation functions need a wrapper for GPU support
# https://github.com/JuliaGPU/CuArrays.jl/issues/614

# @cufunc softplus(x::Real) = ifelse(x > 0, x + log1p(exp(-x)), log1p(exp(x)))

# @cufunc logσ(x::Real) = -softplus(-x)

# @cufunc function gelu(x::Real)
#     p = oftype(x / 1, π)
#     λ = oftype(x / 1, √(2 / p))
#     α = oftype(x / 1, 0.044715)
#     h = oftype(x / 1, 0.5)
#     h * x * (one(x) + tanh(λ * (x + α * x^3)))
# end

# @cufunc lisht(x::Real) = x * tanh(x)

# @cufunc logcosh(x::Real) = x + softplus(-2x) - log(oftype(x, 2))

# @cufunc mish(x::Real) = x * tanh(softplus(x))

# @cufunc tanhshrink(x::Real) = x - tanh(x)


================================================
FILE: ext/NNlibCUDAExt/batchedadjtrans.jl
================================================
using NNlib: BatchedAdjoint, BatchedTranspose, BatchedAdjOrTrans
using Adapt
using Adapt: WrappedArray

const CuBatchedAdjoint{T} = BatchedAdjoint{T, <: CuArray{T}}
const CuBatchedTranspose{T} = BatchedTranspose{T, <: CuArray{T}}
const CuBatchedAdjOrTrans{T} = Union{CuBatchedAdjoint{T}, CuBatchedTranspose{T}}
const WrappedCuBatchedAdjOrTrans{T, N} = WrappedArray{T, N, CuBatchedAdjOrTrans{T}, CuBatchedAdjOrTrans{T}}


Base.print_array(io::IO, b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}) = Base.print_array(io, adapt(Array, b))
Base._show_nonempty(io::IO, b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}, prefix::String) = Base._show_nonempty(io, adapt(Array, b), prefix)
Base.show_vector(io::IO, b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}, opn, cls) = Base.show_vector(io, adapt(Array, b), opn, cls)

Base.convert(::Type{T}, b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}) where {T<:Array} = Base.convert(T, adapt(Array, b))
Base.Array{T, N}(b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}) where {T, N} = Array{T, N}(adapt(Array, b))
Base.collect(b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}) = collect(adapt(Array, b))


================================================
FILE: ext/NNlibCUDAExt/batchedmul.jl
================================================
# Batched matrix multiplication
# 1st argument is produced by NNlib.storage_type(A)
NNlib._batched_gemm!(::Type{<:CuArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C) =
     CUBLAS.gemm_strided_batched!(transA, transB, α, A, B, β, C)

Base.unsafe_convert(::Type{CuPtr{T}}, A::NNlib.BatchedAdjOrTrans{T}) where {T} =
    Base.unsafe_convert(CuPtr{T}, parent(A))


================================================
FILE: ext/NNlibCUDAExt/ctc.jl
================================================
# CTC loss moved from Flux.jl to NNlib

import NNlib: ctc_loss, ctc_alpha, ∇ctc_loss

## GPU implementation

# a port of the GPU kernels from Baidu's C++ warp-ctc package,
# which itself is Copyright 2015-2016 Baidu USA LLC
# and available under the Apache 2.0 license
#
# Apache 2.0 license: https://www.apache.org/licenses/LICENSE-2.0
# GitHub: https://github.com/baidu-research/warp-ctc/
# paper: https://arxiv.org/pdf/1512.02595.pdf

const MAX_THREADS = 256
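
# Numerically stable two-term log-sum-exp used throughout the CTC kernels below:
# log_plus_f(p1, p2) == log(exp(p1) + exp(p2)), evaluated as p1 + log(1 + exp(p2 - p1)) with
# the arguments ordered so the exponent is never positive; infinite inputs fall back to the
# other argument.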

@inline function log_plus_f(p1, p2)
  isinf(p1) && return p2
  isinf(p2) && return p1
  if p1 < p2
    p1, p2 = p2, p1
  end
  return p1 + log(1+exp(p2 - p1))
end

function count_repeats(A)
  repeats = 0
  for (i,elem) in enumerate(A)
    if i > 1 && A[i] == A[i-1]
      repeats += 1
    end
  end
  return repeats
end

function compute_alpha_kernel(probs, labelSize, uttLength, repeats, labelsWithoutBlanks, labelsWithBlanks, alpha, blankLabel)

  tid = threadIdx().x
  L = labelSize
  T = uttLength
  S = length(labelsWithBlanks)

  if L + repeats > T
    return nothing
  end
  labels = labelsWithBlanks

  # Corner-case checking
  start = (L + repeats <= T) ? 0 : 1
  last = S > 1 ? 2 : 1

  # Fill in first column (time step)
  i = tid
  while i <= last - start
    alpha[start+i, 1] = probs[labels[start+i], 1]
    i += blockDim().x
  end
  sync_threads()

  # Fill in coefficients for each time step
  for t=2:T
    # Corner-case checking
    if tid == 1 && !(1 < S - 2*(T-t) - 1)
      if start == 0
        alpha[1, t] = alpha[1, t-1] + probs[blankLabel, t]
      elseif start == 1
        alpha[1, t] = alpha[1, t-1]
      end
    end
    sync_threads()

    # Fill in coefficients for each label class in the target output sequence;
    # each thread will process the calculations for one class
    idx = tid+1
    while idx <= S
      prevSum = log_plus_f(alpha[idx, t-1], alpha[idx-1, t-1])
      if labels[idx] != blankLabel && idx != 2 && labels[idx] != labels[idx-2]
        prevSum = log_plus_f(prevSum, alpha[idx-2, t-1])
      end
      if idx < S - 2*(T-t) - 1
        alpha[idx, t] = -Inf32
      else
        alpha[idx, t] = prevSum + probs[labels[idx], t]
      end
      idx += blockDim().x
    end
    sync_threads()
  end
  return nothing
end

function compute_beta_and_grad_kernel(probs, labelSize, uttLength,
                  repeatsInLabel, labelsWithBlanks,
                  alphas, beta, output, accum,
                  grad, blankLabel, loss)

  tid = threadIdx().x
  L = labelSize
  T = uttLength
  S = 2*L + 1
  repeats = repeatsInLabel
  labels = labelsWithBlanks

  if (L+repeats) > T
    return nothing
  end

  # Corner-case checking
  start = S > 1 ? S-2 : 0
  last = L + repeats < T ? S : S-1
  sync_threads()
  i = tid

  # Calculate coefficients for last column (time step)
  # then determine alpha and beta product
  while i <= last - start
    beta[i+start, T] = 0
    output[i+start, T] = beta[i+start, T] + alphas[i+start, T]
    i += blockDim().x
  end
  sync_threads()

  # Fill in `accum` for last column (time step)
  if tid == 1
    for i=1:S
      labelIdx = labels[i]
      accum[labelIdx, T] = log_plus_f(accum[labelIdx, T], output[i, T])
    end
  end
  sync_threads()

  # Fill in `grad` for last column (time step)
  idx = tid
  while idx <= size(grad, 1)
    s = -Inf32
    for i=1:S
      s = log_plus_f(s, output[i, T])
    end

    # ∂L/∂a (where a is activation before logsoftmax)
    grad[idx, T] = exp(probs[idx, T]) - exp(accum[idx, T] - s)
    idx += blockDim().x
  end
  sync_threads()

  # Fill in the rest of the coefficients
  t = T-1
  while t >= 1
    if t < T
      idx = tid
      while idx <= S
        nextSum = probs[labels[idx], t+1] + beta[idx, t+1]
        if idx < S
          nextSum = log_plus_f(nextSum,
            probs[labels[idx+1], t+1] + beta[idx+1, t+1])
        end
        if labels[idx] != blankLabel && idx != S-1 && labels[idx] != labels[idx+2]
          nextSum = log_plus_f(nextSum,
            probs[labels[idx+2], t+1] + beta[idx + 2, t+1])
        end
        if idx > 2*t
          beta[idx, t] = -Inf32
        else
          beta[idx, t] = nextSum
        end
        idx += blockDim().x
      end
      sync_threads()
      idx = tid
      while idx <= S
        output[idx, t] = alphas[idx, t] + beta[idx, t]
        idx += blockDim().x
      end
      sync_threads()
    end
    sync_threads()

    # Calculate accumulated alpha-beta products for each label class for
    # each time step; used in calculating gradients
    if tid == 1
      for i=1:S
        labelIdx = labels[i]
        accum[labelIdx, t] = log_plus_f(accum[labelIdx, t], output[i, t])
      end
    end
    sync_threads()
    idx = tid

    # Calculate gradients
    while idx <= size(grad, 1)

      # ∂L/∂a (where a is activation before logsoftmax)
      grad[idx, t] = exp(probs[idx, t]) - exp(accum[idx, t] + loss)
      idx += blockDim().x
    end
    sync_threads()
    t -= 1
    sync_threads()
  end
  return nothing
end

function ctc_alpha(ŷ::CuArray, y)
  ŷ = logsoftmax(ŷ)
  blank = size(ŷ, 1)
  ycu = cu(y)
  z′ = CUDA.fill(blank, 2 * length(y) + 1)
  z′[eachindex(y) .* 2] .= ycu
  T = size(ŷ, 2)
  U′ = 2*length(y) + 1
  alphas = CUDA.fill(log(zero(eltype(ŷ))), U′,T)
  nRepeats = count_repeats(CUDA.adapt(Array, y))
  nThreads = min(U′, MAX_THREADS)
  @cuda blocks=1 threads=nThreads compute_alpha_kernel(ŷ, length(y), T, nRepeats, ycu, z′, alphas, blank)
  return (loss=-1 * logsumexp(alphas[end-1:end]), alpha=alphas, z′=z′, yhat=ŷ, nRepeats=nRepeats)
end

ctc_loss(ŷ::CuArray, y) = ctc_alpha(ŷ::CuArray, y).loss

function ∇ctc_loss(ŷ::CuArray, y, out)
  loss, alphas, z′, ŷ, nRepeats = out
  U′, T = size(alphas)
  blank = size(ŷ, 1)
  typed_zero = zero(eltype(ŷ))
  betas = CUDA.fill(log(typed_zero), U′, T)
  output = CUDA.fill(log(typed_zero), U′, T)
  nThreads = min(U′, MAX_THREADS)
  grads = CUDA.fill(log(typed_zero), size(ŷ))
  accum = CUDA.fill(log(typed_zero), size(ŷ))
  @cuda blocks=1 threads=nThreads compute_beta_and_grad_kernel(ŷ, length(y), T, nRepeats, CuArray(z′), alphas, betas, output, accum, grads, blank, loss)
  return grads
end


================================================
FILE: ext/NNlibCUDAExt/sampling.jl
================================================
@inline function NNlib._safe_add!(dx::CuDeviceArray{T, 4}, value, ix, iy, c, n) where T
    @inbounds CUDA.@atomic dx[ix, iy, c, n] += value
end
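
# Each CUDA thread handles one (w, h, n) output location: the flat thread index is decomposed
# with mod/div against the grid extents, and the per-location sampling itself happens inside
# NNlib._grid_sample_kernel!.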

function grid_sample_kernel!(n_elem, output::AbstractArray{T, 4}, input::AbstractArray{T, 4}, grid::AbstractArray{V, 4}, padding_mode) where {T,V}
    index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x
    if index < n_elem
        iW, iH, iC, _ = size(input)
        _, gW, gH, _ = size(grid)

        w = index % gW + 1
        h = (index ÷ gW) % gH + 1
        n = index ÷ (gW * gH) + 1
        NNlib._grid_sample_kernel!(output, input, grid, padding_mode, w, h, n, iW, iH, iC)
    end
    nothing
end

function ∇grid_sample_kernel!(n_elem, dx::AbstractArray{T, 4}, dgrid::AbstractArray{V, 4}, Δ::AbstractArray{T, 4}, input::AbstractArray{T, 4}, grid::AbstractArray{V, 4}, padding_mode) where {T,V}
    index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x
    if index < n_elem
        iW, iH, iC, _ = size(input)
        _, gW, gH, _ = size(grid)

        w = index % gW + 1
        h = (index ÷ gW) % gH + 1
        n = index ÷ (gW * gH) + 1
        NNlib._∇grid_sample_kernel!(dx, dgrid, Δ, input, grid, padding_mode, w, h, n, iW, iH, iC)
    end
    nothing
end

function NNlib.grid_sample(x::CuArray{T, 4}, grid::CuArray{V, 4}; padding_mode = :zeros) where {T, V}
    pad = Val(padding_mode)
    _, _, xC, xN = size(x)
    _, gW, gH, _ = size(grid)
    n_elem = gW * gH * xN
    y = similar(x, T, (gW, gH, xC, xN))

    kernel = @cuda launch=false grid_sample_kernel!(n_elem, y, x, grid, pad)
    config = launch_configuration(kernel.fun; max_threads=256)
    threads = min(n_elem, config.threads)
    blocks = cld(n_elem, threads)
    kernel(n_elem, y, x, grid, pad; threads=threads, blocks=blocks)
    y
end

function NNlib.∇grid_sample(Δ::CuArray{T, 4}, x::CuArray{T, 4}, grid::CuArray{V, 4}; padding_mode = :zeros) where {T, V}
    pad = Val(padding_mode)
    xN = size(x, 4)
    _, gW, gH, _ = size(grid)
    n_elem = gW * gH * xN
    dx, dgrid = CUDA.zeros(T, size(x)), similar(grid)

    kernel = @cuda launch=false ∇grid_sample_kernel!(n_elem, dx, dgrid, Δ, x, grid, pad)
    config = launch_configuration(kernel.fun; max_threads=256)
    threads = min(n_elem, config.threads)
    blocks = cld(n_elem, threads)
    kernel(n_elem, dx, dgrid, Δ, x, grid, pad; threads=threads, blocks=blocks)
    dx, dgrid
end


@inline function NNlib._safe_add!(dx::CuDeviceArray{T, 5}, value, ix, iy, iz, c, n) where T
    @inbounds CUDA.@atomic dx[ix, iy, iz, c, n] += value
end

function grid_sample_kernel!(n_elem, output::AbstractArray{T, 5}, input::AbstractArray{T, 5}, grid::AbstractArray{V, 5}, padding_mode) where {T,V}
    index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x
    if index < n_elem
        iW, iH,iD, iC, _ = size(input)
        _, gW, gH, gD, _ = size(grid)

        w = index % gW + 1
        h = (index ÷ gW) % gH + 1
        d = (index ÷ (gW * gH)) % gD + 1
        n = index ÷ (gW * gH * gD) + 1
        # n = index ÷ (gW * gH) + 1
        # d = (index ÷ (gW * gH * n)) + 1

        NNlib._grid_sample_kernel!(output, input, grid, padding_mode, w, h, d, n, iW, iH, iD, iC)
    end
    nothing
end

function ∇grid_sample_kernel!(n_elem, dx::AbstractArray{T, 5}, dgrid::AbstractArray{V, 5}, Δ::AbstractArray{T, 5}, input::AbstractArray{T, 5}, grid::AbstractArray{V, 5}, padding_mode) where {T,V}
    index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x
    if index < n_elem
        iW, iH, iD, iC, _ = size(input)
        _, gW, gH, gD, _ = size(grid)

        w = index % gW + 1
        h = (index ÷ gW) % gH + 1
        d = (index ÷ (gW * gH)) % gD + 1
        n = index ÷ (gW * gH * gD) + 1
        # n = index ÷ (gW * gH) + 1
        # d = (index ÷ (gW * gH * n)) + 1

        NNlib._∇grid_sample_kernel!(dx, dgrid, Δ, input, grid, padding_mode, w, h, d, n, iW, iH, iD, iC)
    end
    nothing
end

function NNlib.grid_sample(x::CuArray{T, 5}, grid::CuArray{V, 5}; padding_mode = :zeros) where {T, V}
    pad = Val(padding_mode)
    _, _, _, xC, xN = size(x)
    _, gW, gH, gD, _ = size(grid)
    n_elem = gW * gH * gD * xN
    y = similar(x, T, (gW, gH, gD, xC, xN))

    kernel = @cuda launch=false grid_sample_kernel!(n_elem, y, x, grid, pad)
    config = launch_configuration(kernel.fun; max_threads=256)
    threads = min(n_elem, config.threads)
    blocks = cld(n_elem, threads)
    kernel(n_elem, y, x, grid, pad; threads=threads, blocks=blocks)
    y
end

function NNlib.∇grid_sample(Δ::CuArray{T, 5}, x::CuArray{T, 5}, grid::CuArray{V, 5}; padding_mode = :zeros) where {T, V}
    pad = Val(padding_mode)
    xN = size(x, 5)
    _, gW, gH, gD, _ = size(grid)
    n_elem = gW * gH * gD * xN
    dx, dgrid = CUDA.zeros(T, size(x)), similar(grid)

    kernel = @cuda launch=false ∇grid_sample_kernel!(n_elem, dx, dgrid, Δ, x, grid, pad)
    config = launch_configuration(kernel.fun; max_threads=256)
    threads = min(n_elem, config.threads)
    blocks = cld(n_elem, threads)
    kernel(n_elem, dx, dgrid, Δ, x, grid, pad; threads=threads, blocks=blocks)
    dx, dgrid
end

================================================
FILE: ext/NNlibCUDAExt/scatter.jl
================================================
# supported op: +, -, *, /, max, min, &, |, mean

## TODO support sparse dst/src/idx
## See issue https://github.com/FluxML/NNlib.jl/issues/647
# import CUDA.CUSPARSE: AbstractCuSparseMatrix, CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO, AnyCuSparseVector
# const AnyCuSparseMatrix{Tv,Ti} = Union{
#     AbstractCuSparseMatrix{Tv,Ti},
#     CUDA.CuSparseMatrixCSC{Tv,Ti}, # these types do not inherit from AbstractCuSparseMatrix
#     CUDA.CuSparseMatrixCSR{Tv,Ti}, # but from GPUArrays.AbstractGPUSparseMatrixXXX
#     CUDA.CuSparseMatrixCOO{Tv,Ti},
#     }
# const AnyCuSparseArray{Tv,Ti} = Union{AnyCuSparseVector{Tv,Ti},AnyCuSparseMatrix{Tv,Ti}}

function scatter_kernel!(op::OP, dst, src, idx) where OP
    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

    @inbounds if index <= length(idx)
        CUDA.@atomic dst[idx[index]...] = op(dst[idx[index]...], src[index])
    end
    return nothing
end

function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}) where OP
    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

    @inbounds if index <= length(idx)
        li = Base._to_linear_index(dst, Tuple(idx[index])...)
        CUDA.@atomic dst[li] = op(dst[li], src[index])
    end
    return nothing
end

function scatter_kernel!(op::OP, dst, src, idx, max_idx, max_dims_idx, dims_size) where OP
    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

    @inbounds if index <= max_idx
        j, k = divrem(index-1, max_dims_idx)
        dims_i = CartesianIndices(dims_size)[k+1]
        CUDA.@atomic dst[Tuple(dims_i)..., idx[j+1]...] = op(dst[Tuple(dims_i)..., idx[j+1]...], src[index])
    end
    return nothing
end

function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex},
            max_idx, max_dims_idx, dims_size) where OP
    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

    @inbounds if index <= max_idx
        j, k = divrem(index-1, max_dims_idx)
        dims_i = CartesianIndices(dims_size)[k+1]
        li = Base._to_linear_index(dst, Tuple(dims_i)..., Tuple(idx[j+1])...)
        CUDA.@atomic dst[li] = op(dst[li], src[index])
    end
    return nothing
end


function NNlib.scatter!(op::OP, dst::AnyCuArray,
        src::AnyCuArray,
        idx::AnyCuArray) where OP
    isempty(idx) && return dst
    dims = NNlib.scatter_dims(dst, src, idx)
    args = if dims == 0
        max_idx = length(idx)
        op, dst, src, idx
    else
        dims_size = size(dst)[1:dims]
        max_dims_idx = prod(dims_size)
        max_idx = max_dims_idx * length(idx)
        op, dst, src, idx, max_idx, max_dims_idx, dims_size
    end

    kernel = @cuda launch=false scatter_kernel!(args...)
    config = launch_configuration(kernel.fun; max_threads=256)
    threads = min(max_idx, config.threads)
    blocks = cld(max_idx, threads)
    kernel(args...; threads=threads, blocks=blocks)
    return dst
end
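
# Mean scatter: two `+`-scatters build per-destination counts (`Ns`) and sums (`dst_`), and
# their elementwise ratio is added to `dst`; `safe_div` guards destinations that receive no
# contribution.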

function NNlib.scatter!(op::typeof(mean), dst::AnyCuArray,
        src::AnyCuArray,
        idx::AnyCuArray)
    Ns = NNlib.scatter!(+, zero(dst), one.(src), idx)
    dst_ = NNlib.scatter!(+, zero(dst), src, idx)
    dst .+= NNlib.safe_div.(dst_, Ns)
    return dst
end


## Gradients
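# For the multiplicative ops, each gradient entry is combined (via `op`) with the product of
# the other source elements scattered onto the same destination, i.e. prod(group) / src[j];
# the kernels below find those groups through the per-destination lists from `reverse_indices`.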

function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx,
    rev_idx, max_idx, T::Type{TT}) where {OP,TT}
    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

    @inbounds if index <= max_idx
        cart_j = CartesianIndices(idx)[index]
        # get the indices whose values are aggregated onto the same destination, including this element's own index
        inds = rev_idx[idx[cart_j]...]
        # multiply all values to be aggregated but not itself
        x = one(T)
        for k in inds
            x *= src[k]
        end
        x /= src[cart_j]
        # apply `op` on `Δsrc[i, k]` and `x`
        Δsrc[cart_j] = op(Δsrc[cart_j], x)
    end
    return nothing
end

function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex},
            rev_idx, max_idx, T::Type{TT}) where {OP,TT}
    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

    @inbounds if index <= max_idx
        cart_j = CartesianIndices(idx)[index]
        # get the indices whose values are aggregated onto the same destination, including this element's own index
        inds = rev_idx[Tuple(idx[cart_j])...]
        # multiply all values to be aggregated but not itself
        x = one(T)
        for k in inds
            x *= src[k]
        end
        x /= src[cart_j]
        # apply `op` on `Δsrc[i, k]` and `x`
        Δsrc[cart_j] = op(Δsrc[cart_j], x)
    end
    return nothing
end

function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx,
    rev_idx, pre_cart_idx, max_dims_idx, max_idx, T::Type{TT}) where {OP,TT}
    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

    @inbounds if index <= max_idx
        i, j = fldmod1(index, max_dims_idx)
        cart_i = CartesianIndices(idx)[i]
        cart_j = pre_cart_idx[j]
        # get the indices whose values are aggregated onto the same destination, including this element's own index
        inds = rev_idx[idx[cart_i]...]
        # multiply all values to be aggregated but not itself
        x = one(T)
        for k in inds
            jk = Base._to_linear_index(src, Tuple(cart_j)..., Tuple(k)...)
            x *= src[jk]
        end
        x /= src[index]
        # apply `op` on `Δsrc[i, k]` and `x`
        Δsrc[index] = op(Δsrc[index], x)
    end
    return nothing
end

function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex},
                rev_idx, pre_cart_idx, max_dims_idx, max_idx, T::Type{TT}) where {OP,TT}
    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

    @inbounds if index <= max_idx
        i, j = fldmod1(index, max_dims_idx)
        cart_i = CartesianIndices(idx)[i]
        cart_j = pre_cart_idx[j]
        # get the indices whose values are aggregated onto the same destination, including this element's own index
        inds = rev_idx[Tuple(idx[cart_i])...]
        # multiply all values to be aggregated but not itself
        x = one(T)
        for k in inds
            jk = Base._to_linear_index(src, Tuple(cart_j)..., Tuple(k)...)
            x *= src[jk]
        end
        x /= src[index]
        # apply `op` on `Δsrc[i, k]` and `x`
        Δsrc[index] = op(Δsrc[index], x)
    end
    return nothing
end

function NNlib.∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst,
    src::AnyCuArray,
    idx::AnyCuArray)
    dims = ndims(src) - ndims(idx)
    Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), src)
    rev_idx = NNlib.reverse_indices(idx)
    rev_idx = CuArray(map(CUDA.cudaconvert, rev_idx))

    if dims == 0
        max_idx = length(idx)
        args = op, Δsrc, src, idx, rev_idx, max_idx, eltype(src)
    else
        pre_cart_idx = CartesianIndices(axes(src)[1:dims])
        max_dims_idx = length(pre_cart_idx)
        max_idx = max_dims_idx * length(idx)
        args = op, Δsrc, src, idx, rev_idx, pre_cart_idx, max_dims_idx, max_idx, eltype(src)
    end

    kernel = @cuda launch=false ∇scatter_src_kernel!(args...)
    config = launch_configuration(kernel.fun; max_threads=256)
    threads = min(max_idx, config.threads)
    blocks = cld(max_idx, threads)
    kernel(args...; threads=threads, blocks=blocks)

    CUDA.unsafe_free!(rev_idx)
    return Δsrc
end


================================================
FILE: ext/NNlibCUDAExt/utils.jl
================================================
NNlib._rng_from_array(::CuArray) = CUDA.default_rng()

NNlib._rng_compat_array(rng::CUDA.RNG, A::CuArray) = nothing
NNlib._rng_compat_array(rng::AbstractRNG, A::CuArray) = throw(ArgumentError(
    "cannot use rng::$(typeof(rng)) with array::CuArray, only CUDA's own RNG type works"))

function divide_kernel!(xs, ys, max_idx)
    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

    @inbounds if index <= max_idx
        xs[index] = xs[index] / ys[index]
    end
    return nothing
end

function divide_kernel!(xs, counts, max_idx, max_dims_idx, dims_size)
    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

    @inbounds if index <= max_idx
        j, k = divrem(index-1, max_dims_idx)
        dims_i = Tuple(CartesianIndices(dims_size)[k+1])
        CUDA.@atomic xs[dims_i..., j+1] = xs[dims_i..., j+1] / counts[j+1]
    end
    return nothing
end

function NNlib.reverse_indices(idx::AnyCuArray{<:Any,N}) where N
    max_dims = NNlib.maximum_dims(idx)
    T = CartesianIndex{N}
    rev = Array{Vector{T}}(undef, max_dims...)
    for i in eachindex(rev)
        rev[i] = T[]
    end
    NNlib.reverse_indices!(rev, idx)
    return map(cu, rev)
end


================================================
FILE: ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl
================================================
module NNlibEnzymeCoreExt

using NNlib
import EnzymeCore
using Random

using EnzymeCore.EnzymeRules

for (name, dataname, filtername) in (
                                     (typeof(NNlib.conv!), NNlib.∇conv_data!, NNlib.∇conv_filter!),
                                     (typeof(NNlib.depthwiseconv!), NNlib.∇depthwiseconv_data!, NNlib.∇depthwiseconv_filter!),
                                     (typeof(NNlib.∇conv_data!), NNlib.conv!, NNlib.∇conv_filter!),
                                     (typeof(NNlib.∇conv_filter!), NNlib.∇conv_data!, NNlib.conv!),
                                    )
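    # For each in-place primal `name`, `dataname` accumulates the gradient for the primal's
    # second array argument (the `x` slot) and `filtername` for its third (the `w` slot);
    # the same augmented_primal / reverse pair below is generated for every entry.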
    @eval begin

		function EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{$name}, ::Type{RT},
		                                                y::EnzymeCore.Annotation{<:AbstractArray{yT, N}},
		                                                x::EnzymeCore.Annotation{<:AbstractArray{xT, N}},
		                                                w::EnzymeCore.Annotation{<:AbstractArray{wT, N}},
		                                                cdims; kwargs...) where {RT, yT, xT, wT, N}

		    if typeof(y) <: EnzymeCore.Duplicated || typeof(y) <: EnzymeCore.BatchDuplicated
		        func.val(y.val, x.val, w.val, cdims.val; kwargs...)
		    end

		    primal = if EnzymeRules.needs_primal(config)
		        y.val
		    else
		        nothing
		    end
		    shadow = if EnzymeRules.needs_shadow(config)
		        y.dval
		    else
		        nothing
		    end

		    # Cache x if it is overwritten and w is active (and thus required)
		    cache_x = ( EnzymeRules.overwritten(config)[3]
		                && !(typeof(w) <: EnzymeCore.Const)
		                && !(typeof(y) <: EnzymeCore.Const)
		                ) ? copy(x.val) : nothing

		    # Cache w if it is overwritten and x is active (and thus required)
		    cache_w = ( EnzymeRules.overwritten(config)[4]
		                && !(typeof(x) <: EnzymeCore.Const)
		                && !(typeof(y) <: EnzymeCore.Const)
		                ) ? copy(w.val) : nothing

		    cache = (cache_x, cache_w)

		    return EnzymeRules.AugmentedReturn(primal, shadow, cache)
		end

		function EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, ::Type{RT}, cache,
		                                                y::EnzymeCore.Annotation{<:AbstractArray{yT, N}},
		                                                x::EnzymeCore.Annotation{<:AbstractArray{xT, N}},
		                                                w::EnzymeCore.Annotation{<:AbstractArray{wT, N}},
		                                                cdims; kwargs...) where {RT, yT, xT, wT, N}
		    cache_x, cache_w = cache

		    # x was not cached in the forward pass if it was not overwritten; use the live x.val instead
		    if !(typeof(w) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const)
		        if !EnzymeRules.overwritten(config)[3]
		            cache_x = x.val
		        end
		    end

		    # w was not cached in the forward pass if it was not overwritten; use the live w.val instead
		    if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const)
		        if !EnzymeRules.overwritten(config)[4]
		            cache_w = w.val
		        end
		    end

		    dys = y.dval
		    dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval
		    dws = (typeof(w) <: EnzymeCore.Const) ? dys : w.dval

		    if EnzymeRules.width(config) == 1
		        dys = (dys,)
		        dxs = (dxs,)
		        dws = (dws,)
		    end

		    for (dy, dx, dw) in zip(dys, dxs, dws)
		        if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val

		            if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val
		                # dx += grad wrt x.val
		                $dataname(dx, $(name != typeof(NNlib.∇conv_filter!) ? :dy : :cache_w), $(name != typeof(NNlib.∇conv_filter!) ? :cache_w : :dy), cdims.val; alpha=xT(1), beta=xT(1), kwargs...)
		            end
		            if !(typeof(w) <: EnzymeCore.Const) && dw !== w.val
		                # dw += grad wrt w.val
                        $filtername(dw, $(name != typeof(NNlib.∇conv_data!) ? :cache_x : :dy), $(name != typeof(NNlib.∇conv_data!) ? :dy : :cache_x), cdims.val; alpha=wT(1), beta=wT(1), kwargs...)
		            end
		            
		            dy .= 0
		        end
		    end

		    return (nothing, nothing, nothing, nothing)
		end

end
end
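
# A minimal usage sketch of the rules generated above (assumes Enzyme.jl itself is
# loaded alongside NNlib; the array sizes are illustrative only):
#
#     using NNlib, Enzyme
#     x  = randn(Float32, 8, 8, 3, 1)    # (width, height, channels_in, batch)
#     w  = randn(Float32, 3, 3, 3, 4)    # (kernel_w, kernel_h, channels_in, channels_out)
#     dx, dw = zero(x), zero(w)
#     loss(x, w) = sum(conv(x, w))
#     Enzyme.autodiff(Enzyme.Reverse, loss, Enzyme.Active,
#                     Enzyme.Duplicated(x, dx), Enzyme.Duplicated(w, dw))
#     # dx and dw now hold ∂loss/∂x and ∂loss/∂w, accumulated by the rules above
#     # via ∇conv_data! and ∇conv_filter!.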

function EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT}

    if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
        func.val(dst.val, src.val, idx.val)
    end

    primal = if EnzymeRules.needs_primal(config)
        dst.val
    else
        nothing
    end
    shadow = if EnzymeRules.needs_shadow(config)
        dst.dval
    else
        nothing
    end

    # Cache idx if it's overwritten
    cache_idx = ( EnzymeRules.overwritten(config)[4]
                    && !(typeof(src) <: EnzymeCore.Const)
                    && !(typeof(dst) <: EnzymeCore.Const)
                    ) ? copy(idx.val) : nothing

    return EnzymeRules.AugmentedReturn(primal, shadow, cache_idx)
end

function EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, cache_idx, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT}

    # idx is only cached in the forward pass when overwritten; if not, read the live value
    if !(typeof(src) <: EnzymeCore.Const) && !(typeof(dst) <: EnzymeCore.Const)
        if !EnzymeRules.overwritten(config)[4]
            cache_idx = idx.val
        end
    end

    ddsts = dst.dval
    dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval

    if EnzymeRules.width(config) == 1
        ddsts = (ddsts,)
        dsrcs = (dsrcs,)
    end

    for (ddst, dsrc) in zip(ddsts, dsrcs)
        if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val

            if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val
                NNlib.scatter!(+, dsrc, ddst, cache_idx)
            end

            ddst .= 0
        end
    end

    return (nothing, nothing, nothing)
end
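
# A small worked example of why the reverse of `gather` is `scatter(+, ...)`
# (expected values shown in the comments; plain CPU arrays assumed):
#
#     src = [10.0, 20.0, 30.0]
#     idx = [1, 3, 1, 2]
#     NNlib.gather(src, idx)          # [10.0, 30.0, 10.0, 20.0]
#     NNlib.scatter(+, ones(4), idx)  # [2.0, 1.0, 1.0]: each unit cotangent flows back
#                                     # to src[i] once per time index i was gathered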



function EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.scatter!)}, ::Type{RT}, op::EnzymeCore.Const, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT}

    @assert !(OutType <: EnzymeCore.Const)
    if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
        func.val(op.val, dst.val, src.val, idx.val)
    end

    primal = if EnzymeRules.needs_primal(config)
        dst.val
    else
        nothing
    end
    shadow = if EnzymeRules.needs_shadow(config)
        dst.dval
    else
        nothing
    end

    # Cache idx if it's overwritten
    cache_idx = ( EnzymeRules.overwritten(config)[5]
                    && !(typeof(src) <: EnzymeCore.Const)
                    && !(typeof(dst) <: EnzymeCore.Const)
                    ) ? copy(idx.val) : nothing

    return EnzymeRules.AugmentedReturn(primal, shadow, cache_idx)
end

function EnzymeRules.reverse(config,
										func::EnzymeCore.Const{typeof(NNlib.scatter!)},
										::Type{RT},
										cache_idx,
										op::Union{EnzymeCore.Const{typeof(+)},EnzymeCore.Const{typeof(-)}}, dst::OutType,
										src,
										idx::EnzymeCore.Const) where {OutType, RT}

    # idx is only cached in the forward pass when overwritten; if not, read the live value
    if !(typeof(src) <: EnzymeCore.Const) && !(typeof(dst) <: EnzymeCore.Const)
        if !EnzymeRules.overwritten(config)[5]
            cache_idx = idx.val
        end
    end

    ddsts = dst.dval
    dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval

    if EnzymeRules.width(config) == 1
        ddsts = (ddsts,)
        dsrcs = (dsrcs,)
    end

    for (ddst, dsrc) in zip(ddsts, dsrcs)
        if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val

            if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val

                if eltype(typeof(op)) == typeof(+)
                    dsrc .+= NNlib.gather(ddst, cache_idx)
                else
                    @assert eltype(typeof(op)) == typeof(-)
                    dsrc .-= NNlib.gather(ddst, cache_idx)
                end
            end

        end
    end

    return (nothing, nothing, nothing, nothing)
end



for pool in [:maxpool, :meanpool, :lpnormpool]
    pool! = Symbol(pool, :!)
    ∇pool = Symbol(:∇, pool, :!)

    @eval begin

function EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof($pool!)}, ::Type{RT}, y::OutType, x, dims; kwargs...) where {OutType, RT}

    if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
        func.val(y.val, x.val, dims.val; kwargs...)
    end

    primal = if EnzymeRules.needs_primal(config)
        y.val
    else
        nothing
    end
    shadow = if EnzymeRules.needs_shadow(config)
        y.dval
    else
        nothing
    end

    cache_y = ( EnzymeRules.overwritten(config)[2] 
                && !(typeof(x) <: EnzymeCore.Const) 
                && !(typeof(y) <: EnzymeCore.Const) 
                ) ? copy(y.val) : nothing

    cache_x = ( EnzymeRules.overwritten(config)[3]
                && !(typeof(x) <: EnzymeCore.Const) 
                && !(typeof(y) <: EnzymeCore.Const) 
                ) ? copy(x.val) : nothing

    cache = (cache_y, cache_x)

    return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end

function EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof($pool!)}, ::Type{RT}, cache, y, x, dims; kwargs...) where {RT}
    cache_y, cache_x = cache

    # y is only cached in the forward pass when overwritten; if not, read the live value
    if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const)
        if !EnzymeRules.overwritten(config)[2]
            cache_y = y.val
        end
    end

    # x is only cached in the forward pass when overwritten; if not, read the live value
    if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const)
        if !EnzymeRules.overwritten(config)[3]
            cache_x = x.val
        end
    end

    dys = y.dval
    dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval

    if EnzymeRules.width(config) == 1
        dys = (dys,)
        dxs = (dxs,)
    end

    for (dy, dx) in zip(dys, dxs)
        if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val

            if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val
                NNlib.$(∇pool)(dx, dy, cache_y, cache_x, dims.val; alpha=eltype(dx)(1), beta=eltype(dx)(1), kwargs...)
            end

            dy .= 0
        end
    end

    return (nothing, nothing, nothing)
end

end
end

function EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, rng, dst::OutType, src, p, dims) where {OutType, RT}

    T = float(real(eltype(dst.val)))
    val = convert(T, 1/(1-p.val))
    keep = if dims.val isa Colon
        similar(dst.val, T, size(dst.val))
    else
        similar(dst.val, T, ntuple(d -> d in dims.val ? size(dst.val,d) : 1, ndims(dst.val)))
    end
    rand!(rng.val, keep)
    
    keep = keep .> p.val

    if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
        dst.val .= (keep .* val) .* src.val
    end

    primal = if EnzymeRules.needs_primal(config)
        dst.val
    else
        nothing
    end
    shadow = if EnzymeRules.needs_shadow(config)
        dst.dval
    else
        nothing
    end

    if typeof(dst) <: EnzymeCore.Const || typeof(src) <: EnzymeCore.Const
        keep = nothing
    end

    return EnzymeRules.AugmentedReturn(primal, shadow, keep)
end

function EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, keep, rng, dst::OutType, src, p, dims) where {OutType, RT}
    T = float(real(eltype(dst.val)))
    val = convert(T, 1/(1-p.val))

    ddsts = dst.dval
    dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval

    if EnzymeRules.width(config) == 1
        ddsts = (ddsts,)
        dsrcs = (dsrcs,)
    end

    for (ddst, dsrc) in zip(ddsts, dsrcs)
        if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val

            if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val
                dsrc .+= (keep .* val) .* ddst
            end

            ddst .= 0
        end
    end

    dp = if typeof(p) <: EnzymeCore.Active
        typeof(p.val)(0)
    else
        nothing
    end

    return (nothing, nothing, nothing, dp, nothing)
end


end


================================================
FILE: ext/NNlibFFTWExt/NNlibFFTWExt.jl
================================================
module NNlibFFTWExt

using FFTW
using NNlib
using KernelAbstractions

include("stft.jl")

end


================================================
FILE: ext/NNlibFFTWExt/stft.jl
================================================
function NNlib.stft(x;
    n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing,
    center::Bool = true, normalized::Bool = false,
)
    kab = get_backend(x)
    use_window = !isnothing(window)

    use_window && kab != get_backend(window) && throw(ArgumentError(
        "`window` must be on the same device as stft input `x` ($kab), \
        instead: `$(get_backend(window))`."))
    use_window && !(0 < length(window) ≤ n_fft) && throw(ArgumentError(
        "Expected `0 < length(window) ≤ n_fft=$n_fft`, \
        but got `length(window)=$(length(window))`."))
    hop_length ≤ 0 && throw(ArgumentError(
        "Expected `hop_length > 0`, but got `hop_length=$hop_length`."))

    # Pad window on both sides with `0` to `n_fft` length if needed.
    if use_window && length(window) < n_fft
        left = ((n_fft - length(window)) ÷ 2) + 1
        tmp = KernelAbstractions.zeros(kab, eltype(window), n_fft)
        tmp[left:left + length(window) - 1] .= window
        window = tmp
    end

    if center
        pad_amount = n_fft ÷ 2
        x = pad_reflect(x, pad_amount; dims=1)
    end

    n = size(x, 1)
    (0 < n_fft ≤ n) || throw(ArgumentError(
        "Expected `0 < n_fft ≤ size(x, 1)=$n`, but got `n_fft=$n_fft`."))

    n_frames = 1 + (n - n_fft) ÷ hop_length

    # time2col.
    # Reshape `x` to (n_fft, n_frames, B) if needed.
    # Each row in `n_frames` is shifted by `hop_length`.
    if n_frames > 1
        # TODO can be more efficient if we support something like torch.as_strided
        ids = [
            row + hop_length * col
            for row in 1:n_fft, col in 0:(n_frames - 1)]
        x = @inbounds x[ids, ntuple(_ -> Colon(), ndims(x) - 1)...]
    end

    region = 1
    use_window && (x = x .* window;)
    y = eltype(x) <: Complex ? fft(x, region) : rfft(x, region)

    normalized && (y = y .* eltype(y)(n_fft^-0.5);)
    return y
end
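
# A minimal usage sketch (assumes FFTW.jl is loaded so that this method is active;
# the expected output size follows the shape arithmetic above):
#
#     using NNlib, FFTW
#     x = rand(Float32, 1024)                  # mono signal, 1024 samples
#     y = stft(x; n_fft=256, hop_length=64, window=hann_window(256))
#     size(y)  # (129, 17): n_fft ÷ 2 + 1 real-FFT bins by n_frames (with center=true padding)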

function NNlib.istft(y;
    n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing,
    center::Bool = true, normalized::Bool = false,
    return_complex::Bool = false,
    original_length::Union{Nothing, Int} = nothing,
)
    kab = get_backend(y)
    use_window = !isnothing(window)

    use_window && kab != get_backend(window) && throw(ArgumentError(
        "`window` must be on the same device as istft input `y` ($kab), \
        instead: `$(get_backend(window))`."))
    use_window && !(0 < length(window) ≤ n_fft) && throw(ArgumentError(
        "Expected `0 < length(window) ≤ n_fft=$n_fft`, \
        but got `length(window)=$(length(window))`."))
    hop_length ≤ 0 && throw(ArgumentError(
        "Expected `hop_length > 0`, but got `hop_length=$hop_length`."))

    # TODO check `y` eltype is complex

    n_frames = size(y, 2)

    # Pad window on both sides with `0` to `n_fft` length if needed.
    if use_window && length(window) < n_fft
        left = ((n_fft - length(window)) ÷ 2) + 1
        tmp = KernelAbstractions.zeros(kab, eltype(window), n_fft)
        tmp[left:left + length(window) - 1] .= window
        window = tmp
    end

    # Denormalize.
    normalized && (y = y .* eltype(y)(n_fft^0.5);)

    region = 1
    x = return_complex ? ifft(y, region) : irfft(y, n_fft, region)

    # De-apply window.
    use_window && (x = x ./ window;)

    # col2time.
    expected_output_len = n_fft + hop_length * (n_frames - 1)

    ids = Vector{Int}(undef, expected_output_len)
    in_idx, out_idx = 0, 0
    prev_e, v = 0, 0

    for col in 0:(n_frames - 1)
        for row in 1:n_fft
            in_idx += 1
            v = row + hop_length * col
            v > prev_e || continue

            out_idx += 1
            ids[out_idx] = in_idx
        end
        prev_e = v
    end

    # In case of batched input, reshape it (n_fft, n_frames, batch) -> (:, batch).
    nd = ntuple(_ -> Colon(), ndims(x) - 2)
    ndims(x) == 3 && (x = reshape(x, (:, size(x, 3)));)
    x = @inbounds x[ids, nd...]

    # Trim padding.
    left = center ? (n_fft ÷ 2 + 1) : 1
    right = if isnothing(original_length)
        center ? (size(x, 1) - n_fft ÷ 2) : expected_output_len
    else
        left + original_length - 1
    end
    x = x[left:right, nd...]
    return x
end
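
# Round-trip sketch without a window (so the `x ./ window` step is skipped);
# under these settings `istft` should undo `stft` up to floating-point error:
#
#     x = rand(Float32, 1024)
#     y = stft(x; n_fft=256, center=true)
#     x̂ = istft(y; n_fft=256, center=true, original_length=length(x))
#     x̂ ≈ x  # expected to hold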


================================================
FILE: ext/NNlibForwardDiffExt.jl
================================================
module NNlibForwardDiffExt

using ForwardDiff: ForwardDiff
using NNlib: NNlib

NNlib.within_gradient(x::ForwardDiff.Dual) = true
NNlib.within_gradient(x::AbstractArray{<:ForwardDiff.Dual}) = true
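
# Sketch: these overloads let code branch on "am I being differentiated?" for
# ForwardDiff dual numbers (NNlib's generic fallback returns `false`):
#
#     using NNlib, ForwardDiff
#     NNlib.within_gradient(1.0)                                           # false
#     ForwardDiff.derivative(x -> NNlib.within_gradient(x) ? 2x : x, 1.0)  # 2.0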

end


================================================
FILE: ext/NNlibMetalExt.jl
================================================
module NNlibMetalExt


using Metal: method_table, @device_override
using NNlib: NNlib

@device_override NNlib.tanh_fast(x) = Base.FastMath.tanh_fast(x)

end


================================================
FILE: ext/NNlibSpecialFunctionsExt.jl
================================================
module NNlibSpecialFunctionsExt

using NNlib: NNlib, oftf
using SpecialFunctions: erf

# Full gelu (gelu_erf)
NNlib.gelu_erf(x) = x/2*(1 + erf(x/sqrt(oftf(x,2))))

function NNlib.deriv_gelu_erf(x)
    SQRT2 = sqrt(oftf(x,2))
    Φ = (1 + erf(x/SQRT2))/2
    Φ + x/SQRT2*exp(-(x^2)/2)/sqrt(oftf(x,π))
end
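
# Quick numeric check of the closed form above (value is approximate):
#
#     using NNlib, SpecialFunctions
#     gelu_erf(1f0)  # ≈ 0.8413447f0, i.e. 1 * Φ(1) with Φ the standard normal CDF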

end

================================================
FILE: src/NNlib.jl
================================================
module NNlib

import Atomix
import ChainRulesCore: rrule

using Base.Broadcast: broadcasted
using Base.Threads
using ChainRulesCore
using GPUArraysCore
using KernelAbstractions
using KernelAbstractions: @atomic
using LinearAlgebra
using LinearAlgebra.BLAS: @blasfunc, BlasInt
using LinearAlgebra: AdjOrTransAbsMat, Adjoint, BlasFloat, Transpose
using Random
using ScopedValues
using Statistics
using Statistics: mean

const Numeric = Union{AbstractArray{<:T}, T} where {T<:Number}

# internal. TODO: change to an approach where amount of threading is controlled, not just on/off
const ALLOW_SPAWNS = ScopedValue(true)
should_use_spawn() = Threads.nthreads(:default) > 1 && ALLOW_SPAWNS[]
"""
    @disallow_spawns ex

Disallow NNlib from using `@spawn` on divisible workloads, e.g. within `conv`.
"""
macro disallow_spawns(ex)
    quote
        @with ALLOW_SPAWNS => false $(esc(ex))
    end
end
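
# A minimal usage sketch (sizes are illustrative): wrap a call to stop NNlib from
# spawning tasks for its divisible workloads, e.g. when threading is already
# managed at a higher level:
#
#     using NNlib
#     x, w = randn(Float32, 28, 28, 3, 8), randn(Float32, 3, 3, 3, 16)
#     y = NNlib.@disallow_spawns conv(x, w)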

# Include APIs
include("dim_helpers.jl")
export ConvDims, DenseConvDims, PoolDims, DepthwiseConvDims

include("activations.jl")
for f in ACTIVATIONS
    @eval export $(f)
end
export sigmoid, hardsigmoid, logsigmoid, thresholdrelu, gelu # Aliases

include("attention.jl")
export dot_product_attention, dot_product_attention_scores, make_causal_mask

include("dropout.jl")
export dropout, dropout!

include("softmax.jl")
export softmax, softmax!, ∇softmax, ∇softmax!, logsoftmax,
    logsoftmax!, ∇logsoftmax, ∇logsoftmax!, logsumexp

include("batched/batchedadjtrans.jl")
include("batched/batchedmul.jl")
export batched_mul, batched_mul!, ⊠,  batched_vec,
    batched_transpose, batched_adjoint

include("gemm.jl")
export grid_sample, ∇grid_sample

include("conv.jl")
export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter,
    ∇conv_filter!, depthwiseconv, depthwiseconv!,
    ∇depthwiseconv_data, ∇depthwiseconv_data!,
    ∇depthwiseconv_filter, ∇depthwiseconv_filter!

include("conv_bias_act.jl")
export conv_bias_act, conv_bias_act!

include("bias_act.jl")
export bias_act!

include("fold.jl")

include("ctc.jl")
export ctc_loss

include("pooling.jl")
export maxpool, maxpool!, meanpool, meanpool!, lpnormpool, lpnormpool!,
    ∇maxpool, ∇maxpool!, ∇meanpool, ∇meanpool!, ∇lpnormpool, ∇lpnormpool!

include("padding.jl")
export pad_constant, pad_repeat, pad_reflect, pad_zeros, pad_symmetric, pad_circular

include("upsample.jl")
export upsample_nearest, ∇upsample_nearest,
    upsample_linear, ∇upsample_linear,
    upsample_bilinear, ∇upsample_bilinear,
    upsample_trilinear, ∇upsample_trilinear,
    pixel_shuffle

include("gather.jl")
include("scatter.jl")
include("utils.jl")

include("sampling.jl")
include("functions.jl")

include("normalization.jl")
# export batchnorm, ∇batchnorm

## Include implementations
include("impl/padding_edges.jl")

# Direct implementations of convolutional and depthwise-convolutional algorithms
include("impl/conv_direct.jl")
include("impl/depthwiseconv_direct.jl")
# im2col implementations of convolutional and depthwise-convolutional algorithms
include("impl/conv_im2col.jl")
include("impl/depthwiseconv_im2col.jl")

# Direct implementations of pooling
include("impl/pooling_direct.jl")
include("deprecations.jl")

include("rotation.jl")
export imrotate, ∇imrotate

include("audio/stft.jl")
include("audio/spectrogram.jl")
include("audio/mel.jl")
export stft, istft, hann_window, hamming_window, spectrogram, melscale_filterbanks

end # module NNlib


================================================
FILE: src/activations.jl
================================================
## Activation functions
#
# Some of the activation functions have GPU-specific wrapper methods in NNlibCUDAExt.jl.
# https://github.com/JuliaGPU/CuArrays.jl/issues/614

ACTIVATIONS = [
    :σ, :hardσ, :hardtanh, :relu,
    :leakyrelu, :relu6, :rrelu, :elu, :gelu_tanh, :gelu_sigmoid, :gelu_erf, :swish, :hardswish, :selu,
    :celu, :softplus, :softsign, :logσ, :logcosh,
    :mish, :tanhshrink, :softshrink, :trelu, :lisht,
    :tanh_fast, :sigmoid_fast,
]

# of type float (to allow for integer inputs)
oftf(x, y) = oftype(float(x), y)

# oftype contains control flow on 1.10+, which can lead to type instabilities under AD 
function rrule(::typeof(oftf), x, y)
    proj_y = ChainRulesCore.ProjectTo(y)
    oftf_pullback(Δ) = (NoTangent(), NoTangent(), proj_y(Δ))
    return oftf(x, y), oftf_pullback
end
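
# Sketch of what `oftf` returns (expected values in the comments):
#
#     oftf(1f0, π)  # 3.1415927f0: the second argument converted to float(typeof(x))
#     oftf(2, 1/3)  # 0.3333333333333333: an Int input promotes to Float64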

"""
    σ(x) = 1 / (1 + exp(-x))

Classic [sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function) activation
function.
Unicode `σ` can be entered as `\\sigma` then tab, in many editors.
The ascii name `sigmoid` is also exported.

See also [`sigmoid_fast`](@ref).

```julia-repl
julia> using UnicodePlots

julia> lineplot(sigmoid, -5, 5, height=7)
          ┌────────────────────────────────────────┐     
        1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⣀⡠⠤⠖⠒⠒⠋⠉⠉⠉⠉⠉⠉│ σ(x)
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⢀⡠⠖⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│     
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⣀⠔⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│     
   f(x)   │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⡏⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│     
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡔⠋⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│     
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠊⠁⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│     
        0 │⣀⣀⣀⣀⣀⣀⣀⠤⠤⠤⠒⠊⠉⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│     
          └────────────────────────────────────────┘     
          ⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀     
          ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀     

julia> sigmoid === σ
true
```
"""
function σ(x)
    t = exp(-abs(x))
    ifelse(x ≥ 0, inv(1 + t), t / (1 + t))
end

const sigmoid = σ

"""
    hardσ(x) = max(0, min(1, (x + 3) / 6))

Piecewise linear approximation of [`sigmoid`](@ref).

```julia-repl
julia> lineplot(hardsigmoid, -5, 5, height=7)
          ┌────────────────────────────────────────┐         
        1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⢀⡠⠖⠋⠉⠉⠉⠉⠉⠉⠉⠉│ hardσ(x)
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⣀⡤⠒⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⡠⠔⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         
   f(x)   │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⡗⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠊⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡤⠖⠋⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         
        0 │⣀⣀⣀⣀⣀⣀⣀⣀⣀⠤⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         
          └────────────────────────────────────────┘         
          ⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀         
          ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀         

julia> lineplot(sigmoid, -5, 5, height=7)
          ┌────────────────────────────────────────┐     
        1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⣀⡠⠤⠖⠒⠒⠋⠉⠉⠉⠉⠉⠉│ σ(x)
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⢀⡠⠖⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│     
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⣀⠔⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│     
   f(x)   │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⡏⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│     
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡔⠋⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│     
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠊⠁⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│     
        0 │⣀⣀⣀⣀⣀⣀⣀⠤⠤⠤⠒⠊⠉⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│     
          └────────────────────────────────────────┘     
          ⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀     
          ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀     
```
"""
hardσ(x) = clamp((x + 3) / 6, 0, 1)

# https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html

const hardsigmoid = hardσ

"""
    logσ(x)

Return `log(σ(x))` which is computed in a numerically stable way.

```julia-repl
julia> lineplot(logsigmoid, -5, 5, height=7)
           ┌────────────────────────────────────────┐        
         0 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡧⠤⠔⠒⠒⠒⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉│ logσ(x)
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠖⠊⠉⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠒⠉⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        
   f(x)    │⠀⠀⠀⠀⠀⠀⢀⡤⠖⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        
           │⠀⠀⠀⣀⠔⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        
           │⡤⠖⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        
        -6 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        
           └────────────────────────────────────────┘        
           ⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀        
           ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀        
```
"""
logσ(x) = -softplus(-x)

const logsigmoid = logσ

"""
    hardtanh(x) = max(-1, min(1, x))

Segment-wise linear approximation of `tanh`, much cheaper to compute.
See ["Large Scale Machine Learning"](https://ronan.collobert.com/pub/matos/2004_phdthesis_lip6.pdf).

See also [`tanh_fast`](@ref).
```julia-repl
julia> lineplot(hardtanh, -2, 2, height=7)
           ┌────────────────────────────────────────┐            
         1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⣀⠔⠋⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉│ hardtanh(x)
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⣀⡤⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⢀⡤⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            
   f(x)    │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⡤⡷⠥⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│            
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⠖⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⠖⠋⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            
        -1 │⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⠔⠋⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            
           └────────────────────────────────────────┘            
           ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀            
           ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x

julia> lineplot(tanh, -2, 2, height=7)
           ┌────────────────────────────────────────┐        
         1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⣀⡠⠤⠤⠒⠒⠒⠊⠉⠉⠉│ tanh(x)
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⢀⡠⠔⠊⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⢀⡤⠒⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        
   f(x)    │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⡤⡷⠥⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│        
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡤⠖⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡠⠔⠊⠁⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        
        -1 │⣀⣀⣀⡠⠤⠤⠤⠖⠒⠊⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        
           └────────────────────────────────────────┘        
           ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀        
           ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀        
```
"""
hardtanh(x) = clamp(x, oftype(x, -1), oftype(x, 1))  # clamp(x, -1, 1) is type-stable, but would promote Int32, for which we have tests

"""
    relu(x) = max(0, x)

[Rectified Linear Unit](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))
activation function.

```julia-repl
julia> lineplot(relu, -2, 2, height=7)
          ┌────────────────────────────────────────┐        
        2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠋│ relu(x)
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠊⠁⠀⠀│        
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡤⠊⠁⠀⠀⠀⠀⠀│        
   f(x)   │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⢀⡤⠖⠁⠀⠀⠀⠀⠀⠀⠀⠀│        
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⡠⠖⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⡠⠖⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        
        0 │⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣇⠔⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        
          └────────────────────────────────────────┘        
          ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀        
          ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀        
```
"""
relu(x) = ifelse(x<0, zero(x), x)  # faster than max(zero(x), x), still preserves NaN

"""
    leakyrelu(x, a=0.01) = max(a*x, x)

Leaky [Rectified Linear Unit](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))
activation function.
You can also specify the coefficient explicitly, e.g. `leakyrelu(x, 0.01)`.

```julia-repl
julia> lineplot(x -> leakyrelu(x, 0.5), -2, 2, height=7)
           ┌────────────────────────────────────────┐       
         2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠒⠉│ #42(x)
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊⠉⠀⠀⠀⠀│       
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⣀⡤⠖⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀│       
   f(x)    │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⣀⠤⠖⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│       
           │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⣤⣤⡤⡧⠶⠭⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│       
           │⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⣀⠤⠤⠒⠒⠋⠉⠁⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│       
        -1 │⣀⣀⠤⠤⠒⠒⠊⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│       
           └────────────────────────────────────────┘       
           ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀       
           ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀       

julia> leakyrelu(-10f0, 0.2)
-2.0f0

julia> leakyrelu(-10f0, 0.02)
-0.2f0
```
"""
leakyrelu(x, a=oftf(x, leakyrelu_a)) = ifelse(x>0, float(x), oftf(x, a*x))  # max(a*x, x) is 3x slower

const leakyrelu_a = 0.01  # also used in gradient below

"""
    relu6(x) = min(max(0, x), 6)

[Rectified Linear Unit](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))
activation function capped at 6.
See ["Convolutional Deep Belief Networks"](https://www.cs.toronto.edu/~kriz/conv-cifar10-aug2010.pdf) from CIFAR-10.

```julia-repl
julia> lineplot(relu6, -10, 10, height=7)
          ┌────────────────────────────────────────┐         
        6 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠎⠉⠉⠉⠉⠉⠉⠉⠉│ relu6(x)
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⢀⡔⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀│         
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⡤⠃⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         
   f(x)   │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⡠⠎⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⢀⠖⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⡔⠃⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         
        0 │⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⡧⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         
          └────────────────────────────────────────┘         
          ⠀-10⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀10⠀         
          ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀         
```
"""
relu6(x) = clamp(x, oftype(x, 0), oftype(x, 6))  # clamp promotes, but clamp(x, 0, 6) would promote x::Int32

"""
    rrelu(x, lo=1/8, hi=1/3) = max(a*x, x)
    # where `a` is randomly sampled from uniform distribution `U(lo, hi)`

Randomized Leaky Rectified Linear Unit activation function.
See ["Empirical Evaluation of Rectified Activations"](https://arxiv.org/abs/1505.00853)
You can also specify the bound explicitly, e.g. `rrelu(x, 0.0, 1.0)`.

```julia-repl
julia> lineplot(rrelu, -20, 10, height=7)
            ┌────────────────────────────────────────┐         
         10 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠖⠋│ rrelu(x)
            │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⢀⡠⠖⠋⠁⠀⠀⠀│         
            │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⢀⡠⠔⠊⠁⠀⠀⠀⠀⠀⠀⠀│         
   f(x)     │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⡤⠤⣤⣤⢤⣤⣤⠤⠤⠤⢼⠮⠥⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│         
            │⣰⢀⣆⡄⣄⡄⡠⡰⠦⠷⡜⢢⠷⠳⠢⠊⠉⠉⠀⠀⠁⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         
            │⠃⠉⠙⠘⠃⠈⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         
        -10 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         
            └────────────────────────────────────────┘         
            ⠀-20⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀10⠀         
            ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀         

julia> extrema(rrelu.(fill(-10f0, 1000)))
(-3.3316886f0, -1.2548422f0)
```
"""
function rrelu(x::T, l=oftf(x,1/8), u=oftf(x,1/3)) where T<:Number
    a = (u - l) * rand(float(T)) + l
    return leakyrelu(x, a)
end

"""
    elu(x, α=1) = x > 0 ? x : α * (exp(x) - 1)

Exponential Linear Unit activation function.
See ["Fast and Accurate Deep Network Learning by Exponential Linear Units"](https://arxiv.org/abs/1511.07289).
You can also specify the coefficient explicitly, e.g. `elu(x, 1)`.

```julia-repl
julia> lineplot(elu, -2, 2, height=7)
           ┌────────────────────────────────────────┐       
         2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠒⠉│ elu(x)
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊⠉⠀⠀⠀⠀│       
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⣀⡤⠖⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀│       
   f(x)    │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⣀⠤⠖⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│       
           │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⡤⡧⠶⠭⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│       
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⣀⠤⠔⠒⠋⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│       
        -1 │⠤⠤⠤⠤⠔⠒⠒⠒⠊⠉⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│       
           └────────────────────────────────────────┘       
           ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀       
           ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀       

julia> elu(-10f0)
-0.9999546f0

julia> elu(-10f0, 2)
-1.9999092f0
```
"""
elu(x, α=1) = ifelse(x ≥ 0, float(x), @fastmath oftf(x, α) * (exp(x) - 1))

deriv_elu(Ω, α=1) = ifelse(Ω ≥ 0, one(Ω), Ω + oftype(Ω, α))

"""
    gelu_tanh(x) = 0.5x * (1 + tanh(√(2/π) * (x + 0.044715x^3)))

Activation function from ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415) using tanh approximation.

This implementation uses `tanh` which allows for better pattern matching and fusion in optimizing 
compilers compared to the sigmoid-based implementation. For a potentially faster implementation 
that uses `sigmoid_fast`, see [`gelu_sigmoid`](@ref).

```julia-repl
julia> lineplot(gelu_tanh, -2, 2, height=7)
           ┌────────────────────────────────────────┐        
         2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊│ gelu_tanh(x)
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠊⠁⠀⠀⠀│        
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠒⠉⠀⠀⠀⠀⠀⠀⠀│        
   f(x)    │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⣀⡠⠤⠒⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        
           │⣤⣤⣤⣤⣤⣤⣤⣤⡤⠤⠤⠤⠤⠤⠤⠤⣤⣤⣤⡤⡧⠶⠶⠭⠥⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│        
           │⠀⠀⠀⠀⠀⠀⠀⠀⠈⠉⠉⠉⠉⠉⠉⠉⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        
        -1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        
           └────────────────────────────────────────┘        
           ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀        
           ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀        

julia> lineplot(gelu_tanh, -5, 0, height=7);

julia> lineplot!(ans, swish)
             ┌────────────────────────────────────────┐         
           0 │⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠒⠒⠤⣄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸│ gelu_tanh(x) 
             │⠑⠒⠢⠤⣄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠓⢄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇│ swish(x)
             │⠀⠀⠀⠀⠀⠈⠉⠒⠤⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠑⢆⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣸⠁│         
   f(x)      │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠒⢄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠑⢄⠀⠀⠀⠀⠀⠀⠀⠀⢠⡇⠀│         
             │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠓⢄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠓⣄⠀⠀⠀⠀⠀⢠⡞⠀⠀│         
             │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠓⢄⣀⣀⡤⢣⠃⠀⠀│         
        -0.2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠓⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢠⠇⠀⠀⠀│         
             └────────────────────────────────────────┘         
             ⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀0⠀         
             ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀         
```
"""
function gelu_tanh(x)
    α = oftf(x, 0.044715)
    λ = oftf(x, gelu_λ)
    x/2 * (1 + tanh_fast(λ * (x + α * x^3)))
end

const gelu_λ = √(2 / π)
const gelu_2λ = √(8 / π)

function deriv_gelu_tanh(x)
    α = oftf(x, 0.044715)
    α2 = oftf(x, 0.08943)
    λ = oftf(x, gelu_λ)
    x2 = x * x
    t = muladd(x2, α, one(x))
    z = λ * x * t
    Ω = tanh_fast(z)
    sech2 = 1 - Ω^2
    (1 + Ω)/2 + x * λ * muladd(x2, α2, t) * sech2 / 2
end

"""
    gelu_sigmoid(x) = x * σ(√(8/π) * (x + 0.044715x^3))

Alternative implementation of the GELU activation function using `sigmoid` instead of `tanh`.
This is mathematically equivalent to [`gelu_tanh`](@ref) but may be faster in some cases.

The sigmoid-based implementation may prevent pattern matching and fusion in some optimizing 
compilers. Use [`gelu_tanh`](@ref) if you need better compiler optimization support.

See ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415).
"""
function gelu_sigmoid(x)
    α = oftf(x, 0.044715)
    λλ = oftf(x, gelu_2λ)
    x * sigmoid_fast(λλ * x * muladd(x^2, α, one(x)))
end

function deriv_gelu_sigmoid(x)
    α = oftf(x, 0.044715)
    α2 = oftf(x, 0.08943)
    λλ = oftf(x, gelu_2λ)
    x2 = x * x
    t = muladd(x2, α, one(x))
    Ω = sigmoid_fast(λλ * x * t)
    dσ = conj(Ω * (1 - Ω))
    muladd(dσ * λλ * muladd(x2, α2, t), x, Ω)
end

"""
    gelu_erf(x) = xΦ(x) = 0.5x * (1 + erf(x/√2))

Activation function from ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415) without approximation.
The SpecialFunctions.jl package needs to be loaded to use this function.
"""
function gelu_erf end
function deriv_gelu_erf end

"""
    gelu(x) = gelu_tanh(x)

Activation function from ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415). 
See [`gelu_tanh`](@ref).
"""
const gelu = gelu_tanh
# Need to alias the type as well to ensure serialization libraries still work
# See https://github.com/FluxML/NNlib.jl/issues/631
const var"#gelu" = typeof(gelu_tanh)
const deriv_gelu = deriv_gelu_tanh

"""
    swish(x) = x * σ(x)

Self-gated activation function.
See ["Swish: a Self-Gated Activation Function"](https://arxiv.org/abs/1710.05941).

```julia-repl
julia> lineplot(swish, -2, 2, height=7)
           ┌────────────────────────────────────────┐         
         2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤│ swish(x)
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠖⠋⠁⠀│         
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠖⠋⠁⠀⠀⠀⠀⠀│         
   f(x)    │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⢀⣀⡤⠔⠊⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         
           │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⣤⣤⡤⡧⠴⠶⠯⠥⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│         
           │⠉⠑⠒⠒⠒⠒⠒⠒⠒⠒⠒⠒⠉⠉⠉⠉⠁⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         
        -1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         
           └────────────────────────────────────────┘         
           ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀         
           ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀         
```
"""
@inline swish(x) = x * sigmoid_fast(x)

"""
    hardswish(x) = x * hardσ(x)

Hard-Swish activation function.
See ["Searching for MobileNetV3"](https://arxiv.org/abs/1905.02244).

```julia-repl
julia> lineplot(hardswish, -2, 5, height = 7)
           ┌────────────────────────────────────────┐             
         5 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠔⠒⠉│ hardswish(x)
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡠⠔⠒⠉⠁⠀⠀⠀⠀│             
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠖⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀│             
   f(x)    │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│             
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⢀⣀⠤⠖⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│             
           │⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣇⣤⣤⣖⣚⣉⣁⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀│             
        -1 │⠉⠒⠒⠒⠒⠉⠉⠉⠉⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│             
           └────────────────────────────────────────┘             
           ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀             
           ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀             

julia> lineplot(hardswish, -4, 0, height = 7);

julia> lineplot!(ans, swish)
             ┌────────────────────────────────────────┐             
           0 │⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⢣⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡜│ hardswish(x)
             │⠒⠒⠢⠤⢄⣀⡀⠀⠀⠀⠀⠱⡄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠎⠀│ swish(x)    
             │⠀⠀⠀⠀⠀⠀⠈⠉⠑⠒⠦⢄⣘⢄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡴⠃⠀⠀│             
   f(x)      │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠑⡖⠦⢄⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⢔⠏⠁⠀⠀⠀│             
             │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠣⣄⠀⠉⠑⠒⠦⠤⢄⣀⣀⣀⣀⡠⠤⠖⣊⠕⠁⠀⠀⠀⠀⠀│             
             │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠓⠤⡀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡤⠖⠁⠀⠀⠀⠀⠀⠀⠀│             
        -0.4 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠉⠒⠢⠤⠤⠔⠒⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│             
             └────────────────────────────────────────┘             
             ⠀-4⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀0⠀             
             ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀             

julia> hardswish.(-5:5)'
1×11 adjoint(::Vector{Float64}) with eltype Float64:
 -0.0  -0.0  -0.0  -0.333333  -0.333333  0.0  0.666667  1.66667  3.0  4.0  5.0
```
"""
@inline hardswish(x) = x * hardσ(x)

deriv_hardswish(x) = ifelse(x < -3, oftf(x,0), ifelse(x > 3, oftf(x,1), x/3 + oftf(x,1/2)))

"""
    lisht(x) = x * tanh(x)

Activation function from 
["LiSHT: Non-Parametric Linearly Scaled Hyperbolic Tangent ..."](https://arxiv.org/abs/1901.05894)

```julia-repl
julia> lineplot(lisht, -2, 2, height=7)
          ┌────────────────────────────────────────┐         
        2 │⠢⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔│ lisht(x)
          │⠀⠈⠑⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡤⠊⠁⠀│         
          │⠀⠀⠀⠀⠈⠣⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠁⠀⠀⠀⠀│         
   f(x)   │⠀⠀⠀⠀⠀⠀⠀⠑⢆⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⠊⠁⠀⠀⠀⠀⠀⠀│         
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠢⡄⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⢀⠔⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀│         
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠓⢄⡀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⢀⡠⠖⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         
        0 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠓⠦⣄⣀⣀⣇⣀⣀⠤⠒⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         
          └────────────────────────────────────────┘         
          ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀         
          ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀         

julia> lineplot!(ans, logcosh)
          ┌────────────────────────────────────────┐           
        2 │⠢⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔│ lisht(x)  
          │⠀⠈⠑⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡤⠊⠁⠀│ logcosh(x)
          │⠢⣄⠀⠀⠈⠣⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠁⠀⠀⣀⠔│           
   f(x)   │⠀⠈⠑⠢⣀⠀⠀⠑⢆⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⠊⠁⠀⣀⠔⠊⠁⠀│           
          │⠀⠀⠀⠀⠀⠉⠢⢄⡀⠉⠢⡄⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⢀⠔⠋⠀⡠⠔⠋⠁⠀⠀⠀⠀│           
          │⠀⠀⠀⠀⠀⠀⠀⠀⠉⠓⠦⣌⡓⢄⡀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⢀⡠⠖⣁⠤⠒⠉⠀⠀⠀⠀⠀⠀⠀⠀│           
        0 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠓⠪⠷⣦⣄⣀⣀⣇⣀⣀⣤⠶⠕⠒⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│           
          └────────────────────────────────────────┘           
          ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀           
          ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀           
```
"""
lisht(x) = x * tanh_fast(x)

"""
    selu(x) = λ * (x ≥ 0 ? x : α * (exp(x) - 1))

    λ ≈ 1.05070...
    α ≈ 1.67326...

Scaled exponential linear units.
See ["Self-Normalizing Neural Networks"](https://arxiv.org/abs/1706.02515).

```julia-repl
julia> lineplot(selu, -3, 2, height=7)
           ┌────────────────────────────────────────┐        
         3 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ selu(x)
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠤⠒│        
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⢀⣀⠤⠖⠊⠉⠀⠀⠀⠀│        
   f(x)    │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⣀⡠⠤⠒⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀│        
           │⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⣉⠭⠛⡏⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉│        
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⣀⡤⠤⠒⠊⠉⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        
        -2 │⠤⠤⠖⠒⠒⠒⠒⠒⠒⠒⠉⠉⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        
           └────────────────────────────────────────┘        
           ⠀-3⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀        
           ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀        

julia> selu(-10f0)
-1.7580194f0
```
"""
function selu(x)
    λ = oftf(x, selu_λ)
    α = oftf(x, selu_α)
    λ * ifelse(x > 0, x, @fastmath α * (exp(x) - 1))
end

const selu_λ = 1.0507009873554804934193349852946
const selu_α = 1.6732632423543772848170429916717

function deriv_selu(Ω)
    λ = oftf(Ω, selu_λ)
    α = oftf(Ω, selu_α)
    ifelse(Ω > 0, λ, Ω + α * λ)
end

"""
    celu(x, α=1) = x ≥ 0 ? x : α * (exp(x/α) - 1)

Activation function from ["Continuously Differentiable Exponential Linear Units"](https://arxiv.org/abs/1704.07483).

```julia-repl
julia> lineplot(celu, -2, 2, height=7)
           ┌────────────────────────────────────────┐        
         2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠒⠉│ celu(x)
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊⠉⠀⠀⠀⠀│        
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⣀⡤⠖⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀│        
   f(x)    │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⣀⠤⠖⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        
           │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⡤⡧⠶⠭⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│        
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⣀⠤⠔⠒⠋⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        
        -1 │⠤⠤⠤⠤⠔⠒⠒⠒⠊⠉⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        
           └────────────────────────────────────────┘        
           ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀        
           ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀        

julia> celu(-10f0)
-0.9999546f0
```
"""
celu(x, α=1) = ifelse(x ≥ 0, float(x), oftf(x,α) * (exp(x/oftf(x,α)) - 1))

deriv_celu(Ω, α=1) = ifelse(Ω > 0, oftf(Ω, 1), Ω / oftf(Ω, α) + 1)

"""
    trelu(x, theta=1) = x > theta ? x : 0

Threshold gated rectified linear activation function.
See ["Zero-bias autoencoders and the benefits of co-adapting features"](https://arxiv.org/abs/1402.3337)

```julia-repl
julia> lineplot(trelu, -2, 4, height=7)
          ┌────────────────────────────────────────┐         
        4 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠖⠋│ trelu(x)
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠖⠋⠁⠀⠀⠀│         
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊⠁⠀⠀⠀⠀⠀⠀⠀│         
   f(x)   │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠴⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⣠⠤⠒⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⡏⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         
        0 │⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣇⣀⣀⣀⣀⣀⣀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         
          └────────────────────────────────────────┘         
          ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀4⠀         
          ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀         
```
"""
trelu(x, theta=1) = ifelse(x <= theta, zero(x), x)

const thresholdrelu = trelu

"""
    softsign(x) = x / (1 + |x|)

See ["Quadratic Polynomials Learn Better Image Features"](http://www.iro.umontreal.ca/~lisa/publications2/index.php/attachments/single/205) (2009).

```julia-repl
julia> lineplot(softsign, -5, 5, height=7)
           ┌────────────────────────────────────────┐            
         1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣀⣀⣀⣀⠤⠤⠤⠤⠤│ softsign(x)
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⣀⡤⠖⠒⠋⠉⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀│            
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⡔⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            
   f(x)    │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⡯⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│            
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠁⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣀⣀⠤⠤⠒⠋⠁⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            
        -1 │⠒⠒⠒⠒⠒⠊⠉⠉⠉⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            
           └────────────────────────────────────────┘            
           ⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀            
           ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀            

julia> lineplot!(ans, tanh)
           ┌────────────────────────────────────────┐            
         1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⢀⡤⠖⠊⠉⠉⠉⣉⣉⣉⣉⣉⠭⠭⠭⠭⠭│ softsign(x)
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⡔⣃⡤⠖⠒⠋⠉⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀│ tanh(x)    
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣧⡞⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            
   f(x)    │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⡯⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│            
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡴⠃⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣀⣀⠤⠤⠒⢋⠕⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            
        -1 │⣒⣒⣒⣒⣒⣊⣉⣉⣉⣉⣁⣀⣀⡠⠤⠒⠁⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            
           └────────────────────────────────────────┘            
           ⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀            
           ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀            

julia> softsign(1f0)
0.5f0

julia> softsign(100f0)
0.990099f0
```
"""
softsign(x) = x / (1 + abs(x))

deriv_softsign(x) = 1 / (1 + abs(x))^2

"""
    softplus(x) = log(exp(x) + 1)

See ["Deep Sparse Rectifier Neural Networks"](http://proceedings.mlr.press/v15/glorot11a/glorot11a.pdf), JMLR 2011.

```julia-repl
julia> lineplot(softplus, -3, 3, height=7)
          ┌────────────────────────────────────────┐            
        4 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ softplus(x)
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠│            
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊⠁⠀│            
   f(x)   │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠔⠊⠁⠀⠀⠀⠀⠀│            
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⣀⡠⠤⠒⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀│            
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣀⡧⠤⠒⠊⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            
        0 │⣀⣀⣀⣀⣀⣀⣀⡠⠤⠤⠤⠤⠔⠒⠒⠚⠉⠉⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            
          └────────────────────────────────────────┘            
          ⠀-3⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀3⠀            
          ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀            

julia> lineplot!(ans, relu)
          ┌────────────────────────────────────────┐            
        4 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ softplus(x)
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣠│ relu(x)    
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣠⡴⠞⠋⠁│            
   f(x)   │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣤⡴⠞⠋⠁⠀⠀⠀⠀│            
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⣀⡠⢤⡲⠝⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀│            
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣀⡧⠤⠒⠊⣉⠥⠚⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            
        0 │⣀⣀⣀⣀⣀⣀⣀⣠⣤⣤⣤⣤⣔⣒⣒⣚⣉⣉⣁⣀⣇⠴⠒⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            
          └────────────────────────────────────────┘            
          ⠀-3⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀3⠀            
          ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀            

julia> softplus(16f0)
16.0f0
```
"""
softplus(x) = log1p(exp(-abs(x))) + relu(x)

"""
    logcosh(x)

Return `log(cosh(x))` which is computed in a numerically stable way.

```julia-repl
julia> lineplot(logcosh, -5, 5, height=7)
          ┌────────────────────────────────────────┐           
        5 │⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ logcosh(x)
          │⠉⠢⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠋│           
          │⠀⠀⠀⠑⠢⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠊⠁⠀⠀│           
   f(x)   │⠀⠀⠀⠀⠀⠀⠑⠦⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠊⠁⠀⠀⠀⠀⠀│           
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠑⠦⡀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⢀⡤⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀│           
          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠓⠦⡀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⢀⡤⠒⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│           
        0 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠑⠢⢄⣀⣀⣇⣀⡠⠔⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│           
          └────────────────────────────────────────┘           
          ⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀           
          ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀           
```
"""
logcosh(x) = x + softplus(-2x) - oftf(x, log2)

const log2 = log(2)

"""
    mish(x) = x * tanh(softplus(x))

Activation function from ["Mish: A Self Regularized Non-Monotonic Neural Activation Function"](https://arxiv.org/abs/1908.08681).

```julia-repl
julia> lineplot(mish, -5, 5, height=7)
           ┌────────────────────────────────────────┐        
         5 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠖⠋│ mish(x)
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠒⠁⠀⠀⠀│        
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⠔⠋⠁⠀⠀⠀⠀⠀⠀│        
   f(x)    │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⢀⡠⠖⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⢀⡤⠖⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        
           │⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣧⣔⣊⣁⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀│        
        -1 │⠀⠀⠀⠀⠀⠀⠀⠀⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        
           └────────────────────────────────────────┘        
           ⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀        
           ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀        
```
"""
mish(x) = x * tanh(softplus(x))

"""
    tanhshrink(x) = x - tanh(x)

See ["Tanhshrink Activation Function"](https://www.gabormelli.com/RKB/Tanhshrink_Activation_Function).

```julia-repl
julia> lineplot(tanhshrink, -3, 3, height=7)
           ┌────────────────────────────────────────┐              
         3 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ tanhshrink(x)
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡠⠤⠖⠊│              
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⢀⣀⡠⠤⠒⠊⠉⠁⠀⠀⠀⠀│              
   f(x)    │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⣤⡤⠤⠤⠤⠤⠤⠤⡷⠶⠶⠶⠶⠶⠮⠭⠥⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│              
           │⠀⠀⠀⠀⠀⣀⡠⠴⠒⠊⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│              
           │⡠⠴⠒⠊⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│              
        -3 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│              
           └────────────────────────────────────────┘              
           ⠀-3⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀3⠀              
           ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀              

julia> tanhshrink.((-10f0, 10f0))
(-9.0f0, 9.0f0)
```
"""
tanhshrink(x) = x - tanh_fast(x)

"""
    softshrink(x, λ=0.5) =
        (x ≥ λ ? x - λ : (-λ ≥ x ? x + λ : 0))

See ["Softshrink Activation Function"](https://www.gabormelli.com/RKB/Softshrink_Activation_Function).

```julia-repl
julia> lineplot(softshrink, -2, 2, height=7)
           ┌────────────────────────────────────────┐              
         2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀│ softshrink(x)
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⡤⠔⠒⠉⠁│              
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⣀⡤⠤⠒⠋⠁⠀⠀⠀⠀⠀⠀│              
   f(x)    │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⣤⡤⠤⠤⠤⠤⠤⠤⡧⠤⠤⠤⠤⠶⠮⠭⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│              
           │⠀⠀⠀⠀⠀⠀⢀⣀⠤⠖⠒⠉⠁⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│              
           │⠀⣀⠤⠔⠒⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│              
        -2 │⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│              
           └────────────────────────────────────────┘              
           ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀              
           ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀              

julia> lineplot!(ans, tanhshrink)
           ┌────────────────────────────────────────┐              
         2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀│ softshrink(x)
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⡤⠔⠒⣉⡡│ tanhshrink(x)
           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⣀⡤⠤⣒⣋⠥⠤⠒⠊⠉⠁⠀│              
   f(x)    │⠤⠤⠤⠤⠤⠤⠤⠤⠤⣤⣤⣤⣤⡤⠤⠤⠤⠤⠤⠤⡷⠶⠶⠶⠶⠶⠾⠿⠯⠭⠭⠤⠤⠤⠤⠤⠤⠤⠤⠤│              
           │⠀⢀⣀⡠⠤⠖⢒⣋⠭⠗⠒⠉⠁⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│              
           │⠊⣉⠤⠔⠒⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│              
        -2 │⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│              
           └────────────────────────────────────────┘              
           ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀              
           ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀

julia> softshrink.((-10f0, 10f0))
(-9.5f0, 9.5f0)
```
"""
function softshrink(x, λ = 0.5)
    lo = x - oftf(x, λ)
    hi = x + oftf(x, λ)
    ifelse(hi > 0, ifelse(lo < 0, zero(hi), lo), hi)
end

# Define broadcasts for activation functions on arrays
for f in ACTIVATIONS
  @eval $(f)(x::AbstractArray, args...) = $(f).(x, args...)
end

## Faster, less accurate, versions of some.

"""
    tanh_fast(x)

This is a faster but slightly less accurate version of `tanh`.

Where Julia's `tanh` function has an error under 2 eps, this
may be wrong by 5 eps, a reduction by less than one decimal digit. 

For `x::Float32` this is usually about 10 times faster,
with a smaller speedup for `x::Float64`.
For any other number types, it just calls `tanh`.

See also [`sigmoid_fast`](@ref).

```julia-repl
julia> tanh(0.5f0)
0.46211717f0

julia> tanh_fast(0.5f0)
0.46211714f0

julia> hardtanh(0.5f0)
0.5f0
```
"""
@inline function tanh_fast(x::Float32)
    # This method added in NNlib.jl#345 by @mcabbott and @oscardssmith,
    # with coefficients found using Remez.jl
    x2 = abs2(x)
    n = evalpoly(x2, (1.0f0, 0.1346604f0, 0.0035974074f0, 2.2332108f-5, 1.587199f-8))
    d = evalpoly(x2, (1.0f0, 0.4679937f0, 0.026262015f0, 0.0003453992f0, 8.7767893f-7))
    ifelse(x2 < 66f0, x * (n / d), sign(x))
end

@inline function tanh_fast(x::Float64)
    exp2x = @fastmath exp(x + x)
    y = (exp2x - 1) / (exp2x + 1) 
    # That has large errors near zero; using `expm1` would be more accurate, but about as slow as `tanh`.
    # Instead, we switch to a polynomial, which is very accurate within its range:
    x2 = x * x
    ypoly = x * evalpoly(x2, (1.0, -0.33333333333324583, 0.13333333325511604, -0.05396823125794372, 0.02186660872609521, -0.008697141630499953))
    ifelse(x2 > 900.0, sign(x), ifelse(x2 < 0.017, ypoly, y))
end

# These approximations are very badly behaved for Float16; none are fast.
# They are also a bit slower with ForwardDiff.Dual numbers, so we fall back to Base:
tanh_fast(x::Number) = Base.tanh(x)

"""
    sigmoid_fast(x)

This is a faster, and very slightly less accurate, version of `sigmoid`.
For `x::Float32`, perhaps 3 times faster, and maximum errors 2 eps instead of 1.

See also [`tanh_fast`](@ref).

```julia-repl
julia> sigmoid(0.2f0)
0.54983395f0

julia> sigmoid_fast(0.2f0)
0.54983395f0

julia> hardσ(0.2f0)
0.53333336f0
```
"""
function sigmoid_fast(x::Real)
    @static if VERSION ≥ v"1.11-"
        @inline
    end
    t = @fastmath exp(-abs(x))
    y = ifelse(x ≥ 0, inv(1 + t), t / (1 + t))
    ifelse(x > 40, one(y), ifelse(x < -80, zero(y), y))
end
# For x::Float32, this is not as quick as the rational tanh_fast(x) above,
# but that polynomial has poor relative accuracy for negative x.

sigmoid_fast(x::Float16) = sigmoid(x)  # sigmoid_fast is extremely badly behaved at large x

function sigmoid_fast(x::Number)
    Base.depwarn("sigmoid only makes sense on real numbers, got $(typeof(x))", :sigmoid_fast)
    sigmoid(x)
end

"""
    NNlib.fast_act(f, [x::AbstractArray])

Replaces `f == tanh` with [`tanh_fast`](@ref), etc.

Takes an optional 2nd argument, so that you can disable
this replacement for some array or element types.
"""
@inline fast_act(f::F, ::AbstractArray = 1:0) where {F<:Function} = f
@inline fast_act(::typeof(tanh), ::AbstractArray = 1:0) = tanh_fast
@inline fast_act(::typeof(sigmoid), ::AbstractArray = 1:0) = sigmoid_fast
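
# Sketch (follows directly from the methods above):
#
#     NNlib.fast_act(tanh)    === tanh_fast     # true
#     NNlib.fast_act(sigmoid) === sigmoid_fast  # true
#     NNlib.fast_act(relu)    === relu          # true: other functions pass through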

## Define rrules for some activation functions, along with the
## broadcasted rrule activation functions.

## This is a performance hack specifically for Zygote, because it doesn't handle fused
## broadcasts well; but it generally should be good (or at least harmless) for any AD, as
## it saves ADing the broadcasting machinery.
## Related Issue https://github.com/JuliaDiff/ChainRulesCore.jl/issues/271

## TODO: add to the lists below all activations.

UNARY_ACTS = [ # f, dfdx
    ## In the same order as above!
    (:σ,            :(conj(Ω * (1 - Ω)))),
    (:hardσ,        :(ifelse((Ω>0)&(Ω<1), oftf(Ω, 1/6), oftf(Ω, 1)))),
    (:logσ,         :(sigmoid_fast(-x))),
    (:hardtanh,     :((Ω>-1) & (Ω<1))),
    (:relu,         :(Ω > 0)),
    (:leakyrelu,    :(ifelse(Ω > 0, oftf(Ω, 1), oftf(Ω, leakyrelu_a)))),
    (:relu6,        :((Ω>0) & (Ω<6))),
    # rrelu is random, can't write a rule.
    (:elu,          :(deriv_elu(Ω))),
    (:gelu_tanh,    :(deriv_gelu_tanh(x))),
    (:gelu_sigmoid, :(deriv_gelu_sigmoid(x))),
    (:gelu_erf,     :(deriv_gelu_erf(x))),
    (:swish,        :(Ω + sigmoid_fast(x) * (1 - Ω))),
    (:hardswish,    :(deriv_hardswish(x))),
    # lisht
    (:selu,         :(deriv_selu(Ω))),
    (:celu,         :(deriv_celu(Ω))),
    (:trelu,        :(Ω > 0)),
    (:softsign,     :(deriv_softsign(x))),
    (:softplus,     :(sigmoid_fast(x))),
    # (:softplus,     :(1 - @fastmath exp(-Ω))),  # slightly faster, check accuracy?
    # logcosh
    # mish
    (:tanhshrink,    :((x - Ω)^2)),
    (:softshrink,    :(Ω != 0)),
    ## Fast variants are the same!
    (:tanh_fast,    :(conj(1 - Ω^2))),
    (:sigmoid_fast, :(conj(Ω * (1 - Ω)))),
]

for (f, dfdx) in UNARY_ACTS
    @eval @scalar_rule($f(x), $dfdx)

    pullback = Symbol(:broadcasted_, f, :_pullback)
    @eval function rrule(::typeof(broadcasted),
                         ::typeof($f), x::Union{Numeric, Broadcast.Broadcasted})
        Ω = $f.(x)
        function $pullback(dΩ)
            x_thunk = InplaceableThunk(
                dx -> @.(dx += dΩ * $dfdx),
                @thunk @.(dΩ * $dfdx)
            )
            NoTangent(), NoTangent(), x_thunk
        end
        return Ω, $pullback
    end
end
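
# Sketch of the rule this loop generates, invoked through ChainRulesCore directly
# (an AD such as Zygote would normally call it for you):
#
#     using NNlib, ChainRulesCore
#     using Base.Broadcast: broadcasted
#     Ω, back = ChainRulesCore.rrule(broadcasted, relu, [-1.0, 2.0])
#     Ω                             # [0.0, 2.0]
#     unthunk(back([1.0, 1.0])[3])  # [0.0, 1.0]: the cotangent masked by (Ω > 0)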

# NO_ACT_GRAD = ChainRulesCore.@not_implemented "for simplicity NNlib assumes the 2nd argument of this activation function is a constant"
NO_ACT_GRAD = NaN  ## Still reminds you not to use this, but is perhaps more GPU friendly.

BINARY_ACTS = [ # f, dfdx1, dfdx2
    ## In the same order as above!
    (:leakyrelu,   :(ifelse(Ω > 0, oftf(Ω, 1), oftf(Ω, x2))), NO_ACT_GRAD),
    (:elu,         :(deriv_elu(Ω, x2)),      NO_ACT_GRAD),
    (:celu,        :(deriv_celu(Ω, x2)),     NO_ACT_GRAD),
    (:trelu,       :(Ω > 0),                 ZeroTangent()),
    (:softshrink,  :(Ω != 0),                NO_ACT_GRAD),
]

for (f, dfdx1, dfdx2) in BINARY_ACTS
    @eval @scalar_rule($f(x1, x2), ($dfdx1, $dfdx2))

    pullback = Symbol(:broadcasted_, f, :_pullback_2arg)
    @eval function rrule(::typeof(broadcasted),
                         ::typeof($f), 
                         x1::Union{Numeric, Broadcast.Broadcasted}, x2::Number)
        Ω = $f.(x1, x2)
        ## Allowing x2::Array would allow size(Ω) != size(x1), which is not handled here:
        $pullback(dΩ) = (NoTangent(), NoTangent(), @.(dΩ * $dfdx1), NO_ACT_GRAD)
        return Ω, $pullback
    end
end
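
## For illustration (a sketch): the rule generated above for `leakyrelu` reports
## NO_ACT_GRAD (= NaN) for the scalar slope, which is treated as a constant:
##
##     Ω, back = rrule(broadcasted, leakyrelu, x, 0.1)
##     back(dΩ)   # (NoTangent(), NoTangent(), @.(dΩ * ifelse(Ω > 0, 1, 0.1)), NaN)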


================================================
FILE: src/attention.jl
================================================
const AA3{T} = AbstractArray{T,3}
const AA4{T} = AbstractArray{T,4}
const AA{N,T} = AbstractArray{T,N}

"""
    dot_product_attention(query, key, value, [bias]; [fdrop, mask, nheads])

Multihead dot product attention used in transformer architectures.

The input arrays must have the first two dimensions given by the number of features
and the sequence length, then an arbitrary number of batch dimensions or none.


Returns the attention output array of size `(v_dim, q_len, batch_size...)` and the attention scores
of size `(kv_len, q_len, nheads, batch_size...)`.

See also [`dot_product_attention_scores`](@ref) if you only need the attention scores.

# Arguments

- `query`: Query array of size `(qk_dim, q_len, batch_size...)`.
- `key`: Key array of size `(qk_dim, kv_len, batch_size...)`.
- `value`: Value array of size `(v_dim, kv_len, batch_size...)`.
- `bias`: Either `nothing` or an array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
          It will be added to the attention scores before applying the softmax. Default `nothing`.
- `fdrop`: A dropout function or layer to be applied on the attention scores right after the softmax.
           Default `identity` (no dropout).
- `mask`: Either `nothing` or a boolean array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
          The mask is applied to the attention scores just before the softmax.
          See [`make_causal_mask`](@ref) for creating causal masks. Default `nothing`.
- `nheads`: Number of heads to split the input arrays into. Default `1`.

# Examples

```julia
q, k, v = rand(10, 20, 2), rand(10, 30, 2), rand(20, 30, 2)
y, α = dot_product_attention(q, k, v)
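
# For illustration, a self-attention sketch with two heads and a causal mask
# (the feature dimension, here 10, must be divisible by nheads):
x = rand(10, 20, 2)
mask = make_causal_mask(x)   # 20×20 upper-triangular Bool matrix
y, α = dot_product_attention(x, x, x; nheads=2, mask)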
```
"""
function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}, args...; kws...) where N
    batch_size = size(q)[3:end]
    batch_size == size(k)[3:end] == size(v)[3:end] || throw(ArgumentError("Batch dimensions have to be the same."))
    q, k, v = map(x -> reshape(x, size(x, 1), size(x, 2), :), (q, k, v))

    x, α = dot_product_attention(q, k, v, args...; kws...)

    x = reshape(x, size(x, 1), size(x, 2), batch_size...)
    α = reshape(α, size(α)[1:3]..., batch_size...)
    return x, α
end

function dot_product_attention(q::AA3, k::AA3, v::AA3, bias=nothing;
            fdrop=identity, mask=nothing, nheads=1)

    (all(size.((q, k, v), 1) .% nheads .== 0)) || throw(ArgumentError("""
    First dimension in query, key and value must be divisible by `nheads`.
    Instead:
    - size(q): $(size(q))
    - size(k): $(size(k))
    - size(v): $(size(v))
    - nheads: $nheads
    """))
    (size(q, 3) == size(k, 3) == size(v, 3)) || throw(ArgumentError("""
    Batch dimensions have to be the same. Instead:
    - size(q): $(size(q))
    - size(k): $(size(k))
    - size(v): $(size(v))
    """))
    size(q, 1) == size(k, 1) || throw(ArgumentError("""
    First dimension in query and key has to be the same. Instead:
    - size(q): $(size(q))
    - size(k): $(size(k))
    """))
    size(k, 2) == size(v, 2) || throw(ArgumentError("""
    Second dimension in key and value has to be the same. Instead:
    - size(k): $(size(k))
    - size(v): $(size(v))
    """))

    # Multihead attention. TODO create fastpath for singlehead attention.
    q, k, v = split_heads.((q, k, v), nheads)
    x, α = _dot_product_attention(q, k, v, bias, fdrop, mask)
    return join_heads(x), α
end

function _dot_product_attention(q::AA4, k::AA4, v::AA4, bias, fdrop, mask)
    # [q] = [qk_dim ÷ nheads, nheads, q_len, batch_size]
    # [k] = [qk_dim ÷ nheads, nheads, kv_len, batch_size]
    # [v] = [v_dim ÷ nheads, nheads, kv_len, batch_size]

    α = dot_product_attention_scores(q, k, bias; fdrop, mask)
    # [α] = [kv_len, q_len, nheads, batch_size]

    # The following permutedims and batched_mul are equivalent to
    # @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b]
    vt = permutedims(v, (1, 3, 2, 4))
    x = batched_mul(vt, α)
    x = permutedims(x, (1, 3, 2, 4))
    # [x] = [v_dim ÷ nheads, nheads, q_len, batch_size]
    return x, α
end

"""
    dot_product_attention_scores(query, key, [bias]; [fdrop, mask])

Return the attention scores for the [`dot_product_attention`](@ref).
Input arrays must have dimensions
`(num_features ÷ nheads, nheads, sequence_length, batch_size)`.

See [`dot_product_attention`](@ref) for more details.
"""
function dot_product_attention_scores(q::AA4{T}, k::AA4{T}, bias=nothing;
            fdrop=identity, mask=nothing) where T

    # The following permutedims and batched_mul are equivalent to
    # @tullio logits[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] / √T(qk_dim)
    kt = permutedims(k, (3, 1, 2, 4))
    qt = permutedims(q, (1, 3, 2, 4)) ./ √T(size(q, 1))
    logits = batched_mul(kt, qt)
    # [logits] = [kv_len, q_len, nheads, batch_size]

    logits = apply_attn_bias(logits, bias)
    logits = apply_attn_mask(logits, mask)

    α = softmax(logits, dims=1)
    return fdrop(α)
end
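
# For illustration (a sketch): with 3-dim `q` and `k` as accepted by
# `dot_product_attention`, per-head scores can be obtained via
#
#   α = dot_product_attention_scores(split_heads(q, nheads), split_heads(k, nheads))
#
# giving an array of size `(kv_len, q_len, nheads, batch_size)`.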

apply_attn_bias(logits, bias::Nothing) = logits

apply_attn_bias(logits, bias) = logits .+ bias

apply_attn_mask(logits, mask::Nothing) = logits

function apply_attn_mask(logits, mask)
    neginf = typemin(eltype(logits))
    ifelse.(mask, logits, neginf)
end


"""
    make_causal_mask(x, dims=2)

Return a boolean square matrix `m` of the same type as `x` and of side `size(x, dims)`.
Its elements are set such that `m[i, j] == i ≤ j`.

Can be used to mask the attention scores in [`dot_product_attention`](@ref).
"""
function make_causal_mask(x::AbstractArray; dims::Int=2)
  len = size(x, dims)
  mask = triu(trues_like(x, (len, len)))
  return mask
end

trues_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), true)
falses_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), false)

split_heads(x, nheads) = reshape(x, size(x, 1) ÷ nheads, nheads, size(x)[2:end]...)
join_heads(x) = reshape(x, :, size(x)[3:end]...)
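
# For example (a sketch): `split_heads(rand(8, 5, 2), 4)` has size `(2, 4, 5, 2)`,
# and `join_heads` maps it back to size `(8, 5, 2)`.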

@non_differentiable make_causal_mask(::Any...)
@non_differentiable trues_like(::Any...)
@non_differentiable falses_like(::Any...)


================================================
FILE: src/audio/mel.jl
================================================
"""
    melscale_filterbanks(;
        n_freqs::Int, n_mels::Int, sample_rate::Int,
        fmin::Float32 = 0f0, fmax::Float32 = Float32(sample_rate ÷ 2))

Create triangular Mel scale filter banks
(ref: [Mel scale - Wikipedia](https://en.wikipedia.org/wiki/Mel_scale)).
Each column is a filterbank that highlights its own frequency.

# Arguments:

- `n_freqs::Int`: Number of frequencies to highlight.
- `n_mels::Int`: Number of mel filterbanks.
- `sample_rate::Int`: Sample rate of the audio waveform.
- `fmin::Float32`: Minimum frequency in Hz.
- `fmax::Float32`: Maximum frequency in Hz.

# Returns:

Filterbank matrix of shape `(n_freqs, n_mels)` where each column is a filterbank.

```jldoctest
julia> n_mels = 8;

julia> fb = melscale_filterbanks(; n_freqs=200, n_mels, sample_rate=16000);

julia> plot = lineplot(fb[:, 1]);

julia> for i in 2:n_mels
           lineplot!(plot, fb[:, i])
       end

julia> plot
     ┌────────────────────────────────────────┐
   1 │⠀⡀⢸⠀⢸⠀⠀⣧⠀⠀⢸⡄⠀⠀⠀⣷⠀⠀⠀⠀⠀⣷⠀⠀⠀⠀⠀⠀⢀⣿⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
     │⠀⡇⢸⡆⢸⡇⠀⣿⠀⠀⡜⡇⠀⠀⢰⠋⡆⠀⠀⠀⢰⠁⡇⠀⠀⠀⠀⠀⡸⠀⢣⠀⠀⠀⠀⠀⠀⠀⠀⠀│
     │⠀⣿⢸⡇⡇⡇⢰⠹⡄⠀⡇⢱⠀⠀⢸⠀⢣⠀⠀⠀⡜⠀⢸⡀⠀⠀⠀⢀⠇⠀⠈⡇⠀⠀⠀⠀⠀⠀⠀⠀│
     │⠀⣿⡇⡇⡇⡇⢸⠀⡇⢀⠇⠸⡀⠀⡇⠀⠸⡀⠀⢀⠇⠀⠀⢇⠀⠀⠀⡸⠀⠀⠀⠸⡄⠀⠀⠀⠀⠀⠀⠀│
     │⢠⢻⡇⡇⡇⢱⢸⠀⢇⢸⠀⠀⡇⢀⠇⠀⠀⡇⠀⢸⠀⠀⠀⠸⡀⠀⢠⠇⠀⠀⠀⠀⢱⠀⠀⠀⠀⠀⠀⠀│
     │⢸⢸⡇⢱⡇⢸⡇⠀⢸⢸⠀⠀⢣⢸⠀⠀⠀⢸⠀⡇⠀⠀⠀⠀⢇⠀⡜⠀⠀⠀⠀⠀⠈⢇⠀⠀⠀⠀⠀⠀│
     │⢸⢸⡇⢸⠀⢸⡇⠀⢸⡇⠀⠀⢸⡎⠀⠀⠀⠈⣶⠁⠀⠀⠀⠀⠸⣤⠃⠀⠀⠀⠀⠀⠀⠘⡆⠀⠀⠀⠀⠀│
     │⢸⠀⡇⢸⠀⠀⡇⠀⠀⡇⠀⠀⠀⡇⠀⠀⠀⠀⣿⠀⠀⠀⠀⠀⠀⣿⠀⠀⠀⠀⠀⠀⠀⠀⢱⡀⠀⠀⠀⠀│
     │⢸⢸⡇⢸⠀⢸⡇⠀⢸⡇⠀⠀⢸⢇⠀⠀⠀⢀⠿⡀⠀⠀⠀⠀⢰⠛⡄⠀⠀⠀⠀⠀⠀⠀⠀⢣⠀⠀⠀⠀│
     │⢸⢸⡇⡸⡇⢸⡇⠀⢸⢸⠀⠀⡜⢸⠀⠀⠀⢸⠀⡇⠀⠀⠀⠀⡎⠀⢣⠀⠀⠀⠀⠀⠀⠀⠀⠘⡆⠀⠀⠀│
     │⢸⢸⡇⡇⡇⡸⢸⠀⡎⢸⠀⠀⡇⠈⡆⠀⠀⡇⠀⢸⠀⠀⠀⢰⠁⠀⠘⡆⠀⠀⠀⠀⠀⠀⠀⠀⠸⡄⠀⠀│
     │⡇⢸⡇⡇⡇⡇⢸⠀⡇⠈⡆⢰⠁⠀⡇⠀⢰⠁⠀⠈⡆⠀⠀⡎⠀⠀⠀⢱⠀⠀⠀⠀⠀⠀⠀⠀⠀⢣⠀⠀│
     │⡇⢸⢸⡇⡇⡇⠸⣰⠃⠀⡇⡸⠀⠀⢸⠀⡜⠀⠀⠀⢣⠀⢸⠁⠀⠀⠀⠈⡆⠀⠀⠀⠀⠀⠀⠀⠀⠈⢇⠀│
     │⡇⡇⢸⠇⢸⡇⠀⣿⠀⠀⢣⡇⠀⠀⠸⣄⠇⠀⠀⠀⠸⡀⡇⠀⠀⠀⠀⠀⢱⠀⠀⠀⠀⠀⠀⠀⠀⠀⠸⡄│
   0 │⣇⣇⣸⣀⣸⣀⣀⣟⣀⣀⣸⣃⣀⣀⣀⣿⣀⣀⣀⣀⣀⣿⣀⣀⣀⣀⣀⣀⣈⣇⣀⣀⣀⣀⣀⣀⣀⣀⣀⣱│
     └────────────────────────────────────────┘
     ⠀0⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀200⠀
```
"""
function melscale_filterbanks(;
    n_freqs::Int, n_mels::Int, sample_rate::Int,
    fmin::Float32 = 0f0, fmax::Float32 = Float32(sample_rate ÷ 2),
)
    mel_min, mel_max = _hz_to_mel(fmin), _hz_to_mel(fmax)
    mel_points = range(mel_min, mel_max; length=n_mels + 2)

    all_freqs = collect(range(0f0, Float32(sample_rate ÷ 2); length=n_freqs))
    freq_points = _mel_to_hz.(mel_points)
    filter_banks = _triangular_filterbanks(freq_points, all_freqs)

    if any(maximum(filter_banks; dims=1) .≈ 0f0)
        @warn """At least one mel filterbank has all zero values.
        The value for `n_mels=$n_mels` may be set too high.
        Or the value for `n_freqs=$n_freqs` may be set too low.
        """
    end
    return filter_banks
end

_hz_to_mel(freq::T) where T = T(2595) * log10(T(1) + (freq / T(700)))

_mel_to_hz(mel::T) where T = T(700) * (T(10)^(mel / T(2595)) - T(1))
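
# For example (a sketch): the two maps are mutually inverse up to floating-point error,
# so `_mel_to_hz(_hz_to_mel(440f0)) ≈ 440f0`.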

"""
    _triangular_filterbanks(
        freq_points::Vector{Float32}, all_freqs::Vector{Float32})

Create triangular filter banks.

# Arguments:

- `freq_points::Vector{Float32}`: Filter peak frequencies (including the two outer band edges) of size `n_filters + 2`.
- `all_freqs::Vector{Float32}`: Frequency points of size `n_freqs`.

# Returns:

Array of size `(n_freqs, n_filters)`.
"""
function _triangular_filterbanks(
    freq_points::Vector{Float32}, all_freqs::Vector{Float32},
)
    diff = @view(freq_points[2:end]) .- @view(freq_points[1:end - 1])
    slopes = transpose(reshape(freq_points, :, 1) .- reshape(all_freqs, 1, :))

    down_slopes = -(@view(slopes[:, 1:end - 2]) ./ reshape(@view(diff[1:end - 1]), 1, :))
    up_slopes = @view(slopes[:, 3:end]) ./ reshape(@view(diff[2:end]), 1, :)
    return max.(0f0, min.(down_slopes, up_slopes))
end


================================================
FILE: src/audio/spectrogram.jl
================================================
"""
    spectrogram(waveform;
        pad::Int = 0, n_fft::Int, hop_length::Int, window,
        center::Bool = true, power::Real = 2.0,
        normalized::Bool = false, window_normalized::Bool = false,
    )

Create a spectrogram or a batch of spectrograms from a raw audio signal.

# Arguments

- `pad::Int`:
    The amount of padding to apply on both sides.
- `window_normalized::Bool`:
    Whether to normalize the waveform by the window’s L2 energy.
- `power::Real`:
    Exponent for the magnitude spectrogram (must be ≥ 0)
    e.g., `1` for magnitude, `2` for power, etc.
    If `0`, complex spectrum is returned instead.

See [`stft`](@ref) for other arguments.

# Returns

Spectrogram in the shape `(T, F, B)`, where
`T` is the number of window hops and `F = n_fft ÷ 2 + 1`.
"""
function spectrogram(waveform::AbstractArray{T};
    pad::Int = 0, n_fft::Int, hop_length::Int, window,
    center::Bool = true, power::Real = 2.0,
    normalized::Bool = false, window_normalized::Bool = false,
) where T
    pad > 0 && (waveform = pad_zeros(waveform, pad; dims=1);)

    # Pack batch dimensions.
    sz = size(waveform)
    spec_ = stft(reshape(waveform, (sz[1], :));
        n_fft, hop_length, window, center, normalized)
    # Unpack batch dimensions.
    spec = reshape(spec_, (size(spec_)[1:2]..., sz[2:end]...))
    window_normalized && (spec = spec .* inv(norm(window));)

    if power > 0
        p = T(power)
        spec = abs.(spec .+ eps(T)).^p
    end
    return spec
end
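
# For illustration, a typical call (a sketch; `stft` requires the FFTW package
# to be loaded, see the NNlibFFTWExt extension):
#
#   waveform = randn(Float32, 16_000)    # one second of audio at 16 kHz
#   spec = spectrogram(waveform; n_fft=1024, hop_length=256, window=hann_window(1024))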

"""
    power_to_db(s; ref::Real = 1f0, amin::Real = 1f-10, top_db::Real = 80f0)

Convert a power spectrogram (amplitude squared) to decibel (dB) units.

# Arguments

- `s`: Input power.
- `ref`: Scalar w.r.t. which the input is scaled.
- `amin`: Minimum threshold for `s`.
- `top_db`: Threshold the output at `top_db` below the peak:
    `max.(s_db, maximum(s_db) - top_db)`.

# Returns

`s_db ~= 10 * log10(s) - 10 * log10(ref)`
"""
function power_to_db(s; ref::Real = 1f0, amin::Real = 1f-10, top_db::Real = 80f0)
    log_spec = 10f0 .* (log10.(max.(amin, s)) .- log10.(max.(amin, ref)))
    return max.(log_spec, maximum(log_spec) - top_db)
end

"""
    db_to_power(s_db; ref::Real = 1f0)

Inverse of [`power_to_db`](@ref).
"""
function db_to_power(s_db; ref::Real = 1f0)
    return ref .* 10f0.^(s_db .* 0.1f0)
end
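
# For example (a sketch): when no clipping by `amin` or `top_db` occurs,
# `db_to_power` undoes `power_to_db`:
#
#   s = Float32[1, 10, 100]
#   db_to_power(power_to_db(s)) ≈ s    # true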


================================================
FILE: src/audio/stft.jl
================================================
"""
    hamming_window(
        window_length::Int, ::Type{T} = Float32; periodic::Bool = true,
        α::T = T(0.54), β::T = T(0.46),
    ) where T <: Real

Hamming window function
(ref: [Window function § Hann and Hamming windows - Wikipedia](https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows)).
Generalized version of `hann_window`.

``w[n] = \\alpha - \\beta \\cos(\\frac{2 \\pi n}{N - 1})``

Where ``N`` is the window length.

```julia-repl
julia> lineplot(hamming_window(100); width=30, height=10)
     ┌──────────────────────────────┐
   1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⠚⠉⠉⠉⠢⡄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
     │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠎⠁⠀⠀⠀⠀⠀⠈⢢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
     │⠀⠀⠀⠀⠀⠀⠀⠀⠀⢠⠊⠀⠀⠀⠀⠀⠀⠀⠀⠀⢣⡀⠀⠀⠀⠀⠀⠀⠀⠀│
     │⠀⠀⠀⠀⠀⠀⠀⠀⢰⠃⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠱⡀⠀⠀⠀⠀⠀⠀⠀│
     │⠀⠀⠀⠀⠀⠀⠀⣠⠃⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠳⡀⠀⠀⠀⠀⠀⠀│
     │⠀⠀⠀⠀⠀⠀⢰⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠱⡄⠀⠀⠀⠀⠀│
     │⠀⠀⠀⠀⠀⡰⠃⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠱⡀⠀⠀⠀⠀│
     │⠀⠀⠀⢀⠴⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠘⢄⠀⠀⠀│
     │⠀⢀⡠⠊⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠳⣀⠀│
   0 │⠉⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉│
     └──────────────────────────────┘
     ⠀0⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀100⠀
```

# Arguments:

- `window_length::Int`: Size of the window.
- `::Type{T}`: Element type of the window.

# Keyword Arguments:

- `periodic::Bool`: If `true` (default), returns a window to be used as
    a periodic function. If `false`, returns a symmetric window.

    The following always holds:

```jldoctest
julia> N = 256;

julia> hamming_window(N; periodic=true) ≈ hamming_window(N + 1; periodic=false)[1:end - 1]
true
```
- `α::Real`: Coefficient α in the equation above.
- `β::Real`: Coefficient β in the equation above.

# Returns:

Vector of length `window_length` and eltype `T`.
"""
function hamming_window(
    window_length::Int, ::Type{T} = Float32; periodic::Bool = true,
    α::T = T(0.54), β::T = T(0.46),
) where T <: Real
    window_length < 1 && throw(ArgumentError(
        "`window_length` must be > 0, instead: `$window_length`."))

    n::T = ifelse(periodic, window_length, window_length - 1)
    scale = T(2) * π / n
    return [α - β * cos(scale * T(k)) for k in 0:(window_length - 1)]
end

"""
    hann_window(
        window_length::Int, ::Type{T} = Float32; periodic::Bool = true,
    ) where T <: Real

Hann window function
(ref: [Window function § Hann and Hamming windows - Wikipedia](https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows)).

``w[n] = \\frac{1}{2}[1 - \\cos(\\frac{2 \\pi n}{N - 1})]``

Where ``N`` is the window length.

```julia-repl
julia> lineplot(hann_window(100); width=30, height=10)
     ┌──────────────────────────────┐
   1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣠⠚⠉⠉⠉⠢⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
     │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡔⠁⠀⠀⠀⠀⠀⠘⢄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
     │⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠞⠀⠀⠀⠀⠀⠀⠀⠀⠈⢆⠀⠀⠀⠀⠀⠀⠀⠀⠀│
     │⠀⠀⠀⠀⠀⠀⠀⠀⢀⡎⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢣⠀⠀⠀⠀⠀⠀⠀⠀│
     │⠀⠀⠀⠀⠀⠀⠀⠀⡎⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢦⠀⠀⠀⠀⠀⠀⠀│
     │⠀⠀⠀⠀⠀⠀⢀⠞⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢆⠀⠀⠀⠀⠀⠀│
     │⠀⠀⠀⠀⠀⢀⡜⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢇⠀⠀⠀⠀⠀│
     │⠀⠀⠀⠀⢀⠎⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢦⠀⠀⠀⠀│
     │⠀⠀⠀⢠⠊⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠣⡀⠀⠀│
   0 │⣀⣀⠔⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢤⣀│
     └──────────────────────────────┘
     ⠀0⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀100⠀
```

# Arguments:

- `window_length::Int`: Size of the window.
- `::Type{T}`: Element type of the window.

# Keyword Arguments:

- `periodic::Bool`: If `true` (default), returns a window to be used as
    a periodic function. If `false`, returns a symmetric window.

    The following always holds:

```jldoctest
julia> N = 256;

julia> hann_window(N; periodic=true) ≈ hann_window(N + 1; periodic=false)[1:end - 1]
true

julia> hann_window(N) ≈ hamming_window(N; α=0.5f0, β=0.5f0)
true
```

# Returns:

Vector of length `window_length` and eltype `T`.
"""
function hann_window(
    window_length::Int, ::Type{T} = Float32; periodic::Bool = true,
) where T <: Real
    hamming_window(window_length, T; periodic, α=T(0.5), β=T(0.5))
end

"""
    stft(x;
        n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing,
        center::Bool = true, normalized::Bool = false,
    )

Short-time Fourier transform (STFT).

The STFT computes the Fourier transform of short overlapping windows of the input,
giving frequency components of the signal as they change over time.

``Y[\\omega, m] = \\sum_{k = 0}^{N - 1} \\text{window}[k] \\text{input}[m \\times \\text{hop length} + k] \\exp(-j \\frac{2 \\pi \\omega k}{\\text{n fft}})``

where ``N`` is the window length,
``\\omega`` is the frequency ``0 \\le \\omega < \\text{n fft}``
and ``m`` is the index of the sliding window.

# Arguments:

- `x`: Input, must be either a 1D time sequence (`(L,)` shape)
    or a 2D batch of time sequence (`(L, B)` shape).

# Keyword Arguments:

- `n_fft::Int`: Size of Fourier transform.
- `hop_length::Int`: Distance between neighboring sliding window frames.
- `window`: Optional window function to apply.
    Must be a 1D vector with `0 < length(window) ≤ n_fft`.
    If window is shorter than `n_fft`, it is padded with zeros on both sides.
    If `nothing` (default), then no window is applied.
- `center::Bool`: Whether to pad input on both sides so that ``t``-th frame
    is centered at time ``t \\times \\text{hop length}``.
    Padding is done with `pad_reflect` function.
- `normalized::Bool`: Whether to return normalized STFT,
    i.e. multiplied with ``\\text{n fft}^{-0.5}``.

# Returns:

Complex array of shape `(n_fft, n_frames, B)`,
where `B` is the optional batch dimension.
"""
function stft end
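
# For illustration, a minimal call (a sketch; the method is provided by the
# NNlibFFTWExt extension, so FFTW must be loaded):
#
#   using FFTW, NNlib
#   x = randn(Float32, 1024)
#   y = stft(x; n_fft=256, hop_length=64, window=hann_window(256))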

"""
    istft(y;
        n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing,
        center::Bool = true, normalized::Bool = false,
        return_complex::Bool = false,
        original_length::Union{Nothing, Int} = nothing,
    )

Inverse Short-time Fourier Transform.

Returns the least-squares estimate of the original signal.

# Arguments:

- `y`: Input complex array of shape `(n_fft, n_frames, B)`,
    where `B` is the optional batch dimension.

# Keyword Arguments:

- `n_fft::Int`: Size of Fourier transform.
- `hop_length::Int`: Distance between neighboring sliding window frames.
- `window`: Window function that was applied to the input of `stft`.
    If `nothing` (default), then no window was applied.
- `center::Bool`: Whether input to `stft` was padded on both sides
    so that ``t``-th frame is centered at time ``t \\times \\text{hop length}``.
    Padding is done with `pad_reflect` function.
- `normalized::Bool`: Whether input to `stft` was normalized.
- `return_complex::Bool`: Whether the output should be complex,
    or if the input should be assumed to derive from a real signal and window.
- `original_length::Union{Nothing, Int}`: Optional size of the first dimension
    of the input to `stft`. Helps restore the exact `stft` input size;
    otherwise, the returned array might be slightly shorter.
"""
function istft end
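
# For illustration, a round-trip sketch (again requires FFTW to be loaded; exact
# recovery of the original length relies on passing `original_length`):
#
#   x = randn(Float32, 4096)
#   w = hann_window(1024)
#   y = stft(x; n_fft=1024, hop_length=256, window=w)
#   x̂ = istft(y; n_fft=1024, hop_length=256, window=w, original_length=length(x))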


================================================
FILE: src/batched/batchedadjtrans.jl
================================================
import Base: -
import Adapt: adapt_structure, adapt

_batched_doc = """
    batched_transpose(A::AbstractArray{T,3})
    batched_adjoint(A)

Equivalent to applying `transpose` or `adjoint` to each matrix `A[:,:,k]`.

These exist to control how `batched_mul` behaves,
as it operates on such matrix slices of an array with `ndims(A)==3`.

`PermutedDimsArray(A, (2,1,3))` is equivalent to `batched_transpose(A)`,
and is also understood by `batched_mul` (and more widely supported elsewhere).

    BatchedTranspose{T, S} <: AbstractBatchedMatrix{T, 3}
    BatchedAdjoint{T, S}

Lazy wrappers analogous to `Transpose` and `Adjoint`, returned by `batched_transpose` etc.
"""

@doc _batched_doc
struct BatchedTranspose{T, S} <: AbstractArray{T, 3}
    parent::S
    BatchedTranspose{T, S}(X::S) where {T, S} = new{T, S}(X)
end

@doc _batched_doc
batched_transpose(A::AbstractArray{T, 3}) where T = BatchedTranspose(A)
batched_transpose(A::BatchedTranspose) = A.parent

@doc _batched_doc
struct BatchedAdjoint{T, S} <: AbstractArray{T, 3}
    parent::S
    BatchedAdjoint{T, S}(X::S) where {T, S} = new{T, S}(X)
end

@doc _batched_doc
batched_adjoint(A::AbstractArray{T, 3}) where T = BatchedAdjoint(A)
batched_adjoint(A::BatchedAdjoint) = A.parent

batched_adjoint(A::BatchedTranspose{<:Real}) = A.parent
batched_transpose(A::BatchedAdjoint{<:Real}) = A.parent
batched_adjoint(A::PermutedDimsArray{<:Real,3,(2,1,3)}) = A.parent
batched_transpose(A::PermutedDimsArray{<:Number,3,(2,1,3)}) = A.parent
# if you can't unwrap, put BatchedAdjoint outside (for dispatch):
batched_transpose(A::BatchedAdjoint{<:Complex}) = BatchedAdjoint(BatchedTranspose(A.parent))

BatchedAdjoint(A) = BatchedAdjoint{Base.promote_op(adjoint,eltype(A)),typeof(A)}(A)
BatchedTranspose(A) = BatchedTranspose{Base.promote_op(transpose,eltype(A)),typeof(A)}(A)

const BatchedAdjOrTrans{T, S} = Union{BatchedTranspose{T, S}, BatchedAdjoint{T, S}}

LinearAlgebra.wrapperop(A::BatchedAdjoint) = batched_adjoint
LinearAlgebra.wrapperop(B::BatchedTranspose) = batched_transpose

# AbstractArray Interface
Base.length(A::BatchedAdjOrTrans) = length(A.parent)
Base.size(m::BatchedAdjOrTrans) = (size(m.parent, 2), size(m.parent, 1), size(m.parent, 3))
Base.axes(m::BatchedAdjOrTrans) = (axes(m.parent, 2), axes(m.parent, 1), axes(m.parent, 3))

Base.IndexStyle(::Type{<:BatchedAdjOrTrans}) = IndexCartesian()
Base.@propagate_inbounds Base.getindex(m::BatchedTranspose, i::Int, j::Int, k::Int) = getindex(m.parent, j, i, k)
Base.@propagate_inbounds Base.getindex(m::BatchedAdjoint, i::Int, j::Int, k::Int) = adjoint(getindex(m.parent, j, i, k))
Base.@propagate_inbounds Base.setindex!(m::BatchedTranspose, v, i::Int, j::Int, k::Int) = setindex!(m.parent, v, j, i, k)
Base.@propagate_inbounds Base.setindex!(m::BatchedAdjoint, v, i::Int, j::Int, k::Int) = setindex!(m.parent, adjoint(v), j, i, k)
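
# For example (a sketch of the indexing behaviour defined above):
#   batched_transpose(A)[i, j, k] == A[j, i, k]
#   batched_adjoint(A)[i, j, k]   == adjoint(A[j, i, k])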

Base.similar(A::BatchedAdjOrTrans, T::Type, dims::Dims) = similar(A.parent, T, dims)
Base.similar(A::BatchedAdjOrTrans, dims::Dims) = similar(A.parent, dims)
Base.similar(A::BatchedAdjOrTrans, T::Type) = similar(A.parent, T, size(A))
Base.similar(A::BatchedAdjOrTrans) = similar(A.parent, size(A))

Base.parent(A::BatchedAdjOrTrans) = A.parent

(-)(A::BatchedAdjoint)   = BatchedAdjoint(  -A.parent)
(-)(A::BatchedTranspose) = BatchedTranspose(-A.parent)

# C interface
function Base.strides(A::Union{BatchedTranspose, BatchedAdjoint{<:Real}})
    sp = strides(A.parent)
    (sp[2], sp[1], sp[3])
end

function Base.stride(A::Union{BatchedTranspose, BatchedAdjoint{<:Real}}, d::Integer)
    d == 1 && return Base.stride(A.parent, 2)
    d == 2 && return Base.stride(A.parent, 1)
    Base.stride(A.parent, d)
end

Base.pointer(A::BatchedAdjOrTrans) = pointer(parent(A))
Base.unsafe_convert(::Type{Ptr{T}}, A::BatchedAdjOrTrans{T}) where {T} =
    Base.unsafe_convert(Ptr{T}, parent(A))

# Gradients
function rrule(::typeof(batched_transpose), A::AbstractArray{<:Any,3})
    b_transpose_back(Δ) = (NoTangent(), batched_transpose(unthunk(Δ)))
    batched_transpose(A), b_transpose_back
end
function rrule(::typeof(batched_adjoint), A::AbstractArray{<:Any,3})
    b_adjoint_back(Δ) = (NoTangent(), batched_adjoint(unthunk(Δ)))
    batched_adjoint(A), b_adjoint_back
end

adapt_structure(to, x::BatchedAdjoint) = BatchedAdjoint(adapt(to, parent(x)))
adapt_structure(to, x::BatchedTranspose) = BatchedTranspose(adapt(to, parent(x)))

Broadcast.BroadcastStyle(::Type{<:BatchedAdjOrTrans{T, S}}) where {T, S} = Broadcast.BroadcastStyle(S)


================================================
FILE: src/batched/batchedmul.jl
================================================
_unbatch(A) = A
_unbatch(A::BatchedAdjOrTrans) = parent(A)

"""
    batched_mul(A, B) -> C
    A ⊠ B  # \\boxtimes

Batched matrix multiplication. Result has `C[:,:,k...] == A[:,:,k...] * B[:,:,k...]` where `k...` represent 
any indices in the last dimensions.

If `ndims(A) == ndims(B) == 3` and `size(B,3) == 1` then instead `C[:,:,k] == A[:,:,k] * B[:,:,1]`, and similarly for `A`.

To transpose each matrix, apply `batched_transpose` to the array,
or `batched_adjoint` for conjugate-transpose:

```jldoctest
julia> A, B = randn(2,5,17), randn(5,9,17);

julia> A ⊠ B |> size
(2, 9, 17)

julia> batched_adjoint(A) |> size
(5, 2, 17)

julia> batched_mul(A, batched_adjoint(randn(9,5,17))) |> size
(2, 9, 17)

julia> A ⊠ randn(5,9,1) |> size
(2, 9, 17)

julia> batched_transpose(A) == PermutedDimsArray(A, (2,1,3))
true
```

The equivalent `PermutedDimsArray` may be used in place of `batched_transpose`.
Other permutations are also handled by BLAS,
provided that the batch index `k` is not the first dimension of the underlying array.
Thus `PermutedDimsArray(::Array, (1,3,2))` and `PermutedDimsArray(::Array, (3,1,2))` are fine.

However, `A = PermutedDimsArray(::Array, (3,2,1))` is not acceptable to BLAS,
since the batch dimension is the contiguous one: `stride(A,3) == 1`.
This will be copied, as doing so is faster than `batched_mul_generic!`.

Both this `copy` and `batched_mul_generic!` produce `@debug` messages,
and setting for instance `ENV["JULIA_DEBUG"] = NNlib` will display them.
"""
function batched_mul(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
    batch_size = size(x)[3:end]
    @assert batch_size == size(y)[3:end] "batch size has to be the same for the two arrays."
    x2 = reshape(x, size(x, 1), size(x, 2), :)
    y2 = reshape(y, size(y, 1), size(y, 2), :)
    z = batched_mul(x2, y2)
    return reshape(z, size(z, 1), size(z, 2), batch_size...)
end

function batched_mul(A::AbstractArray{T1, 3}, B::AbstractArray{T2, 3}) where {T1, T2}
    size(A, 3) == size(B, 3) || size(A, 3) == 1 || size(B, 3) == 1 ||
        throw(DimensionMismatch("batch size mismatch: A != B"))
    _batched_mul(storage_typejoin(A, B), A, B)
end

const ⊠ = batched_mul

function _batched_mul(::Type, A, B)
    T = promote_type(eltype(A), eltype(B))
    C = similar(A, T, (size(A, 1), size(B, 2), max(size(A, 3), size(B, 3))))
    batched_mul!(C, A, B)
    C
end
function _batched_mul(::Type{<:DenseArray{T}}, A, B) where {T<:BlasFloat}
    C = similar(A, T, (size(A, 1), size(B, 2), max(size(A, 3), size(B, 3))))
    batched_mul!(C, _copy_if_faster(A), _copy_if_faster(B))
    C
end

function _copy_if_faster(X::AbstractArray{<:Number, 3})
    is_strided(X) || return X
    if Base.stride(X, 3) == 1 && Base.stride(X, 1) != 1
        @debug "copying to avoid batched_mul_generic!" typeof(X) size(X) strides(X)
        return copy(X)
    end
    X
end
function _copy_if_faster(X::BatchedAdjoint{<:Complex})
    Xbase = _unbatch(X)
    is_strided(Xbase) || return X
    if Base.stride(Xbase, 1) != 1
        @debug "copying to avoid batched_mul_generic!" typeof(X) size(X) strides(_unbatch(X))
        return copy(X) # or batched_adjoint(copy(Xbase)), may be better on GPU?
    end
    X
end

# Gradient, allowing that size(A,3)==1 means it's "broadcasted" out to size(B,3)

function rrule(::typeof(batched_mul), A::AbstractArray{<:Any,3}, B::AbstractArray{<:Any,3})
    function batched_mul_pullback(_Δ)
        Δ = unthunk(_Δ)
        Athunk = @thunk begin
            tmp = batched_mul(Δ, batched_adjoint(B))
            size(A,3) == 1 ? sum(tmp, dims=3) : tmp
        end
        Bthunk = @thunk begin
            tmp = batched_mul(batched_adjoint(A), Δ)
            size(B,3) == 1 ? sum(tmp, dims=3) : tmp
        end
        return (NoTangent(), Athunk, Bthunk)
    end
    batched_mul(A, B), batched_mul_pullback
end

"""
    batched_mul(A::Array{T,3}, B::Matrix)
    batched_mul(A::Matrix, B::Array{T,3})
    A ⊠ B

This is always matrix-matrix multiplication, but
either `A` or `B` may lack a batch index.

* When `B` is a matrix, result has `C[:,:,k] == A[:,:,k] * B[:,:]` for all `k`.

* When `A` is a matrix, then `C[:,:,k] == A[:,:] * B[:,:,k]`.
  This can also be done by reshaping and calling `*`,
  for instance `A ⊡ B` using TensorCore.jl, but is implemented here using
  `batched_gemm` instead of `gemm`.

```jldoctest
julia> randn(16,8,32) ⊠ randn(8,4) |> size
(16, 4, 32)

julia> randn(16,8,32) ⊠ randn(8,4,1) |> size  # equivalent
(16, 4, 32)

julia> randn(16,8) ⊠ randn(8,4,32) |> size
(16, 4, 32)
```

See also `batched_vec` to regard `B` as a batch of vectors, `A[:,:,k] * B[:,k]`.
"""
batched_mul(A::AbstractArray{T,3} where T, B::AbstractMatrix) = _semi_batched_mul(A,B)

# Simplify signature of batched_mul by hiding dispatch on Adjoint etc:

_semi_batched_mul(A::AbstractArray{<:Any,3}, B::AbstractMatrix) =
    batched_mul(A, reshape(B, size(B)..., 1))

_semi_batched_mul(A::AbstractArray{<:Any,3}, B::Adjoint{<:Number,<:AbstractMatrix}) =
    batched_mul(A, batched_adjoint(reshape(parent(B), size(parent(B))..., 1)))

_semi_batched_mul(A::AbstractArray{<:Any,3}, B::Transpose{<:Number,<:AbstractMatrix}) =
    batched_mul(A, batched_transpose(reshape(parent(B), size(parent(B))..., 1)))

batched_mul(A::AbstractMatrix, B::AbstractArray{T,3} where T) = _semi_batched_mul(A,B)

_semi_batched_mul(A::AbstractMatrix, B::AbstractArray{<:Any,3}) =
    batched_mul(reshape(A, size(A)..., 1), B)

_semi_batched_mul(A::Adjoint{<:Number,<:AbstractMatrix}, B::AbstractArray{<:Any,3}) =
    batched_mul(batched_adjoint(reshape(parent(A), size(parent(A))..., 1)), B)

_semi_batched_mul(A::Transpose{<:Number,<:AbstractMatrix}, B::AbstractArray{<:Any,3}) =
    batched_mul(batched_transpose(reshape(parent(A), size(parent(A))..., 1)), B)

"""
    batched_vec(A::AbstractArray{T,3}, B::AbstractMatrix)
    batched_vec(A::AbstractArray{T,3}, b::AbstractVector)
    batched_vec(A::AbstractArray, B::AbstractArray)

Batched matrix-vector multiplication. For the 3D case:
the result has `C[:,k] == A[:,:,k] * B[:,k]` for all `k`,
or else `C[:,k] == A[:,:,k] * b` for `b::Vector`.

For the general N-D case where `ndims(A) == ndims(B) + 1`:
the result has `C[:,k...] == A[:,:,k...] * B[:,k...]` for all batch indices `k...`.
The batch dimensions must match: `size(A)[3:end] == size(B)[2:end]`.

With the same argument types, `batched_mul(A, B)` would regard `B` as
a fixed matrix, not a batch of vectors. Both reshape and then
call `batched_mul(::Array{T,3}, ::Array{T,3})`.

```jldoctest
julia> A, B, b = randn(16,8,32), randn(8,32), randn(8);

julia> batched_vec(A,B) |> size
(16, 32)

julia> batched_vec(A,b) |> size
(16, 32)

julia> A4d, B3d = randn(16,8,10,32), randn(8,10,32);  # 4D and 3D arrays

julia> batched_vec(A4d, B3d) |> size
(16, 10, 32)
```
"""
function batched_vec(A::AbstractArray, B::AbstractArray)
    ndims(A) == ndims(B) + 1 || throw(DimensionMismatch(
        "batched_vec requires ndims(A) == ndims(B) + 1, got ndims(A)=$(ndims(A)) and ndims(B)=$(ndims(B))"))
    size(A)[3:end] == size(B)[2:end] || throw(DimensionMismatch(
        "batch dimensions must match: size(A)[3:end]=$(size(A)[3:end]) != size(B)[2:end]=$(size(B)[2:end])"))
    
    # Reshape B to add a singleton dimension for matrix multiplication
    B_reshaped = reshape(B, size(B, 1), 1, size(B)[2:end]...)
    # Perform batched multiplication
    C = batched_mul(A, B_reshaped)
    # Remove the singleton dimension
    return dropdims(C, dims=2)
end

batched_vec(A::AbstractArray{T,3} where T, B::AbstractMatrix) =
    reshape(batched_mul(A, reshape(B, size(B,1), 1, size(B,2))), size(A,1), size(A,3))

# If B is transposed, then stride=1 is the batch dim, so we will end up copying anyway:
batched_vec(A::AbstractArray{T,3} where T, B::AdjOrTransAbsMat{<:BlasFloat, <:StridedMatrix}) =
    batched_vec(A, copy(B))

batched_vec(A::AbstractArray{T,3} where T, b::AbstractVector) =
    reshape(batched_mul(A, reshape(b, length(b), 1, 1)), size(A,1), size(A,3))


"""
    batched_mul!(C, A, B) -> C
    batched_mul!(C, A, B, α=1, β=0)

In-place batched matrix multiplication, equivalent to
`mul!(C[:,:,k], A[:,:,k], B[:,:,k], α, β)` for all `k`.
If `size(B,3) == 1` then every batch uses `B[:,:,1]` instead.

This will call `batched_gemm!` whenever possible. For real arrays this means that,
for `X ∈ [A,B,C]`, either `stride(X,1)==1` or `stride(X,2)==1`; the latter may
be caused by `batched_transpose` or, for instance, by `PermutedDimsArray(::Array, (3,1,2))`.
Unlike `batched_mul` this will never make a copy.

For complex arrays, the wrapper made by `batched_adjoint` must be outermost to be seen.
In this case the strides accepted by BLAS are more restricted: if `stride(C,1)==1` then
only `stride(AorB::BatchedAdjoint,2) == 1` is accepted.
"""
function batched_mul!(C::AbstractArray{T,3}, A::AbstractArray{<:Any,3}, B::AbstractArray{<:Any,3},
        α::Number=one(T), β::Number=zero(T)) where {T}
    _batched_mul!(storage_typejoin(C,A,B), C, A, B, α, β)
    C
end
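
# For illustration (a sketch): pre-allocating the output and multiplying in place,
#
#   A, B = randn(2,5,17), randn(5,9,17);
#   C = similar(A, 2, 9, 17);
#   batched_mul!(C, A, B)          # C[:,:,k] = A[:,:,k] * B[:,:,k]
#   batched_mul!(C, A, B, 2, 1)    # C[:,:,k] = 2 * A[:,:,k] * B[:,:,k] + C[:,:,k]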

_batched_mul!(::Type, C, A, B, α::Number, β::Number) = batched_mul_generic!(C, A, B, α, β)

_batched_mul!(::Type{D
  },
  {
    "path": "test/ext_cuda/conv.jl",
    "chars": 5827,
    "preview": "using NNlib: DenseConvDims\n\n@testset \"convolution\" begin\n@testset \"$T\" for T in (Float64, ComplexF64)\n    a, b, c = rand"
  },
  {
    "path": "test/ext_cuda/ctc.jl",
    "chars": 1499,
    "preview": "# Custom function to check numerical gradient of ctc loss,\n# based on `ngradient` in `Tracker.jl`\nfunction ctc_ngradient"
  },
  {
    "path": "test/ext_cuda/dropout.jl",
    "chars": 1301,
    "preview": "@testset \"dropout + CUDA\" begin\n    # Basics\n    x1 = CUDA.randn(3, 4)\n    @test size(@inferred dropout(x1, 0.1)) == (3,"
  },
  {
    "path": "test/ext_cuda/fold.jl",
    "chars": 1193,
    "preview": "\n@testset \"fold\" begin\n    # Test for agreement between CPU/GPU versions, across a variety of kwargs\n    options = Dict{"
  },
  {
    "path": "test/ext_cuda/gather.jl",
    "chars": 3338,
    "preview": "@testset \"gather\" begin\n    T = Float32\n    CT = CuArray{Float32}\n\n    ## 1d src, 2d index of ints -> 2d output\n    src "
  },
  {
    "path": "test/ext_cuda/pooling.jl",
    "chars": 738,
    "preview": "@testset \"pooling\" begin\n\n    # Test for agreement between CPU NNlib and CuDNN versions, across a variety of kwargs\n    "
  },
  {
    "path": "test/ext_cuda/runtests.jl",
    "chars": 557,
    "preview": "using Test\nusing NNlib\nusing Zygote\nusing ForwardDiff: Dual\nusing Statistics: mean\nusing CUDA, cuDNN\nimport CUDA.CUSPARS"
  },
  {
    "path": "test/ext_cuda/sampling.jl",
    "chars": 4493,
    "preview": "@testset \"Grid Sampling\" begin\n    for T in (Float32, Float64)\n        x = ones(T, (2, 2, 1, 1))\n        grid = Array{T}"
  },
  {
    "path": "test/ext_cuda/scatter.jl",
    "chars": 3941,
    "preview": "dsts = Dict(\n    0 => cu([3, 4, 5, 6, 7]),\n    1 => cu([3 3 4 4 5;\n             5 5 6 6 7]),\n)\nsrcs = Dict(\n    (0, true"
  },
  {
    "path": "test/ext_cuda/softmax.jl",
    "chars": 1033,
    "preview": "@testset \"softmax\" begin\n    for (sz, dims) in [((5,), :), ((5,), 1), ((5,5), :), ((5,5), 1), ((5,5), 2), ((5,5,5,5), (2"
  },
  {
    "path": "test/ext_cuda/test_utils.jl",
    "chars": 872,
    "preview": "function gputest(f, xs...; checkgrad=true, rtol=1e-7, atol=1e-10, broken=false, broken_grad=false, kws...)\n    cpu_in = "
  },
  {
    "path": "test/ext_metal/activations.jl",
    "chars": 580,
    "preview": "@testset \"activation broadcast\" begin\n    broken_f = (:hardσ, :leakyrelu) \n    for name in NNlib.ACTIVATIONS\n        # p"
  },
  {
    "path": "test/ext_metal/runtests.jl",
    "chars": 919,
    "preview": "using NNlib\nusing Test\nusing Metal\nusing Zygote: gradient\nusing MLDataDevices: gpu_device\nusing ForwardDiff: Dual\n\nMetal"
  },
  {
    "path": "test/functions.jl",
    "chars": 282,
    "preview": "using NNlib: glu\nusing Zygote\n\n@testset \"glu\" begin\n    x = [1 2 3; 4 5 6; 7 8 9; 10 11 12]\t \n    @test ceil.(glu(x, 1))"
  },
  {
    "path": "test/inference.jl",
    "chars": 2194,
    "preview": "import NNlib: conv_direct, conv_im2col, channels_in, channels_out\n\n@testset \"Conv Inference\" begin\n    for T in (Float32"
  },
  {
    "path": "test/padding.jl",
    "chars": 5135,
    "preview": "using NNlib: pad_constant, pad_repeat, pad_zeros, pad_reflect, pad_symmetric, pad_circular\n\n@testset \"padding constant\" "
  },
  {
    "path": "test/pooling.jl",
    "chars": 34137,
    "preview": "# using NNlib, Test\n\nmaxpool_answer_dict = Dict(\n    1 => Dict(\n        \"y\"          => [2, 4.],\n        \"y_nostride\" =>"
  },
  {
    "path": "test/runtests.jl",
    "chars": 5694,
    "preview": "using NNlib, Test, Statistics, Random\nusing ChainRulesCore, ChainRulesTestUtils\nusing Base.Broadcast: broadcasted\nimport"
  },
  {
    "path": "test/sampling.jl",
    "chars": 5561,
    "preview": "@testset \"Known gradients\" begin\n    x = ones(Float64, (2, 2, 1, 1))\n    grid = Array{Float64}(undef, 2, 2, 2, 1)\n    gr"
  },
  {
    "path": "test/softmax.jl",
    "chars": 5009,
    "preview": "using Statistics: mean\nusing NNlib: ∇softmax_data, ∇logsoftmax_data\n\n@testset \"softmax integer input\" begin\n    @test so"
  },
  {
    "path": "test/test_utils.jl",
    "chars": 2571,
    "preview": "const IntOrTuple = Union{Int, NTuple{N,Int} where N}\n\ngradtest(f, dims::IntOrTuple...; kw...) =\n    gradtest(f, randn.(R"
  },
  {
    "path": "test/testsuite/fold.jl",
    "chars": 1887,
    "preview": "import NNlib\n\nfunction fold_testsuite(Backend)\n    device(x) = adapt(Backend(), x)\n    gradtest_fn = Backend == CPU ? gr"
  },
  {
    "path": "test/testsuite/gather.jl",
    "chars": 6210,
    "preview": "using NNlib: gather, gather!\nimport EnzymeTestUtils\nusing EnzymeCore\n\nfunction gather_testsuite(Backend)\n    device(x) ="
  },
  {
    "path": "test/testsuite/rotation.jl",
    "chars": 4169,
    "preview": "function rotation_testsuite(Backend)\n    device(x) = adapt(Backend(), x)\n    gradtest_fn = Backend == CPU ? gradtest : g"
  },
  {
    "path": "test/testsuite/scatter.jl",
    "chars": 8429,
    "preview": "using NNlib: scatter, scatter!\n\ndsts = Dict(\n    0 => [3, 4, 5, 6, 7],\n    1 => [3 3 4 4 5;\n          5 5 6 6 7],\n)\nsrcs"
  },
  {
    "path": "test/testsuite/spectral.jl",
    "chars": 5816,
    "preview": "function spectral_testsuite(Backend)\n    cpu(x) = adapt(CPU(), x)\n    device(x) = adapt(Backend(), x)\n    gradtest_fn = "
  },
  {
    "path": "test/testsuite/upsample.jl",
    "chars": 7405,
    "preview": "function upsample_testsuite(Backend)\n    device(x) = adapt(Backend(), x)\n    gradtest_fn = Backend == CPU ? gradtest : g"
  },
  {
    "path": "test/utils.jl",
    "chars": 1518,
    "preview": "@testset \"within_gradient\" begin\n    @test NNlib.within_gradient([1.0]) === false\n    @test gradient(x -> NNlib.within_g"
  }
]

// ... and 1 more file (download for full content)
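
Because the previews above are truncated, here is a minimal usage sketch (not taken from the repository itself) for three of the APIs they reference: softmax (src/softmax.jl), pad_zeros (src/padding.jl), and pixel_shuffle (src/upsample.jl). The array sizes and the WHCN layout shown are illustrative assumptions chosen for the example, not values from the repo.

using NNlib

# Illustrative shapes only (WHCN: width, height, channels, batch).
x = randn(Float32, 8, 8, 4, 2)

p = softmax(x; dims = 3)                    # probabilities along the channel dimension
@assert all(sum(p; dims = 3) .≈ 1)          # each softmax slice sums to one

xp = pad_zeros(x, 1; dims = (1, 2))         # zero-pad one element on each side of W and H
@assert size(xp) == (10, 10, 4, 2)

y = pixel_shuffle(x, 2)                     # rearrange channels into space, upscaling by r = 2
@assert size(y) == (16, 16, 1, 2)           # (8, 8, 2^2*1, 2) -> (16, 16, 1, 2)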

About this extraction

This page contains the full source code of the FluxML/NNlib.jl GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction covers 141 files (634.0 KB, approximately 226.8k tokens). Use it with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
