[
  {
    "path": ".buildkite/pipeline.yml",
    "content": "steps:\n  - label: \":julia: Julia {{matrix.julia}} - CUDA GPU\"\n    command:\n      - echo 'CUDA = \"052768ef-5323-5732-b1bb-66c8b64840ba\"' >> test/Project.toml\n      - echo 'cuDNN = \"02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd\"' >> test/Project.toml\n    plugins:\n      - JuliaCI/julia#v1:\n          version: \"{{matrix.julia}}\"\n      - JuliaCI/julia-test#v1:\n          test_args: \"--quickfail\"\n      - JuliaCI/julia-coverage#v1:\n          codecov: true\n          dirs:\n            - src\n            - ext\n    agents:\n      queue: \"juliagpu\"\n      cuda: \"*\"\n    env:\n      JULIA_NUM_THREADS: 4\n      NNLIB_TEST_CUDA: \"true\"\n      NNLIB_TEST_CPU: \"false\"\n    if: build.message !~ /\\[skip tests\\]/\n    timeout_in_minutes: 180\n    matrix:\n      setup:\n        julia:\n          - \"1.10\"\n          - \"1\"\n          - \"nightly\"\n      adjustments:\n        - with:\n            julia: \"nightly\"\n          soft_fail: true\n\n\n  - label: \":julia: Julia {{matrix.julia}} - AMD GPU\"\n    command:\n      - echo 'AMDGPU = \"21141c5a-9bdb-4563-92ae-f87d6854732e\"' >> test/Project.toml\n    plugins:\n      - JuliaCI/julia#v1:\n          version: \"1\"\n      - JuliaCI/julia-test#v1:\n          test_args: \"--quickfail\"\n      - JuliaCI/julia-coverage#v1:\n          codecov: true\n          dirs:\n            - src\n            - ext\n    agents:\n      queue: \"juliagpu\"\n      rocm: \"*\"\n      rocmgpu: \"*\"\n    timeout_in_minutes: 180\n    env:\n      JULIA_AMDGPU_CORE_MUST_LOAD: \"1\"\n      JULIA_AMDGPU_HIP_MUST_LOAD: \"1\"\n      JULIA_AMDGPU_DISABLE_ARTIFACTS: \"1\"\n      NNLIB_TEST_AMDGPU: \"true\"\n      NNLIB_TEST_CPU: \"false\"\n      JULIA_NUM_THREADS: 4\n    matrix:\n      setup:\n        julia:\n          # - \"1.10\"  \n          - \"1\"\n          # - \"nightly\"\n      # adjustments:\n      #   - with:\n      #       julia: \"nightly\"\n      #     soft_fail: true\n\n\n  - label: \":julia: Julia {{matrix.julia}} - Metal GPU\"\n    command:\n      - echo 'Metal = \"dde4c033-4e86-420c-a63e-0dd931031962\"' >> test/Project.toml\n    plugins:\n      - JuliaCI/julia#v1:\n          version: \"{{matrix.julia}}\"\n      - JuliaCI/julia-test#v1:\n          test_args: \"--quickfail\"\n      - JuliaCI/julia-coverage#v1:\n          codecov: true\n          dirs:\n            - src\n            - ext\n    agents:\n      queue: \"juliaecosystem\"\n      os: \"macos\"\n      arch: \"aarch64\"\n    timeout_in_minutes: 180\n    env:\n      NNLIB_TEST_METAL: \"true\"\n      NNLIB_TEST_CPU: \"false\"\n      JULIA_NUM_THREADS: 4\n    matrix:\n      setup:\n        julia:\n          # - \"1.10\"\n          - \"1\"\n          # - \"nightly  \"\n      # adjustments:\n      #   - with:\n      #       julia: \"nightly\"\n      #     soft_fail: true\n\n\n  - label: \"Benchmarks\"\n    plugins:\n      - JuliaCI/julia#v1:\n          version: 1\n    env:\n      JULIA_NUM_THREADS: 4\n    command:\n      - julia --project=benchmark -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()'\n      - julia --project=benchmark benchmark/runbenchmarks.jl\n      - printf '%b\\n' \"$(cat benchmark/report.md)\" | buildkite-agent annotate --style 'info'\n    agents:\n      queue: \"juliagpu\"\n    if: build.pull_request.labels includes \"benchmark\"\n    timeout_in_minutes: 30\n\nenv:\n  SECRET_CODECOV_TOKEN: 
\"IlEMvDI6RciJQr5eX7qBBpHYFAe8+Svf3lNJh9gZi0MeJZQvMZWzHfW/lVncA9d9K+gDBBTv/zwqF86xOaIFLuACNdcGZiGgHS+NGeXN5CEppjqLnqKuaeHmLgJ43jygxRwgF88LhwTGcHG7pmESIp1Bn3Jd23UUv4t8hJLBDF+KJLZMefzCXnEVzfwJYxhJktnKJPA4dOv59w33Vj1x5uCYZbQlLP54IJPBm8UGdXS+JrUX8Z7lhxbkJUi6c+R6cvVBw27uRjF0pUJY26mt1frx8MzTGTOweXTpi+Kc5JhzlokMlan17j6T/b7qMC13IuKopfqu1GhkSBQD3ZhQqA==;U2FsdGVkX19l7JMB48k4oJHLoaqC7/MmvQWmaiBxRN472ZC6AcQ0uCBRy6Fw8tI0YcjIxKDScaBnJ2v/deOfhg==\"\n"
  },
  {
    "path": ".codecov.yml",
    "content": "comment: false\n"
  },
  {
    "path": ".github/copilot-instructions.md",
    "content": "# NNlib.jl Copilot Instructions\n\n## Repository Overview\n\nNNlib.jl is a library providing fundamental neural network operations and primitives for Julia. It is primarily used by Flux.jl but can be used independently. The library provides:\n\n- Activation functions (sigmoid, relu, gelu, etc.)\n- Convolution and pooling operations\n- Attention mechanisms\n- Batched matrix operations\n- Neural network utilities (dropout, normalization, etc.)\n- GPU acceleration support (CUDA, AMDGPU)\n\n## Project Structure\n\n```\nNNlib.jl/\n├── src/              # Core library implementation\n│   ├── NNlib.jl      # Main module file\n│   ├── activations.jl # Activation functions\n│   ├── attention.jl   # Attention mechanisms\n│   ├── conv.jl        # Convolution operations\n│   ├── pooling.jl     # Pooling operations\n│   ├── batched/       # Batched operations\n│   └── impl/          # Implementation details\n├── ext/              # Package extensions for GPU backends\n│   ├── NNlibCUDAExt/      # CUDA-specific implementations\n│   ├── NNlibAMDGPUExt/    # AMDGPU-specific implementations\n│   └── NNlibCUDACUDNNExt/ # cuDNN-specific implementations\n├── test/             # Test suite\n└── docs/             # Documentation\n```\n\n## Julia Version\n\n- Minimum Julia version: 1.10\n- CI tests on: minimum julia version, latest stable (1.x), and pre-release versions\n\n## Coding Standards\n\n### Julia Conventions\n\n1. **Naming**:\n   - Functions: lowercase with underscores (e.g., `dot_product_attention`)\n   - Types: PascalCase (e.g., `ConvDims`, `PoolDims`)\n   - Constants: UPPERCASE with underscores (e.g., `ACTIVATIONS`)\n\n2. **Documentation**:\n   - Use Julia docstrings (\"\"\" ... \"\"\") for all exported functions\n   - Include examples in docstrings where appropriate\n   - Keep documentation up-to-date with implementation changes\n\n3. **Type Annotations**:\n   - Use type parameters and abstract types for generic implementations\n   - Leverage Julia's multiple dispatch for specialized implementations\n   - Define clear type hierarchies (e.g., `DenseConvDims`, `DepthwiseConvDims`)\n\n4. **Performance**:\n   - Prefer in-place operations where appropriate (functions ending with `!`)\n   - Use `@inbounds` judiciously when bounds checking is verified\n   - Consider thread safety for multi-threaded operations\n   - Use `NNlib.@disallow_spawns` to control threading behavior\n\n### Code Organization\n\n1. **Core Implementations**: CPU implementations go in `src/`\n2. **GPU Extensions**: GPU-specific code belongs in `ext/` as package extensions\n3. **Tests**: Mirror the structure of `src/` in `test/`\n4. **Gradients**: Define gradients using ChainRules.jl (`rrule` functions)\n\n## Testing\n\n### Test Infrastructure\n\n- Uses the standard Julia `Test` framework\n- Tests are organized to mirror the source structure\n- GPU tests are conditional (controlled by environment variables)\n\n### Running Tests\n\n```julia\n# Run all CPU tests\njulia --project -e 'using Pkg; Pkg.test()'\n\n# Run tests with threading\nJULIA_NUM_THREADS=4 julia --project -e 'using Pkg; Pkg.test()'\n```\n\n### Test Patterns\n\n1. **Activation Functions**: Test at specific values (0.0, 1.0, -1.0) and verify expected outputs\n2. **Gradient Tests**: Use `ChainRulesTestUtils` for gradient correctness\n3. **Type Stability**: Use `@inferred` where appropriate\n4. 
**GPU Tests**: Conditional testing based on environment variables:\n   - `ENV[\"NNLIB_TEST_CUDA\"]` for CUDA tests\n   - `ENV[\"NNLIB_TEST_AMDGPU\"]` for AMDGPU tests\n\n### Writing New Tests\n\n- Include tests for edge cases (zero inputs, negative values, boundary conditions)\n- Test both forward pass and gradients (using ChainRulesTestUtils)\n- For array operations, test multiple dimensions and batch sizes\n- Include tests for type stability when performance-critical\n\n## Dependencies\n\n### Core Dependencies\n\n- **ChainRulesCore**: For automatic differentiation support\n- **KernelAbstractions**: For GPU kernel abstractions\n- **Adapt**: For moving data between CPU/GPU\n- **GPUArraysCore**: GPU array interface\n\n### Weak Dependencies (Extensions)\n\n- **CUDA.jl/cuDNN**: NVIDIA GPU support\n- **AMDGPU.jl**: AMD GPU support\n- **FFTW**: Fast Fourier transforms\n- **ForwardDiff**: Forward-mode AD support\n- **EnzymeCore**: Enzyme AD support\n\n### Adding New Dependencies\n\n- Consider whether the dependency should be a weak dependency (extension)\n- Update `Project.toml` with version constraints\n- Ensure compatibility with supported Julia versions\n- Run full test suite after adding dependencies\n\n## GPU Support\n\nNNlib uses Julia's package extension system for GPU backends:\n\n1. **CUDA**: Load with `using NNlib, CUDA, cuDNN`\n2. **AMDGPU**: Load with `using NNlib, AMDGPU`\n\n### GPU Implementation Guidelines\n\n- Keep GPU-specific code in appropriate extensions (`ext/` directory)\n- Provide CPU fallback implementations in `src/`\n- Test GPU implementations separately (conditional on hardware availability)\n- Use KernelAbstractions for portable GPU kernels when possible\n\n## Build and CI/CD\n\n### Continuous Integration\n\n- **CI Workflow**: `.github/workflows/ci.yml`\n  - Tests on Linux (always), Windows, and macOS\n  - Tests with different Julia versions (LTS, stable, pre-release)\n  - Tests with different thread counts\n  \n### Additional Workflows\n\n- **TagBot**: Automatic release tagging\n- **CompatHelper**: Dependency compatibility updates\n- **Downstream**: Tests dependent packages\n- **BenchmarkTrigger**: Performance regression testing\n\n## Common Tasks\n\n### Adding a New Activation Function\n\n1. Add function to `src/activations.jl`\n2. Add to `ACTIVATIONS` tuple for automatic export\n3. Define gradient with `@scalar_rule` or `rrule`\n4. Add tests in `test/activations.jl` at key values\n5. Document with docstring and example\n\n### Adding a New Operation\n\n1. Implement in appropriate file in `src/`\n2. Export from `src/NNlib.jl`\n3. Define gradients using ChainRules\n4. Add comprehensive tests\n5. Add GPU implementations in extensions if applicable\n6. Document in appropriate file in `docs/src/`\n\n### Modifying Existing Functions\n\n1. Check for dependent code in Flux.jl and other downstream packages\n2. Maintain backward compatibility or document breaking changes\n3. Update tests to cover new behavior\n4. Update gradients if needed\n5. Consider performance implications\n\n## Performance Considerations\n\n1. **Memory Allocation**: Minimize allocations in hot paths\n2. **Threading**: NNlib uses Julia threads for parallel operations\n   - Control with `NNlib.@disallow_spawns` if needed\n   - Thread count controlled by `JULIA_NUM_THREADS`\n3. **GPU Kernels**: Optimize kernel launch parameters and memory access patterns\n4. 
**Type Stability**: Ensure type-stable code for performance-critical paths\n\n## Documentation\n\n- Documentation source: `docs/src/`\n- Built with Documenter.jl\n- Includes API reference, examples, and guides\n- Documentation tests run via `DocTestSetup`\n\n## Related Projects\n\n- **Flux.jl**: Primary consumer of NNlib\n- **Zygote.jl**: Automatic differentiation (uses ChainRules)\n- **ChainRules.jl**: Gradient definitions\n- **KernelAbstractions.jl**: GPU kernel abstraction\n\n## Getting Help\n\n- Documentation: https://fluxml.ai/NNlib.jl/dev/\n- Issues: https://github.com/FluxML/NNlib.jl/issues\n- FluxML community: https://github.com/FluxML\n"
  },
  {
    "path": ".github/dependabot.yml",
    "content": "# To get started with Dependabot version updates, you'll need to specify which\n# package ecosystems to update and where the package manifests are located.\n# Please see the documentation for all configuration options:\n# https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file\n\nversion: 2\nupdates:\n  - package-ecosystem: \"github-actions\"\n    directory: \"/\" # Location of package manifests\n    schedule:\n      interval: \"weekly\"\n"
  },
  {
    "path": ".github/workflows/BenchmarkTrigger.yml",
    "content": "name: Benchmark Trigger\n\non:\n  pull_request_target:\n    types: [ labeled ]\n  workflow_dispatch:\n    inputs:\n      pr_id:\n        type: string\n        description: id of the pull request that triggers this workflow\n      target_url:\n        type: string\n        description: url of target\n      baseline_url:\n        type: string\n        description: url of baseline\n\njobs:\n  benchmark_trigger:\n    if: ${{ github.event.label.name == 'benchmark' }}\n    runs-on: ubuntu-latest\n    env:\n      REPOSITORY: ${{ github.event.repository.full_name }}\n      PR_ID: ${{ github.event.inputs.pr_id || github.event.pull_request.number }}\n      TARGET_URL: ${{ github.event.inputs.target_url || format('{0}#{1}', github.event.pull_request.head.repo.html_url, github.event.pull_request.head.sha) }}\n      BASELINE_URL: ${{ github.event.inputs.baseline_url || format('{0}#{1}', github.event.pull_request.base.repo.html_url, github.event.pull_request.base.sha) }}\n    steps:\n      -\n        name: Get app installation token (ghs)\n        id: get-app-token\n        uses: tibdex/github-app-token@v2\n        with: \n          app_id: ${{ secrets.BENCH_APP_ID }}\n          installation_id: ${{ secrets.BENCH_INSTALL_ID }}\n          private_key: ${{ secrets.BENCH_PRIVATE_KEY }}\n      -\n        uses: benc-uk/workflow-dispatch@v1\n        with:\n          repo: FluxML/FluxMLBenchmarks.jl\n          ref: refs/heads/main\n          workflow: Benchmark.yml\n          token: ${{ steps.get-app-token.outputs.token }}\n          inputs: '{ \"repository\": \"${{ env.REPOSITORY }}\", \"pr_id\": \"${{ env.PR_ID }}\", \"target_url\": \"${{ env.TARGET_URL }}\", \"baseline_url\": \"${{ env.BASELINE_URL }}\" }'\n"
  },
  {
    "path": ".github/workflows/CompatHelper.yml",
    "content": "name: CompatHelper\non:\n  schedule:\n    - cron: 0 0 * * *\n  workflow_dispatch:\npermissions:\n  contents: write\n  pull-requests: write\njobs:\n  CompatHelper:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Check if Julia is already available in the PATH\n        id: julia_in_path\n        run: which julia\n        continue-on-error: true\n      - name: Install Julia, but only if it is not already available in the PATH\n        uses: julia-actions/setup-julia@v2\n        with:\n          version: '1'\n          arch: ${{ runner.arch }}\n        if: steps.julia_in_path.outcome != 'success'\n      - name: \"Add the General registry via Git\"\n        run: |\n          import Pkg\n          ENV[\"JULIA_PKG_SERVER\"] = \"\"\n          Pkg.Registry.add(\"General\")\n        shell: julia --color=yes {0}\n      - name: \"Install CompatHelper\"\n        run: |\n          import Pkg\n          name = \"CompatHelper\"\n          uuid = \"aa819f21-2bde-4658-8897-bab36330d9b7\"\n          version = \"3\"\n          Pkg.add(; name, uuid, version)\n        shell: julia --color=yes {0}\n      - name: \"Run CompatHelper\"\n        run: |\n          import CompatHelper\n          CompatHelper.main()\n        shell: julia --color=yes {0}\n        env:\n          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n          COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }}\n          # COMPATHELPER_PRIV: ${{ secrets.COMPATHELPER_PRIV }}"
  },
  {
    "path": ".github/workflows/Downstream.yml",
    "content": "name: IntegrationTest\non:\n  push:\n    branches: [master]\n    tags: [v*]\n  pull_request:\n\n# needed to allow julia-actions/cache to delete old caches that it has created\npermissions:\n  actions: write\n  contents: read\n\njobs:\n  test:\n    name: ${{ matrix.package.repo }}/${{ matrix.package.group }}\n    runs-on: ${{ matrix.os }}\n    env:\n      GROUP: ${{ matrix.package.group }}\n    strategy:\n      fail-fast: false\n      matrix:\n        julia-version: [1]\n        os: [ubuntu-latest]\n        package:\n          - {user: FluxML, repo: Flux.jl, group: All}\n          - {user: FluxML, repo: Tracker.jl, group: All}\n          - {user: LuxDL, repo: Lux.jl, group: All}\n    steps:\n      - uses: actions/checkout@v6\n      - uses: julia-actions/setup-julia@v2\n        with:\n          version: ${{ matrix.julia-version }}\n          arch: x64\n      - uses: julia-actions/cache@v2\n      - uses: julia-actions/julia-buildpkg@latest\n      - name: Clone Downstream\n        uses: actions/checkout@v6\n        with:\n          repository: ${{ matrix.package.user }}/${{ matrix.package.repo }}\n          path: downstream\n      - name: Load this and run the downstream tests\n        shell: julia --color=yes --project=downstream {0}\n        run: |\n          using Pkg\n          try\n            # force it to use this PR's version of the package\n            Pkg.develop(PackageSpec(path=\".\"))  # resolver may fail with main deps\n            Pkg.update()\n            Pkg.test()  # resolver may fail with test time deps\n          catch err\n            err isa Pkg.Resolve.ResolverError || rethrow()\n            # If we can't resolve that means this is incompatible by SemVer and this is fine\n            # It means we marked this as a breaking change, so we don't need to worry about\n            # Mistakenly introducing a breaking change, as we have intentionally made one\n            @info \"Not compatible with this release. No problem.\" exception=err\n            exit(0)  # Exit immediately, as a success\n          end\n        env:\n          RETESTITEMS_NWORKERS: 4\n          BACKEND_GROUP: CPU  # for Lux.jl\n\n"
  },
  {
    "path": ".github/workflows/TagBot.yml",
    "content": "name: TagBot\non:\n  issue_comment:\n    types:\n      - created\n  workflow_dispatch:\n    inputs:\n      lookback:\n        default: 3\npermissions:\n  actions: read\n  checks: read\n  contents: write\n  deployments: read\n  issues: read\n  discussions: read\n  packages: read\n  pages: read\n  pull-requests: read\n  repository-projects: read\n  security-events: read\n  statuses: read\njobs:\n  TagBot:\n    if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot'\n    runs-on: ubuntu-latest\n    steps:\n      - uses: JuliaRegistries/TagBot@v1\n        with:\n          token: ${{ secrets.GITHUB_TOKEN }}\n          # Edit the following line to reflect the actual name of the GitHub Secret containing your private key\n          ssh: ${{ secrets.DOCUMENTER_KEY }}\n          # ssh: ${{ secrets.NAME_OF_MY_SSH_PRIVATE_KEY_SECRET }}"
  },
  {
    "path": ".github/workflows/ci.yml",
    "content": "name: CI\n\non:\n  push:\n    branches:\n      - master\n      - staging\n      - trying\n    tags: '*'\n  pull_request:\n\n# needed to allow julia-actions/cache to delete old caches that it has created\npermissions:\n  actions: write\n  contents: read\n\ndefaults:\n  run:\n    shell: bash\n\njobs:\n  test:\n    name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.julia-threads }} thread(s) \n    runs-on: ${{ matrix.os }}\n    env:\n      JULIA_NUM_THREADS: ${{ matrix.julia-threads }}\n    strategy:\n      fail-fast: false\n      matrix:\n        version:\n          - '1.10' # uncomment when julia 1.10 is out\n          - '1'   # automatically expands to the latest stable 1.x release of Julia\n          - 'nightly'\n        os:\n          - ubuntu-latest\n          # - macOS-latest\n          # - windows-latest\n        julia-threads:\n          - '1'\n\n        include:\n          - os: windows-latest\n            version: '1'\n            julia-threads: '1'\n          - os: macOS-latest\n            version: '1'\n            julia-threads: '1'\n          - os: ubuntu-latest\n            version: '1'\n            julia-threads: '2'\n  \n    steps:\n      - uses: actions/checkout@v6\n      - uses: julia-actions/setup-julia@v2\n        with:\n          version: ${{ matrix.version }}\n      - uses: julia-actions/cache@v2\n      - uses: julia-actions/julia-buildpkg@v1\n\n      - name: \"Run test without coverage\"\n        uses: julia-actions/julia-runtest@v1\n        if: ${{ !contains(fromJson('[\"1\"]'), matrix.version) || matrix.os != 'ubuntu-latest' }}\n        with:\n          coverage: false\n\n      - name: \"Run test with coverage\"\n        uses: julia-actions/julia-runtest@v1\n        if: contains(fromJson('[\"1\"]'), matrix.version) && matrix.os == 'ubuntu-latest'\n      - uses: julia-actions/julia-processcoverage@v1\n        if: contains(fromJson('[\"1\"]'), matrix.version) && matrix.os == 'ubuntu-latest'\n      - uses: codecov/codecov-action@v5\n        if: contains(fromJson('[\"1\"]'), matrix.version) && matrix.os == 'ubuntu-latest'\n        with:\n          file: lcov.info\n\n  docs:\n    name: Documentation\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v6\n      - uses: julia-actions/setup-julia@v2\n        with:\n          version: '1.10'\n      - uses: julia-actions/cache@v2\n      - run: |\n          julia --project=docs -e '\n            using Pkg\n            Pkg.develop(PackageSpec(path=pwd()))\n            Pkg.instantiate()'\n      - run: julia --project=docs docs/make.jl\n        env:\n          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n          DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}\n"
  },
  {
    "path": ".github/workflows/clean_preview.yml",
    "content": "# from https://github.com/CliMA/ClimaTimeSteppers.jl\nname: Doc Preview Cleanup\n\non:\n  pull_request:\n    types: [closed]\n\njobs:\n  doc-preview-cleanup:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout gh-pages branch\n        uses: actions/checkout@v6\n        with:\n          ref: gh-pages\n      - name: Delete preview and history + push changes\n        run: |\n            if [ -d \"previews/PR$PRNUM\" ]; then\n              git config user.name \"Documenter.jl\"\n              git config user.email \"documenter@juliadocs.github.io\"\n              git rm -rf \"previews/PR$PRNUM\"\n              git commit -m \"delete preview\"\n              git branch gh-pages-new $(echo \"delete history\" | git commit-tree HEAD^{tree})\n              git push --force origin gh-pages-new:gh-pages\n            fi\n        env:\n            PRNUM: ${{ github.event.number }}\n"
  },
  {
    "path": ".github/workflows/pr_comment.yml",
    "content": "name: pr_comment\non:\n  pull_request:\n    types: [labeled]\njobs:\n  pr_comment:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Create PR comment\n        if: github.event_name == 'pull_request' && github.repository == github.event.pull_request.head.repo.full_name && github.event.label.name == 'documentation' # if this is a pull request build AND the pull request is NOT made from a fork\n        uses: thollander/actions-comment-pull-request@24bffb9b452ba05a4f3f77933840a6a841d1b32b\n        with:\n          message: 'Once the documentation build has completed, you can preview any updated documentation at this URL: https://fluxml.ai/NNlib.jl/previews/PR${{ github.event.number }}/'\n          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n"
  },
  {
    "path": ".gitignore",
    "content": "*.jl.cov\n*.jl.*.cov\n*.jl.mem\n*.o\n*.so\n*.dylib\n*.dll\n*~\n\\#*\ndeps/usr\ndeps.jl\n*.log\n.vscode/\n/Manifest.toml\ntest/Manifest.toml\nbenchmark/Manifest.toml\nbenchmark/*.json\nbenchmark/report.md\n"
  },
  {
    "path": "LICENSE.md",
    "content": "The NNlib.jl package is licensed under the MIT \"Expat\" License:\n\n> Copyright (c) 2017-19: Julia Computing, Inc., Mike J Innes, and Contributors\n> \n> Permission is hereby granted, free of charge, to any person obtaining a copy\n> of this software and associated documentation files (the \"Software\"), to deal\n> in the Software without restriction, including without limitation the rights\n> to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n> copies of the Software, and to permit persons to whom the Software is\n> furnished to do so, subject to the following conditions:\n> \n> The above copyright notice and this permission notice shall be included in all\n> copies or substantial portions of the Software.\n> \n> THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n> IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n> FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n> AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n> LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n> OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n> SOFTWARE.\n> \n"
  },
  {
    "path": "Project.toml",
    "content": "name = \"NNlib\"\nuuid = \"872c559c-99b0-510c-b3b7-b6c96a88d5cd\"\nversion = \"0.9.34\"\n\n[deps]\nAdapt = \"79e6a3ab-5dfb-504d-930d-738a2a938a0e\"\nAtomix = \"a9b6321e-bd34-4604-b9c9-b65b8de01458\"\nChainRulesCore = \"d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4\"\nGPUArraysCore = \"46192b85-c4d5-4398-a991-12ede77f4527\"\nKernelAbstractions = \"63c18a36-062a-441e-b654-da1e3ab1ce7c\"\nLinearAlgebra = \"37e2e46d-f89d-539d-b4ee-838fcccc9c8e\"\nRandom = \"9a3f8284-a2c9-5f02-9a11-845980a1fd5c\"\nScopedValues = \"7e506255-f358-4e82-b7e4-beb19740aa63\"\nStatistics = \"10745b16-79ce-11e8-11f9-7d13ad32a3b2\"\n\n[weakdeps]\nAMDGPU = \"21141c5a-9bdb-4563-92ae-f87d6854732e\"\nCUDA = \"052768ef-5323-5732-b1bb-66c8b64840ba\"\nEnzymeCore = \"f151be2c-9106-41f4-ab19-57ee4f262869\"\nFFTW = \"7a1cc6ca-52ef-59f5-83cd-3a7055c09341\"\nForwardDiff = \"f6369f11-7733-5829-9624-2563aa707210\"\nMetal = \"dde4c033-4e86-420c-a63e-0dd931031962\"\nSpecialFunctions = \"276daf66-3868-5448-9aa4-cd146d93841b\"\ncuDNN = \"02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd\"\n\n[extensions]\nNNlibAMDGPUExt = \"AMDGPU\"\nNNlibCUDACUDNNExt = [\"CUDA\", \"cuDNN\"]\nNNlibCUDAExt = \"CUDA\"\nNNlibEnzymeCoreExt = \"EnzymeCore\"\nNNlibFFTWExt = \"FFTW\"\nNNlibForwardDiffExt = \"ForwardDiff\"\nNNlibMetalExt = \"Metal\"\nNNlibSpecialFunctionsExt = \"SpecialFunctions\"\n\n[compat]\nAMDGPU = \"1, 2\"\nAdapt = \"3.2, 4\"\nAtomix = \"0.1, 1\"\nCUDA = \"4, 5, 6\"\nChainRulesCore = \"1.25\"\nEnzymeCore = \"0.7, 0.8\"\nFFTW = \"1.8.0\"\nForwardDiff = \"1\"\nGPUArraysCore = \"0.2\"\nKernelAbstractions = \"0.9.2\"\nLinearAlgebra = \"1\"\nMetal = \"1.6\"\nRandom = \"1\"\nScopedValues = \"1.3.0\"\nSpecialFunctions = \"2\"\nStatistics = \"1\"\ncuDNN = \"1, 6\"\njulia = \"1.10\"\n"
  },
  {
    "path": "README.md",
    "content": "<img align=\"right\" width=\"200px\" src=\"https://github.com/FluxML/NNlib.jl/raw/master/docs/src/assets/logo.png\">\n\n# NNlib.jl\n\n[![Documentation][docs-dev-img]][docs-dev-url]\n[![CI](https://github.com/FluxML/NNlib.jl/actions/workflows/ci.yml/badge.svg)](https://github.com/FluxML/NNlib.jl/actions/workflows/ci.yml)\n[![Coverage](https://codecov.io/gh/FluxML/NNlib.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/FluxML/NNlib.jl) \n\n[docs-stable-img]: https://img.shields.io/badge/docs-stable-blue.svg\n[docs-stable-url]: https://fluxml.ai/NNlib.jl/stable/\n\n[docs-dev-img]: https://img.shields.io/badge/docs-latest-blue.svg\n[docs-dev-url]: https://fluxml.ai/NNlib.jl/dev/\n\nThis package provides a library of functions useful for neural networks, such as softmax, sigmoid, batched multiplication, convolutions and pooling. Many of these are used by [Flux.jl](https://github.com/FluxML/Flux.jl), which loads this package, but they may be used independently.\n\nFor use with automatic differentiation, this package defines gradients using [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl). These will be seen by various packages including [Zygote.jl](https://github.com/FluxML/Zygote.jl).\n\nGPU support is provided as package extensions (see the `ext/` folder). In order to load the extensions, use the imports\n```julia\nusing NNlib, CUDA, cuDNN\n```\nfor CUDA support, or\n```julia\nusing NNlib, AMDGPU\n```\nfor AMDGPU support.\n"
  },
  {
    "path": "benchmark/Project.toml",
    "content": "[deps]\nArgParse = \"c7e460c6-2fb9-53a9-8c5b-16f535851c63\"\nBenchmarkCI = \"20533458-34a3-403d-a444-e18f38190b5b\"\nBenchmarkTools = \"6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf\"\nNNlib = \"872c559c-99b0-510c-b3b7-b6c96a88d5cd\"\nPkgBenchmark = \"32113eaa-f34f-5b0d-bd6c-c81e245fc73d\"\n\n[compat]\n# No compat bounds for NNlib because we may test breaking versions\nArgParse = \"1\"\nBenchmarkCI = \"0.1\"\nBenchmarkTools = \"1.3\"\nPkgBenchmark = \"0.2\"\njulia = \"1.6\"\n"
  },
  {
    "path": "benchmark/benchmarks.jl",
    "content": "using BenchmarkTools\nusing NNlib\nusing NNlib.ChainRulesCore: rrule\nusing Random\n\nRandom.seed!(1234567890)\n\nconst SUITE = BenchmarkGroup()\n\nSUITE[\"activations\"] = BenchmarkGroup()\nfor et in (Float16, Float32, Float64)\n    et_suite = BenchmarkGroup()\n    SUITE[\"activations\"][string(et)] = et_suite\n    let x = rand(et, 1024, 1024), y = similar(x)\n        for f in NNlib.ACTIVATIONS\n            act = @eval($f)\n            et_suite[string(f)] = @benchmarkable broadcast!($act, $y, $x)\n        end\n    end\nend\n\nfor (fn!, fn_bw) in [(softmax!, NNlib.∇softmax_data), (logsoftmax!, NNlib.∇logsoftmax_data)]\n    fn_suite = BenchmarkGroup()\n    SUITE[rstrip(string(fn!), '!')] = fn_suite\n    let SIZES = [\n        (128, 384, 8),\n        (512, 784, 8),\n        (768, 1024, 4),\n        (1024, 2048, 4),\n        (2048, 2048, 2),\n        (4096, 2048, 2),\n        (4096, 4096, 2),\n        (12288, 2048, 1)\n    ]\n        for et in (Float16, Float32)\n            et_suite = BenchmarkGroup(\"fw\" => BenchmarkGroup(), \"bw\" => BenchmarkGroup())\n            fn_suite[string(et)] = et_suite\n            for sz in SIZES\n                x = randn(et, sz)\n                y = similar(x)\n                dy = zero(x)\n                fn!(y, x)\n                et_suite[\"fw\"][string(sz)] = @benchmarkable $fn!($y, $x)\n                et_suite[\"bw\"][string(sz)] = @benchmarkable $fn_bw($dy, $y)\n            end\n        end\n    end\nend\n\n"
  },
  {
    "path": "benchmark/perf_report.jl",
    "content": "using JLD2, NNlib, BenchmarkTools\n\n# TODO organize and compare benchmarks using BenchmarkGroups\n\n# We need things to go quickly here\nBenchmarkTools.DEFAULT_PARAMETERS.samples = 20\nBenchmarkTools.DEFAULT_PARAMETERS.seconds = 2.5\n\nresults = Dict()\n\nfunction add_result(val, keys...)\n    r = results\n    for k in keys[1:end-1]\n        if !haskey(r, k)\n            r[k] = Dict()\n        end\n        r = r[k]\n    end\n    r[keys[end]] = val\n    return r\nend\n\n# Modify these as needed\nfor rank in (2,),\n    N in (20, 40, 80),\n    C_in in (1,),\n    C_out in (1,),\n    K in (3,),\n    stride in (1,),\n    dilation in (1,),\n    padding in (0, 2)\n\n    benchmark_items = [\n            (NNlib.conv_direct!, NNlib.∇conv_data_direct!, NNlib.∇conv_filter_direct!, DenseConvDims, \"direct\"),\n            (NNlib.conv_im2col!, NNlib.∇conv_data_im2col!, NNlib.∇conv_filter_im2col!, DenseConvDims, \"im2col\"),\n            (NNlib.depthwiseconv_direct!, NNlib.∇depthwiseconv_data_direct!, NNlib.∇depthwiseconv_filter_direct!, DepthwiseConvDims, \"direct\"),\n            (NNlib.depthwiseconv_im2col!, NNlib.∇depthwiseconv_data_im2col!, NNlib.∇depthwiseconv_filter_im2col!, DepthwiseConvDims, \"im2col\"),\n    ]\n\n    for (conv!, ∇conv_data!, ∇conv_filter!, cT, backend) in benchmark_items\n\n        x = zeros(Float32, repeat([N], rank)..., C_in, 1)\n        if cT == DenseConvDims\n            w = zeros(Float32, repeat([K], rank)..., C_in, C_out)\n        else\n            w = zeros(Float32, repeat([K], rank)..., C_out, C_in)\n        end\n        cdims = try\n            cT(x, w; stride=stride, dilation=dilation, padding=padding)\n        catch\n            continue\n        end\n\n        if cT == DenseConvDims\n            y = zeros(Float32, NNlib.output_size(cdims)..., C_out, 1)\n        else\n            y = zeros(Float32, NNlib.output_size(cdims)..., C_out*C_in, 1)\n        end\n\n        dx = similar(x)\n        dw = similar(w)\n        dy = similar(y)\n\n        t_fwd = @benchmark $(conv!)($y, $x, $w, $cdims)\n        t_dx = @benchmark $(∇conv_data!)($dx, $y, $w, $cdims)\n        t_dw = @benchmark $(∇conv_filter!)($dw, $x, $y, $cdims)\n\n        add_result(t_fwd, \"conv$(rank)d\", backend, cdims)\n        add_result(t_dx, \"conv$(rank)d_data\", backend, cdims)\n        add_result(t_dw, \"conv$(rank)d_filter\", backend, cdims)\n\n        @show(cdims)\n        @save \"results.jld2\" results\n    end\nend\n\n\n# Modify these as needed\nfor rank in (2,),\n    N in (20,),\n    K in (2, 4),\n    stride in (1, 2, 4)\n\n    x = zeros(Float32, repeat([N], rank)..., 1, 1)\n    pdims = PoolDims(x, K; stride=stride)\n    y = zeros(Float32, NNlib.output_size(pdims)..., 1, 1)\n    dx = similar(x)\n\n    for (pool, ∇pool, name) in (\n            (NNlib.maxpool!, NNlib.∇maxpool!, \"maxpool\"),\n            (NNlib.meanpool!, NNlib.∇meanpool!, \"meanpool\"),\n            (NNlib.lpnormpool!, NNlib.∇lpnormpool!, \"lpnormpool\"),\n        )\n\n        t_fwd  = @benchmark $(pool)( $y, $x, $pdims)\n        t_data = @benchmark $(∇pool)($dx, $y, $y, $x, $pdims)\n\n        add_result(t_fwd, \"$(name)$(rank)d\", \"direct\", pdims)\n        add_result(t_data, \"$(name)$(rank)d_data\", \"direct\", pdims)\n\n        @show(pdims)\n        @save \"results.jld2\" results\n    end\nend\n"
  },
  {
    "path": "benchmark/runbenchmarks.jl",
    "content": "# Adapted from\n# https://github.com/kul-forbes/ProximalOperators.jl/tree/master/benchmark\nusing ArgParse\nusing PkgBenchmark\nusing BenchmarkCI: displayjudgement, printresultmd, CIResult\nusing Markdown\n\nfunction markdown_report(judgement)\n    md = sprint(printresultmd, CIResult(judgement = judgement))\n    md = replace(md, \":x:\" => \"❌\")\n    md = replace(md, \":white_check_mark:\" => \"✅\")\n    return md\nend\n\nfunction parse_commandline()\n    s = ArgParseSettings()\n\n    @add_arg_table! s begin\n        \"--target\"\n            help = \"the branch/commit/tag to use as target\"\n            default = \"HEAD\"\n        \"--baseline\"\n            help = \"the branch/commit/tag to use as baseline\"\n            default = \"master\"\n        \"--retune\"\n            help = \"force re-tuning (ignore existing tuning data)\"\n            action = :store_false\n    end\n\n    return parse_args(s)\nend\n\nfunction main()\n    parsed_args = parse_commandline()\n\n    mkconfig(; kwargs...) =\n        BenchmarkConfig(\n            env = Dict(\n                \"JULIA_NUM_THREADS\" => get(ENV, \"JULIA_NUM_THREADS\", \"1\"),\n            );\n            kwargs...\n        )\n\n    target = parsed_args[\"target\"]\n    group_target = benchmarkpkg(\n        dirname(@__DIR__),\n        mkconfig(id = target),\n        resultfile = joinpath(@__DIR__, \"result-$(target).json\"),\n        retune = parsed_args[\"retune\"],\n    )\n\n    baseline = parsed_args[\"baseline\"]\n    group_baseline = benchmarkpkg(\n        dirname(@__DIR__),\n        mkconfig(id = baseline),\n        resultfile = joinpath(@__DIR__, \"result-$(baseline).json\"),\n    )\n\n    judgement = judge(group_target, group_baseline)\n    report_md = markdown_report(judgement)\n    write(joinpath(@__DIR__, \"report.md\"), report_md)\n    display(Markdown.parse(report_md))\nend\n\nmain()\n"
  },
  {
    "path": "docs/.gitignore",
    "content": "build/\nsite/\nManifest.toml\n"
  },
  {
    "path": "docs/Project.toml",
    "content": "[deps]\nCairoMakie = \"13f3f980-e62b-5c42-98c6-ff1f3baf88f0\"\nDocumenter = \"e30172f5-a6a5-5a46-863b-614d45cd2de4\"\nFLAC = \"abae9e3b-a9a0-4778-b5c6-ca109b507d99\"\nFileIO = \"5789e2e9-d7fb-5bc7-8068-2c6fae9b9549\"\nMakie = \"ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a\"\nNNlib = \"872c559c-99b0-510c-b3b7-b6c96a88d5cd\"\nUnicodePlots = \"b8865327-cd53-5732-bb35-84acbb429228\"\nFFTW = \"7a1cc6ca-52ef-59f5-83cd-3a7055c09341\"\n"
  },
  {
    "path": "docs/make.jl",
    "content": "using Documenter, NNlib\n\nDocMeta.setdocmeta!(NNlib, :DocTestSetup,\n    :(using FFTW, NNlib, UnicodePlots); recursive = true)\n\nmakedocs(modules = [NNlib],\n    sitename = \"NNlib.jl\",\n    doctest = true,\n    pages = [\"Home\" => \"index.md\",\n             \"Reference\" => \"reference.md\",\n             \"Audio\" => \"audio.md\"],\n    format = Documenter.HTML(\n        canonical = \"https://fluxml.ai/NNlib.jl/stable/\",\n        # analytics = \"UA-36890222-9\",\n        assets = [\"assets/flux.css\"],\n        prettyurls = get(ENV, \"CI\", nothing) == \"true\"),\n    warnonly=[:missing_docs,]\n)\n\ndeploydocs(repo = \"github.com/FluxML/NNlib.jl.git\",\n           target = \"build\",\n           push_preview = true)\n"
  },
  {
    "path": "docs/src/assets/flux.css",
    "content": "@import url('https://fonts.googleapis.com/css?family=Lato:400,400i');\n\nbody {\n  font-family: Lato, \"Segoe UI\",Roboto,\"Helvetica Neue\",Arial,sans-serif;\n}\n\nnav.toc {\n  padding-top: 0;\n  background: rgb(240, 240, 240);\n  line-height: 2em;\n  cursor: default;\n  user-select: none;\n}\n\nh1+h2 {\n  margin-top: 0;\n}\n\n/* Green banner in ToC */\nnav.toc > h1 {\n  margin-top: 0;\n  padding-top: 0.4em;\n  padding-bottom: 0.5em;\n  border-bottom: 5px solid white;\n  box-shadow: 0px -2px 5px rgb(60,60,60);\n  margin-bottom: 0.5em;\n  background: rgb(60, 150, 60);\n\n  font-style: italic;\n  font-weight: normal;\n  font-size: 50pt;\n  text-transform: lowercase;\n  text-shadow: 2px 2px 5px rgba(0,0,0,0.2);\n  color: white;\n}\n\n/* Reduce ToC font size */\n.toctext {\n  font-size: 10pt;\n}\n\n/* Fade out non-clickable ToC headers */\nnav.toc ul span.toctext {\n  color: rgb(180, 180, 180);\n}\n\nnav.toc ul .toctext {\n  color: rgb(100, 100, 100);\n}\n\nnav.toc ul a.toctext:hover {\n  color: inherit;\n  background: rgb(220, 220, 220);\n  cursor: default;\n}\n\nnav.toc li.current > .toctext {\n  background: linear-gradient(90deg, rgb(245,245,245) 0%, white 90%);\n  font-weight: normal;\n}\n\nnav.toc ul.internal li.toplevel {\n  font-weight: normal;\n}\n\n/* Content */\n\narticle { max-width: none; }\n\narticle > p, article > ul {\n  max-width: 45em;\n}\n\n/* Links */\na, a:visited { color: rgb(0, 120, 0); }\narticle p a { border-bottom: 1px solid rgb(200, 230, 200); }\na:hover, a:visited:hover { color: rgb(0, 80, 0); }\n\n/* Article Links */\narticle p a { border-bottom: 1px solid rgb(200, 230, 200); }\narticle p a:hover, article a:visited:hover { color: rgb(0, 120, 0); }\narticle p a:hover { border-bottom: 1px solid rgb(150, 200, 150); }\n\n/* Doctstrings */\narticle section.docstring {\n  padding: 0.5em 0;\n  border-left: none;\n  border-right: none;\n  border-bottom: none;\n}\n\n/* Code */\n\narticle pre, article p > code {\n  background: rgb(245, 250, 245);\n}\n\narticle pre {\n  border: none;\n  max-width: none;\n  padding: 1em;\n  border-radius: 10px 0px 0px 10px;\n}\n\n.hljs-comment {\n  font-style: italic;\n}\n\n.hljs-number {\n  color: rgb(0, 150, 150);\n}\n"
  },
  {
    "path": "docs/src/audio.md",
    "content": "# Reference\n\n!!! note\n    Spectral functions require importing `FFTW` package to enable them.\n\n## Window functions\n\n```@docs\nhann_window\nhamming_window\n```\n\n## Spectral\n\n```@docs\nstft\nistft\nNNlib.power_to_db\nNNlib.db_to_power\n```\n\n## Spectrogram\n\n```@docs\nmelscale_filterbanks\nspectrogram\n```\n\nExample:\n\n```@example 1\nusing FFTW # <- required for STFT support.\nusing NNlib\nusing FileIO\nusing Makie, CairoMakie\nCairoMakie.activate!()\n\nwaveform, sampling_rate = load(\"./assets/jfk.flac\")\nfig = lines(reshape(waveform, :))\nsave(\"waveform.png\", fig)\n\n# Spectrogram.\n\nn_fft = 1024\nspec = spectrogram(waveform; n_fft, hop_length=n_fft ÷ 4, window=hann_window(n_fft))\nfig = heatmap(transpose(NNlib.power_to_db(spec)[:, :, 1]))\nsave(\"spectrogram.png\", fig)\n\n# Mel-scale spectrogram.\n\nn_freqs = n_fft ÷ 2 + 1\nfb = melscale_filterbanks(; n_freqs, n_mels=128, sample_rate=Int(sampling_rate))\nmel_spec = permutedims(spec, (2, 1, 3)) ⊠ fb # (time, n_mels)\nfig = heatmap(NNlib.power_to_db(mel_spec)[:, :, 1])\nsave(\"mel-spectrogram.png\", fig)\nnothing # hide\n```\n\n|Waveform|Spectrogram|Mel Spectrogram|\n|:---:|:---:|:---:|\n|![](waveform.png)|![](spectrogram.png)|![](mel-spectrogram.png)|\n"
  },
  {
    "path": "docs/src/index.md",
    "content": "# NNlib.jl\n\n`NNlib` provides a library of functions useful for neural networks, such as softmax, sigmoid, batched multiplication, convolutions and pooling. Many of these are used by [Flux.jl](https://github.com/FluxML/Flux.jl), which loads this package, but they may be used independently.\n\nFor use with automatic differentiation, this package defines gradients using [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl). These will be seen by various packages including [Zygote.jl](https://github.com/FluxML/Zygote.jl).\n\nGPU support is provided as package extensions. In order to load the extensions, use the imports\n```julia\nusing NNlib, CUDA, cuDNN\n```\nfor CUDA support, or\n```julia\nusing NNlib, AMDGPU\n```\nfor AMDGPU support.\n\n## Threading\n\nVarious `NNlib` functions utilize available julia threads on divisible workloads. To disable this use\nthe `ScopedValue`-backed switch `NNlib.@disallow_spawns`\ni.e.\n```julia\nNNlib.@disallow_spawns function_that_uses_nnlib()\n```\n"
  },
  {
    "path": "docs/src/reference.md",
    "content": "# Reference\n\nThe API reference of `NNlib`.\n\n## Activation Functions\n\nNon-linearities that go between layers of your model. Note that, unless otherwise stated, activation functions operate on scalars. To apply them to an array you can call `σ.(xs)`, `relu.(xs)` and so on.\n\n```@docs\ncelu\nelu\ngelu\ngelu_tanh\ngelu_sigmoid\ngelu_erf\nhardsigmoid\nsigmoid_fast\nhardtanh\ntanh_fast\nleakyrelu\nlisht\nlogcosh\nlogsigmoid\nmish\nrelu\nrelu6\nrrelu\nselu\nsigmoid\nsoftplus\nsoftshrink\nsoftsign\nswish\nhardswish\ntanhshrink\ntrelu\n```\n\n## Attention \n\n```@docs\ndot_product_attention\ndot_product_attention_scores\nmake_causal_mask\n```\n\n## Softmax\n\n`Flux`'s `logitcrossentropy` uses `NNlib.softmax` internally.\n\n```@docs\nsoftmax\nlogsoftmax\n```\n\n## Pooling\n\n`Flux`'s `AdaptiveMaxPool`, `AdaptiveMeanPool`, `GlobalMaxPool`, `GlobalMeanPool`, `MaxPool`, `MeanPool` and `lpnormpool` use `NNlib.PoolDims`, `NNlib.maxpool`, `NNlib.meanpool` and `NNlib.lpnormpool` as their backend.\n\n```@docs\nPoolDims\nmaxpool\nmeanpool\nlpnormpool\n```\n\n## Padding\n\n```@docs\npad_reflect\npad_symmetric\npad_circular\npad_repeat\npad_constant\npad_zeros\n```\n\n## Convolution\n\n`Flux`'s `Conv` and `CrossCor` layers use `NNlib.DenseConvDims` and `NNlib.conv` internally.\n\n`NNlib.conv` supports complex datatypes on CPU and CUDA devices.\n\n!!! note \"AMDGPU MIOpen supports only cross-correlation (`flipkernel=true`).\"\n\n    Therefore for every regular convolution (`flipkernel=false`)\n    kernel is flipped before calculation.\n    For better performance, use cross-correlation (`flipkernel=true`)\n    and manually flip the kernel before `NNlib.conv` call.\n    `Flux` handles this automatically, this is only required for direct calls.\n\n```@docs\nconv\nConvDims\ndepthwiseconv\nDepthwiseConvDims\nDenseConvDims\nNNlib.unfold\nNNlib.fold\n```\n\n## Upsampling\n\n`Flux`'s `Upsample` layer uses `NNlib.upsample_nearest`, `NNlib.upsample_bilinear`, and `NNlib.upsample_trilinear` as its backend. Additionally, `Flux`'s `PixelShuffle` layer uses `NNlib.pixel_shuffle` as its backend.\n\n```@docs\nupsample_nearest\n∇upsample_nearest\nupsample_linear\n∇upsample_linear\nupsample_bilinear\n∇upsample_bilinear\nupsample_trilinear\n∇upsample_trilinear\npixel_shuffle\n```\n\n## Rotation\nRotate images in the first two dimensions of an array.\n\n```@docs\nimrotate\n∇imrotate\n```\n\n## Batched Operations\n\n`Flux`'s `Bilinear` layer uses `NNlib.batched_mul` internally.\n\n```@docs\nbatched_mul\nbatched_mul!\nbatched_adjoint\nbatched_transpose\nbatched_vec\n```\n\n## Gather and Scatter\n\n`Flux`'s `Embedding` layer uses `NNlib.gather` as its backend.\n\n```@docs\nNNlib.gather\nNNlib.gather!\nNNlib.scatter\nNNlib.scatter!\n```\n\n## Sampling\n\n```@docs\ngrid_sample\n∇grid_sample\n```\n\n## Losses\n\n```@docs\nctc_loss\n```\n\n## Miscellaneous\n\n```@docs\nlogsumexp\nNNlib.glu\nNNlib.within_gradient\nbias_act!\n```\n"
  },
  {
    "path": "ext/NNlibAMDGPUExt/NNlibAMDGPUExt.jl",
    "content": "module NNlibAMDGPUExt\n\nusing Adapt\nusing AMDGPU\nusing ChainRulesCore\nusing NNlib\nusing NNlib: BatchedAdjoint, BatchedTranspose, BatchedAdjOrTrans\nusing NNlib: DenseConvDims, PoolDims\n\nconst MIOPENFloat = Union{Float16, Float32}\n\nconst ROCBatchedAdjoint{T} = BatchedAdjoint{T, <: ROCArray{T}}\nconst ROCBatchedTranspose{T} = BatchedTranspose{T, <: ROCArray{T}}\nconst ROCBatchedAdjOrTrans{T} = Union{ROCBatchedAdjoint{T}, ROCBatchedTranspose{T}}\nconst WrappedROCBatchedAdjOrTrans{T, N} = Adapt.WrappedArray{T, N, ROCBatchedAdjOrTrans{T}, ROCBatchedAdjOrTrans{T}}\nconst AnyROCBatchedAdjOrTrans = Union{ROCBatchedAdjOrTrans, WrappedROCBatchedAdjOrTrans}\n\nfunction Base.convert(::Type{T}, b::AnyROCBatchedAdjOrTrans) where {T <: Array}\n    Base.convert(T, adapt(Array, b))\nend\n\nfunction Base.Array{T, N}(b::AnyROCBatchedAdjOrTrans) where {T, N}\n    Array{T, N}(adapt(Array, b))\nend\n\nBase.collect(b::AnyROCBatchedAdjOrTrans) = collect(adapt(Array, b))\n\nfunction Base.show(\n    io::IO, mime::MIME{Symbol(\"text/plain\")}, x::AnyROCBatchedAdjOrTrans,\n)\n    show(io, mime, adapt(Array, x))\nend\n\nBase.show(io::IO, x::AnyROCBatchedAdjOrTrans) = show(io, adapt(Array, x))\n\nBase.display(x::AnyROCBatchedAdjOrTrans) = display(adapt(Array, x))\n\nfunction nnlib_padding(dims)\n    pd = NNlib.padding(dims)\n    if !all(pd[1:2:end] .== pd[2:2:end])\n        @warn \"\"\"\n        MIOpen does not support asymmetric padding, defaulting to symmetric choice:\n        $pd -> $(pd[1:2:end]).\n        \"\"\" maxlog=1\n    end\n    pd[1:2:end]\nend\n\ninclude(\"batched_mul.jl\")\n\n@static if AMDGPU.functional(:MIOpen)\n    using AMDGPU.MIOpen\n\n    include(\"conv.jl\")\n    include(\"pool.jl\")\n    include(\"activations.jl\")\nelse\n    @warn \"\"\"\n    ROCm MIOpen is not available for AMDGPU.\n    NNlib has limited functionality for AMDGPU.\n    \"\"\"\nend\n\nend\n"
  },
  {
    "path": "ext/NNlibAMDGPUExt/activations.jl",
    "content": "for (f, op) in [\n        NNlib.relu => MIOpen.relu,\n        NNlib.relu6 => x -> MIOpen.clippedrelu(x, 6),\n        NNlib.softplus => MIOpen.softrelu,\n        NNlib.σ => MIOpen.sigmoid,\n        Base.tanh => MIOpen.tanh,\n        # TODO define for leakyrelu, elu, etc.?\n    ], N in 1:5\n    @eval function Base.materialize(\n        bc::Broadcast.Broadcasted{<:Any,<:Any,typeof($f),<:Tuple{ROCArray{<:MIOPENFloat,$N}}}\n    )\n        return $op(bc.args[1])\n    end\nend\n\nBase.broadcasted(::typeof(identity), x::ROCArray{T}) where {T<:MIOPENFloat} = x\n"
  },
  {
    "path": "ext/NNlibAMDGPUExt/batched_mul.jl",
    "content": "function _blas_at(x)\n    Base.stride(x, 1) == 1 && return x, 'N'\n    Base.stride(x, 2) == 1 && return batched_transpose(x), 'T'\n    throw(ArgumentError(\"\"\"\n    Unsupported array layout for batched mul.\n    - Size: $(size(x))\n    - Strides: $(strides(x))\n    \"\"\"))\nend\n\nfunction NNlib._batched_mul!(\n    ::Type{AT}, C, A, B, α::Float16, β::Float16,\n) where AT <: ROCArray{Float16}\n    blasA, transA = _blas_at(A)\n    blasB, transB = _blas_at(B)\n    NNlib._batched_gemm!(AT, transA, transB, α, blasA, blasB, β, C)\n    C\nend\n\nfunction NNlib._batched_gemm!(\n    ::Type{<:ROCArray{T}}, transA::Char, transB::Char, α::T, A, B, β::T, C,\n) where T <: Union{MIOPENFloat, Float64}\n    AMDGPU.rocBLAS.gemm_batched!(transA, transB, α, A, B, β, C)\nend\n"
  },
  {
    "path": "ext/NNlibAMDGPUExt/conv.jl",
    "content": "function NNlib.conv!(\n    y::ROCArray{T, N}, x::ROCArray{T, N}, w::ROCArray{T, N}, cdims::DenseConvDims,\n) where {T <: MIOPENFloat, N}\n    if !NNlib.flipkernel(cdims)\n        @warn \"\"\"\n        MIOpen supports only cross-correlation (flipkernel=true).\n        Therefore for every regular convolution (flipkernel=false)\n        kernel is flipped before calculation.\n        For better performance, use cross-correlation (flipkernel=true)\n        and manually flip the kernel before `NNlib.conv` call.\n        \"\"\" maxlog=1\n        flip_dims = ntuple(\n            i -> (i ≤ ndims(w) - 2) ? (size(w, i):-1:1) : Colon(),\n            ndims(w))\n        w = w[flip_dims...]\n    end\n\n    nd = max(0, 4 - N)\n    ncdims = NNlib.insert_singleton_spatial_dimension(cdims, nd)\n    MIOpen.convolution!(\n        NNlib.insert_singleton_spatial_dimension(y, nd),\n        NNlib.insert_singleton_spatial_dimension(x, nd),\n        NNlib.insert_singleton_spatial_dimension(w, nd);\n        padding=nnlib_padding(ncdims), stride=NNlib.stride(ncdims),\n        dilation=NNlib.dilation(ncdims), groups=NNlib.groupcount(ncdims))\n    return y\nend\n\nfunction NNlib.∇conv_data!(\n    dx::ROCArray{T, N}, dy::ROCArray{T, N}, w::ROCArray{T, N}, cdims::DenseConvDims,\n) where {T <: MIOPENFloat, N}\n    if !NNlib.flipkernel(cdims)\n        @warn \"\"\"\n        MIOpen supports only cross-correlation (flipkernel=true).\n        Therefore for every regular convolution (flipkernel=false)\n        kernel is flipped before calculation.\n        For better performance, use cross-correlation (flipkernel=true)\n        and manually flip the kernel before `NNlib.conv` call.\n        \"\"\" maxlog=1\n        flip_dims = ntuple(\n            i -> (i ≤ ndims(w) - 2) ? (size(w, i):-1:1) : Colon(),\n            ndims(w))\n        w = w[flip_dims...]\n    end\n\n    nd = max(0, 4 - N)\n    ncdims = NNlib.insert_singleton_spatial_dimension(cdims, nd)\n    MIOpen.∇convolution_data!(\n        NNlib.insert_singleton_spatial_dimension(dx, nd),\n        NNlib.insert_singleton_spatial_dimension(dy, nd),\n        NNlib.insert_singleton_spatial_dimension(w, nd);\n        padding=nnlib_padding(ncdims), stride=NNlib.stride(ncdims),\n        dilation=NNlib.dilation(ncdims), groups=NNlib.groupcount(ncdims))\n    return dx\nend\n\nfunction NNlib.∇conv_filter!(\n    dw::ROCArray{T, N}, x::ROCArray{T, N}, dy::ROCArray{T, N}, cdims::DenseConvDims,\n) where {T <: MIOPENFloat, N}\n    nd = max(0, 4 - N)\n    ncdims = NNlib.insert_singleton_spatial_dimension(cdims, nd)\n    MIOpen.∇convolution_weight!(\n        NNlib.insert_singleton_spatial_dimension(dw, nd),\n        NNlib.insert_singleton_spatial_dimension(dy, nd),\n        NNlib.insert_singleton_spatial_dimension(x, nd);\n        padding=nnlib_padding(ncdims), stride=NNlib.stride(ncdims),\n        dilation=NNlib.dilation(ncdims), groups=NNlib.groupcount(ncdims))\n\n    if !NNlib.flipkernel(cdims)\n        @warn \"\"\"\n        MIOpen supports only cross-correlation (flipkernel=true).\n        Therefore for every regular convolution (flipkernel=false)\n        kernel is flipped before calculation.\n        For better performance, use cross-correlation (flipkernel=true)\n        and manually flip the kernel before `NNlib.conv` call.\n        \"\"\" maxlog=1\n        flip_dims = ntuple(\n            i -> (i ≤ ndims(dw) - 2) ? (size(dw, i):-1:1) : Colon(),\n            ndims(dw))\n        dw = dw[flip_dims...]\n    end\n    return dw\nend\n"
  },
  {
    "path": "ext/NNlibAMDGPUExt/pool.jl",
    "content": "for poolname in (:maxpool, :meanpool)\n    @eval function NNlib.$(poolname)(\n        x::ROCArray{T, N}, pdims::PoolDims,\n    ) where {T <: MIOPENFloat, N}\n        y = similar(x, NNlib.output_size(pdims)..., NNlib.channels_out(pdims), size(x, N))\n        nd = max(0, 4 - N)\n        npdims = NNlib.insert_singleton_spatial_dimension(pdims, nd)\n        MIOpen.$(Symbol(\"$(poolname)!\"))(\n            NNlib.insert_singleton_spatial_dimension(y, nd),\n            NNlib.insert_singleton_spatial_dimension(x, nd);\n            dims=NNlib.kernel_size(npdims), padding=nnlib_padding(npdims),\n            stride=NNlib.stride(npdims), do_backward=false)\n        return y\n    end\n\n    @eval function ChainRulesCore.rrule(\n        ::typeof(NNlib.$(poolname)), x::ROCArray{T, N}, pdims::PoolDims,\n    ) where {T <: MIOPENFloat, N}\n        y = similar(x, NNlib.output_size(pdims)..., NNlib.channels_out(pdims), size(x, N))\n        nd = max(0, 4 - N)\n        npdims = NNlib.insert_singleton_spatial_dimension(pdims, nd)\n\n        # `workspace` is used in the pullback.\n        _, workspace = MIOpen.$(Symbol(\"$(poolname)!\"))(\n            NNlib.insert_singleton_spatial_dimension(y, nd),\n            NNlib.insert_singleton_spatial_dimension(x, nd);\n            dims=NNlib.kernel_size(npdims), padding=nnlib_padding(npdims),\n            stride=NNlib.stride(npdims))\n\n        function _pooling_pullback(Δ)\n            dx = similar(x)\n            MIOpen.$(Symbol(\"∇$(poolname)!\"))(\n                NNlib.insert_singleton_spatial_dimension(dx, nd),\n                NNlib.insert_singleton_spatial_dimension(unthunk(Δ), nd),\n                NNlib.insert_singleton_spatial_dimension(y, nd),\n                NNlib.insert_singleton_spatial_dimension(x, nd);\n                dims=NNlib.kernel_size(npdims), padding=nnlib_padding(npdims),\n                stride=NNlib.stride(npdims), workspace)\n            return NoTangent(), dx, NoTangent()\n        end\n        y, _pooling_pullback\n    end\nend\n"
  },
  {
    "path": "ext/NNlibCUDACUDNNExt/NNlibCUDACUDNNExt.jl",
    "content": "module NNlibCUDACUDNNExt\n\nusing NNlib\nusing cuDNN\nusing CUDA\nusing Random, Statistics\n\nusing cuDNN: handle, with_workspace, cudnnTensorDescriptor, cudnnFilterDescriptor,\n             cudnnDataType, math_mode, CUDNN_DEFAULT_REORDER, CUDNN_CROSS_CORRELATION,\n             CUDNN_NOT_PROPAGATE_NAN, CUDNN_TENSOR_NCHW, dim4\n\ncudnnversion() = cuDNN.version()\n\nfunction nnlibPadding(dims)\n    pd = NNlib.padding(dims)\n    if !all(pd[1:2:end] .== pd[2:2:end])\n        @warn \"cuDNN does not support asymmetric padding; defaulting to symmetric choice\" maxlog=1\n    end\n    return pd[1:2:end]\nend\n\ninclude(\"conv.jl\")\ninclude(\"pooling.jl\")\ninclude(\"softmax.jl\")\ninclude(\"activations.jl\")\ninclude(\"batchnorm.jl\")\n\nend # module"
  },
  {
    "path": "ext/NNlibCUDACUDNNExt/activations.jl",
    "content": "\n# Activation\n\nusing Base.Broadcast\nusing cuDNN: cudnnActivationForward!, cudnnOpTensor!,\n             CUDNN_ACTIVATION_TANH, CUDNN_ACTIVATION_SIGMOID, CUDNN_ACTIVATION_ELU,\n             CUDNN_ACTIVATION_RELU, CUDNN_ACTIVATION_CLIPPED_RELU, CUDNN_OP_TENSOR_MAX,\n             CUDNN_ACTIVATION_IDENTITY\n\nfor (f, op) in [\n    CUDA.tanh       => (src,dst)->cudnnActivationForward!(dst, src, mode=CUDNN_ACTIVATION_TANH),\n    NNlib.σ         => (src,dst)->cudnnActivationForward!(dst, src, mode=CUDNN_ACTIVATION_SIGMOID),\n    NNlib.elu       => (src,dst)->cudnnActivationForward!(dst, src, mode=CUDNN_ACTIVATION_ELU),\n    NNlib.relu      => (src,dst)->cudnnActivationForward!(dst, src, mode=CUDNN_ACTIVATION_RELU),\n    # NNlib.relu6     => (src,dst)->cudnnActivationForward!(dst, src, mode=CUDNN_ACTIVATION_CLIPPED_RELU, coef=6.0),\n    # NNlib.leakyrelu => (src,dst)->cudnnOpTensor!(dst, src, src; op=CUDNN_OP_TENSOR_MAX, alpha1=0.01),\n    ]\n\n    @eval begin\n        # in-place\n        function Base.materialize!(dst::DenseCuArray{<:CUDNNFloat},\n                                   bc::Broadcast.Broadcasted{<:Any,<:Any,typeof($f),<:Tuple{DenseCuArray{<:CUDNNFloat}}})\n            $op(bc.args[1], dst)\n            return dst\n        end\n\n        # out of place\n        function Base.materialize(bc::Broadcast.Broadcasted{<:Any,<:Any,typeof($f),<:Tuple{DenseCuArray{<:CUDNNFloat}}})\n            ElType = Broadcast.combine_eltypes(bc.f, bc.args)\n            dst = similar(bc, ElType)\n            $op(bc.args[1], dst)\n            return dst\n        end\n    end\nend\n\n# CUDNN_ACTIVATION_IDENTITY does not work with cudnnActivationForward\n# FIXME: put this optimization in GPUArrays' `copyto!` (like Base.Broadcast's `copyto!`)\nBase.broadcasted(::typeof(identity), x::DenseCuArray{T}) where {T<:CUDNNFloat} = x\n\n"
  },
  {
    "path": "ext/NNlibCUDACUDNNExt/batchnorm.jl",
    "content": "using cuDNN: CUDNN_BN_MIN_EPSILON, cudnnBatchNormalizationBackward,\n             cudnnBatchNormalizationForwardInference, CUDNN_BATCHNORM_SPATIAL,\n             cudnnBatchNormalizationForwardTraining\nimport NNlib: batchnorm, ∇batchnorm\n\n# TODO: replace with new cudnn normalization interface\n# https://github.com/JuliaGPU/CUDA.jl/blob/master/lib/cudnn/normalization.jl\n\nmutable struct BNCache\n  mean\n  ivar\nend\n\nBNCache() = BNCache(nothing, nothing)\n\n@inline _wsize(x::AbstractArray{<:Any,N}) where N = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)\n\nfunction batchnorm(g::Nothing, b::Nothing, x::DenseCuArray,\n                   running_mean, running_var, momentum; kws...)\n  affine_sz = _wsize(x)\n  g = fill!(similar(x, affine_sz), 1)\n  b = fill!(similar(x, affine_sz), 0)\n  return batchnorm(g, b, x, running_mean, running_var, momentum; kws...)\nend\n\n# NOTE: CuDNN supports only 4D and 5D Tensors for BatchNorm Operations\n# so reshape a 2D Tensor into 4D\nfunction batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T,2},\n                   running_mean, running_var, momentum; kws...) where T<:CUDNNFloat\n  x = reshape(x, 1, 1, size(x, 1), size(x, 2))\n  y = batchnorm(g, b, x, running_mean, running_var, momentum; kws...)\n  return dropdims(y, dims = (1, 2))\nend\n\nfunction batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::Union{DenseCuArray{T,4},DenseCuArray{T,5}},\n                   running_mean, running_var, momentum; kws...) where T<:CUDNNFloat\n  cudnnBNForward!(similar(x), g, b, x, running_mean, running_var, momentum; kws...)\nend\n\nfunction cudnnBNForward!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T},\n                        running_mean, running_var, momentum;\n                        cache = nothing,\n                        alpha = T(1), beta = T(0),\n                        eps = T(1e-5),\n                        training = true,\n                        affine = true,\n                        track_stats = true) where T<:CUDNNFloat\n  dims = _wsize(x)\n  if eps < CUDNN_BN_MIN_EPSILON\n    @warn \"eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON\"\n    eps = CUDNN_BN_MIN_EPSILON\n  end\n\n  if running_mean === nothing || running_var === nothing\n    running_mean !== running_var && throw(ArgumentError(\"both or neither of running_mean and running_var must be nothing\"))\n    if track_stats || !training\n      running_mean = fill!(similar(x, dims), 0)\n      running_var = fill!(similar(x, dims), 1)\n    end\n  end\n\n  xd = cudnnTensorDescriptor(x)\n  yd = cudnnTensorDescriptor(y)\n  gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(dims)), dim4(dims,Val(CUDNN_TENSOR_NCHW)))\n\n  if training\n    if !track_stats\n      running_mean = CU_NULL\n      running_var = CU_NULL\n    end\n\n    if cache !== nothing\n      mean = fill!(similar(x, dims), 0)\n      ivar = fill!(similar(x, dims), 1)\n    else\n      mean = CU_NULL\n      ivar = CU_NULL\n    end\n\n    cudnnBatchNormalizationForwardTraining(handle(), CUDNN_BATCHNORM_SPATIAL, scalingParameter(T, alpha), scalingParameter(T, beta), xd, x, yd, y, gd, g, b, momentum, running_mean, running_var, eps, mean, ivar)\n\n    if cache !== nothing\n      cache.mean = mean\n      cache.ivar = ivar\n    end\n  else\n    if track_stats\n      cudnnBatchNormalizationForwardInference(handle(), CUDNN_BATCHNORM_SPATIAL, scalingParameter(T, alpha), scalingParameter(T, beta), xd, x, yd, y, gd, g, b, 
running_mean, running_var, eps)\n    else\n      # cudnnBatchNormalizationForwardInference does not accept CU_NULL for running_mean\n      # and running_var. We could calculate the mean and var of `x` here, but instead we call\n      # cudnnBatchNormalizationForwardTraining, which does accept CU_NULL and will calculate\n      # mean and var itself.\n      cudnnBatchNormalizationForwardTraining(handle(), CUDNN_BATCHNORM_SPATIAL, scalingParameter(T, alpha), scalingParameter(T, beta), xd, x, yd, y, gd, g, b, momentum, CU_NULL, CU_NULL, eps, CU_NULL, CU_NULL)\n    end\n  end\n  return y\nend\n\nfunction ∇batchnorm(g::Nothing, b::Nothing, x::DenseCuArray, dy::DenseCuArray,\n                    running_mean, running_var, momentum; kws...)\n  affine_sz = _wsize(x)\n  g = fill!(similar(x, affine_sz), 1)\n  b = fill!(similar(x, affine_sz), 0)\n  return ∇batchnorm(g, b, x, dy, running_mean, running_var, momentum; kws...)\nend\n\nfunction ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T, 2}, dy::DenseCuArray{T, 2},\n            running_mean, running_var, momentum;\n            kws...) where T<:CUDNNFloat\n  dg, db, dx = ∇batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), reshape(dy, 1, 1, size(dy, 1),\n                          size(dy, 2)), running_mean, running_var, momentum; kws...)\n  (dg, db, dropdims(dx, dims = (1, 2)))\nend\n\n\nfunction ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},\n                    running_mean, running_var, momentum;\n                    affine=true, kws...) where T<:CUDNNFloat\n  dg = similar(g)\n  db = similar(b)\n  dx = similar(x)\n  cudnnBNBackward!(dg, g, db, dx, x, dy, running_mean, running_var, T(momentum); kws...)\n  if affine\n    (dg, db, dx)\n  else\n    # cuDNN always calculates dg and db, therefore we just have to drop them\n    (nothing, nothing, dx)\n  end\nend\n\nfunction cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuArray{T},\n                          dx::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},\n                          running_mean, running_var,\n                          momentum; cache = nothing, eps = T(1e-5),\n                          alpha = T(1), beta = T(0),\n                          dalpha = T(1), dbeta = T(0), training = true,\n                          track_stats = true) where T<:CUDNNFloat\n  if !track_stats\n    running_mean = CU_NULL\n    running_var = CU_NULL\n  end\n\n  xd = cudnnTensorDescriptor(x)\n  dyd = cudnnTensorDescriptor(dy)\n  dxd = cudnnTensorDescriptor(dx)\n  gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(_wsize(x))), dim4(_wsize(x),Val(CUDNN_TENSOR_NCHW)))\n  if cache !== nothing\n    @debug \"fetching mean and ivar from the cache\"\n    mean, ivar = cache.mean, cache.ivar\n  else\n    mean, ivar = CU_NULL, CU_NULL\n  end\n\n  if eps < CUDNN_BN_MIN_EPSILON\n    @warn \"eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON\"\n    eps = CUDNN_BN_MIN_EPSILON\n  end\n\n  cudnnBatchNormalizationBackward(handle(), CUDNN_BATCHNORM_SPATIAL,\n        scalingParameter(T, alpha), scalingParameter(T, beta), scalingParameter(T, dalpha), scalingParameter(T, dbeta),\n        xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps, mean, ivar)\nend\n"
  },
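  {
    "path": "ext/NNlibCUDACUDNNExt/batchnorm_usage_sketch.jl",
    "content": "# Illustrative sketch only (not loaded by the extension): a minimal call into the\n# cuDNN-backed batchnorm/∇batchnorm methods defined in batchnorm.jl. Assumes a CUDA-capable\n# GPU with CUDA.jl and cuDNN.jl loaded so this extension is active; array sizes, the\n# momentum value and the length-C parameter vectors are illustrative choices.\n\nusing NNlib, CUDA, cuDNN\n\nx = CUDA.rand(Float32, 7, 7, 3, 4)      # W × H × C × N\ng = CUDA.ones(Float32, 3)               # per-channel scale\nb = CUDA.zeros(Float32, 3)              # per-channel shift\nrunning_mean = CUDA.zeros(Float32, 3)\nrunning_var  = CUDA.ones(Float32, 3)\n\n# Forward pass in training mode, updating the running statistics.\ny = NNlib.batchnorm(g, b, x, running_mean, running_var, 0.1f0;\n                    training = true, track_stats = true, eps = 1f-5)\n\n# Backward pass: gradients w.r.t. scale, shift and input.\ndy = CUDA.ones(Float32, size(y))\ndg, db, dx = NNlib.∇batchnorm(g, b, x, dy, running_mean, running_var, 0.1f0;\n                              training = true, track_stats = true)\n"
  },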
  {
    "path": "ext/NNlibCUDACUDNNExt/conv.jl",
    "content": "\nusing NNlib: DenseConvDims\nimport NNlib: conv!, ∇conv_filter!, ∇conv_data!, conv_bias_act!\n\nusing cuDNN: scalingParameter, CUDNN_CONVOLUTION, convdims,\n             cudnnConvolutionBwdDataAlgoPerf,\n             cudnnConvolutionForward!, cudnnConvolutionBwdFilterAlgoPerf,\n             cudnnConvolutionBackwardData, cudnnConvolutionBackwardFilter,\n             cudnnConvolutionBackwardBias\nimport cuDNN: cudnnConvolutionDescriptor\n\nconst CUDNNFloat = Union{Float16,Float32,Float64}\nconst CUDNNComplexFloat = Union{ComplexF16,ComplexF32,ComplexF64}\n\nfunction cudnnConvolutionDescriptorAndPaddedInput(cdims::DenseConvDims, x::DenseCuArray{T}) where T\n    # The main purpose of this function is to catch asymmetric padding which cudnn does not support\n    # If we find asymmetric padding we'll make a copy of x which is manually padded so that we can\n    # call cudnn with symmetric padding.\n    pad = NNlib.padding(cdims)\n    sdims = NNlib.spatial_dims(cdims)\n    all(i -> pad[i] .== pad[i+1], 1:2:2sdims) && return (cudnnConvolutionDescriptor(cdims, x), x, identity)\n\n    # Naive implementation, is there a faster way?\n    # How much we need to pad x manually: The absolute difference between pad_left and pad_right, pad_top\n    # and pad_bottom etc. respectively. We keep the sign here though because we use it below to figure out\n    # which side of x to pad. Oh, and we use a CartesianIndex as we will mainly use this to index in x\n    pad_manual = CartesianIndex(ntuple(i -> i > sdims ? 0 : pad[2(i-1)+1] - pad[2(i-1)+2], ndims(x)))\n\n    # How much we can let cudnn pad: The smallest padding amount between pad_left and pad_right, pad_top\n    # and pad_bottom etc. respectively\n    pad_cudnn = ntuple(i -> min(pad[2(i-1)+1], pad[2(i-1)+2]), sdims)\n\n    x_padded_size = ntuple(i -> i <= sdims ? size(x, i) + abs(pad_manual[i]) : size(x ,i), ndims(x))\n    x_padded = similar(x, x_padded_size)\n    fill!(x_padded, 0)\n    # This is a bit yucky, but we are basically figuring out where in x_padded we shall insert x\n    # Haven't benchmarked if this has any advantages over a more readable solution, e.g. writing dim\n    # by dim to an array in a loop\n    xIs = CartesianIndices(x)\n    xI_first = first(xIs)\n    xI_last = last(xIs)\n    xIs_pad = max(xI_first, xI_first + pad_manual) : max(xI_last, xI_last + pad_manual)\n    x_padded[xIs_pad] = x\n\n    return cudnnConvolutionDescriptor(cdims, x_padded, pad_cudnn), x_padded, _x -> _x[xIs_pad]\nend\n\nfunction cudnnConvolutionDescriptor(cdims::DenseConvDims, x::DenseCuArray{T}, pad = nnlibPadding(cdims)) where T\n    mode=(NNlib.flipkernel(cdims) ? CUDNN_CROSS_CORRELATION : CUDNN_CONVOLUTION)\n    cudnnConvolutionDescriptor(convdims(pad, size(x),0),\n                               convdims(NNlib.stride(cdims),size(x),1),\n                               convdims(NNlib.dilation(cdims),size(x),1),\n                               mode,\n                               cudnnDataType(real(T)),\n                               math_mode(),\n                               CUDNN_DEFAULT_REORDER,\n                               Cint(NNlib.groupcount(cdims)))\nend\n\n@inline function _complex!(y::DenseCuArray{T1}, yr::DenseCuArray{T2}, yi::DenseCuArray{T2}; bias=zero(T1), alpha=one(T1), beta=zero(T1), σ=identity) where {T1 <: CUDNNComplexFloat, T2<:CUDNNFloat}\n    # if y is from similar(), it may have NaNs, and beta*NaN will propagate.\n    if beta != 0\n        @. y = σ(alpha*(yr + im*yi) + bias + beta*y)\n    else\n        @. 
y = σ(alpha*(yr + im*yi) + bias)\n    end\n    return y\nend\n\nfunction conv!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T}, cdims::DenseConvDims;\n               alpha=1, beta=0, algo=-1) where T<:CUDNNFloat\n    if cudnnversion() < v\"6\"\n        all(x -> x == 1, dilation(cdims)) || error(\"Only dilation = 1 is supported in cuDNN version < 6\")\n    end\n    if algo != -1\n        @warn \"algo option has been deprecated, the fastest algo is computed automatically\" maxlog=1\n    end\n    d, x, _ = cudnnConvolutionDescriptorAndPaddedInput(cdims, x)\n    cudnnConvolutionForward!(y, w, x, d; alpha, beta, z=y)\nend\n\n# Complex convolution with Gauss's trick (1 complex mul === 3 real mul):\n# Consider x = xr + im*xi, y = yr + im*yi,\n# so x*y = (xr*yr - xi*yi) + im*(xr*yi + xi*yr).\n# Let a = xr*yr,\n#     b = xi*yi,\n#     c = (xr + xi)*(yr + yi) = xr*yr + xr*yi + xi*yr + xi*yi.\n# Then,\n# x*y = (a - b) + im*(c - a - b).\n# Convolution is linear so this multiplication trick translates to convolution.\nfunction conv!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T}, cdims::DenseConvDims;\n               alpha=1, beta=0, algo=-1) where T<:CUDNNComplexFloat\n    xr, xi = reim(x)\n    wr, wi = reim(w)\n    a = conv!(similar(real(y)), xr, wr, cdims; algo=algo)\n    b = conv!(similar(a), xi, wi, cdims; algo=algo)\n    c = conv!(similar(a), xr + xi, wr + wi, cdims; algo=algo)\n    return _complex!(y, a - b, c - a - b; alpha=alpha, beta=beta)\nend\n\n# (xr + im*xi) * w = xr*w + im*(xi*w)\nfunction conv!(y::DenseCuArray{T1}, x::DenseCuArray{T1}, w::DenseCuArray{T2}, cdims::DenseConvDims;\n               alpha=1, beta=0, algo=-1) where {T1<:CUDNNComplexFloat, T2<:CUDNNFloat}\n    xr, xi = reim(x)\n    yr = conv!(similar(real(y)), xr, w, cdims; algo=algo)\n    yi = conv!(similar(yr), xi, w, cdims; algo=algo)\n    return _complex!(y, yr, yi; alpha=alpha, beta=beta)\nend\n\n# x * (wr + im*wi) = x*wr + im*(x*wi)\nfunction conv!(y::DenseCuArray{T1}, x::DenseCuArray{T2}, w::DenseCuArray{T1}, cdims::DenseConvDims;\n               alpha=1, beta=0, algo=-1) where {T1<:CUDNNComplexFloat, T2<:CUDNNFloat}\n    wr, wi = reim(w)\n    yr = conv!(similar(real(y)), x, wr, cdims; algo=algo)\n    yi = conv!(similar(yr), x, wi, cdims; algo=algo)\n    return _complex!(y, yr, yi; alpha=alpha, beta=beta)\nend\n\nfunction conv_bias_act!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T},\n                        cdims::DenseConvDims, bias::DenseCuArray{T}, σ=identity;\n                        z::DenseCuArray{T}=y, alpha=1, beta=0, algo=-1) where T<:CUDNNFloat\n    if cudnnversion() < v\"6\"\n        all(x -> x == 1, dilation(cdims)) || error(\"Only dilation = 1 is supported in cuDNN version < 6\")\n    end\n    if algo != -1\n        @warn \"The algo option has been deprecated, the fastest algo is computed automatically\" maxlog=1\n    end\n    d, x, _ = cudnnConvolutionDescriptorAndPaddedInput(cdims, x)\n    # only relu and identity are supported by cudnnConvolutionForward!\n    activation = (σ == NNlib.relu ? CUDNN_ACTIVATION_RELU : CUDNN_ACTIVATION_IDENTITY)\n    cudnnConvolutionForward!(y, w, x, d; z, bias, activation, alpha, beta)\n    if activation === CUDNN_ACTIVATION_IDENTITY && σ ∉ (nothing, identity)\n        @. 
y = σ(y)\n    end\n    return y\nend\n\nfunction conv_bias_act!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T},\n                        cdims::DenseConvDims, bias::DenseCuArray{T}, σ=identity;\n                        z::DenseCuArray{T}=y, alpha=1, beta=0, algo=-1) where T<:CUDNNComplexFloat\n    xr, xi = reim(x)\n    wr, wi = reim(w)\n    a = conv!(similar(real(y)), xr, wr, cdims; alpha=1, beta=0, algo=algo)\n    b = conv!(similar(a), xi, wi, cdims; alpha=1, beta=0, algo=algo)\n    c = conv!(similar(a), xr + xi, wr + wi, cdims; alpha=1, beta=0, algo=algo)\n    return _complex!(y, a - b, c - a - b; bias=bias, alpha=alpha, beta=beta, σ=σ)\nend\n\nfunction ∇conv_data!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, w::DenseCuArray{T},\n                     cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where T<:CUDNNFloat\n    if cudnnversion() < v\"6\"\n        all(x -> x == 1, dilation(cdims)) || error(\"Only dilation = 1 is supported in cuDNN version < 6\")\n    end\n    if algo != -1\n        @warn \"The algo option has been deprecated, the fastest algo is computed automatically\" maxlog=1\n    end\n    alpha, beta = scalingParameter(T,alpha), scalingParameter(T,beta);\n    convDesc, dx, depad = cudnnConvolutionDescriptorAndPaddedInput(cdims, dx)\n    xDesc, yDesc, wDesc = cudnnTensorDescriptor(dx), cudnnTensorDescriptor(dy), cudnnFilterDescriptor(w)\n    p = cudnnConvolutionBwdDataAlgoPerf(wDesc, w, yDesc, dy, convDesc, xDesc, dx, beta!=0)\n    with_workspace(p.memory) do workspace\n        cudnnConvolutionBackwardData(handle(), alpha, wDesc, w, yDesc, dy, convDesc, p.algo, workspace, sizeof(workspace), beta, xDesc, dx)\n    end\n    return depad(dx)\nend\n\nfunction ∇conv_data!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, w::DenseCuArray{T},\n                     cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where T<:CUDNNComplexFloat\n    dyr, dyi = reim(dy)\n    wr, wi = reim(w)\n    # note: w is conjugated, i.e. 
wi is negated below\n    a = ∇conv_data!(similar(real(dx)), dyr, wr, cdims; alpha=1, beta=0, algo=algo)\n    b = ∇conv_data!(similar(a), dyi, -wi, cdims; alpha=1, beta=0, algo=algo)\n    c = ∇conv_data!(similar(a), dyr + dyi, wr - wi, cdims; alpha=1, beta=0, algo=algo)\n    return _complex!(dx, a - b, c - a - b; alpha=alpha, beta=beta)\nend\n\n# dx = (dyr + im*dyi)*w = dyr*w + im*(dyi*w)\nfunction ∇conv_data!(dx::DenseCuArray{T1}, dy::DenseCuArray{T1}, w::DenseCuArray{T2},\n                     cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where {T1<:CUDNNComplexFloat, T2<:CUDNNFloat}\n    dyr, dyi = reim(dy)\n    dxr = ∇conv_data!(similar(real(dx)), dyr, w, cdims; alpha=1, beta=0, algo=algo)\n    dxi = ∇conv_data!(similar(dxr), dyi, w, cdims; alpha=1, beta=0, algo=algo)\n    return _complex!(dx, dxr, dxi; alpha=alpha, beta=beta)\nend\n\nfunction ∇conv_filter!(dw::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},\n                       cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where T<:CUDNNFloat\n    if cudnnversion() < v\"6\"\n        all(x -> x == 1, dilation(cdims)) || error(\"Only dilation = 1 is supported in cuDNN version < 6\")\n    end\n    if algo != -1\n        @warn \"The algo option has been deprecated, the fastest algo is computed automatically\" maxlog=1\n    end\n    alpha, beta = scalingParameter(T,alpha), scalingParameter(T,beta);\n    convDesc, x, _ = cudnnConvolutionDescriptorAndPaddedInput(cdims, x)\n    xDesc, yDesc, wDesc = cudnnTensorDescriptor(x), cudnnTensorDescriptor(dy), cudnnFilterDescriptor(dw)\n    p = cudnnConvolutionBwdFilterAlgoPerf(xDesc, x, yDesc, dy, convDesc, wDesc, dw, beta!=0);\n    with_workspace(p.memory) do workspace\n        cudnnConvolutionBackwardFilter(handle(), alpha, xDesc, x, yDesc, dy, convDesc, p.algo, workspace, sizeof(workspace), beta, wDesc, dw);\n    end\n    return dw\nend\n\nfunction ∇conv_filter!(dw::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},\n                       cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where T<:CUDNNComplexFloat\n    xr, xi = reim(x)\n    dyr, dyi = reim(dy)\n    # note: x is conjugated, i.e. 
xi is negated below\n    a = ∇conv_filter!(similar(real(dw)), xr, dyr, cdims; alpha=1, beta=0, algo=algo)\n    b = ∇conv_filter!(similar(a), -xi, dyi, cdims; alpha=1, beta=0, algo=algo)\n    c = ∇conv_filter!(similar(a), xr - xi, dyr + dyi, cdims; alpha=1, beta=0, algo=algo)\n    return _complex!(dw, a - b, c - a - b; alpha=alpha, beta=beta)\nend\n\n# dw = x*(dyr + im*dyi) = x*dyr + im*(x*dyi)\nfunction ∇conv_filter!(dw::DenseCuArray{T1}, x::DenseCuArray{T2}, dy::DenseCuArray{T1},\n                       cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where {T1<:CUDNNComplexFloat, T2<:CUDNNFloat}\n    dyr, dyi = reim(dy)\n    dwr = ∇conv_filter!(similar(real(dw)), x, dyr, cdims; alpha=1, beta=0, algo=algo)\n    dwi = ∇conv_filter!(similar(dwr), x, dyi, cdims; alpha=1, beta=0, algo=algo)\n    return _complex!(dw, dwr, dwi; alpha=alpha, beta=beta)\nend\n\nfunction ∇conv_bias!(db::DenseCuArray{T}, dy::DenseCuArray{T}; alpha=1, beta=0) where T<:CUDNNFloat\n    alpha,beta = scalingParameter(T,alpha), scalingParameter(T,beta)\n    bDesc, yDesc = cudnnTensorDescriptor.((db,dy))\n    cudnnConvolutionBackwardBias(handle(), alpha, yDesc, dy, beta, bDesc, db)\n    return db\nend\n\nfunction ∇conv_bias!(db::DenseCuArray{T}, dy::DenseCuArray{T}; alpha=1, beta=0) where T<:CUDNNComplexFloat\n    dyr, dyi = reim(dy)\n    dbr = ∇conv_bias!(similar(real(db)), dyr; alpha=1, beta=0)\n    dbi = ∇conv_bias!(similar(dbr), dyi; alpha=1, beta=0)\n    return _complex!(db, dbr, dbi; alpha=alpha, beta=beta)\nend\n"
  },
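  {
    "path": "ext/NNlibCUDACUDNNExt/conv_complex_sketch.jl",
    "content": "# Illustrative sketch only (not loaded by the extension): the three-real-multiplication\n# (Gauss) identity behind the complex conv! methods in conv.jl, checked on scalars, plus a\n# hypothetical complex convolution call. Assumes CUDA.jl and cuDNN.jl are loaded so those\n# methods are active; sizes and values are illustrative.\n\n# Scalar check of the identity x*w == (a - b) + im*(c - a - b).\nx = 3.0 + 2.0im\nw = -1.0 + 4.0im\na = real(x) * real(w)\nb = imag(x) * imag(w)\nc = (real(x) + imag(x)) * (real(w) + imag(w))\n@assert (a - b) + im * (c - a - b) == x * w\n\n# Convolution is linear, so the same identity drives the complex conv! method above:\nusing NNlib, CUDA, cuDNN\nxc = CUDA.rand(ComplexF32, 8, 8, 1, 1)   # W × H × C × N\nwc = CUDA.rand(ComplexF32, 3, 3, 1, 1)\nyc = NNlib.conv(xc, wc)                  # three real cuDNN convolutions internally\n"
  },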
  {
    "path": "ext/NNlibCUDACUDNNExt/pooling.jl",
    "content": "using cuDNN: cudnnPoolingMode_t, CUDNN_POOLING_MAX,\n             CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING,\n             cudnnPoolingForward!, pooldims, cudnnPoolingBackward\n\nimport NNlib: maxpool!, ∇maxpool!, meanpool!, ∇meanpool!\nimport cuDNN: cudnnPoolingDescriptor\n\nfunction cudnnPoolingDescriptor(pdims::PoolDims, x::DenseCuArray{T}, mode::cudnnPoolingMode_t) where T\n    window, padding, stride = NNlib.kernel_size(pdims), nnlibPadding(pdims), NNlib.stride(pdims)\n    nanOpt = CUDNN_NOT_PROPAGATE_NAN\n    cudnnPoolingDescriptor(mode, nanOpt, Cint(ndims(x)-2), pooldims(window,size(x)), pooldims(padding,size(x)), pooldims(stride,size(x)))\nend\n\nfunction maxpool!(y::DenseCuArray{T}, x::DenseCuArray{T}, pdims::PoolDims) where T<:CUDNNFloat\n    d = cudnnPoolingDescriptor(pdims, x, CUDNN_POOLING_MAX)\n    cudnnPoolingForward!(y, x, d)\nend\n\nfunction ∇maxpool!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, y::DenseCuArray{T}, x::DenseCuArray{T}, pdims::PoolDims) where T<:CUDNNFloat\n    xDesc, yDesc = cudnnTensorDescriptor.((x, y))\n    d = cudnnPoolingDescriptor(pdims, x, CUDNN_POOLING_MAX)\n    alpha, beta = scalingParameter(T,1), scalingParameter(T,0)\n    cudnnPoolingBackward(handle(), d, alpha, yDesc, y, yDesc, dy, xDesc, x, beta, xDesc, dx)\n    return dx\nend\n\nfunction meanpool!(y::DenseCuArray{T}, x::DenseCuArray{T}, pdims::PoolDims) where T<:CUDNNFloat\n    d = cudnnPoolingDescriptor(pdims, x, CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING)\n    cudnnPoolingForward!(y, x, d)\nend\n\nfunction ∇meanpool!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, y::DenseCuArray{T}, x::DenseCuArray{T}, pdims::PoolDims) where T<:CUDNNFloat\n    xDesc, yDesc = cudnnTensorDescriptor.((x, y))\n    d = cudnnPoolingDescriptor(pdims, x, CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING)\n    alpha, beta = scalingParameter(T,1), scalingParameter(T,0)\n    cudnnPoolingBackward(handle(), d, alpha, yDesc, y, yDesc, dy, xDesc, x, beta, xDesc, dx)\n    return dx\nend\n\n### Since CUDA.jl does not support 1D pooling, we have to convert to 2d\n\nadd1d(x) = reshape(x, 1, size(x)...)\n\nfunction fix_pooldims_1d(pdims::PoolDims{1,K,S,P,D}) where {K,S,P,D}\n    PoolDims{2, K + 1, S + 1, P + 2, D + 1}((1, NNlib.input_size(pdims)...),\n                                            (1, NNlib.kernel_size(pdims)...),\n                                            NNlib.channels_in(pdims),\n                                            (1, NNlib.stride(pdims)...),\n                                            (0, 0, NNlib.padding(pdims)...),\n                                            (1, NNlib.dilation(pdims)...))\nend\n\nfunction maxpool!(y::DenseCuArray{T,3}, x::DenseCuArray{T,3}, pdims::PoolDims) where T<:CUDNNFloat\n    maxpool!(add1d(y), add1d(x), fix_pooldims_1d(pdims))\n    return y\nend\n\nfunction meanpool!(y::DenseCuArray{T,3}, x::DenseCuArray{T,3}, pdims::PoolDims) where T<:CUDNNFloat\n    meanpool!(add1d(y), add1d(x), fix_pooldims_1d(pdims))\n    return y\nend\n\nfunction ∇maxpool!(dx::DenseCuArray{T,3}, dy::DenseCuArray{T,3}, y::DenseCuArray{T,3}, x::DenseCuArray{T,3}, pdims::PoolDims) where T<:CUDNNFloat\n    ∇maxpool!(add1d(dx), add1d(dy), add1d(y), add1d(x), fix_pooldims_1d(pdims))\n    return dx\nend\n\nfunction ∇meanpool!(dx::DenseCuArray{T,3}, dy::DenseCuArray{T,3}, y::DenseCuArray{T,3}, x::DenseCuArray{T,3}, pdims::PoolDims) where T<:CUDNNFloat\n    ∇meanpool!(add1d(dx), add1d(dy), add1d(y), add1d(x), fix_pooldims_1d(pdims))\n    return dx\nend\n\n\n"
  },
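  {
    "path": "ext/NNlibCUDACUDNNExt/pooling_usage_sketch.jl",
    "content": "# Illustrative sketch only (not loaded by the extension): 1-D pooling on a 3-D CuArray,\n# which the methods in pooling.jl handle by inserting a singleton spatial dimension and\n# reusing the cuDNN 2-D pooling path. Assumes CUDA.jl and cuDNN.jl are loaded; sizes are\n# illustrative.\n\nusing NNlib, CUDA, cuDNN\n\nx = CUDA.rand(Float32, 16, 3, 2)   # (length, channels, batch)\ny = NNlib.maxpool(x, (4,))         # window 4, default stride 4, so size(y) == (4, 3, 2)\nm = NNlib.meanpool(x, (4,))        # same layout, average pooling\n"
  },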
  {
    "path": "ext/NNlibCUDACUDNNExt/softmax.jl",
    "content": "import NNlib: softmax, softmax!, ∇softmax, ∇softmax!,\n              logsoftmax, logsoftmax!, ∇logsoftmax, ∇logsoftmax!\n\nusing cuDNN: CUDNN_SOFTMAX_LOG, CUDNN_SOFTMAX_MODE_CHANNEL,\n             CUDNN_SOFTMAX_FAST, CUDNN_SOFTMAX_ACCURATE, cudnnSoftmaxForward!,\n             cudnnSoftmaxBackward\n\n# Softmax\n\n# @denizyuret: do not do inplace operations with softmax/logsoftmax when (1) cpu version is not, (2) one can use softmax!\nfunction softmax(x::T; dims=1) where {T<:DenseCuArray}\n    softmax!(similar(x), x; dims)\nend\n\nfunction ∇softmax(dy::T, x::T, y::T; dims=1) where {T<:DenseCuArray}\n    ∇softmax!(similar(x), dy, x, y; dims)\nend\n\nfunction logsoftmax(x::T; dims=1) where {T<:DenseCuArray}\n    logsoftmax!(similar(x), x; dims)\nend\n\nfunction ∇logsoftmax(dy::T, x::T, y::T; dims=1) where {T<:DenseCuArray}\n    ∇logsoftmax!(similar(x), dy, x, y; dims)\nend\n\n# @denizyuret: backup implementations for unsupported/slow size/dims combinations:\nfunction _softmax!(y::T, x::T; dims) where {T<:DenseCuArray}\n    y .= exp.(x .- maximum(x; dims))\n    y ./= sum(y; dims)\nend\n\nfunction _∇softmax!(dx::T, dy::T, x::T, y::T; dims) where {T<:DenseCuArray}\n    dx .= y .* (dy .- sum(dy .* y; dims))\nend\n\nfunction _logsoftmax!(y::T, x::T; dims) where {T<:DenseCuArray}\n    y .= x .- maximum(x; dims)\n    y .-= log.(sum(exp.(y); dims))\nend\n\nfunction _∇logsoftmax!(dx::T, dy::T, x::T, y::T; dims) where {T<:DenseCuArray}\n    dx .= dy .- sum(dy; dims) .* exp.(y)\nend\n\n# Trick by @norci to use cudnn for softmax dims args that are contiguous:\n# If dims=(dmin:dmax) then CUDNN_SOFTMAX_MODE_CHANNEL does the trick with reshape\n#    (1, prod(size(x)[1:dmin-1]), prod(size(x)[dmin:dmax]), :)\n# softmaxdims returns nothing when the backup implementation should be used.\n\nfunction softmaxdims(x, dims)\n    dims === Colon() && return (1, 1, length(x), 1)\n    mind,maxd = minimum(dims),maximum(dims)\n    all(i in dims for i in mind:maxd) || return nothing # cannot handle if not contiguous\n    stride = dimsize = 1\n    for i in 1:(mind-1); stride *= size(x,i); end # Using size(x,i) assumes trailing dims = 1, robust to maxd > ndims(x)\n    for i in mind:maxd; dimsize *= size(x,i); end\n    batchsize = length(x)÷(stride*dimsize)\n    # Here is a region where cudnn is slower, so we go with the backup:\n    batchsize == 1 && 64 <= stride <= 4096 && 64 <= dimsize <= 4096 && return nothing\n    return (1, stride, dimsize, batchsize)\nend\n\n# Determine softmax algo based on math_mode\n\nsoftmaxalgo() = (CUDA.math_mode()===CUDA.FAST_MATH ? 
CUDNN_SOFTMAX_FAST : CUDNN_SOFTMAX_ACCURATE)\n\n# Main implementations:\n\nfunction softmax!(y::T, x::T = y; dims=1) where {T<:DenseCuArray}\n    s = softmaxdims(x, dims)\n    s === nothing && return _softmax!(y, x; dims)\n    cudnnSoftmaxForward!(reshape(y,s), reshape(x,s); mode = CUDNN_SOFTMAX_MODE_CHANNEL, algo = softmaxalgo())\n    return y\nend\n\nfunction ∇softmax!(dx::T, dy::T, x::T, y::T; dims=1) where {R,T<:DenseCuArray{R}}\n    s = softmaxdims(x, dims)\n    s === nothing && return _∇softmax!(dx, dy, x, y; dims)\n    xDesc = cudnnTensorDescriptor(reshape(x,s))\n    alpha, beta = scalingParameter(R,1), scalingParameter(R,0)\n    cudnnSoftmaxBackward(handle(), softmaxalgo(), CUDNN_SOFTMAX_MODE_CHANNEL,\n                         alpha, xDesc, y, xDesc, dy, beta, xDesc, dx)\n    return dx\nend\n\nfunction logsoftmax!(y::T, x::T = y; dims=1) where {T<:DenseCuArray}\n    s = softmaxdims(x, dims)\n    s === nothing && return _logsoftmax!(y, x; dims)\n    cudnnSoftmaxForward!(reshape(y,s), reshape(x,s); mode = CUDNN_SOFTMAX_MODE_CHANNEL, algo = CUDNN_SOFTMAX_LOG)\n    return y\nend\n\nfunction ∇logsoftmax!(dx::T, dy::T, x::T, y::T; dims=1) where {R,T<:DenseCuArray{R}}\n    s = softmaxdims(x, dims)\n    s === nothing && return _∇logsoftmax!(dx, dy, x, y; dims)\n    xDesc = cudnnTensorDescriptor(reshape(x,s))\n    alpha, beta = scalingParameter(R,1), scalingParameter(R,0)\n    cudnnSoftmaxBackward(handle(), CUDNN_SOFTMAX_LOG, CUDNN_SOFTMAX_MODE_CHANNEL,\n                         alpha, xDesc, y, xDesc, dy, beta, xDesc, dx)\n    return dx\nend\n"
  },
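  {
    "path": "ext/NNlibCUDACUDNNExt/softmax_dims_sketch.jl",
    "content": "# Illustrative sketch only (not loaded by the extension): a CPU check of the reshape trick\n# described in softmax.jl, where softmax over contiguous dims dmin:dmax equals a channel-mode\n# softmax on a (1, stride, dimsize, batchsize) reshape. The array shape is illustrative.\n\nusing NNlib\n\nx = rand(Float32, 2, 3, 4, 5)\nlead   = size(x, 1)                     # product of the dims before dmin = 2\nmiddle = size(x, 2) * size(x, 3)        # product of the softmax dims 2:3\nbatch  = length(x) ÷ (lead * middle)    # product of the trailing dims\nxr = reshape(x, 1, lead, middle, batch)\n@assert reshape(softmax(xr; dims = 3), size(x)) ≈ softmax(x; dims = (2, 3))\n"
  },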
  {
    "path": "ext/NNlibCUDAExt/NNlibCUDAExt.jl",
    "content": "module NNlibCUDAExt\n\nusing NNlib\nusing CUDA\nusing Random, Statistics\n\ninclude(\"sampling.jl\")\ninclude(\"activations.jl\")\ninclude(\"batchedadjtrans.jl\")\ninclude(\"batchedmul.jl\")\ninclude(\"ctc.jl\")\ninclude(\"scatter.jl\")\ninclude(\"utils.jl\")\n\nend # module\n"
  },
  {
    "path": "ext/NNlibCUDAExt/activations.jl",
    "content": "# Activation functions\n\n# Some of activation functions need a wrapper for GPU support\n# https://github.com/JuliaGPU/CuArrays.jl/issues/614\n\n# @cufunc softplus(x::Real) = ifelse(x > 0, x + log1p(exp(-x)), log1p(exp(x)))\n\n# @cufunc logσ(x::Real) = -softplus(-x)\n\n# @cufunc function gelu(x::Real)\n#     p = oftype(x / 1, π)\n#     λ = oftype(x / 1, √(2 / p))\n#     α = oftype(x / 1, 0.044715)\n#     h = oftype(x / 1, 0.5)\n#     h * x * (one(x) + tanh(λ * (x + α * x^3)))\n# end\n\n# @cufunc lisht(x::Real) = x * tanh(x)\n\n# @cufunc logcosh(x::Real) = x + softplus(-2x) - log(oftype(x, 2))\n\n# @cufunc mish(x::Real) = x * tanh(softplus(x))\n\n# @cufunc tanhshrink(x::Real) = x - tanh(x)\n"
  },
  {
    "path": "ext/NNlibCUDAExt/batchedadjtrans.jl",
    "content": "using NNlib: BatchedAdjoint, BatchedTranspose, BatchedAdjOrTrans\nusing Adapt\nusing Adapt: WrappedArray\n\nconst CuBatchedAdjoint{T} = BatchedAdjoint{T, <: CuArray{T}}\nconst CuBatchedTranspose{T} = BatchedTranspose{T, <: CuArray{T}}\nconst CuBatchedAdjOrTrans{T} = Union{CuBatchedAdjoint{T}, CuBatchedTranspose{T}}\nconst WrappedCuBatchedAdjOrTrans{T, N} = WrappedArray{T, N, CuBatchedAdjOrTrans{T}, CuBatchedAdjOrTrans{T}}\n\n\nBase.print_array(io::IO, b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}) = Base.print_array(io, adapt(Array, b))\nBase._show_nonempty(io::IO, b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}, prefix::String) = Base._show_nonempty(io, adapt(Array, b), prefix)\nBase.show_vector(io::IO, b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}, opn, cls) = Base.show_vector(io, adapt(Array, b), opn, cls)\n\nBase.convert(::Type{T}, b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}) where {T<:Array} = Base.convert(T, adapt(Array, b))\nBase.Array{T, N}(b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}) where {T, N} = Array{T, N}(adapt(Array, b))\nBase.collect(b::Union{CuBatchedAdjOrTrans, WrappedCuBatchedAdjOrTrans}) = collect(adapt(Array, b))\n"
  },
  {
    "path": "ext/NNlibCUDAExt/batchedmul.jl",
    "content": "# Batched matrix multiplication\n# 1st argument is produced by NNlib.storage_type(A)\nNNlib._batched_gemm!(::Type{<:CuArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C) =\n     CUBLAS.gemm_strided_batched!(transA, transB, α, A, B, β, C)\n\nBase.unsafe_convert(::Type{CuPtr{T}}, A::NNlib.BatchedAdjOrTrans{T}) where {T} =\n    Base.unsafe_convert(CuPtr{T}, parent(A))\n"
  },
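  {
    "path": "ext/NNlibCUDAExt/batchedmul_usage_sketch.jl",
    "content": "# Illustrative sketch only (not loaded by the extension): batched_mul on CuArrays, which\n# reaches the strided-batched CUBLAS GEMM through the _batched_gemm! method in batchedmul.jl.\n# Assumes CUDA.jl is loaded; sizes are illustrative.\n\nusing NNlib, CUDA\n\nA = CUDA.rand(Float32, 3, 4, 16)   # 16 independent 3×4 matrices\nB = CUDA.rand(Float32, 4, 5, 16)\nC = NNlib.batched_mul(A, B)        # 3×5×16, one gemm_strided_batched! call\nCt = NNlib.batched_mul(NNlib.batched_transpose(A), CUDA.rand(Float32, 3, 5, 16))   # 4×5×16\n"
  },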
  {
    "path": "ext/NNlibCUDAExt/ctc.jl",
    "content": "# CTC loss moved from Flux.jl to NNlib\n\nimport NNlib: ctc_loss, ctc_alpha, ∇ctc_loss\n\n## GPU implementation\n\n# a port of the GPU kernels from Baidu's C++ warp-ctc package,\n# which itself is Copyright 2015-2016 Baidu USA LLC\n# and available under the Apache 2.0 license\n#\n# Apache 2.0 license: https://www.apache.org/licenses/LICENSE-2.0\n# GitHub: https://github.com/baidu-research/warp-ctc/\n# paper: https://arxiv.org/pdf/1512.02595.pdf\n\nconst MAX_THREADS = 256\n\n@inline function log_plus_f(p1, p2)\n  isinf(p1) && return p2\n  isinf(p2) && return p1\n  if p1 < p2\n    p1, p2 = p2, p1\n  end\n  return p1 + log(1+exp(p2 - p1))\nend\n\nfunction count_repeats(A)\n  repeats = 0\n  for (i,elem) in enumerate(A)\n    if i > 1 && A[i] == A[i-1]\n      repeats += 1\n    end\n  end\n  return repeats\nend\n\nfunction compute_alpha_kernel(probs, labelSize, uttLength, repeats, labelsWithoutBlanks, labelsWithBlanks, alpha, blankLabel)\n\n  tid = threadIdx().x\n  L = labelSize\n  T = uttLength\n  S = length(labelsWithBlanks)\n\n  if L + repeats > T\n    return nothing\n  end\n  labels = labelsWithBlanks\n\n  # Corner-case checking\n  start = (L + repeats <= T) ? 0 : 1\n  last = S > 1 ? 2 : 1\n\n  # Fill in first column (time step)\n  i = tid\n  while i <= last - start\n    alpha[start+i, 1] = probs[labels[start+i], 1]\n    i += blockDim().x\n  end\n  sync_threads()\n\n  # Fill in coefficients for each time step\n  for t=2:T\n    # Corner-case checking\n    if tid == 1 && !(1 < S - 2*(T-t) - 1)\n      if start == 0\n        alpha[1, t] = alpha[1, t-1] + probs[blankLabel, t]\n      elseif start == 1\n        alpha[1, t] = alpha[1, t-1]\n      end\n    end\n    sync_threads()\n\n    # Fill in coefficients for each label class in the target output sequence;\n    # each thread will process the calculations for one class\n    idx = tid+1\n    while idx <= S\n      prevSum = log_plus_f(alpha[idx, t-1], alpha[idx-1, t-1])\n      if labels[idx] != blankLabel && idx != 2 && labels[idx] != labels[idx-2]\n        prevSum = log_plus_f(prevSum, alpha[idx-2, t-1])\n      end\n      if idx < S - 2*(T-t) - 1\n        alpha[idx, t] = -Inf32\n      else\n        alpha[idx, t] = prevSum + probs[labels[idx], t]\n      end\n      idx += blockDim().x\n    end\n    sync_threads()\n  end\n  return nothing\nend\n\nfunction compute_beta_and_grad_kernel(probs, labelSize, uttLength,\n                  repeatsInLabel, labelsWithBlanks,\n                  alphas, beta, output, accum,\n                  grad, blankLabel, loss)\n\n  tid = threadIdx().x\n  L = labelSize\n  T = uttLength\n  S = 2*L + 1\n  repeats = repeatsInLabel\n  labels = labelsWithBlanks\n\n  if (L+repeats) > T\n    return nothing\n  end\n\n  # Corner-case checking\n  start = S > 1 ? S-2 : 0\n  last = L + repeats < T ? 
S : S-1\n  sync_threads()\n  i = tid\n\n  # Calculate coefficients for last column (time step)\n  # then determine alpha and beta product\n  while i <= last - start\n    beta[i+start, T] = 0\n    output[i+start, T] = beta[i+start, T] + alphas[i+start, T]\n    i += blockDim().x\n  end\n  sync_threads()\n\n  # Fill in `accum` for last column (time step)\n  if tid == 1\n    for i=1:S\n      labelIdx = labels[i]\n      accum[labelIdx, T] = log_plus_f(accum[labelIdx, T], output[i, T])\n    end\n  end\n  sync_threads()\n\n  # Fill in `grad` for last column (time step)\n  idx = tid\n  while idx <= size(grad, 1)\n    s = -Inf32\n    for i=1:S\n      s = log_plus_f(s, output[i, T])\n    end\n\n    # ∂L/∂a (where a is activation before logsoftmax)\n    grad[idx, T] = exp(probs[idx, T]) - exp(accum[idx, T] - s)\n    idx += blockDim().x\n  end\n  sync_threads()\n\n  # Fill in the rest of the coefficients\n  t = T-1\n  while t >= 1\n    if t < T\n      idx = tid\n      while idx <= S\n        nextSum = probs[labels[idx], t+1] + beta[idx, t+1]\n        if idx < S\n          nextSum = log_plus_f(nextSum,\n            probs[labels[idx+1], t+1] + beta[idx+1, t+1])\n        end\n        if labels[idx] != blankLabel && idx != S-1 && labels[idx] != labels[idx+2]\n          nextSum = log_plus_f(nextSum,\n            probs[labels[idx+2], t+1] + beta[idx + 2, t+1])\n        end\n        if idx > 2*t\n          beta[idx, t] = -Inf32\n        else\n          beta[idx, t] = nextSum\n        end\n        idx += blockDim().x\n      end\n      sync_threads()\n      idx = tid\n      while idx <= S\n        output[idx, t] = alphas[idx, t] + beta[idx, t]\n        idx += blockDim().x\n      end\n      sync_threads()\n    end\n    sync_threads()\n\n    # Calculate accumulated alpha-beta products for each label class for\n    # each time step; used in calculating gradients\n    if tid == 1\n      for i=1:S\n        labelIdx = labels[i]\n        accum[labelIdx, t] = log_plus_f(accum[labelIdx, t], output[i, t])\n      end\n    end\n    sync_threads()\n    idx = tid\n\n    # Calculate gradients\n    while idx <= size(grad, 1)\n\n      # ∂L/∂a (where a is activation before logsoftmax)\n      grad[idx, t] = exp(probs[idx, t]) - exp(accum[idx, t] + loss)\n      idx += blockDim().x\n    end\n    sync_threads()\n    t -= 1\n    sync_threads()\n  end\n  return nothing\nend\n\nfunction ctc_alpha(ŷ::CuArray, y)\n  ŷ = logsoftmax(ŷ)\n  blank = size(ŷ, 1)\n  ycu = cu(y)\n  z′ = CUDA.fill(blank, 2 * length(y) + 1)\n  z′[eachindex(y) .* 2] .= ycu\n  T = size(ŷ, 2)\n  U′ = 2*length(y) + 1\n  alphas = CUDA.fill(log(zero(eltype(ŷ))), U′,T)\n  nRepeats = count_repeats(CUDA.adapt(Array, y))\n  nThreads = min(U′, MAX_THREADS)\n  @cuda blocks=1 threads=nThreads compute_alpha_kernel(ŷ, length(y), T, nRepeats, ycu, z′, alphas, blank)\n  return (loss=-1 * logsumexp(alphas[end-1:end]), alpha=alphas, z′=z′, yhat=ŷ, nRepeats=nRepeats)\nend\n\nctc_loss(ŷ::CuArray, y) = ctc_alpha(ŷ::CuArray, y).loss\n\nfunction ∇ctc_loss(ŷ::CuArray, y, out)\n  loss, alphas, z′, ŷ, nRepeats = out\n  U′, T = size(alphas)\n  blank = size(ŷ, 1)\n  typed_zero = zero(eltype(ŷ))\n  betas = CUDA.fill(log(typed_zero), U′, T)\n  output = CUDA.fill(log(typed_zero), U′, T)\n  nThreads = min(U′, MAX_THREADS)\n  grads = CUDA.fill(log(typed_zero), size(ŷ))\n  accum = CUDA.fill(log(typed_zero), size(ŷ))\n  @cuda blocks=1 threads=nThreads compute_beta_and_grad_kernel(ŷ, length(y), T, nRepeats, CuArray(z′), alphas, betas, output, accum, grads, blank, loss)\n  return 
grads\nend\n"
  },
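  {
    "path": "ext/NNlibCUDAExt/ctc_usage_sketch.jl",
    "content": "# Illustrative sketch only (not loaded by the extension): the label-with-blanks sequence\n# that ctc_alpha in ctc.jl builds, checked on the CPU, plus a hypothetical ctc_loss call.\n# Assumes CUDA.jl is loaded; the class count, label sequence and time-step count are\n# illustrative.\n\n# CPU check of the interleaved-blank construction (blank = last class index).\ny = [2, 1, 3]\nblank = 5\nz′ = fill(blank, 2 * length(y) + 1)\nz′[eachindex(y) .* 2] .= y\n@assert z′ == [5, 2, 5, 1, 5, 3, 5]\n\n# GPU call: ŷ holds unnormalised scores (classes × time); logsoftmax is applied internally.\nusing NNlib, CUDA\nŷ = CUDA.rand(Float32, 5, 12)\nloss = NNlib.ctc_loss(ŷ, [2, 1, 3])\n"
  },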
  {
    "path": "ext/NNlibCUDAExt/sampling.jl",
    "content": "@inline function NNlib._safe_add!(dx::CuDeviceArray{T, 4}, value, ix, iy, c, n) where T\n    @inbounds CUDA.@atomic dx[ix, iy, c, n] += value\nend\n\nfunction grid_sample_kernel!(n_elem, output::AbstractArray{T, 4}, input::AbstractArray{T, 4}, grid::AbstractArray{V, 4}, padding_mode) where {T,V}\n    index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x\n    if index < n_elem\n        iW, iH, iC, _ = size(input)\n        _, gW, gH, _ = size(grid)\n\n        w = index % gW + 1\n        h = (index ÷ gW) % gH + 1\n        n = index ÷ (gW * gH) + 1\n        NNlib._grid_sample_kernel!(output, input, grid, padding_mode, w, h, n, iW, iH, iC)\n    end\n    nothing\nend\n\nfunction ∇grid_sample_kernel!(n_elem, dx::AbstractArray{T, 4}, dgrid::AbstractArray{V, 4}, Δ::AbstractArray{T, 4}, input::AbstractArray{T, 4}, grid::AbstractArray{V, 4}, padding_mode) where {T,V}\n    index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x\n    if index < n_elem\n        iW, iH, iC, _ = size(input)\n        _, gW, gH, _ = size(grid)\n\n        w = index % gW + 1\n        h = (index ÷ gW) % gH + 1\n        n = index ÷ (gW * gH) + 1\n        NNlib._∇grid_sample_kernel!(dx, dgrid, Δ, input, grid, padding_mode, w, h, n, iW, iH, iC)\n    end\n    nothing\nend\n\nfunction NNlib.grid_sample(x::CuArray{T, 4}, grid::CuArray{V, 4}; padding_mode = :zeros) where {T, V}\n    pad = Val(padding_mode)\n    _, _, xC, xN = size(x)\n    _, gW, gH, _ = size(grid)\n    n_elem = gW * gH * xN\n    y = similar(x, T, (gW, gH, xC, xN))\n\n    kernel = @cuda launch=false grid_sample_kernel!(n_elem, y, x, grid, pad)\n    config = launch_configuration(kernel.fun; max_threads=256)\n    threads = min(n_elem, config.threads)\n    blocks = cld(n_elem, threads)\n    kernel(n_elem, y, x, grid, pad; threads=threads, blocks=blocks)\n    y\nend\n\nfunction NNlib.∇grid_sample(Δ::CuArray{T, 4}, x::CuArray{T, 4}, grid::CuArray{V, 4}; padding_mode = :zeros) where {T, V}\n    pad = Val(padding_mode)\n    xN = size(x, 4)\n    _, gW, gH, _ = size(grid)\n    n_elem = gW * gH * xN\n    dx, dgrid = CUDA.zeros(T, size(x)), similar(grid)\n\n    kernel = @cuda launch=false ∇grid_sample_kernel!(n_elem, dx, dgrid, Δ, x, grid, pad)\n    config = launch_configuration(kernel.fun; max_threads=256)\n    threads = min(n_elem, config.threads)\n    blocks = cld(n_elem, threads)\n    kernel(n_elem, dx, dgrid, Δ, x, grid, pad; threads=threads, blocks=blocks)\n    dx, dgrid\nend\n\n\n@inline function NNlib._safe_add!(dx::CuDeviceArray{T, 5}, value, ix, iy, iz, c, n) where T\n    @inbounds CUDA.@atomic dx[ix, iy, iz, c, n] += value\nend\n\nfunction grid_sample_kernel!(n_elem, output::AbstractArray{T, 5}, input::AbstractArray{T, 5}, grid::AbstractArray{V, 5}, padding_mode) where {T,V}\n    index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x\n    if index < n_elem\n        iW, iH,iD, iC, _ = size(input)\n        _, gW, gH, gD, _ = size(grid)\n\n        w = index % gW + 1\n        h = (index ÷ gW) % gH + 1\n        d = (index ÷ (gW * gH)) % gD + 1\n        n = index ÷ (gW * gH * gD) + 1\n        # n = index ÷ (gW * gH) + 1\n        # d = (index ÷ (gW * gH * n)) + 1\n\n        NNlib._grid_sample_kernel!(output, input, grid, padding_mode, w, h, d, n, iW, iH, iD, iC)\n    end\n    nothing\nend\n\nfunction ∇grid_sample_kernel!(n_elem, dx::AbstractArray{T, 5}, dgrid::AbstractArray{V, 5}, Δ::AbstractArray{T, 5}, input::AbstractArray{T, 5}, grid::AbstractArray{V, 5}, padding_mode) where {T,V}\n    index = (threadIdx().x - 1) + (blockIdx().x 
- 1) * blockDim().x\n    if index < n_elem\n        iW, iH, iD, iC, _ = size(input)\n        _, gW, gH, gD, _ = size(grid)\n\n        w = index % gW + 1\n        h = (index ÷ gW) % gH + 1\n        d = (index ÷ (gW * gH)) % gD + 1\n        n = index ÷ (gW * gH * gD) + 1\n        # n = index ÷ (gW * gH) + 1\n        # d = (index ÷ (gW * gH * n)) + 1\n\n        NNlib._∇grid_sample_kernel!(dx, dgrid, Δ, input, grid, padding_mode, w, h, d, n, iW, iH, iD, iC)\n    end\n    nothing\nend\n\nfunction NNlib.grid_sample(x::CuArray{T, 5}, grid::CuArray{V, 5}; padding_mode = :zeros) where {T, V}\n    pad = Val(padding_mode)\n    _, _, _, xC, xN = size(x)\n    _, gW, gH, gD, _ = size(grid)\n    n_elem = gW * gH * gD * xN\n    y = similar(x, T, (gW, gH, gD, xC, xN))\n\n    kernel = @cuda launch=false grid_sample_kernel!(n_elem, y, x, grid, pad)\n    config = launch_configuration(kernel.fun; max_threads=256)\n    threads = min(n_elem, config.threads)\n    blocks = cld(n_elem, threads)\n    kernel(n_elem, y, x, grid, pad; threads=threads, blocks=blocks)\n    y\nend\n\nfunction NNlib.∇grid_sample(Δ::CuArray{T, 5}, x::CuArray{T, 5}, grid::CuArray{V, 5}; padding_mode = :zeros) where {T, V}\n    pad = Val(padding_mode)\n    xN = size(x, 5)\n    _, gW, gH, gD, _ = size(grid)\n    n_elem = gW * gH * gD * xN\n    dx, dgrid = CUDA.zeros(T, size(x)), similar(grid)\n\n    kernel = @cuda launch=false ∇grid_sample_kernel!(n_elem, dx, dgrid, Δ, x, grid, pad)\n    config = launch_configuration(kernel.fun; max_threads=256)\n    threads = min(n_elem, config.threads)\n    blocks = cld(n_elem, threads)\n    kernel(n_elem, dx, dgrid, Δ, x, grid, pad; threads=threads, blocks=blocks)\n    dx, dgrid\nend"
  },
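  {
    "path": "ext/NNlibCUDAExt/sampling_usage_sketch.jl",
    "content": "# Illustrative sketch only (not loaded by the extension): the linear-index decomposition\n# used by the grid_sample kernels in sampling.jl, checked on the CPU, plus a hypothetical\n# 4-D grid_sample call. Assumes CUDA.jl is loaded; sizes are illustrative.\n\n# Each 0-based linear index maps to a unique (w, h, n) grid position.\ngW, gH, N = 6, 6, 3\ndecoded = [(i % gW + 1, (i ÷ gW) % gH + 1, i ÷ (gW * gH) + 1) for i in 0:(gW * gH * N - 1)]\n@assert sort(decoded) == sort(vec([(w, h, n) for w in 1:gW, h in 1:gH, n in 1:N]))\n\nusing NNlib, CUDA\nx    = CUDA.rand(Float32, 8, 8, 2, 3)              # W × H × C × N\ngrid = CUDA.rand(Float32, 2, 6, 6, 3) .* 2 .- 1    # sampling coordinates in [-1, 1]\ny = NNlib.grid_sample(x, grid; padding_mode = :border)\ndx, dgrid = NNlib.∇grid_sample(CUDA.ones(Float32, size(y)), x, grid; padding_mode = :border)\n"
  },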
  {
    "path": "ext/NNlibCUDAExt/scatter.jl",
    "content": "# supported op: +, -, *, /, max, min, &, |, mean\n\n## TODO support sparse dst/src/idx\n## See issue https://github.com/FluxML/NNlib.jl/issues/647\n# import CUDA.CUSPARSE: AbstractCuSparseMatrix, CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO, AnyCuSparseVector\n# const AnyCuSparseMatrix{Tv,Ti} = Union{\n#     AbstractCuSparseMatrix{Tv,Ti},\n#     CUDA.CuSparseMatrixCSC{Tv,Ti}, # these types do not inherit from AbstractCuSparseMatrix\n#     CUDA.CuSparseMatrixCSR{Tv,Ti}, # but from GPUArrays.AbstractGPUSparseMatrixXXX\n#     CUDA.CuSparseMatrixCOO{Tv,Ti},\n#     }\n# const AnyCuSparseArray{Tv,Ti} = Union{AnyCuSparseVector{Tv,Ti},AnyCuSparseMatrix{Tv,Ti}}\n\nfunction scatter_kernel!(op::OP, dst, src, idx) where OP\n    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x\n\n    @inbounds if index <= length(idx)\n        CUDA.@atomic dst[idx[index]...] = op(dst[idx[index]...], src[index])\n    end\n    return nothing\nend\n\nfunction scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}) where OP\n    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x\n\n    @inbounds if index <= length(idx)\n        li = Base._to_linear_index(dst, Tuple(idx[index])...)\n        CUDA.@atomic dst[li] = op(dst[li], src[index])\n    end\n    return nothing\nend\n\nfunction scatter_kernel!(op::OP, dst, src, idx, max_idx, max_dims_idx, dims_size) where OP\n    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x\n\n    @inbounds if index <= max_idx\n        j, k = divrem(index-1, max_dims_idx)\n        dims_i = CartesianIndices(dims_size)[k+1]\n        CUDA.@atomic dst[Tuple(dims_i)..., idx[j+1]...] = op(dst[Tuple(dims_i)..., idx[j+1]...], src[index])\n    end\n    return nothing\nend\n\nfunction scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex},\n            max_idx, max_dims_idx, dims_size) where OP\n    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x\n\n    @inbounds if index <= max_idx\n        j, k = divrem(index-1, max_dims_idx)\n        dims_i = CartesianIndices(dims_size)[k+1]\n        li = Base._to_linear_index(dst, Tuple(dims_i)..., Tuple(idx[j+1])...)\n        CUDA.@atomic dst[li] = op(dst[li], src[index])\n    end\n    return nothing\nend\n\n\nfunction NNlib.scatter!(op::OP, dst::AnyCuArray,\n        src::AnyCuArray,\n        idx::AnyCuArray) where OP\n    isempty(idx) && return dst\n    dims = NNlib.scatter_dims(dst, src, idx)\n    args = if dims == 0\n        max_idx = length(idx)\n        op, dst, src, idx\n    else\n        dims_size = size(dst)[1:dims]\n        max_dims_idx = prod(dims_size)\n        max_idx = max_dims_idx * length(idx)\n        op, dst, src, idx, max_idx, max_dims_idx, dims_size\n    end\n\n    kernel = @cuda launch=false scatter_kernel!(args...)\n    config = launch_configuration(kernel.fun; max_threads=256)\n    threads = min(max_idx, config.threads)\n    blocks = cld(max_idx, threads)\n    kernel(args...; threads=threads, blocks=blocks)\n    return dst\nend\n\nfunction NNlib.scatter!(op::typeof(mean), dst::AnyCuArray,\n        src::AnyCuArray,\n        idx::AnyCuArray)\n    Ns = NNlib.scatter!(+, zero(dst), one.(src), idx)\n    dst_ = NNlib.scatter!(+, zero(dst), src, idx)\n    dst .+= NNlib.safe_div.(dst_, Ns)\n    return dst\nend\n\n\n## Gradients\n\nfunction ∇scatter_src_kernel!(op::OP, Δsrc, src, idx,\n    rev_idx, max_idx, T::Type{TT}) where {OP,TT}\n    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x\n\n    @inbounds if index <= max_idx\n        cart_j = 
CartesianIndices(idx)[index]\n        # get aggregating indeices, which is to be aggregated together, and itself index\n        inds = rev_idx[idx[cart_j]...]\n        # multiply all values to be aggregated but not itself\n        x = one(T)\n        for k in inds\n            x *= src[k]\n        end\n        x /= src[cart_j]\n        # apply `op` on `Δsrc[i, k]` and `x`\n        Δsrc[cart_j] = op(Δsrc[cart_j], x)\n    end\n    return nothing\nend\n\nfunction ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex},\n            rev_idx, max_idx, T::Type{TT}) where {OP,TT}\n    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x\n\n    @inbounds if index <= max_idx\n        cart_j = CartesianIndices(idx)[index]\n        # get aggregating indeices, which is to be aggregated together, and itself index\n        inds = rev_idx[Tuple(idx[cart_j])...]\n        # multiply all values to be aggregated but not itself\n        x = one(T)\n        for k in inds\n            x *= src[k]\n        end\n        x /= src[cart_j]\n        # apply `op` on `Δsrc[i, k]` and `x`\n        Δsrc[cart_j] = op(Δsrc[cart_j], x)\n    end\n    return nothing\nend\n\nfunction ∇scatter_src_kernel!(op::OP, Δsrc, src, idx,\n    rev_idx, pre_cart_idx, max_dims_idx, max_idx, T::Type{TT}) where {OP,TT}\n    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x\n\n    @inbounds if index <= max_idx\n        i, j = fldmod1(index, max_dims_idx)\n        cart_i = CartesianIndices(idx)[i]\n        cart_j = pre_cart_idx[j]\n        # get aggregating indeices, which is to be aggregated together, and itself index\n        inds = rev_idx[idx[cart_i]...]\n        # multiply all values to be aggregated but not itself\n        x = one(T)\n        for k in inds\n            jk = Base._to_linear_index(src, Tuple(cart_j)..., Tuple(k)...)\n            x *= src[jk]\n        end\n        x /= src[index]\n        # apply `op` on `Δsrc[i, k]` and `x`\n        Δsrc[index] = op(Δsrc[index], x)\n    end\n    return nothing\nend\n\nfunction ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex},\n                rev_idx, pre_cart_idx, max_dims_idx, max_idx, T::Type{TT}) where {OP,TT}\n    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x\n\n    @inbounds if index <= max_idx\n        i, j = fldmod1(index, max_dims_idx)\n        cart_i = CartesianIndices(idx)[i]\n        cart_j = pre_cart_idx[j]\n        # get aggregating indeices, which is to be aggregated together, and itself index\n        inds = rev_idx[Tuple(idx[cart_i])...]\n        # multiply all values to be aggregated but not itself\n        x = one(T)\n        for k in inds\n            jk = Base._to_linear_index(src, Tuple(cart_j)..., Tuple(k)...)\n            x *= src[jk]\n        end\n        x /= src[index]\n        # apply `op` on `Δsrc[i, k]` and `x`\n        Δsrc[index] = op(Δsrc[index], x)\n    end\n    return nothing\nend\n\nfunction NNlib.∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst,\n    src::AnyCuArray,\n    idx::AnyCuArray)\n    dims = ndims(src) - ndims(idx)\n    Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), src)\n    rev_idx = NNlib.reverse_indices(idx)\n    rev_idx = CuArray(map(CUDA.cudaconvert, rev_idx))\n\n    if dims == 0\n        max_idx = length(idx)\n        args = op, Δsrc, src, idx, rev_idx, max_idx, eltype(src)\n    else\n        pre_cart_idx = CartesianIndices(axes(src)[1:dims])\n        max_dims_idx = length(pre_cart_idx)\n        max_idx = max_dims_idx * length(idx)\n        args = 
op, Δsrc, src, idx, rev_idx, pre_cart_idx, max_dims_idx, max_idx, eltype(src)\n    end\n\n    kernel = @cuda launch=false ∇scatter_src_kernel!(args...)\n    config = launch_configuration(kernel.fun; max_threads=256)\n    threads = min(max_idx, config.threads)\n    blocks = cld(max_idx, threads)\n    kernel(args...; threads=threads, blocks=blocks)\n\n    CUDA.unsafe_free!(rev_idx)\n    return Δsrc\nend\n"
  },
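  {
    "path": "ext/NNlibCUDAExt/scatter_usage_sketch.jl",
    "content": "# Illustrative sketch only (not loaded by the extension): scatter! on CuArrays, including\n# the mean specialisation defined in scatter.jl (per-destination counts, then safe_div).\n# Assumes CUDA.jl is loaded; values and the index layout are illustrative.\n\nusing NNlib, CUDA\nusing Statistics: mean\n\nsrc = cu(Float32[1, 2, 3, 4])\nidx = cu([1, 2, 1, 2])                # destination slot for each source element\n\ndst = CUDA.zeros(Float32, 2)\nNNlib.scatter!(+, dst, src, idx)      # dst == [1 + 3, 2 + 4] == [4, 6]\n\ndstm = CUDA.zeros(Float32, 2)\nNNlib.scatter!(mean, dstm, src, idx)  # dstm == [mean([1, 3]), mean([2, 4])] == [2, 3]\n"
  },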
  {
    "path": "ext/NNlibCUDAExt/utils.jl",
    "content": "NNlib._rng_from_array(::CuArray) = CUDA.default_rng()\n\nNNlib._rng_compat_array(rng::CUDA.RNG, A::CuArray) = nothing\nNNlib._rng_compat_array(rng::AbstractRNG, A::CuArray) = throw(ArgumentError(\n    \"cannot use rng::$(typeof(rng)) with array::CuArray, only CUDA's own RNG type works\"))\n\nfunction divide_kernel!(xs, ys, max_idx)\n    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x\n\n    @inbounds if index <= max_idx\n        xs[index] = xs[index] / ys[index]\n    end\n    return nothing\nend\n\nfunction divide_kernel!(xs, counts, max_idx, max_dims_idx, dims_size)\n    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x\n\n    @inbounds if index <= max_idx\n        j, k = divrem(index-1, max_dims_idx)\n        dims_i = Tuple(CartesianIndices(dims_size)[k+1])\n        CUDA.@atomic xs[dims_i..., j+1] = xs[dims_i..., j+1] / counts[j+1]\n    end\n    return nothing\nend\n\nfunction NNlib.reverse_indices(idx::AnyCuArray{<:Any,N}) where N\n    max_dims = NNlib.maximum_dims(idx)\n    T = CartesianIndex{N}\n    rev = Array{Vector{T}}(undef, max_dims...)\n    for i in eachindex(rev)\n        rev[i] = T[]\n    end\n    NNlib.reverse_indices!(rev, idx)\n    return map(cu, rev)\nend\n"
  },
  {
    "path": "ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl",
    "content": "module NNlibEnzymeCoreExt\n\nusing NNlib\nimport EnzymeCore\nusing Random\n\nusing EnzymeCore.EnzymeRules\n\nfor (name, dataname, filtername) in (\n                                     (typeof(NNlib.conv!), NNlib.∇conv_data!, NNlib.∇conv_filter!),\n                                     (typeof(NNlib.depthwiseconv!), NNlib.∇depthwiseconv_data!, NNlib.∇depthwiseconv_filter!),\n                                     (typeof(NNlib.∇conv_data!), NNlib.conv!, NNlib.∇conv_filter!),\n                                     (typeof(NNlib.∇conv_filter!), NNlib.∇conv_data!, NNlib.conv!),\n                                    )\n    @eval begin\n\n\t\tfunction EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{$name}, ::Type{RT},\n\t\t                                                y::EnzymeCore.Annotation{<:AbstractArray{yT, N}},\n\t\t                                                x::EnzymeCore.Annotation{<:AbstractArray{xT, N}},\n\t\t                                                w::EnzymeCore.Annotation{<:AbstractArray{wT, N}},\n\t\t                                                cdims; kwargs...) where {RT, yT, xT, wT, N}\n\n\t\t    if typeof(y) <: EnzymeCore.Duplicated || typeof(y) <: EnzymeCore.BatchDuplicated\n\t\t        func.val(y.val, x.val, w.val, cdims.val; kwargs...)\n\t\t    end\n\n\t\t    primal = if EnzymeRules.needs_primal(config)\n\t\t        y.val\n\t\t    else\n\t\t        nothing\n\t\t    end\n\t\t    shadow = if EnzymeRules.needs_shadow(config)\n\t\t        y.dval\n\t\t    else\n\t\t        nothing\n\t\t    end\n\n\t\t    # Cache x if its overwritten and w is active (and thus required)\n\t\t    cache_x = ( EnzymeRules.overwritten(config)[3]\n\t\t                && !(typeof(w) <: EnzymeCore.Const)\n\t\t                && !(typeof(y) <: EnzymeCore.Const)\n\t\t                ) ? copy(x.val) : nothing\n\n\t\t    # Cache w if its overwritten and x is active (and thus required)\n\t\t    cache_w = ( EnzymeRules.overwritten(config)[4]\n\t\t                && !(typeof(x) <: EnzymeCore.Const)\n\t\t                && !(typeof(y) <: EnzymeCore.Const)\n\t\t                ) ? copy(w.val) : nothing\n\n\t\t    cache = (cache_x, cache_w)\n\n\t\t    return EnzymeRules.AugmentedReturn(primal, shadow, cache)\n\t\tend\n\n\t\tfunction EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, ::Type{RT}, cache,\n\t\t                                                y::EnzymeCore.Annotation{<:AbstractArray{yT, N}},\n\t\t                                                x::EnzymeCore.Annotation{<:AbstractArray{xT, N}},\n\t\t                                                w::EnzymeCore.Annotation{<:AbstractArray{wT, N}},\n\t\t                                                cdims; kwargs...) where {RT, yT, xT, wT, N}\n\t\t    cache_x, cache_w = cache\n\n\t\t    # Don't cache x if not overwritten and w is active (and thus required)\n\t\t    if !(typeof(w) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const)\n\t\t        if !EnzymeRules.overwritten(config)[3]\n\t\t            cache_x = x.val\n\t\t        end\n\t\t    end\n\n\t\t    # Don't cache w if not overwritten and x is active (and thus required)\n\t\t    if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const)\n\t\t        if !EnzymeRules.overwritten(config)[4]\n\t\t            cache_w = w.val\n\t\t        end\n\t\t    end\n\n\t\t    dys = y.dval\n\t\t    dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval\n\t\t    dws = (typeof(w) <: EnzymeCore.Const) ? 
dys : w.dval\n\n\t\t    if EnzymeRules.width(config) == 1\n\t\t        dys = (dys,)\n\t\t        dxs = (dxs,)\n\t\t        dws = (dws,)\n\t\t    end\n\n\t\t    for (dy, dx, dw) in zip(dys, dxs, dws)\n\t\t        if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val\n\n\t\t            if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val\n\t\t                # dx += grad wrt x.val\n\t\t                $dataname(dx, $(name != typeof(NNlib.∇conv_filter!) ? :dy : :cache_w), $(name != typeof(NNlib.∇conv_filter!) ? :cache_w : :dy), cdims.val; alpha=xT(1), beta=xT(1), kwargs...)\n\t\t            end\n\t\t            if !(typeof(w) <: EnzymeCore.Const) && dw !== w.val\n\t\t                # dw += grad wrt w.val\n                        $filtername(dw, $(name != typeof(NNlib.∇conv_data!) ? :cache_x : :dy), $(name != typeof(NNlib.∇conv_data!) ? :dy : :cache_x), cdims.val; alpha=wT(1), beta=wT(1), kwargs...)\n\t\t            end\n\t\t            \n\t\t            dy .= 0\n\t\t        end\n\t\t    end\n\n\t\t    return (nothing, nothing, nothing, nothing)\n\t\tend\n\nend\nend\n\nfunction EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT}\n\n    if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated\n        func.val(dst.val, src.val, idx.val)\n    end\n\n    primal = if EnzymeRules.needs_primal(config)\n        dst.val\n    else\n        nothing\n    end\n    shadow = if EnzymeRules.needs_shadow(config)\n        dst.dval\n    else\n        nothing\n    end\n\n    # Cache idx if its overwritten\n    cache_idx = ( EnzymeRules.overwritten(config)[4]\n                    && !(typeof(src) <: EnzymeCore.Const)\n                    && !(typeof(dst) <: EnzymeCore.Const)\n                    ) ? copy(idx.val) : nothing\n\n    return EnzymeRules.AugmentedReturn(primal, shadow, cache_idx)\nend\n\nfunction EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, cache_idx, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT}\n\n    # Don't cache idx if not overwritten\n    if !(typeof(src) <: EnzymeCore.Const) && !(typeof(dst) <: EnzymeCore.Const)\n        if !EnzymeRules.overwritten(config)[4]\n            cache_idx = idx.val\n        end\n    end\n\n    ddsts = dst.dval\n    dsrcs = (typeof(src) <: EnzymeCore.Const) ? 
ddsts : src.dval\n\n    if EnzymeRules.width(config) == 1\n        ddsts = (ddsts,)\n        dsrcs = (dsrcs,)\n    end\n\n    for (ddst, dsrc) in zip(ddsts, dsrcs)\n        if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val\n\n            if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val\n                NNlib.scatter!(+, dsrc, ddst, cache_idx)\n            end\n\n            ddst .= 0\n        end\n    end\n\n    return (nothing, nothing, nothing)\nend\n\n\n\nfunction EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.scatter!)}, ::Type{RT}, op::EnzymeCore.Const, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT}\n\n    @assert !(OutType <: EnzymeCore.Const)\n    if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated\n        func.val(op.val, dst.val, src.val, idx.val)\n    end\n\n    primal = if EnzymeRules.needs_primal(config)\n        dst.val\n    else\n        nothing\n    end\n    shadow = if EnzymeRules.needs_shadow(config)\n        dst.dval\n    else\n        nothing\n    end\n\n    # Cache idx if its overwritten\n    cache_idx = ( EnzymeRules.overwritten(config)[4]\n                    && !(typeof(src) <: EnzymeCore.Const)\n                    && !(typeof(dst) <: EnzymeCore.Const)\n                    ) ? copy(idx.val) : nothing\n\n    return EnzymeRules.AugmentedReturn(primal, shadow, cache_idx)\nend\n\nfunction EnzymeRules.reverse(config,\n\t\t\t\t\t\t\t\t\t\tfunc::EnzymeCore.Const{typeof(NNlib.scatter!)},\n\t\t\t\t\t\t\t\t\t\t::Type{RT},\n\t\t\t\t\t\t\t\t\t\tcache_idx,\n\t\t\t\t\t\t\t\t\t\top::Union{EnzymeCore.Const{typeof(+)},EnzymeCore.Const{typeof(-)}}, dst::OutType,\n\t\t\t\t\t\t\t\t\t\tsrc,\n\t\t\t\t\t\t\t\t\t\tidx::EnzymeCore.Const) where {OutType, RT}\n\n    # Don't cache idx if not overwritten\n    if !(typeof(src) <: EnzymeCore.Const) && !(typeof(dst) <: EnzymeCore.Const)\n        if !EnzymeRules.overwritten(config)[4]\n            cache_idx = idx.val\n        end\n    end\n\n    ddsts = dst.dval\n    dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval\n\n    if EnzymeRules.width(config) == 1\n        ddsts = (ddsts,)\n        dsrcs = (dsrcs,)\n    end\n\n    for (ddst, dsrc) in zip(ddsts, dsrcs)\n        if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val\n\n            if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val\n\n                if eltype(typeof(op)) == typeof(+)\n                    dsrc .+= NNlib.gather(ddst, cache_idx)\n                else\n                    @assert eltype(typeof(op)) == typeof(-)\n                    dsrc .-= NNlib.gather(ddst, cache_idx)\n                end\n            end\n\n        end\n    end\n\n    return (nothing, nothing, nothing, nothing)\nend\n\n\n\nfor pool in [:maxpool, :meanpool, :lpnormpool]\n    pool! = Symbol(pool, :!)\n    ∇pool = Symbol(:∇, pool, :!)\n\n    @eval begin\n\nfunction EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof($pool!)}, ::Type{RT}, y::OutType, x, dims; kwargs...) 
where {OutType, RT}\n\n    if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated\n        func.val(y.val, x.val, dims.val; kwargs...)\n    end\n\n    primal = if EnzymeRules.needs_primal(config)\n        y.val\n    else\n        nothing\n    end\n    shadow = if EnzymeRules.needs_shadow(config)\n        y.dval\n    else\n        nothing\n    end\n\n    cache_y = ( EnzymeRules.overwritten(config)[2] \n                && !(typeof(x) <: EnzymeCore.Const) \n                && !(typeof(y) <: EnzymeCore.Const) \n                ) ? copy(y.val) : nothing\n\n    cache_x = ( EnzymeRules.overwritten(config)[3]\n                && !(typeof(x) <: EnzymeCore.Const) \n                && !(typeof(y) <: EnzymeCore.Const) \n                ) ? copy(x.val) : nothing\n\n    cache = (cache_y, cache_x)\n\n    return EnzymeRules.AugmentedReturn(primal, shadow, cache)\nend\n\nfunction EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof($pool!)}, ::Type{RT}, cache, y, x, dims; kwargs...) where {RT}\n    cache_y, cache_x = cache\n\n    # Don't cache y if not overwritten\n    if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const)\n        if !EnzymeRules.overwritten(config)[2]\n            cache_y = y.val\n        end\n    end\n\n    # Don't cache x if not overwritten\n    if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const)\n        if !EnzymeRules.overwritten(config)[3]\n            cache_x = x.val\n        end\n    end\n\n    dys = y.dval\n    dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval\n\n    if EnzymeRules.width(config) == 1\n        dys = (dys,)\n        dxs = (dxs,)\n    end\n\n    for (dy, dx) in zip(dys, dxs)\n        if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val\n\n            if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val\n                NNlib.$(∇pool)(dx, dy, cache_y, cache_x, dims.val; alpha=eltype(dx)(1), beta=eltype(dx)(1), kwargs...)\n            end\n\n            dy .= 0\n        end\n    end\n\n    return (nothing, nothing, nothing)\nend\n\nend\nend\n\nfunction EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, rng, dst::OutType, src, p, dims) where {OutType, RT}\n\n    T = float(real(eltype(dst.val)))\n    val = convert(T, 1/(1-p.val))\n    keep = if dims.val isa Colon\n        similar(dst.val, T, size(dst.val))\n    else\n        similar(dst.val, T, ntuple(d -> d in dims.val ? size(dst.val,d) : 1, ndims(dst.val)))\n    end\n    rand!(rng.val, keep)\n    \n    keep = keep .> p.val\n\n    if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated\n        dst.val .= (keep .* val) .* src.val\n    end\n\n    primal = if EnzymeRules.needs_primal(config)\n        dst.val\n    else\n        nothing\n    end\n    shadow = if EnzymeRules.needs_shadow(config)\n        dst.dval\n    else\n        nothing\n    end\n\n    if typeof(dst) <: EnzymeCore.Const || typeof(src) <: EnzymeCore.Const\n        keep = nothing\n    end\n\n    return EnzymeRules.AugmentedReturn(primal, shadow, keep)\nend\n\nfunction EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, keep, rng, dst::OutType, src, p, dims) where {OutType, RT}\n    T = float(real(eltype(dst.val)))\n    val = convert(T, 1/(1-p.val))\n\n    ddsts = dst.dval\n    dsrcs = (typeof(src) <: EnzymeCore.Const) ? 
ddsts : src.dval\n\n    if EnzymeRules.width(config) == 1\n        ddsts = (ddsts,)\n        dsrcs = (dsrcs,)\n    end\n\n    for (ddst, dsrc) in zip(ddsts, dsrcs)\n        if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val\n\n            if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val\n                dsrc .+= (keep .* val) .* ddst\n            end\n\n            ddst .= 0\n        end\n    end\n\n    dp = if typeof(p) <: EnzymeCore.Active\n        typeof(p.val)(0)\n    else\n        nothing\n    end\n\n    return (nothing, nothing, nothing, dp, nothing)\nend\n\n\nend\n"
  },
  {
    "path": "ext/NNlibFFTWExt/NNlibFFTWExt.jl",
    "content": "module NNlibFFTWExt\n\nusing FFTW\nusing NNlib\nusing KernelAbstractions\n\ninclude(\"stft.jl\")\n\nend\n"
  },
  {
    "path": "ext/NNlibFFTWExt/stft.jl",
    "content": "function NNlib.stft(x;\n    n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing,\n    center::Bool = true, normalized::Bool = false,\n)\n    kab = get_backend(x)\n    use_window = !isnothing(window)\n\n    use_window && kab != get_backend(window) && throw(ArgumentError(\n        \"`window` must be on the same device as stft input `x` ($kab), \\\n        instead: `$(get_backend(window))`.\"))\n    use_window && !(0 < length(window) ≤ n_fft) && throw(ArgumentError(\n        \"Expected `0 < length(window) ≤ n_fft=$n_fft`, \\\n        but got `length(window)=$(length(window))`.\"))\n    hop_length < 0 && throw(ArgumentError(\n        \"Expected `hop_length > 0`, but got `hop_length=$hop_length`.\"))\n\n    # Pad window on both sides with `0` to `n_fft` length if needed.\n    if use_window && length(window) < n_fft\n        left = ((n_fft - length(window)) ÷ 2) + 1\n        tmp = KernelAbstractions.zeros(kab, eltype(window), n_fft)\n        tmp[left:left + length(window) - 1] .= window\n        window = tmp\n    end\n\n    if center\n        pad_amount = n_fft ÷ 2\n        x = pad_reflect(x, pad_amount; dims=1)\n    end\n\n    n = size(x, 1)\n    (0 < n_fft ≤ n) || throw(ArgumentError(\n        \"Expected `0 < n_fft ≤ size(x, 1)=$n`, but got `n_fft=$n_fft`.\"))\n\n    n_frames = 1 + (n - n_fft) ÷ hop_length\n\n    # time2col.\n    # Reshape `x` to (n_fft, n_frames, B) if needed.\n    # Each row in `n_frames` is shifted by `hop_length`.\n    if n_frames > 1\n        # TODO can be more efficient if we support something like torch.as_strided\n        ids = [\n            row + hop_length * col\n            for row in 1:n_fft, col in 0:(n_frames - 1)]\n        x = @inbounds x[ids, ntuple(_ -> Colon(), ndims(x) - 1)...]\n    end\n\n    region = 1\n    use_window && (x = x .* window;)\n    y = eltype(x) <: Complex ? fft(x, region) : rfft(x, region)\n\n    normalized && (y = y .* eltype(y)(n_fft^-0.5);)\n    return y\nend\n\nfunction NNlib.istft(y;\n    n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing,\n    center::Bool = true, normalized::Bool = false,\n    return_complex::Bool = false,\n    original_length::Union{Nothing, Int} = nothing,\n)\n    kab = get_backend(y)\n    use_window = !isnothing(window)\n\n    use_window && kab != get_backend(window) && throw(ArgumentError(\n        \"`window` must be on the same device as istft input `y` ($kab), \\\n        instead: `$(get_backend(window))`.\"))\n    use_window && !(0 < length(window) ≤ n_fft) && throw(ArgumentError(\n        \"Expected `0 < length(window) ≤ n_fft=$n_fft`, \\\n        but got `length(window)=$(length(window))`.\"))\n    hop_length < 0 && throw(ArgumentError(\n        \"Expected `hop_length > 0`, but got `hop_length=$hop_length`.\"))\n\n    # TODO check `y` eltype is complex\n\n    n_frames = size(y, 2)\n\n    # Pad window on both sides with `0` to `n_fft` length if needed.\n    if use_window && length(window) < n_fft\n        left = ((n_fft - length(window)) ÷ 2) + 1\n        tmp = KernelAbstractions.zeros(kab, eltype(window), n_fft)\n        tmp[left:left + length(window) - 1] .= window\n        window = tmp\n    end\n\n    # Denormalize.\n    normalized && (y = y .* eltype(y)(n_fft^0.5);)\n\n    region = 1\n    x = return_complex ? 
ifft(y, region) : irfft(y, n_fft, region)\n\n    # De-apply window.\n    use_window && (x = x ./ window;)\n\n    # col2time.\n    expected_output_len = n_fft + hop_length * (n_frames - 1)\n\n    ids = Vector{Int}(undef, expected_output_len)\n    in_idx, out_idx = 0, 0\n    prev_e, v = 0, 0\n\n    for col in 0:(n_frames - 1)\n        for row in 1:n_fft\n            in_idx += 1\n            v = row + hop_length * col\n            v > prev_e || continue\n\n            out_idx += 1\n            ids[out_idx] = in_idx\n        end\n        prev_e = v\n    end\n\n    # In case of batched input, reshaped it (n_fft, n_frames, batch) -> (:, batch).\n    nd = ntuple(_ -> Colon(), ndims(x) - 2)\n    ndims(x) == 3 && (x = reshape(x, (:, size(x, 3)));)\n    x = @inbounds x[ids, nd...]\n\n    # Trim padding.\n    left = center ? (n_fft ÷ 2 + 1) : 1\n    right = if isnothing(original_length)\n        center ? (size(x, 1) - n_fft ÷ 2) : expected_output_len\n    else\n        left + original_length - 1\n    end\n    x = x[left:right, nd...]\n    return x\nend\n"
  },
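A minimal usage sketch for the FFTW-backed `stft`/`istft` above, assuming FFTW.jl is loaded so the `NNlibFFTWExt` extension activates; the signal length, `n_fft`, hop and window choice are illustrative only.

```julia
using NNlib, FFTW   # loading FFTW activates NNlibFFTWExt

x = randn(Float32, 16_000)     # illustrative 1 s mono signal at 16 kHz
n_fft, hop = 1024, 256

# Real input -> one-sided spectrum: size(y, 1) == n_fft ÷ 2 + 1 frequency bins.
y = stft(x; n_fft, hop_length=hop, window=hamming_window(n_fft), center=true)

# `original_length` trims the reflect-padding that `center=true` added.
x̂ = istft(y; n_fft, hop_length=hop, window=hamming_window(n_fft),
          center=true, original_length=length(x))
```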
  {
    "path": "ext/NNlibForwardDiffExt.jl",
    "content": "module NNlibForwardDiffExt\n\nusing ForwardDiff: ForwardDiff\nusing NNlib: NNlib\n\nNNlib.within_gradient(x::ForwardDiff.Dual) = true\nNNlib.within_gradient(x::AbstractArray{<:ForwardDiff.Dual}) = true\n\nend\n"
  },
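A small sketch of what this extension enables: `NNlib.within_gradient` reports `true` when it sees ForwardDiff dual numbers, so code can branch on whether it is currently being differentiated (the toy function `f` below is made up for illustration).

```julia
using ForwardDiff, NNlib

f(x) = NNlib.within_gradient(x) ? 2x : x

f(1.0)                           # 1.0 — plain call, no dual numbers involved
ForwardDiff.derivative(f, 1.0)   # 2.0 — duals detected, the `2x` branch is taken
```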
  {
    "path": "ext/NNlibMetalExt.jl",
    "content": "module NNlibMetalExt\n\n\nusing Metal: method_table, @device_override\nusing NNlib: NNlib\n\n@device_override NNlib.tanh_fast(x) = Base.FastMath.tanh_fast(x)\n\nend\n"
  },
  {
    "path": "ext/NNlibSpecialFunctionsExt.jl",
    "content": "module NNlibSpecialFunctionsExt\n\nusing NNlib: NNlib, oftf\nusing SpecialFunctions: erf\n\n# Full gelu (gelu_erf)\nNNlib.gelu_erf(x) = x/2*(1 + erf(x/sqrt(oftf(x,2))))\n\nfunction NNlib.deriv_gelu_erf(x)\n    SQRT2 = sqrt(oftf(x,2))\n    Φ = (1 + erf(x/SQRT2))/2\n    Φ + x/SQRT2*exp(-(x^2)/2)/sqrt(oftf(x,π))\nend\n\nend"
  },
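A brief sketch of using the erf-based GELU once SpecialFunctions.jl is loaded; the numeric values in the comments are approximate.

```julia
using NNlib, SpecialFunctions   # loading SpecialFunctions activates NNlibSpecialFunctionsExt

gelu_erf(1f0)                # x * Φ(x) ≈ 0.8413f0 — the exact GELU
gelu_tanh(1f0)               # ≈ 0.8412f0 — the tanh approximation aliased as `gelu`
gelu_erf.(randn(Float32, 4)) # broadcasts like the other activations
```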
  {
    "path": "src/NNlib.jl",
    "content": "module NNlib\n\nimport Atomix\nimport ChainRulesCore: rrule\n\nusing Base.Broadcast: broadcasted\nusing Base.Threads\nusing ChainRulesCore\nusing GPUArraysCore\nusing KernelAbstractions\nusing KernelAbstractions: @atomic\nusing LinearAlgebra\nusing LinearAlgebra.BLAS: @blasfunc, BlasInt\nusing LinearAlgebra: AdjOrTransAbsMat, Adjoint, BlasFloat, Transpose\nusing Random\nusing ScopedValues\nusing Statistics\nusing Statistics: mean\n\nconst Numeric = Union{AbstractArray{<:T}, T} where {T<:Number}\n\n# internal. TODO: change to an approach where amount of threading is controlled, not just on/off\nconst ALLOW_SPAWNS = ScopedValue(true)\nshould_use_spawn() = Threads.nthreads(:default) > 1 && ALLOW_SPAWNS[]\n\"\"\"\n    @disallow_spawns ex\n\nDisallow NNlib to use `@spawn` on divisible workloads. i.e. within `conv` etc.\n\"\"\"\nmacro disallow_spawns(ex)\n    quote\n        @with ALLOW_SPAWNS => false $(esc(ex))\n    end\nend\n\n# Include APIs\ninclude(\"dim_helpers.jl\")\nexport ConvDims, DenseConvDims, PoolDims, DepthwiseConvDims\n\ninclude(\"activations.jl\")\nfor f in ACTIVATIONS\n    @eval export $(f)\nend\nexport sigmoid, hardsigmoid, logsigmoid, thresholdrelu, gelu # Aliases\n\ninclude(\"attention.jl\")\nexport dot_product_attention, dot_product_attention_scores, make_causal_mask\n\ninclude(\"dropout.jl\")\nexport dropout, dropout!\n\ninclude(\"softmax.jl\")\nexport softmax, softmax!, ∇softmax, ∇softmax!, logsoftmax,\n    logsoftmax!, ∇logsoftmax, ∇logsoftmax!, logsumexp\n\ninclude(\"batched/batchedadjtrans.jl\")\ninclude(\"batched/batchedmul.jl\")\nexport batched_mul, batched_mul!, ⊠,  batched_vec,\n    batched_transpose, batched_adjoint\n\ninclude(\"gemm.jl\")\nexport grid_sample, ∇grid_sample\n\ninclude(\"conv.jl\")\nexport conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter,\n    ∇conv_filter!, depthwiseconv, depthwiseconv!,\n    ∇depthwiseconv_data, ∇depthwiseconv_data!,\n    ∇depthwiseconv_filter, ∇depthwiseconv_filter!\n\ninclude(\"conv_bias_act.jl\")\nexport conv_bias_act, conv_bias_act!\n\ninclude(\"bias_act.jl\")\nexport bias_act!\n\ninclude(\"fold.jl\")\n\ninclude(\"ctc.jl\")\nexport ctc_loss\n\ninclude(\"pooling.jl\")\nexport maxpool, maxpool!, meanpool, meanpool!, lpnormpool, lpnormpool!,\n    ∇maxpool, ∇maxpool!, ∇meanpool, ∇meanpool!, ∇lpnormpool, ∇lpnormpool!\n\ninclude(\"padding.jl\")\nexport pad_constant, pad_repeat, pad_reflect, pad_zeros, pad_symmetric, pad_circular\n\ninclude(\"upsample.jl\")\nexport upsample_nearest, ∇upsample_nearest,\n    upsample_linear, ∇upsample_linear,\n    upsample_bilinear, ∇upsample_bilinear,\n    upsample_trilinear, ∇upsample_trilinear,\n    pixel_shuffle\n\ninclude(\"gather.jl\")\ninclude(\"scatter.jl\")\ninclude(\"utils.jl\")\n\ninclude(\"sampling.jl\")\ninclude(\"functions.jl\")\n\ninclude(\"normalization.jl\")\n# export batchnorm, ∇batchnorm\n\n## Include implementations\ninclude(\"impl/padding_edges.jl\")\n\n# Direct implementations of convolutional and depthwise-convolutional algorithms\ninclude(\"impl/conv_direct.jl\")\ninclude(\"impl/depthwiseconv_direct.jl\")\n# im2col implementations of convolutional and depthwise-convolutional algorithms\ninclude(\"impl/conv_im2col.jl\")\ninclude(\"impl/depthwiseconv_im2col.jl\")\n\n# Direct implementations of pooling\ninclude(\"impl/pooling_direct.jl\")\ninclude(\"deprecations.jl\")\n\ninclude(\"rotation.jl\")\nexport imrotate, ∇imrotate\n\ninclude(\"audio/stft.jl\")\ninclude(\"audio/spectrogram.jl\")\ninclude(\"audio/mel.jl\")\nexport stft, istft, hann_window, hamming_window, 
spectrogram, melscale_filterbanks\n\nend # module NNlib\n"
  },
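A short sketch of the `@disallow_spawns` switch documented above; the array sizes are arbitrary.

```julia
using NNlib

x = randn(Float32, 32, 32, 3, 4)   # WHCN input
w = randn(Float32, 5, 5, 3, 8)     # 5×5 filters, 3 -> 8 channels

y1 = conv(x, w)                          # may take threaded (@spawn) code paths
y2 = NNlib.@disallow_spawns conv(x, w)   # forces the single-task code path
y1 ≈ y2                                  # same result either way
```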
  {
    "path": "src/activations.jl",
    "content": "## Activation functions\n#\n# Some of activation functions have its wrapper function for GPU in NNlibCUDAExt.jl.\n# https://github.com/JuliaGPU/CuArrays.jl/issues/614\n\nACTIVATIONS = [\n    :σ, :hardσ, :hardtanh, :relu,\n    :leakyrelu, :relu6, :rrelu, :elu, :gelu_tanh, :gelu_sigmoid, :gelu_erf, :swish, :hardswish, :selu,\n    :celu, :softplus, :softsign, :logσ, :logcosh,\n    :mish, :tanhshrink, :softshrink, :trelu, :lisht,\n    :tanh_fast, :sigmoid_fast,\n]\n\n# of type float (to allow for integer inputs)\noftf(x, y) = oftype(float(x), y)\n\n# oftype contains control flow on 1.10+, which can lead to type instabilities under AD \nfunction rrule(::typeof(oftf), x, y)\n    proj_y = ChainRulesCore.ProjectTo(y)\n    oftf_pullback(Δ) = (NoTangent(), NoTangent(), proj_y(Δ))\n    return oftf(x, y), oftf_pullback\nend\n\n\"\"\"\n    σ(x) = 1 / (1 + exp(-x))\n\nClassic [sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function) activation\nfunction.\nUnicode `σ` can be entered as `\\\\sigma` then tab, in many editors.\nThe ascii name `sigmoid` is also exported.\n\nSee also [`sigmoid_fast`](@ref).\n\n```julia-repl\njulia> using UnicodePlots\n\njulia> lineplot(sigmoid, -5, 5, height=7)\n          ┌────────────────────────────────────────┐     \n        1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⣀⡠⠤⠖⠒⠒⠋⠉⠉⠉⠉⠉⠉│ σ(x)\n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⢀⡠⠖⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│     \n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⣀⠔⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│     \n   f(x)   │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⡏⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│     \n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡔⠋⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│     \n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠊⠁⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│     \n        0 │⣀⣀⣀⣀⣀⣀⣀⠤⠤⠤⠒⠊⠉⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│     \n          └────────────────────────────────────────┘     \n          ⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀     \n          ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀     \n\njulia> sigmoid === σ\ntrue\n```\n\"\"\"\nfunction σ(x)\n    t = exp(-abs(x))\n    ifelse(x ≥ 0, inv(1 + t), t / (1 + t))\nend\n\nconst sigmoid = σ\n\n\"\"\"\n    hardσ(x) = max(0, min(1, (x + 3) / 6))\n\nPiecewise linear approximation of [`sigmoid`](@ref).\n\n```julia-repl\njulia> lineplot(hardsigmoid, -5, 5, height=7)\n          ┌────────────────────────────────────────┐         \n        1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⢀⡠⠖⠋⠉⠉⠉⠉⠉⠉⠉⠉│ hardσ(x)\n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⣀⡤⠒⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         \n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⡠⠔⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         \n   f(x)   │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⡗⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         \n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠊⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         \n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡤⠖⠋⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         \n        0 │⣀⣀⣀⣀⣀⣀⣀⣀⣀⠤⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         \n          └────────────────────────────────────────┘         \n          ⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀         \n          ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀         \n\njulia> lineplot(sigmoid, -5, 5, height=7)\n          ┌────────────────────────────────────────┐     \n        1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⣀⡠⠤⠖⠒⠒⠋⠉⠉⠉⠉⠉⠉│ σ(x)\n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⢀⡠⠖⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│     \n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⣀⠔⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│     \n   f(x)   │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⡏⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│     \n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡔⠋⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│     \n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠊⠁⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│     \n        0 │⣀⣀⣀⣀⣀⣀⣀⠤⠤⠤⠒⠊⠉⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│     \n          └────────────────────────────────────────┘     \n          ⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀     \n          
⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀     \n```\n\"\"\"\nhardσ(x) = clamp((x + 3) / 6, 0, 1)\n\n# https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html\n\nconst hardsigmoid = hardσ\n\n\"\"\"\n    logσ(x)\n\nReturn `log(σ(x))` which is computed in a numerically stable way.\n\n```julia-repl\njulia> lineplot(logsigmoid, -5, 5, height=7)\n           ┌────────────────────────────────────────┐        \n         0 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡧⠤⠔⠒⠒⠒⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉│ logσ(x)\n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠖⠊⠉⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        \n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠒⠉⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        \n   f(x)    │⠀⠀⠀⠀⠀⠀⢀⡤⠖⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        \n           │⠀⠀⠀⣀⠔⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        \n           │⡤⠖⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        \n        -6 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        \n           └────────────────────────────────────────┘        \n           ⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀        \n           ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀        \n```\n\"\"\"\nlogσ(x) = -softplus(-x)\n\nconst logsigmoid = logσ\n\n\"\"\"\n    hardtanh(x) = max(-1, min(1, x))\n\nSegment-wise linear approximation of `tanh`, much cheaper to compute.\nSee [\"Large Scale Machine Learning\"](https://ronan.collobert.com/pub/matos/2004_phdthesis_lip6.pdf).\n\nSee also [`tanh_fast`](@ref).\n```julia-repl\njulia> lineplot(hardtanh, -2, 2, height=7)\n           ┌────────────────────────────────────────┐            \n         1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⣀⠔⠋⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉│ hardtanh(x)\n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⣀⡤⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            \n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⢀⡤⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            \n   f(x)    │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⡤⡷⠥⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│            \n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⠖⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            \n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⠖⠋⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            \n        -1 │⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⠔⠋⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            \n           └────────────────────────────────────────┘            \n           ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀            \n           ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x\n\njulia> lineplot(tanh, -2, 2, height=7)\n           ┌────────────────────────────────────────┐        \n         1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⣀⡠⠤⠤⠒⠒⠒⠊⠉⠉⠉│ tanh(x)\n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⢀⡠⠔⠊⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        \n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⢀⡤⠒⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        \n   f(x)    │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⡤⡷⠥⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│        \n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡤⠖⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        \n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡠⠔⠊⠁⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        \n        -1 │⣀⣀⣀⡠⠤⠤⠤⠖⠒⠊⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        \n           └────────────────────────────────────────┘        \n           ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀        \n           ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀        \n```\n\"\"\"\nhardtanh(x) = clamp(x, oftype(x, -1), oftype(x, 1))  # clamp(x, -1, 1) is type-stable, but would promote Int32, for which we have tests\n\n\"\"\"\n    relu(x) = max(0, x)\n\n[Rectified Linear Unit](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))\nactivation function.\n\n```julia-repl\njulia> lineplot(relu, -2, 2, height=7)\n          ┌────────────────────────────────────────┐        \n        2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠋│ relu(x)\n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠊⠁⠀⠀│        \n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡤⠊⠁⠀⠀⠀⠀⠀│        \n   f(x)   │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⢀⡤⠖⠁⠀⠀⠀⠀⠀⠀⠀⠀│        \n          
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⡠⠖⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        \n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⡠⠖⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        \n        0 │⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣇⠔⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        \n          └────────────────────────────────────────┘        \n          ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀        \n          ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀        \n```\n\"\"\"\nrelu(x) = ifelse(x<0, zero(x), x)  # faster than max(zero(x), x), still preserves NaN\n\n\"\"\"\n    leakyrelu(x, a=0.01) = max(a*x, x)\n\nLeaky [Rectified Linear Unit](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))\nactivation function.\nYou can also specify the coefficient explicitly, e.g. `leakyrelu(x, 0.01)`.\n\n```julia-repl\njulia> lineplot(x -> leakyrelu(x, 0.5), -2, 2, height=7)\n           ┌────────────────────────────────────────┐       \n         2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠒⠉│ #42(x)\n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊⠉⠀⠀⠀⠀│       \n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⣀⡤⠖⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀│       \n   f(x)    │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⣀⠤⠖⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│       \n           │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⣤⣤⡤⡧⠶⠭⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│       \n           │⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⣀⠤⠤⠒⠒⠋⠉⠁⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│       \n        -1 │⣀⣀⠤⠤⠒⠒⠊⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│       \n           └────────────────────────────────────────┘       \n           ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀       \n           ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀       \n\njulia> leakyrelu(-10f0, 0.2)\n-2.0f0\n\njulia> leakyrelu(-10f0, 0.02)\n-0.5f0\n```\n\"\"\"\nleakyrelu(x, a=oftf(x, leakyrelu_a)) = ifelse(x>0, float(x), oftf(x, a*x))  # max(a*x, x) is 3x slower\n\nconst leakyrelu_a = 0.01  # also used in gradient below\n\n\"\"\"\n    relu6(x) = min(max(0, x), 6)\n\n[Rectified Linear Unit](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))\nactivation function capped at 6.\nSee [\"Convolutional Deep Belief Networks\"](https://www.cs.toronto.edu/~kriz/conv-cifar10-aug2010.pdf) from CIFAR-10.\n\n```julia-repl\njulia> lineplot(relu6, -10, 10, height=7)\n          ┌────────────────────────────────────────┐         \n        6 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠎⠉⠉⠉⠉⠉⠉⠉⠉│ relu6(x)\n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⢀⡔⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀│         \n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⡤⠃⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         \n   f(x)   │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⡠⠎⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         \n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⢀⠖⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         \n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⡔⠃⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         \n        0 │⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⡧⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         \n          └────────────────────────────────────────┘         \n          ⠀-10⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀10⠀         \n          ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀         \n```\n\"\"\"\nrelu6(x) = clamp(x, oftype(x, 0), oftype(x, 6))  # clamp promotes, but clamp(x, 0, 6) would promote x::Int32\n\n\"\"\"\n    rrelu(x, lo=1/8, hi=1/3) = max(a*x, x)\n    # where `a` is randomly sampled from uniform distribution `U(lo, hi)`\n\nRandomized Leaky Rectified Linear Unit activation function.\nSee [\"Empirical Evaluation of Rectified Activations\"](https://arxiv.org/abs/1505.00853)\nYou can also specify the bound explicitly, e.g. 
`rrelu(x, 0.0, 1.0)`.\n\n```julia-repl\njulia> lineplot(rrelu, -20, 10, height=7)\n            ┌────────────────────────────────────────┐         \n         10 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠖⠋│ rrelu(x)\n            │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⢀⡠⠖⠋⠁⠀⠀⠀│         \n            │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⢀⡠⠔⠊⠁⠀⠀⠀⠀⠀⠀⠀│         \n   f(x)     │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⡤⠤⣤⣤⢤⣤⣤⠤⠤⠤⢼⠮⠥⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│         \n            │⣰⢀⣆⡄⣄⡄⡠⡰⠦⠷⡜⢢⠷⠳⠢⠊⠉⠉⠀⠀⠁⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         \n            │⠃⠉⠙⠘⠃⠈⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         \n        -10 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         \n            └────────────────────────────────────────┘         \n            ⠀-20⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀10⠀         \n            ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀         \n\njulia> extrema(rrelu.(fill(-10f0, 1000)))\n(-3.3316886f0, -1.2548422f0)\n```\n\"\"\"\nfunction rrelu(x::T, l=oftf(x,1/8), u=oftf(x,1/3)) where T<:Number\n    a = (u - l) * rand(float(T)) + l\n    return leakyrelu(x, a)\nend\n\n\"\"\"\n    elu(x, α=1) = x > 0 ? x : α * (exp(x) - 1)\n\nExponential Linear Unit activation function.\nSee [\"Fast and Accurate Deep Network Learning by Exponential Linear Units\"](https://arxiv.org/abs/1511.07289).\nYou can also specify the coefficient explicitly, e.g. `elu(x, 1)`.\n\n```julia-repl\njulia> lineplot(elu, -2, 2, height=7)\n           ┌────────────────────────────────────────┐       \n         2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠒⠉│ elu(x)\n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊⠉⠀⠀⠀⠀│       \n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⣀⡤⠖⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀│       \n   f(x)    │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⣀⠤⠖⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│       \n           │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⡤⡧⠶⠭⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│       \n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⣀⠤⠔⠒⠋⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│       \n        -1 │⠤⠤⠤⠤⠔⠒⠒⠒⠊⠉⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│       \n           └────────────────────────────────────────┘       \n           ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀       \n           ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀       \n\njulia> elu(-10f0)\n-0.9999546f0\n\njulia> elu(-10f0, 2)\n-1.9999092f0\n```\n\"\"\"\nelu(x, α=1) = ifelse(x ≥ 0, float(x), @fastmath oftf(x, α) * (exp(x) - 1))\n\nderiv_elu(Ω, α=1) = ifelse(Ω ≥ 0, one(Ω), Ω + oftype(Ω, α))\n\n\"\"\"\n    gelu_tanh(x) = 0.5x * (1 + tanh(√(2/π) * (x + 0.044715x^3)))\n\nActivation function from [\"Gaussian Error Linear Units\"](https://arxiv.org/abs/1606.08415) using tanh approximation.\n\nThis implementation uses `tanh` which allows for better pattern matching and fusion in optimizing \ncompilers compared to the sigmoid-based implementation. 
For a potentially faster implementation \nthat uses `sigmoid_fast`, see [`gelu_sigmoid`](@ref).\n\n```julia-repl\njulia> lineplot(gelu_tanh, -2, 2, height=7)\n           ┌────────────────────────────────────────┐        \n         2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊│ gelu_tanh(x)\n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠊⠁⠀⠀⠀│        \n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠒⠉⠀⠀⠀⠀⠀⠀⠀│        \n   f(x)    │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⣀⡠⠤⠒⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        \n           │⣤⣤⣤⣤⣤⣤⣤⣤⡤⠤⠤⠤⠤⠤⠤⠤⣤⣤⣤⡤⡧⠶⠶⠭⠥⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│        \n           │⠀⠀⠀⠀⠀⠀⠀⠀⠈⠉⠉⠉⠉⠉⠉⠉⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        \n        -1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        \n           └────────────────────────────────────────┘        \n           ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀        \n           ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀        \n\njulia> lineplot(gelu_tanh, -5, 0, height=7);\n\njulia> lineplot!(ans, swish)\n             ┌────────────────────────────────────────┐         \n           0 │⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠒⠒⠤⣄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸│ gelu_tanh(x) \n             │⠑⠒⠢⠤⣄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠓⢄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇│ swish(x)\n             │⠀⠀⠀⠀⠀⠈⠉⠒⠤⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠑⢆⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣸⠁│         \n   f(x)      │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠒⢄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠑⢄⠀⠀⠀⠀⠀⠀⠀⠀⢠⡇⠀│         \n             │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠓⢄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠓⣄⠀⠀⠀⠀⠀⢠⡞⠀⠀│         \n             │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠓⢄⣀⣀⡤⢣⠃⠀⠀│         \n        -0.2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠓⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢠⠇⠀⠀⠀│         \n             └────────────────────────────────────────┘         \n             ⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀0⠀         \n             ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀         \n```\n\"\"\"\nfunction gelu_tanh(x)\n    α = oftf(x, 0.044715)\n    λ = oftf(x, gelu_λ)\n    x/2 * (1 + tanh_fast(λ * (x + α * x^3)))\nend\n\nconst gelu_λ = √(2 / π)\nconst gelu_2λ = √(8 / π)\n\nfunction deriv_gelu_tanh(x)\n    α = oftf(x, 0.044715)\n    α2 = oftf(x, 0.08943)\n    λ = oftf(x, gelu_λ)\n    x2 = x * x\n    t = muladd(x2, α, one(x))\n    z = λ * x * t\n    Ω = tanh_fast(z)\n    sech2 = 1 - Ω^2\n    (1 + Ω)/2 + x * λ * muladd(x2, α2, t) * sech2 / 2\nend\n\n\"\"\"\n    gelu_sigmoid(x) = x * σ(√(8/π) * (x + 0.044715x^3))\n\nAlternative implementation of the GELU activation function using `sigmoid` instead of `tanh`.\nThis is mathematically equivalent to [`gelu_tanh`](@ref) but may be faster in some cases.\n\nThe sigmoid-based implementation may prevent pattern matching and fusion in some optimizing \ncompilers. Use [`gelu_tanh`](@ref) if you need better compiler optimization support.\n\nSee [\"Gaussian Error Linear Units\"](https://arxiv.org/abs/1606.08415).\n\"\"\"\nfunction gelu_sigmoid(x)\n    α = oftf(x, 0.044715)\n    λλ = oftf(x, gelu_2λ)\n    x * sigmoid_fast(λλ * x * muladd(x^2, α, one(x)))\nend\n\nfunction deriv_gelu_sigmoid(x)\n    α = oftf(x, 0.044715)\n    α2 = oftf(x, 0.08943)\n    λλ = oftf(x, gelu_2λ)\n    x2 = x * x\n    t = muladd(x2, α, one(x))\n    Ω = sigmoid_fast(λλ * x * t)\n    dσ = conj(Ω * (1 - Ω))\n    muladd(dσ * λλ * muladd(x2, α2, t), x, Ω)\nend\n\n\"\"\"\n    gelu_erf(x) = xΦ(x) = 0.5x * (1 + erf(x/√2))\n\nActivation function from [\"Gaussian Error Linear Units\"](https://arxiv.org/abs/1606.08415) without approximation.\nThe SpecialFunctions.jl package needs to be loaded to use this function.\n\"\"\"\nfunction gelu_erf end\nfunction deriv_gelu_erf end\n\n\"\"\"\n    gelu(x) = gelu_tanh(x)\n\nActivation function from [\"Gaussian Error Linear Units\"](https://arxiv.org/abs/1606.08415). 
\nSee [`gelu_tanh`](@ref).\n\"\"\"\nconst gelu = gelu_tanh\n# Need to alias the type as well to ensure serialization libraries still work\n# See https://github.com/FluxML/NNlib.jl/issues/631\nconst var\"#gelu\" = typeof(gelu_tanh)\nconst deriv_gelu = deriv_gelu_tanh\n\n\"\"\"\n    swish(x) = x * σ(x)\n\nSelf-gated activation function.\nSee [\"Swish: a Self-Gated Activation Function\"](https://arxiv.org/abs/1710.05941).\n\n```julia-repl\njulia> lineplot(swish, -2, 2, height=7)\n           ┌────────────────────────────────────────┐         \n         2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤│ swish(x)\n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠖⠋⠁⠀│         \n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠖⠋⠁⠀⠀⠀⠀⠀│         \n   f(x)    │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⢀⣀⡤⠔⠊⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         \n           │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⣤⣤⡤⡧⠴⠶⠯⠥⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│         \n           │⠉⠑⠒⠒⠒⠒⠒⠒⠒⠒⠒⠒⠉⠉⠉⠉⠁⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         \n        -1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         \n           └────────────────────────────────────────┘         \n           ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀         \n           ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀         \n```\n\"\"\"\n@inline swish(x) = x * sigmoid_fast(x)\n\n\"\"\"\n    hardswish(x) = x * hardσ(x)\n\nHard-Swish activation function.\nSee [\"Searching for MobileNetV3\"](https://arxiv.org/abs/1905.02244).\n\n```julia-repl\njulia> lineplot(hardswish, -2, 5, height = 7)\n           ┌────────────────────────────────────────┐             \n         5 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠔⠒⠉│ hardswish(x)\n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡠⠔⠒⠉⠁⠀⠀⠀⠀│             \n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠖⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀│             \n   f(x)    │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│             \n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⢀⣀⠤⠖⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│             \n           │⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣇⣤⣤⣖⣚⣉⣁⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀│             \n        -1 │⠉⠒⠒⠒⠒⠉⠉⠉⠉⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│             \n           └────────────────────────────────────────┘             \n           ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀             \n           ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀             \n\njulia> lineplot(hardswish, -4, 0, height = 7);\n\njulia> lineplot!(ans, swish)\n             ┌────────────────────────────────────────┐             \n           0 │⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⢣⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡜│ hardswish(x)\n             │⠒⠒⠢⠤⢄⣀⡀⠀⠀⠀⠀⠱⡄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠎⠀│ swish(x)    \n             │⠀⠀⠀⠀⠀⠀⠈⠉⠑⠒⠦⢄⣘⢄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡴⠃⠀⠀│             \n   f(x)      │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠑⡖⠦⢄⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⢔⠏⠁⠀⠀⠀│             \n             │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠣⣄⠀⠉⠑⠒⠦⠤⢄⣀⣀⣀⣀⡠⠤⠖⣊⠕⠁⠀⠀⠀⠀⠀│             \n             │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠓⠤⡀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡤⠖⠁⠀⠀⠀⠀⠀⠀⠀│             \n        -0.4 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠉⠒⠢⠤⠤⠔⠒⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│             \n             └────────────────────────────────────────┘             \n             ⠀-4⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀0⠀             \n             ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀             \n\njulia> hardswish.(-5:5)'\n1×11 adjoint(::Vector{Float64}) with eltype Float64:\n -0.0  -0.0  -0.0  -0.333333  -0.333333  0.0  0.666667  1.66667  3.0  4.0  5.0\n```\n\"\"\"\n@inline hardswish(x) = x * hardσ(x)\n\nderiv_hardswish(x) = ifelse(x < -3, oftf(x,0), ifelse(x > 3, oftf(x,1), x/3 + oftf(x,1/2)))\n\n\"\"\"\n    lisht(x) = x * tanh(x)\n\nActivation function from \n[\"LiSHT: Non-Parametric Linearly Scaled Hyperbolic Tangent ...\"](https://arxiv.org/abs/1901.05894)\n\n```julia-repl\njulia> 
lineplot(lisht, -2, 2, height=7)\n          ┌────────────────────────────────────────┐         \n        2 │⠢⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔│ lisht(x)\n          │⠀⠈⠑⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡤⠊⠁⠀│         \n          │⠀⠀⠀⠀⠈⠣⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠁⠀⠀⠀⠀│         \n   f(x)   │⠀⠀⠀⠀⠀⠀⠀⠑⢆⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⠊⠁⠀⠀⠀⠀⠀⠀│         \n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠢⡄⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⢀⠔⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀│         \n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠓⢄⡀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⢀⡠⠖⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         \n        0 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠓⠦⣄⣀⣀⣇⣀⣀⠤⠒⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         \n          └────────────────────────────────────────┘         \n          ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀         \n          ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀         \n\njulia> lineplot!(ans, logcosh)\n          ┌────────────────────────────────────────┐           \n        2 │⠢⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔│ lisht(x)  \n          │⠀⠈⠑⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡤⠊⠁⠀│ logcosh(x)\n          │⠢⣄⠀⠀⠈⠣⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠁⠀⠀⣀⠔│           \n   f(x)   │⠀⠈⠑⠢⣀⠀⠀⠑⢆⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⠊⠁⠀⣀⠔⠊⠁⠀│           \n          │⠀⠀⠀⠀⠀⠉⠢⢄⡀⠉⠢⡄⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⢀⠔⠋⠀⡠⠔⠋⠁⠀⠀⠀⠀│           \n          │⠀⠀⠀⠀⠀⠀⠀⠀⠉⠓⠦⣌⡓⢄⡀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⢀⡠⠖⣁⠤⠒⠉⠀⠀⠀⠀⠀⠀⠀⠀│           \n        0 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠓⠪⠷⣦⣄⣀⣀⣇⣀⣀⣤⠶⠕⠒⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│           \n          └────────────────────────────────────────┘           \n          ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀           \n          ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀           \n```\n\"\"\"\nlisht(x) = x * tanh_fast(x)\n\n\"\"\"\n    selu(x) = λ * (x ≥ 0 ? x : α * (exp(x) - 1))\n\n    λ ≈ 1.05070...\n    α ≈ 1.67326...\n\nScaled exponential linear units.\nSee [\"Self-Normalizing Neural Networks\"](https://arxiv.org/abs/1706.02515).\n\n```julia-repl\njulia> lineplot(selu, -3, 2, height=7)\n           ┌────────────────────────────────────────┐        \n         3 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ selu(x)\n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠤⠒│        \n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⢀⣀⠤⠖⠊⠉⠀⠀⠀⠀│        \n   f(x)    │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⣀⡠⠤⠒⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀│        \n           │⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⣉⠭⠛⡏⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉│        \n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⣀⡤⠤⠒⠊⠉⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        \n        -2 │⠤⠤⠖⠒⠒⠒⠒⠒⠒⠒⠉⠉⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        \n           └────────────────────────────────────────┘        \n           ⠀-3⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀        \n           ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀        \n\njulia> selu(-10f0)\n-1.7580194f0\n```\n\"\"\"\nfunction selu(x)\n    λ = oftf(x, selu_λ)\n    α = oftf(x, selu_α)\n    λ * ifelse(x > 0, x, @fastmath α * (exp(x) - 1))\nend\n\nconst selu_λ = 1.0507009873554804934193349852946\nconst selu_α = 1.6732632423543772848170429916717\n\nfunction deriv_selu(Ω)\n    λ = oftf(Ω, selu_λ)\n    α = oftf(Ω, selu_α)\n    ifelse(Ω > 0, λ, Ω + α * λ)\nend\n\n\"\"\"\n    celu(x, α=1) = x ≥ 0 ? 
x : α * (exp(x/α) - 1)\n\nActivation function from [\"Continuously Differentiable Exponential Linear Units\"](https://arxiv.org/abs/1704.07483).\n\n```julia-repl\njulia> lineplot(celu, -2, 2, height=7)\n           ┌────────────────────────────────────────┐        \n         2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠒⠉│ celu(x)\n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊⠉⠀⠀⠀⠀│        \n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⣀⡤⠖⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀│        \n   f(x)    │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⣀⠤⠖⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        \n           │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⡤⡧⠶⠭⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│        \n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⣀⠤⠔⠒⠋⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        \n        -1 │⠤⠤⠤⠤⠔⠒⠒⠒⠊⠉⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        \n           └────────────────────────────────────────┘        \n           ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀        \n           ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀        \n\njulia> celu(-10f0)\n-0.9999546f0\n```\n\"\"\"\ncelu(x, α=1) = ifelse(x ≥ 0, float(x), oftf(x,α) * (exp(x/oftf(x,α)) - 1))\n\nderiv_celu(Ω, α=1) = ifelse(Ω > 0, oftf(Ω, 1), Ω / oftf(Ω, α) + 1)\n\n\"\"\"\n    trelu(x, theta=1) = x > theta ? x : 0\n\nThreshold gated rectified linear activation function.\nSee [\"Zero-bias autoencoders and the benefits of co-adapting features\"](https://arxiv.org/abs/1402.3337)\n\n```julia-repl\njulia> lineplot(trelu, -2, 4, height=7)\n          ┌────────────────────────────────────────┐         \n        4 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠖⠋│ trelu(x)\n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠖⠋⠁⠀⠀⠀│         \n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊⠁⠀⠀⠀⠀⠀⠀⠀│         \n   f(x)   │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠴⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         \n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⣠⠤⠒⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         \n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⡏⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         \n        0 │⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣇⣀⣀⣀⣀⣀⣀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│         \n          └────────────────────────────────────────┘         \n          ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀4⠀         \n          ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀         \n```\n\"\"\"\ntrelu(x, theta=1) = ifelse(x <= theta, zero(x), x)\n\nconst thresholdrelu = trelu\n\n\"\"\"\n    softsign(x) = x / (1 + |x|)\n\nSee [\"Quadratic Polynomials Learn Better Image Features\"](http://www.iro.umontreal.ca/~lisa/publications2/index.php/attachments/single/205) (2009).\n\n```julia-repl\njulia> lineplot(softsign, -5, 5, height=7)\n           ┌────────────────────────────────────────┐            \n         1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣀⣀⣀⣀⠤⠤⠤⠤⠤│ softsign(x)\n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⣀⡤⠖⠒⠋⠉⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀│            \n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⡔⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            \n   f(x)    │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⡯⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│            \n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠁⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            \n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣀⣀⠤⠤⠒⠋⠁⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            \n        -1 │⠒⠒⠒⠒⠒⠊⠉⠉⠉⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            \n           └────────────────────────────────────────┘            \n           ⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀            \n           ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀            \n\njulia> lineplot!(ans, tanh)\n           ┌────────────────────────────────────────┐            \n         1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⢀⡤⠖⠊⠉⠉⠉⣉⣉⣉⣉⣉⠭⠭⠭⠭⠭│ softsign(x)\n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⡔⣃⡤⠖⠒⠋⠉⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀│ tanh(x)    \n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣧⡞⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            \n   f(x)    │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⡯⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│            \n           
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡴⠃⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            \n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣀⣀⠤⠤⠒⢋⠕⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            \n        -1 │⣒⣒⣒⣒⣒⣊⣉⣉⣉⣉⣁⣀⣀⡠⠤⠒⠁⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            \n           └────────────────────────────────────────┘            \n           ⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀            \n           ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀            \n\njulia> softsign(1f0)\n0.5f0\n\njulia> softsign(100f0)\n0.990099f0\n```\n\"\"\"\nsoftsign(x) = x / (1 + abs(x))\n\nderiv_softsign(x) = 1 / (1 + abs(x))^2\n\n\"\"\"\n    softplus(x) = log(exp(x) + 1)\n\nSee [\"Deep Sparse Rectifier Neural Networks\"](http://proceedings.mlr.press/v15/glorot11a/glorot11a.pdf), JMLR 2011.\n\n```julia-repl\njulia> lineplot(softplus, -3, 3, height=7)\n          ┌────────────────────────────────────────┐            \n        4 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ softplus(x)\n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠│            \n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊⠁⠀│            \n   f(x)   │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠔⠊⠁⠀⠀⠀⠀⠀│            \n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⣀⡠⠤⠒⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀│            \n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣀⡧⠤⠒⠊⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            \n        0 │⣀⣀⣀⣀⣀⣀⣀⡠⠤⠤⠤⠤⠔⠒⠒⠚⠉⠉⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            \n          └────────────────────────────────────────┘            \n          ⠀-3⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀3⠀            \n          ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀            \n\njulia> lineplot!(ans, relu)\n          ┌────────────────────────────────────────┐            \n        4 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ softplus(x)\n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣠│ relu(x)    \n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣠⡴⠞⠋⠁│            \n   f(x)   │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣤⡴⠞⠋⠁⠀⠀⠀⠀│            \n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⣀⡠⢤⡲⠝⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀│            \n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣀⡧⠤⠒⠊⣉⠥⠚⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            \n        0 │⣀⣀⣀⣀⣀⣀⣀⣠⣤⣤⣤⣤⣔⣒⣒⣚⣉⣉⣁⣀⣇⠴⠒⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│            \n          └────────────────────────────────────────┘            \n          ⠀-3⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀3⠀            \n          ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀            \n\njulia> softplus(16f0)\n16.0f0\n```\n\"\"\"\nsoftplus(x) = log1p(exp(-abs(x))) + relu(x)\n\n\"\"\"\n    logcosh(x)\n\nReturn `log(cosh(x))` which is computed in a numerically stable way.\n\n```julia-repl\njulia> lineplot(logcosh, -5, 5, height=7)\n          ┌────────────────────────────────────────┐           \n        5 │⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ logcosh(x)\n          │⠉⠢⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠋│           \n          │⠀⠀⠀⠑⠢⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠊⠁⠀⠀│           \n   f(x)   │⠀⠀⠀⠀⠀⠀⠑⠦⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠊⠁⠀⠀⠀⠀⠀│           \n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠑⠦⡀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⢀⡤⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀│           \n          │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠓⠦⡀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⢀⡤⠒⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│           \n        0 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠑⠢⢄⣀⣀⣇⣀⡠⠔⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│           \n          └────────────────────────────────────────┘           \n          ⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀           \n          ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀           \n```\n\"\"\"\nlogcosh(x) = x + softplus(-2x) - oftf(x, log2)\n\nconst log2 = log(2)\n\n\"\"\"\n    mish(x) = x * tanh(softplus(x))\n\nActivation function from [\"Mish: A Self Regularized Non-Monotonic Neural Activation Function\"](https://arxiv.org/abs/1908.08681).\n\n```julia-repl\njulia> lineplot(mish, -5, 5, height=7)\n           
┌────────────────────────────────────────┐        \n         5 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠖⠋│ mish(x)\n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠒⠁⠀⠀⠀│        \n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⠔⠋⠁⠀⠀⠀⠀⠀⠀│        \n   f(x)    │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⢀⡠⠖⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        \n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⢀⡤⠖⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        \n           │⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣧⣔⣊⣁⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀│        \n        -1 │⠀⠀⠀⠀⠀⠀⠀⠀⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│        \n           └────────────────────────────────────────┘        \n           ⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀        \n           ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀        \n```\n\"\"\"\nmish(x) = x * tanh(softplus(x))\n\n\"\"\"\n    tanhshrink(x) = x - tanh(x)\n\nSee [\"Tanhshrink Activation Function\"](https://www.gabormelli.com/RKB/Tanhshrink_Activation_Function).\n\n```julia-repl\njulia> lineplot(tanhshrink, -3, 3, height=7)\n           ┌────────────────────────────────────────┐              \n         3 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ tanhshrink(x)\n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡠⠤⠖⠊│              \n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⢀⣀⡠⠤⠒⠊⠉⠁⠀⠀⠀⠀│              \n   f(x)    │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⢤⣤⡤⠤⠤⠤⠤⠤⠤⡷⠶⠶⠶⠶⠶⠮⠭⠥⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│              \n           │⠀⠀⠀⠀⠀⣀⡠⠴⠒⠊⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│              \n           │⡠⠴⠒⠊⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│              \n        -3 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│              \n           └────────────────────────────────────────┘              \n           ⠀-3⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀3⠀              \n           ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀              \n\njulia> tanhshrink.((-10f0, 10f0))\n(-9.0f0, 9.0f0)\n```\n\"\"\"\ntanhshrink(x) = x - tanh_fast(x)\n\n\"\"\"\n    softshrink(x, λ=0.5) =\n        (x ≥ λ ? x - λ : (-λ ≥ x ? 
x + λ : 0))\n\nSee [\"Softshrink Activation Function\"](https://www.gabormelli.com/RKB/Softshrink_Activation_Function).\n\n```julia-repl\njulia> lineplot(softshrink, -2, 2, height=7)\n           ┌────────────────────────────────────────┐              \n         2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀│ softshrink(x)\n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⡤⠔⠒⠉⠁│              \n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⣀⡤⠤⠒⠋⠁⠀⠀⠀⠀⠀⠀│              \n   f(x)    │⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⣤⡤⠤⠤⠤⠤⠤⠤⡧⠤⠤⠤⠤⠶⠮⠭⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│              \n           │⠀⠀⠀⠀⠀⠀⢀⣀⠤⠖⠒⠉⠁⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│              \n           │⠀⣀⠤⠔⠒⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│              \n        -2 │⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│              \n           └────────────────────────────────────────┘              \n           ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀              \n           ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀              \n\njulia> lineplot!(ans, tanhshrink)\n           ┌────────────────────────────────────────┐              \n         2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀│ softshrink(x)\n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⡤⠔⠒⣉⡡│ tanhshrink(x)\n           │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⣀⡤⠤⣒⣋⠥⠤⠒⠊⠉⠁⠀│              \n   f(x)    │⠤⠤⠤⠤⠤⠤⠤⠤⠤⣤⣤⣤⣤⡤⠤⠤⠤⠤⠤⠤⡷⠶⠶⠶⠶⠶⠾⠿⠯⠭⠭⠤⠤⠤⠤⠤⠤⠤⠤⠤│              \n           │⠀⢀⣀⡠⠤⠖⢒⣋⠭⠗⠒⠉⠁⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│              \n           │⠊⣉⠤⠔⠒⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│              \n        -2 │⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│              \n           └────────────────────────────────────────┘              \n           ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀              \n           ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀\n\njulia> softshrink.((-10f0, 10f0))\n(-9.5f0, 9.5f0)\n```\n\"\"\"\nfunction softshrink(x, λ = 0.5)\n    lo = x - oftf(x, λ)\n    hi = x + oftf(x, λ)\n    ifelse(hi > 0, ifelse(lo < 0, zero(hi), lo), hi)\nend\n\n# Define broadcasts for activation functions on arrays\nfor f in ACTIVATIONS\n  @eval $(f)(x::AbstractArray, args...) = $(f).(x, args...)\nend\n\n## Faster, less accurate, versions of some.\n\n\"\"\"\n    tanh_fast(x)\n\nThis is a faster but slighly less accurate version of `tanh`.\n\nWhere Julia's `tanh` function has an error under 2 eps, this\nmay be wrong by 5 eps, a reduction by less than one decimal digit. 
\n\nFor `x::Float32` this is usually about 10 times faster,\nwith a smaller speedup for `x::Float64`.\nFor any other number types, it just calls `tanh`.\n\nSee also [`sigmoid_fast`](@ref).\n\n```julia-repl\njulia> tanh(0.5f0)\n0.46211717f0\n\njulia> tanh_fast(0.5f0)\n0.46211714f0\n\njulia> hard_tanh(0.5f0)\n0.5f0\n```\n\"\"\"\n@inline function tanh_fast(x::Float32)\n    # This method added in NNlib.jl#345 by @mcabbott and @oscardssmith,\n    # with coeffiecients found using Remez.jl\n    x2 = abs2(x)\n    n = evalpoly(x2, (1.0f0, 0.1346604f0, 0.0035974074f0, 2.2332108f-5, 1.587199f-8))\n    d = evalpoly(x2, (1.0f0, 0.4679937f0, 0.026262015f0, 0.0003453992f0, 8.7767893f-7))\n    ifelse(x2 < 66f0, x * (n / d), sign(x))\nend\n\n@inline function tanh_fast(x::Float64)\n    exp2x = @fastmath exp(x + x)\n    y = (exp2x - 1) / (exp2x + 1) \n    # That has large errors near zero; using `expm1` would more accurate, but about as slow as `tanh`.\n    # Instead, we switch to a polynomial, which is very accurate within its range:\n    x2 = x * x\n    ypoly = x * evalpoly(x2, (1.0, -0.33333333333324583, 0.13333333325511604, -0.05396823125794372, 0.02186660872609521, -0.008697141630499953))\n    ifelse(x2 > 900.0, sign(x), ifelse(x2 < 0.017, ypoly, y))\nend\n\n# These approximations are very badly behaved for Float16; none are fast.\n# They are also a bit slower with ForwardDiff.Dual numbers, let's use Base:\ntanh_fast(x::Number) = Base.tanh(x)\n\n\"\"\"\n    sigmoid_fast(x)\n\nThis is a faster, and very slightly less accurate, version of `sigmoid`.\nFor `x::Float32`, perhaps 3 times faster, and maximum errors 2 eps instead of 1.\n\nSee also [`tanh_fast`](@ref).\n\n```julia-repl\njulia> sigmoid(0.2f0)\n0.54983395f0\n\njulia> sigmoid_fast(0.2f0)\n0.54983395f0\n\njulia> hardσ(0.2f0)\n0.53333336f0\n```\n\"\"\"\nfunction sigmoid_fast(x::Real)\n    @static if VERSION ≥ v\"1.11-\"\n        @inline\n    end\n    t = @fastmath exp(-abs(x))\n    y = ifelse(x ≥ 0, inv(1 + t), t / (1 + t))\n    ifelse(x > 40, one(y), ifelse(x < -80, zero(y), y))\nend\n# For x::Float32, this is not as quick as the rational tanh_fast(x) above,\n# but that polynomial has poor relative accuracy for negative x.\n\nsigmoid_fast(x::Float16) = sigmoid(x)  # sigmoid_fast is extremely badly behaved at large x\n\nfunction sigmoid_fast(x::Number)\n    Base.depwarn(\"sigmoid only makes sense on real numbers, got $(typeof(x))\", :sigmoid_fast)\n    sigmoid(x)\nend\n\n\"\"\"\n    NNlib.fast_act(f, [x::AbstractArray])\n\nReplaces `f == tanh` with [`tanh_fast`](@ref), etc.\n\nTakes an optional 2nd argument, so that you can disable\nthis replacement for some array or element types.\n\"\"\"\n@inline fast_act(f::F, ::AbstractArray = 1:0) where {F<:Function} = f\n@inline fast_act(::typeof(tanh), ::AbstractArray = 1:0) = tanh_fast\n@inline fast_act(::typeof(sigmoid), ::AbstractArray = 1:0) = sigmoid_fast\n\n## Define rrules for some activation functions, along with the\n## broadcasted rrule activation functions.\n\n## This is a performance hack specifically for Zygote, because it doesn't handle fused\n## broadcasts well; but it generally should be good (or at least harmless) for any AD, as\n## it saves ADing the broadcasting machinery.\n## Related Issue https://github.com/JuliaDiff/ChainRulesCore.jl/issues/271\n\n## TODO: add to the lists below all activations.\n\nUNARY_ACTS = [ # f, dfdx\n    ## In the same order as above!\n    (:σ,            :(conj(Ω * (1 - Ω)))),\n    (:hardσ,        :(ifelse((Ω>0)&(Ω<1), oftf(Ω, 1/6), oftf(Ω, 1)))),\n    (:logσ,     
    :(sigmoid_fast(-x))),\n    (:hardtanh,     :((Ω>-1) & (Ω<1))),\n    (:relu,         :(Ω > 0)),\n    (:leakyrelu,    :(ifelse(Ω > 0, oftf(Ω, 1), oftf(Ω, leakyrelu_a)))),\n    (:relu6,        :((Ω>0) & (Ω<6))),\n    # rrelu is random, can't write a rule.\n    (:elu,          :(deriv_elu(Ω))),\n    (:gelu_tanh,    :(deriv_gelu_tanh(x))),\n    (:gelu_sigmoid, :(deriv_gelu_sigmoid(x))),\n    (:gelu_erf,     :(deriv_gelu_erf(x))),\n    (:swish,        :(Ω + sigmoid_fast(x) * (1 - Ω))),\n    (:hardswish,    :(deriv_hardswish(x))),\n    # lisht\n    (:selu,         :(deriv_selu(Ω))),\n    (:celu,         :(deriv_celu(Ω))),\n    (:trelu,        :(Ω > 0)),\n    (:softsign,     :(deriv_softsign(x))),\n    (:softplus,     :(sigmoid_fast(x))),\n    # (:softplus,     :(1 - @fastmath exp(-Ω))),  # slightly faster, check accuracy?\n    # logcosh\n    # mish\n    (:tanhshrink,    :((x - Ω)^2)),\n    (:softshrink,    :(Ω != 0)),\n    ## Fast variants are the same!\n    (:tanh_fast,    :(conj(1 - Ω^2))),\n    (:sigmoid_fast, :(conj(Ω * (1 - Ω)))),\n]\n\nfor (f, dfdx) in UNARY_ACTS\n    @eval @scalar_rule($f(x), $dfdx)\n\n    pullback = Symbol(:broadcasted_, f, :_pullback)\n    @eval function rrule(::typeof(broadcasted),\n                         ::typeof($f), x::Union{Numeric, Broadcast.Broadcasted})\n        Ω = $f.(x)\n        function $pullback(dΩ)\n            x_thunk = InplaceableThunk(\n                dx -> @.(dx += dΩ * $dfdx),\n                @thunk @.(dΩ * $dfdx)\n            )\n            NoTangent(), NoTangent(), x_thunk\n        end\n        return Ω, $pullback\n    end\nend\n\n# NO_ACT_GRAD = ChainRulesCore.@not_implemented \"for simplicitly NNlib assumes the 2nd argument of this activation function is a constant\"\nNO_ACT_GRAD = NaN  ## Still reminds you not to use this, but is perhaps more GPU friendly.\n\nBINARY_ACTS = [ # f, dfdx1, dfdx2\n    ## In the same order as above!\n    (:leakyrelu,   :(ifelse(Ω > 0, oftf(Ω, 1), oftf(Ω, x2))), NO_ACT_GRAD),\n    (:elu,         :(deriv_elu(Ω, x2)),      NO_ACT_GRAD),\n    (:celu,        :(deriv_celu(Ω, x2)),     NO_ACT_GRAD),\n    (:trelu,       :(Ω > 0),                 ZeroTangent()),\n    (:softshrink,  :(Ω != 0),                NO_ACT_GRAD),\n]\n\nfor (f, dfdx1, dfdx2) in BINARY_ACTS\n    @eval @scalar_rule($f(x1, x2), ($dfdx1, $dfdx2))\n\n    pullback = Symbol(:broadcasted_, f, :_pullback_2arg)\n    @eval function rrule(::typeof(broadcasted),\n                         ::typeof($f), \n                         x1::Union{Numeric, Broadcast.Broadcasted}, x2::Number)\n        Ω = $f.(x1, x2)\n        ## Allowing x2::Array would allow size(Ω) != size(x1), which is not handled here:\n        $pullback(dΩ) = (NoTangent(), NoTangent(), @.(dΩ * $dfdx1), NO_ACT_GRAD)\n        return Ω, $pullback\n    end\nend\n"
  },
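A short sketch of the `fast_act` substitution documented above; functions without a fast variant pass through unchanged.

```julia
using NNlib

f = NNlib.fast_act(tanh)       # tanh is swapped for the cheaper tanh_fast
f === tanh_fast                # true
g = NNlib.fast_act(relu)       # anything else is returned as-is
g === relu                     # true
```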
  {
    "path": "src/attention.jl",
    "content": "const AA3{T} = AbstractArray{T,3}\nconst AA4{T} = AbstractArray{T,4}\nconst AA{N,T} = AbstractArray{T,N}\n\n\"\"\"\n    dot_product_attention(query, key, value, [bias]; [fdrop, mask, nheads])\n\nMultihead dot product attention used in transformer architectures.\n\nThe input arrays must have the first two dimensions given by the number of features\nand the sequence length, then an arbitrary number of batch dimensions or none.\n\n\nReturns the attention output array of size `(v_dim, q_len, batch_size...)` and the attention scores\nof size `(kv_len, q_len, nheads, batch_size...)`.\n\nSee also [`dot_product_attention_scores`](@ref) if you only need the attention scores.\n\n# Arguments\n\n- `query`: Query array of size `(qk_dim, q_len, batch_size...)`.\n- `key`: Key array of size `(qk_dim, kv_len, batch_size...)`.\n- `value`: Value array of size `(v_dim, kv_len, batch_size...)`.\n- `bias`: Either `nothing` or an array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.\n          It will be added to the attention scores before applying the softmax. Default `nothing`.\n- `fdrop`: A dropout function or layer to be applied on the attention scores right after the softmax.\n           Default `identity` (no dropout).\n- `mask`: Either `nothing` or a boolean array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.\n          The mask is applied to the attention scores just before the softmax.\n          See [`make_causal_mask`](@ref) fore creating causal masks. Default `nothing`.\n- `nheads`: Number of heads to split the input arrays into. Default `1`.\n\n# Examples\n\n```julia\nq, k, v = rand(10, 20, 2), rand(10, 30, 2), rand(20, 30, 2)\ny, α = dot_product_attention(q, k, v)\n```\n\"\"\"\nfunction dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}, args...; kws...) where N\n    batch_size = size(q)[3:end]\n    batch_size == size(k)[3:end] == size(v)[3:end] || throw(ArgumentError(\"Batch dimensions have to be the same.\"))\n    q, k, v = map(x -> reshape(x, size(x, 1), size(x, 2), :), (q, k, v))\n\n    x, α = dot_product_attention(q, k, v, args...; kws...)\n\n    x = reshape(x, size(x, 1), size(x, 2), batch_size...)\n    α = reshape(α, size(α)[1:3]..., batch_size...)\n    return x, α\nend\n\nfunction dot_product_attention(q::AA3, k::AA3, v::AA3, bias=nothing;\n            fdrop=identity, mask=nothing, nheads=1)\n\n    (all(size.((q, k, v), 1) .% nheads .== 0)) || throw(ArgumentError(\"\"\"\n    First dimension in query, key and value must be divisible by `nheads`.\n    Instead:\n    - size(q): $(size(q))\n    - size(k): $(size(k))\n    - size(v): $(size(v))\n    - nheads: $nheads\n    \"\"\"))\n    (size(q, 3) == size(k, 3) == size(v, 3)) || throw(ArgumentError(\"\"\"\n    Batch dimensions have to be the same. Instead:\n    - size(q): $(size(q))\n    - size(k): $(size(k))\n    - size(v): $(size(v))\n    \"\"\"))\n    size(q, 1) == size(k, 1) || throw(ArgumentError(\"\"\"\n    First dimension in query and key has to be the same. Instead:\n    - size(q): $(size(q))\n    - size(k): $(size(k))\n    \"\"\"))\n    size(k, 2) == size(v, 2) || throw(ArgumentError(\"\"\"\n    Second dimension in key and value has to be the same. Instead:\n    - size(k): $(size(k))\n    - size(v): $(size(v))\n    \"\"\"))\n\n    # Multihead attention. 
TODO create fastpath for singlehead attention.\n    q, k, v = split_heads.((q, k, v), nheads)\n    x, α = _dot_product_attention(q, k, v, bias, fdrop, mask)\n    return join_heads(x), α\nend\n\nfunction _dot_product_attention(q::AA4, k::AA4, v::AA4, bias, fdrop, mask)\n    # [q] = [qk_dim ÷ nheads, nheads, q_len, batch_size]\n    # [k] = [qk_dim ÷ nheads, nheads, kv_len, batch_size]\n    # [v] = [v_dim ÷ nheads, nheads, kv_len, batch_size]\n\n    α = dot_product_attention_scores(q, k, bias; fdrop, mask)\n    # [α] = [kv_len, q_len, nheads, batch_size]\n\n    # The following permutedims and batched_mul are equivalent to\n    # @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b]\n    vt = permutedims(v, (1, 3, 2, 4))\n    x = batched_mul(vt, α)\n    x = permutedims(x, (1, 3, 2, 4))\n    # [x] = [kv_dim ÷ nheads, nheads, q_len, batch_size]\n    return x, α\nend\n\n\"\"\"\n    dot_product_attention_scores(query, key, [bias]; [fdrop, mask])\n\nReturn the attention scores for the [`dot_product_attention`](@ref).\nInput arrays must have dimensions\n`(num_features ÷ nheads, nheads, sequence_length, batch_size)`.\n\nSee [`dot_product_attention`](@ref) for more details.\n\"\"\"\nfunction dot_product_attention_scores(q::AA4{T}, k::AA4{T}, bias=nothing;\n            fdrop=identity, mask=nothing) where T\n\n    # The following permutedims and batched_mul are equivalent to\n    # @tullio logits[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] / √T(qk_dim)\n    kt = permutedims(k, (3, 1, 2, 4))\n    qt = permutedims(q, (1, 3, 2, 4)) ./ √T(size(q, 1))\n    logits = batched_mul(kt, qt)\n    # [logits] = [kv_len, q_len, nheads, batch_size]\n\n    logits = apply_attn_bias(logits, bias)\n    logits = apply_attn_mask(logits, mask)\n\n    α = softmax(logits, dims=1)\n    return fdrop(α)\nend\n\napply_attn_bias(logits, bias::Nothing) = logits\n\napply_attn_bias(logits, bias) = logits .+ bias\n\napply_attn_mask(logits, mask::Nothing) = logits\n\nfunction apply_attn_mask(logits, mask)\n    neginf = typemin(eltype(logits))\n    ifelse.(mask, logits, neginf)\nend\n\n\n\"\"\"\n    make_causal_mask(x, dims=2)\n\nReturn a boolean square matrix `m` of the same type as `x` and of side `size(x, dims)`.\nIts elements are set such that `m[i, j] == i ≤ j`.\n\nCan be used to mask the attention scores in [`dot_product_attention`](@ref).\n\"\"\"\nfunction make_causal_mask(x::AbstractArray; dims::Int=2)\n  len = size(x, dims)\n  mask = triu(trues_like(x, (len, len)))\n  return mask\nend\n\ntrues_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), true)\nfalses_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), false)\n\nsplit_heads(x, nheads) = reshape(x, size(x, 1) ÷ nheads, nheads, size(x)[2:end]...)\njoin_heads(x) = reshape(x, :, size(x)[3:end]...)\n\n@non_differentiable make_causal_mask(::Any...)\n@non_differentiable trues_like(::Any...)\n@non_differentiable falses_like(::Any...)\n"
  },
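Building on the docstring above, a sketch of a multi-head call with a causal mask; the dimensions are illustrative, and the feature dimension must be divisible by `nheads`.

```julia
using NNlib

q = rand(Float32, 64, 20, 2)   # (qk_dim, q_len, batch)
k = rand(Float32, 64, 20, 2)   # (qk_dim, kv_len, batch)
v = rand(Float32, 64, 20, 2)   # (v_dim, kv_len, batch)

mask = make_causal_mask(q)     # 20×20 Bool with mask[i, j] == (i ≤ j)
y, α = dot_product_attention(q, k, v; nheads=8, mask)

size(y)   # (64, 20, 2)
size(α)   # (20, 20, 8, 2) == (kv_len, q_len, nheads, batch)
```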
  {
    "path": "src/audio/mel.jl",
    "content": "\"\"\"\n    melscale_filterbanks(;\n        n_freqs::Int, n_mels::Int, sample_rate::Int,\n        fmin::Float32 = 0f0, fmax::Float32 = Float32(sample_rate ÷ 2))\n\nCreate triangular Mel scale filter banks\n(ref: [Mel scale - Wikipedia](https://en.wikipedia.org/wiki/Mel_scale)).\nEach column is a filterbank that highlights its own frequency.\n\n# Arguments:\n\n- `n_freqs::Int`: Number of frequencies to highlight.\n- `n_mels::Int`: Number of mel filterbanks.\n- `sample_rate::Int`: Sample rate of the audio waveform.\n- `fmin::Float32`: Minimum frequency in Hz.\n- `fmax::Float32`: Maximum frequency in Hz.\n\n# Returns:\n\nFilterbank matrix of shape `(n_freqs, n_mels)` where each column is a filterbank.\n\n```jldoctest\njulia> n_mels = 8;\n\njulia> fb = melscale_filterbanks(; n_freqs=200, n_mels, sample_rate=16000);\n\njulia> plot = lineplot(fb[:, 1]);\n\njulia> for i in 2:n_mels\n           lineplot!(plot, fb[:, i])\n       end\n\njulia> plot\n     ┌────────────────────────────────────────┐\n   1 │⠀⡀⢸⠀⢸⠀⠀⣧⠀⠀⢸⡄⠀⠀⠀⣷⠀⠀⠀⠀⠀⣷⠀⠀⠀⠀⠀⠀⢀⣿⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀│\n     │⠀⡇⢸⡆⢸⡇⠀⣿⠀⠀⡜⡇⠀⠀⢰⠋⡆⠀⠀⠀⢰⠁⡇⠀⠀⠀⠀⠀⡸⠀⢣⠀⠀⠀⠀⠀⠀⠀⠀⠀│\n     │⠀⣿⢸⡇⡇⡇⢰⠹⡄⠀⡇⢱⠀⠀⢸⠀⢣⠀⠀⠀⡜⠀⢸⡀⠀⠀⠀⢀⠇⠀⠈⡇⠀⠀⠀⠀⠀⠀⠀⠀│\n     │⠀⣿⡇⡇⡇⡇⢸⠀⡇⢀⠇⠸⡀⠀⡇⠀⠸⡀⠀⢀⠇⠀⠀⢇⠀⠀⠀⡸⠀⠀⠀⠸⡄⠀⠀⠀⠀⠀⠀⠀│\n     │⢠⢻⡇⡇⡇⢱⢸⠀⢇⢸⠀⠀⡇⢀⠇⠀⠀⡇⠀⢸⠀⠀⠀⠸⡀⠀⢠⠇⠀⠀⠀⠀⢱⠀⠀⠀⠀⠀⠀⠀│\n     │⢸⢸⡇⢱⡇⢸⡇⠀⢸⢸⠀⠀⢣⢸⠀⠀⠀⢸⠀⡇⠀⠀⠀⠀⢇⠀⡜⠀⠀⠀⠀⠀⠈⢇⠀⠀⠀⠀⠀⠀│\n     │⢸⢸⡇⢸⠀⢸⡇⠀⢸⡇⠀⠀⢸⡎⠀⠀⠀⠈⣶⠁⠀⠀⠀⠀⠸⣤⠃⠀⠀⠀⠀⠀⠀⠘⡆⠀⠀⠀⠀⠀│\n     │⢸⠀⡇⢸⠀⠀⡇⠀⠀⡇⠀⠀⠀⡇⠀⠀⠀⠀⣿⠀⠀⠀⠀⠀⠀⣿⠀⠀⠀⠀⠀⠀⠀⠀⢱⡀⠀⠀⠀⠀│\n     │⢸⢸⡇⢸⠀⢸⡇⠀⢸⡇⠀⠀⢸⢇⠀⠀⠀⢀⠿⡀⠀⠀⠀⠀⢰⠛⡄⠀⠀⠀⠀⠀⠀⠀⠀⢣⠀⠀⠀⠀│\n     │⢸⢸⡇⡸⡇⢸⡇⠀⢸⢸⠀⠀⡜⢸⠀⠀⠀⢸⠀⡇⠀⠀⠀⠀⡎⠀⢣⠀⠀⠀⠀⠀⠀⠀⠀⠘⡆⠀⠀⠀│\n     │⢸⢸⡇⡇⡇⡸⢸⠀⡎⢸⠀⠀⡇⠈⡆⠀⠀⡇⠀⢸⠀⠀⠀⢰⠁⠀⠘⡆⠀⠀⠀⠀⠀⠀⠀⠀⠸⡄⠀⠀│\n     │⡇⢸⡇⡇⡇⡇⢸⠀⡇⠈⡆⢰⠁⠀⡇⠀⢰⠁⠀⠈⡆⠀⠀⡎⠀⠀⠀⢱⠀⠀⠀⠀⠀⠀⠀⠀⠀⢣⠀⠀│\n     │⡇⢸⢸⡇⡇⡇⠸⣰⠃⠀⡇⡸⠀⠀⢸⠀⡜⠀⠀⠀⢣⠀⢸⠁⠀⠀⠀⠈⡆⠀⠀⠀⠀⠀⠀⠀⠀⠈⢇⠀│\n     │⡇⡇⢸⠇⢸⡇⠀⣿⠀⠀⢣⡇⠀⠀⠸⣄⠇⠀⠀⠀⠸⡀⡇⠀⠀⠀⠀⠀⢱⠀⠀⠀⠀⠀⠀⠀⠀⠀⠸⡄│\n   0 │⣇⣇⣸⣀⣸⣀⣀⣟⣀⣀⣸⣃⣀⣀⣀⣿⣀⣀⣀⣀⣀⣿⣀⣀⣀⣀⣀⣀⣈⣇⣀⣀⣀⣀⣀⣀⣀⣀⣀⣱│\n     └────────────────────────────────────────┘\n     ⠀0⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀200⠀\n```\n\"\"\"\nfunction melscale_filterbanks(;\n    n_freqs::Int, n_mels::Int, sample_rate::Int,\n    fmin::Float32 = 0f0, fmax::Float32 = Float32(sample_rate ÷ 2),\n)\n    mel_min, mel_max = _hz_to_mel(fmin), _hz_to_mel(fmax)\n    mel_points = range(mel_min, mel_max; length=n_mels + 2)\n\n    all_freqs = collect(range(0f0, Float32(sample_rate ÷ 2); length=n_freqs))\n    freq_points = _mel_to_hz.(mel_points)\n    filter_banks = _triangular_filterbanks(freq_points, all_freqs)\n\n    if any(maximum(filter_banks; dims=1) .≈ 0f0)\n        @warn \"\"\"At least one mel filterbank has all zero values.\n        The value for `n_mels=$n_mels` may be set too high.\n        Or the value for `n_freqs=$n_freqs` may be set too low.\n        \"\"\"\n    end\n    return filter_banks\nend\n\n_hz_to_mel(freq::T) where T = T(2595) * log10(T(1) + (freq / T(700)))\n\n_mel_to_hz(mel::T) where T = T(700) * (T(10)^(mel / T(2595)) - T(1))\n\n\"\"\"\n    _triangular_filterbanks(\n        freq_points::Vector{Float32}, all_freqs::Vector{Float32})\n\nCreate triangular filter banks.\n\n# Arguments:\n\n- `freq_points::Vector{Float32}`: Filter midpoints of size `n_filters`.\n- `all_freqs::Vector{Float32}`: Frequency points of size `n_freqs`.\n\n# Returns:\n\nArray of size `(n_freqs, n_filters)`.\n\"\"\"\nfunction _triangular_filterbanks(\n    freq_points::Vector{Float32}, all_freqs::Vector{Float32},\n)\n    diff = @view(freq_points[2:end]) .- @view(freq_points[1:end - 1])\n    slopes = transpose(reshape(freq_points, :, 1) .- reshape(all_freqs, 1, :))\n\n    down_slopes = -(@view(slopes[:, 1:end - 2]) ./ reshape(@view(diff[1:end - 1]), 1, :))\n    up_slopes = @view(slopes[:, 3:end]) ./ 
reshape(@view(diff[2:end]), 1, :)\n    return max.(0f0, min.(down_slopes, up_slopes))\nend\n"
  },
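A short usage sketch for `melscale_filterbanks`, assuming only the functions defined above and standard Julia; the `n_fft`, `n_mels` and spectrogram sizes are illustrative placeholders.

```julia
using NNlib

# 513 frequency bins, as produced by a real STFT with n_fft = 1024 (n_fft ÷ 2 + 1).
fb = melscale_filterbanks(; n_freqs=513, n_mels=64, sample_rate=16_000)
size(fb)                      # (513, 64): one column per mel band

# Each column is a triangular weighting of FFT bins; count the active bins per band.
vec(sum(fb .> 0; dims=1))

# Given a power spectrogram S laid out as (n_freqs, n_frames), the mel spectrogram
# is the contraction of the frequency axis against these columns.
S = rand(Float32, 513, 100)   # placeholder spectrogram
mel = transpose(fb) * S       # (64, 100)
```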
  {
    "path": "src/audio/spectrogram.jl",
    "content": "\"\"\"\n    spectrogram(waveform;\n        pad::Int = 0, n_fft::Int, hop_length::Int, window,\n        center::Bool = true, power::Real = 2.0,\n        normalized::Bool = false, window_normalized::Bool = false,\n    )\n\nCreate a spectrogram or a batch of spectrograms from a raw audio signal.\n\n# Arguments\n\n- `pad::Int`:\n    Then amount of padding to apply on both sides.\n- `window_normalized::Bool`:\n    Whether to normalize the waveform by the window’s L2 energy.\n- `power::Real`:\n    Exponent for the magnitude spectrogram (must be ≥ 0)\n    e.g., `1` for magnitude, `2` for power, etc.\n    If `0`, complex spectrum is returned instead.\n\nSee [`stft`](@ref) for other arguments.\n\n# Returns\n\nSpectrogram in the shape `(T, F, B)`, where\n`T` is the number of window hops and `F = n_fft ÷ 2 + 1`.\n\"\"\"\nfunction spectrogram(waveform::AbstractArray{T};\n    pad::Int = 0, n_fft::Int, hop_length::Int, window,\n    center::Bool = true, power::Real = 2.0,\n    normalized::Bool = false, window_normalized::Bool = false,\n) where T\n    pad > 0 && (waveform = pad_zeros(waveform, pad; dims=1);)\n\n    # Pack batch dimensions.\n    sz = size(waveform)\n    spec_ = stft(reshape(waveform, (sz[1], :));\n        n_fft, hop_length, window, center, normalized)\n    # Unpack batch dimensions.\n    spec = reshape(spec_, (size(spec_)[1:2]..., sz[2:end]...))\n    window_normalized && (spec = spec .* inv(norm(window));)\n\n    if power > 0\n        p = T(power)\n        spec = abs.(spec .+ eps(T)).^p\n    end\n    return spec\nend\n\n\"\"\"\n    power_to_db(s; ref::Real = 1f0, amin::Real = 1f-10, top_db::Real = 80f0)\n\nConvert a power spectrogram (amplitude squared) to decibel (dB) units.\n\n# Arguments\n\n- `s`: Input power.\n- `ref`: Scalar w.r.t. which the input is scaled.\n- `amin`: Minimum threshold for `s`.\n- `top_db`: Threshold the output at `top_db` below the peak:\n    `max.(s_db, maximum(s_db) - top_db)`.\n\n# Returns\n\n`s_db ~= 10 * log10(s) - 10 * log10(ref)`\n\"\"\"\nfunction power_to_db(s; ref::Real = 1f0, amin::Real = 1f-10, top_db::Real = 80f0)\n    log_spec = 10f0 .* (log10.(max.(amin, s)) .- log10.(max.(amin, ref)))\n    return max.(log_spec, maximum(log_spec) - top_db)\nend\n\n\"\"\"\n    db_to_power(s_db; ref::Real = 1f0)\n\nInverse of [`power_to_db`](@ref).\n\"\"\"\nfunction db_to_power(s_db; ref::Real = 1f0)\n    return ref .* 10f0.^(s_db .* 0.1f0)\nend\n"
  },
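A minimal end-to-end sketch of the functions above. It assumes an FFT backend (e.g. FFTW.jl) is loaded so that the `stft` call inside `spectrogram` has a method; the signal and parameters are placeholders.

```julia
using NNlib
using FFTW   # assumed: supplies the stft implementation used by spectrogram

x = randn(Float32, 16_000)    # one second of noise at 16 kHz (illustrative)
n_fft = 512

# Power spectrogram (power = 2) with a Hann analysis window.
S = spectrogram(x; n_fft, hop_length=n_fft ÷ 4, window=hann_window(n_fft), power=2.0)

# Convert to decibels and back; values more than top_db below the peak stay clamped.
S_db   = power_to_db(S; top_db=80f0)
S_back = db_to_power(S_db)
```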
  {
    "path": "src/audio/stft.jl",
    "content": "\"\"\"\n    hamming_window(\n        window_length::Int, ::Type{T} = Float32; periodic::Bool = true,\n        α::T = T(0.54), β::T = T(0.46),\n    ) where T <: Real\n\nHamming window function\n(ref: [Window function § Hann and Hamming windows - Wikipedia](https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows)).\nGeneralized version of `hann_window`.\n\n``w[n] = \\\\alpha - \\\\beta \\\\cos(\\\\frac{2 \\\\pi n}{N - 1})``\n\nWhere ``N`` is the window length.\n\n```julia-repl\njulia> lineplot(hamming_window(100); width=30, height=10)\n     ┌──────────────────────────────┐\n   1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⠚⠉⠉⠉⠢⡄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│\n     │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠎⠁⠀⠀⠀⠀⠀⠈⢢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│\n     │⠀⠀⠀⠀⠀⠀⠀⠀⠀⢠⠊⠀⠀⠀⠀⠀⠀⠀⠀⠀⢣⡀⠀⠀⠀⠀⠀⠀⠀⠀│\n     │⠀⠀⠀⠀⠀⠀⠀⠀⢰⠃⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠱⡀⠀⠀⠀⠀⠀⠀⠀│\n     │⠀⠀⠀⠀⠀⠀⠀⣠⠃⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠳⡀⠀⠀⠀⠀⠀⠀│\n     │⠀⠀⠀⠀⠀⠀⢰⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠱⡄⠀⠀⠀⠀⠀│\n     │⠀⠀⠀⠀⠀⡰⠃⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠱⡀⠀⠀⠀⠀│\n     │⠀⠀⠀⢀⠴⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠘⢄⠀⠀⠀│\n     │⠀⢀⡠⠊⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠳⣀⠀│\n   0 │⠉⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉│\n     └──────────────────────────────┘\n     ⠀0⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀100⠀\n```\n\n# Arguments:\n\n- `window_length::Int`: Size of the window.\n- `::Type{T}`: Elemet type of the window.\n\n# Keyword Arguments:\n\n- `periodic::Bool`: If `true` (default), returns a window to be used as\n    periodic function. If `false`, return a symmetric window.\n\n    Following always holds:\n\n```jldoctest\njulia> N = 256;\n\njulia> hamming_window(N; periodic=true) ≈ hamming_window(N + 1; periodic=false)[1:end - 1]\ntrue\n```\n- `α::Real`: Coefficient α in the equation above.\n- `β::Real`: Coefficient β in the equation above.\n\n# Returns:\n\nVector of length `window_length` and eltype `T`.\n\"\"\"\nfunction hamming_window(\n    window_length::Int, ::Type{T} = Float32; periodic::Bool = true,\n    α::T = T(0.54), β::T = T(0.46),\n) where T <: Real\n    window_length < 1 && throw(ArgumentError(\n        \"`window_length` must be > 0, instead: `$window_length`.\"))\n\n    n::T = ifelse(periodic, window_length, window_length - 1)\n    scale = T(2) * π / n\n    return [α - β * cos(scale * T(k)) for k in 0:(window_length - 1)]\nend\n\n\"\"\"\n    hann_window(\n        window_length::Int, ::Type{T} = Float32; periodic::Bool = true,\n    ) where T <: Real\n\nHann window function\n(ref: [Window function § Hann and Hamming windows - Wikipedia](https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows)).\n\n``w[n] = \\\\frac{1}{2}[1 - \\\\cos(\\\\frac{2 \\\\pi n}{N - 1})]``\n\nWhere ``N`` is the window length.\n\n```julia-repl\njulia> lineplot(hann_window(100); width=30, height=10)\n     ┌──────────────────────────────┐\n   1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣠⠚⠉⠉⠉⠢⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│\n     │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡔⠁⠀⠀⠀⠀⠀⠘⢄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│\n     │⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠞⠀⠀⠀⠀⠀⠀⠀⠀⠈⢆⠀⠀⠀⠀⠀⠀⠀⠀⠀│\n     │⠀⠀⠀⠀⠀⠀⠀⠀⢀⡎⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢣⠀⠀⠀⠀⠀⠀⠀⠀│\n     │⠀⠀⠀⠀⠀⠀⠀⠀⡎⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢦⠀⠀⠀⠀⠀⠀⠀│\n     │⠀⠀⠀⠀⠀⠀⢀⠞⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢆⠀⠀⠀⠀⠀⠀│\n     │⠀⠀⠀⠀⠀⢀⡜⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢇⠀⠀⠀⠀⠀│\n     │⠀⠀⠀⠀⢀⠎⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢦⠀⠀⠀⠀│\n     │⠀⠀⠀⢠⠊⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠣⡀⠀⠀│\n   0 │⣀⣀⠔⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢤⣀│\n     └──────────────────────────────┘\n     ⠀0⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀100⠀\n```\n\n# Arguments:\n\n- `window_length::Int`: Size of the window.\n- `::Type{T}`: Elemet type of the window.\n\n# Keyword Arguments:\n\n- `periodic::Bool`: If `true` (default), returns a window to be used as\n    periodic function. 
If `false`, return a symmetric window.\n\n    Following always holds:\n\n```jldoctest\njulia> N = 256;\n\njulia> hann_window(N; periodic=true) ≈ hann_window(N + 1; periodic=false)[1:end - 1]\ntrue\n\njulia> hann_window(N) ≈ hamming_window(N; α=0.5f0, β=0.5f0)\ntrue\n```\n\n# Returns:\n\nVector of length `window_length` and eltype `T`.\n\"\"\"\nfunction hann_window(\n    window_length::Int, ::Type{T} = Float32; periodic::Bool = true,\n) where T <: Real\n    hamming_window(window_length, T; periodic, α=T(0.5), β=T(0.5))\nend\n\n\"\"\"\n    stft(x;\n        n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing,\n        center::Bool = true, normalized::Bool = false,\n    )\n\nShort-time Fourier transform (STFT).\n\nThe STFT computes the Fourier transform of short overlapping windows of the input,\ngiving frequency components of the signal as they change over time.\n\n``Y[\\\\omega, m] = \\\\sum_{k = 0}^{N - 1} \\\\text{window}[k] \\\\text{input}[m \\\\times \\\\text{hop length} + k] \\\\exp(-j \\\\frac{2 \\\\pi \\\\omega k}{\\\\text{n fft}})``\n\nwhere ``N`` is the window length,\n``\\\\omega`` is the frequency ``0 \\\\le \\\\omega < \\\\text{n fft}``\nand ``m`` is the index of the sliding window.\n\n# Arguments:\n\n- `x`: Input, must be either a 1D time sequence (`(L,)` shape)\n    or a 2D batch of time sequence (`(L, B)` shape).\n\n# Keyword Arguments:\n\n- `n_fft::Int`: Size of Fourier transform.\n- `hop_length::Int`: Distance between neighboring sliding window frames.\n- `window`: Optional window function to apply.\n    Must be 1D vector `0 < length(window) ≤ n_fft`.\n    If window is shorter than `n_fft`, it is padded with zeros on both sides.\n    If `nothing` (default), then no window is applied.\n- `center::Bool`: Whether to pad input on both sides so that ``t``-th frame\n    is centered at time ``t \\\\times \\\\text{hop length}``.\n    Padding is done with `pad_reflect` function.\n- `normalized::Bool`: Whether to return normalized STFT,\n    i.e. multiplied with ``\\\\text{n fft}^{-0.5}``.\n\n# Returns:\n\nComplex array of shape `(n_fft, n_frames, B)`,\nwhere `B` is the optional batch dimension.\n\"\"\"\nfunction stft end\n\n\"\"\"\n    istft(y;\n        n_fft::Int, hop_length::Int = n_fft ÷ 4, window = nothing,\n        center::Bool = true, normalized::Bool = false,\n        return_complex::Bool = false,\n        original_length::Union{Nothing, Int} = nothing,\n    )\n\nInverse Short-time Fourier Transform.\n\nReturn the least squares estimation of the original signal\n\n# Arguments:\n\n- `y`: Input complex array in the `(n_fft, n_frames, B)` shape.\n    Where `B` is the optional batch dimension.\n\n# Keyword Arguments:\n\n- `n_fft::Int`: Size of Fourier transform.\n- `hop_length::Int`: Distance between neighboring sliding window frames.\n- `window`: Window function that was applied to the input of `stft`.\n    If `nothing` (default), then no window was applied.\n- `center::Bool`: Whether input to `stft` was padded on both sides\n    so that ``t``-th frame is centered at time ``t \\\\times \\\\text{hop length}``.\n    Padding is done with `pad_reflect` function.\n- `normalized::Bool`: Whether input to `stft` was normalized.\n- `return_complex::Bool`: Whether the output should be complex,\n    or if the input should be assumed to derive from a real signal and window.\n- `original_length::Union{Nothing, Int}`: Optional size of the first dimension\n    of the input to `stft`. 
Helps restore the exact `stft` input size.\n    Otherwise, the returned array may be slightly shorter.\n\"\"\"\nfunction istft end\n"
  },
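The window constructors and the `stft`/`istft` stubs compose as in the rough sketch below. It assumes an FFT backend (e.g. FFTW.jl) is loaded so the stubs have methods, and relies on the Hann window with 75% overlap giving a near-exact least-squares reconstruction.

```julia
using NNlib
using FFTW   # assumed: provides the actual stft/istft methods for the stubs above

x     = randn(Float32, 4096)
n_fft = 256
w     = hann_window(n_fft)

Y = stft(x; n_fft, hop_length=n_fft ÷ 4, window=w, center=true)   # complex STFT
x̂ = istft(Y; n_fft, hop_length=n_fft ÷ 4, window=w, center=true,
          original_length=length(x))

isapprox(x̂, x; atol=1f-4)   # ≈ true: near-exact reconstruction of the input
```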
  {
    "path": "src/batched/batchedadjtrans.jl",
    "content": "import Base: -\nimport Adapt: adapt_structure, adapt\n\n_batched_doc = \"\"\"\n    batched_transpose(A::AbstractArray{T,3})\n    batched_adjoint(A)\n\nEquivalent to applying `transpose` or `adjoint` to each matrix `A[:,:,k]`.\n\nThese exist to control how `batched_mul` behaves,\nas it operates on such matrix slices of an array with `ndims(A)==3`.\n\n`PermutedDimsArray(A, (2,1,3))` is equivalent to `batched_transpose(A)`,\nand is also understood by `batched_mul` (and more widely supported elsewhere).\n\n    BatchedTranspose{T, S} <: AbstractBatchedMatrix{T, 3}\n    BatchedAdjoint{T, S}\n\nLazy wrappers analogous to `Transpose` and `Adjoint`, returned by `batched_transpose` etc.\n\"\"\"\n\n@doc _batched_doc\nstruct BatchedTranspose{T, S} <: AbstractArray{T, 3}\n    parent::S\n    BatchedTranspose{T, S}(X::S) where {T, S} = new{T, S}(X)\nend\n\n@doc _batched_doc\nbatched_transpose(A::AbstractArray{T, 3}) where T = BatchedTranspose(A)\nbatched_transpose(A::BatchedTranspose) = A.parent\n\n@doc _batched_doc\nstruct BatchedAdjoint{T, S} <: AbstractArray{T, 3}\n    parent::S\n    BatchedAdjoint{T, S}(X::S) where {T, S} = new{T, S}(X)\nend\n\n@doc _batched_doc\nbatched_adjoint(A::AbstractArray{T, 3}) where T = BatchedAdjoint(A)\nbatched_adjoint(A::BatchedAdjoint) = A.parent\n\nbatched_adjoint(A::BatchedTranspose{<:Real}) = A.parent\nbatched_transpose(A::BatchedAdjoint{<:Real}) = A.parent\nbatched_adjoint(A::PermutedDimsArray{<:Real,3,(2,1,3)}) = A.parent\nbatched_transpose(A::PermutedDimsArray{<:Number,3,(2,1,3)}) = A.parent\n# if you can't unwrap, put BatchedAdjoint outside (for dispatch):\nbatched_transpose(A::BatchedAdjoint{<:Complex}) = BatchedAdjoint(BatchedTranspose(A.parent))\n\nBatchedAdjoint(A) = BatchedAdjoint{Base.promote_op(adjoint,eltype(A)),typeof(A)}(A)\nBatchedTranspose(A) = BatchedTranspose{Base.promote_op(transpose,eltype(A)),typeof(A)}(A)\n\nconst BatchedAdjOrTrans{T, S} = Union{BatchedTranspose{T, S}, BatchedAdjoint{T, S}}\n\nLinearAlgebra.wrapperop(A::BatchedAdjoint) = batched_adjoint\nLinearAlgebra.wrapperop(B::BatchedTranspose) = batched_transpose\n\n# AbstractArray Interface\nBase.length(A::BatchedAdjOrTrans) = length(A.parent)\nBase.size(m::BatchedAdjOrTrans) = (size(m.parent, 2), size(m.parent, 1), size(m.parent, 3))\nBase.axes(m::BatchedAdjOrTrans) = (axes(m.parent, 2), axes(m.parent, 1), axes(m.parent, 3))\n\nBase.IndexStyle(::Type{<:BatchedAdjOrTrans}) = IndexCartesian()\nBase.@propagate_inbounds Base.getindex(m::BatchedTranspose, i::Int, j::Int, k::Int) = getindex(m.parent, j, i, k)\nBase.@propagate_inbounds Base.getindex(m::BatchedAdjoint, i::Int, j::Int, k::Int) = adjoint(getindex(m.parent, j, i, k))\nBase.@propagate_inbounds Base.setindex!(m::BatchedTranspose, v, i::Int, j::Int, k::Int) = setindex!(m.parent, v, j, i, k)\nBase.@propagate_inbounds Base.setindex!(m::BatchedAdjoint, v, i::Int, j::Int, k::Int) = setindex!(m.parent, adjoint(v), j, i, k)\n\nBase.similar(A::BatchedAdjOrTrans, T::Type, dims::Dims) = similar(A.parent, T, dims)\nBase.similar(A::BatchedAdjOrTrans, dims::Dims) = similar(A.parent, dims)\nBase.similar(A::BatchedAdjOrTrans, T::Type) = similar(A.parent, T, size(A))\nBase.similar(A::BatchedAdjOrTrans) = similar(A.parent, size(A))\n\nBase.parent(A::BatchedAdjOrTrans) = A.parent\n\n(-)(A::BatchedAdjoint)   = BatchedAdjoint(  -A.parent)\n(-)(A::BatchedTranspose) = BatchedTranspose(-A.parent)\n\n# C interface\nfunction Base.strides(A::Union{BatchedTranspose, BatchedAdjoint{<:Real}})\n    sp = strides(A.parent)\n    (sp[2], sp[1], 
sp[3])\nend\n\nfunction Base.stride(A::Union{BatchedTranspose, BatchedAdjoint{<:Real}}, d::Integer)\n    d == 1 && return Base.stride(A.parent, 2)\n    d == 2 && return Base.stride(A.parent, 1)\n    Base.stride(A.parent, d)\nend\n\nBase.pointer(A::BatchedAdjOrTrans) = pointer(parent(A))\nBase.unsafe_convert(::Type{Ptr{T}}, A::BatchedAdjOrTrans{T}) where {T} =\n    Base.unsafe_convert(Ptr{T}, parent(A))\n\n# Gradients\nfunction rrule(::typeof(batched_transpose), A::AbstractArray{<:Any,3})\n    b_transpose_back(Δ) = (NoTangent(), batched_transpose(unthunk(Δ)))\n    batched_transpose(A), b_transpose_back\nend\nfunction rrule(::typeof(batched_adjoint), A::AbstractArray{<:Any,3})\n    b_adjoint_back(Δ) = (NoTangent(), batched_adjoint(unthunk(Δ)))\n    batched_adjoint(A), b_adjoint_back\nend\n\nadapt_structure(to, x::BatchedAdjoint) = BatchedAdjoint(adapt(to, parent(x)))\nadapt_structure(to, x::BatchedTranspose) = BatchedTranspose(adapt(to, parent(x)))\n\nBroadcast.BroadcastStyle(::Type{<:BatchedAdjOrTrans{T, S}}) where {T, S} = Broadcast.BroadcastStyle(S)\n"
  },
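A small illustration of the lazy wrappers defined above; the array sizes are arbitrary.

```julia
using NNlib

A  = rand(2, 3, 4)
At = batched_transpose(A)        # lazy: no data is copied

size(At)                          # (3, 2, 4)
At[1, 2, 3] == A[2, 1, 3]         # true: each matrix slice is transposed
parent(At) === A                  # true
batched_transpose(At) === A       # true: wrapping twice unwraps instead

# For real element types adjoint and transpose coincide, so the wrappers cancel too:
batched_adjoint(At) === A         # true
```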
  {
    "path": "src/batched/batchedmul.jl",
    "content": "_unbatch(A) = A\n_unbatch(A::BatchedAdjOrTrans) = parent(A)\n\n\"\"\"\n    batched_mul(A, B) -> C\n    A ⊠ B  # \\\\boxtimes\n\nBatched matrix multiplication. Result has `C[:,:,k...] == A[:,:,k...] * B[:,:,k...]` where `k...` represent \nany indices in the last dimensions.\n\nIf `ndims(A) == ndims(B) == 3` and `size(B,3) == 1` then instead `C[:,:,k] == A[:,:,k] * B[:,:,1]`, and similarly for `A`.\n\nTo transpose each matrix, apply `batched_transpose` to the array,\nor `batched_adjoint` for conjugate-transpose:\n\n```jldoctest\njulia> A, B = randn(2,5,17), randn(5,9,17);\n\njulia> A ⊠ B |> size\n(2, 9, 17)\n\njulia> batched_adjoint(A) |> size\n(5, 2, 17)\n\njulia> batched_mul(A, batched_adjoint(randn(9,5,17))) |> size\n(2, 9, 17)\n\njulia> A ⊠ randn(5,9,1) |> size\n(2, 9, 17)\n\njulia> batched_transpose(A) == PermutedDimsArray(A, (2,1,3))\ntrue\n```\n\nThe equivalent `PermutedDimsArray` may be used in place of `batched_transpose`.\nOther permutations are also handled by BLAS,\nprovided that the batch index `k` is not the first dimension of the underlying array.\nThus `PermutedDimsArray(::Array, (1,3,2))` and `PermutedDimsArray(::Array, (3,1,2))` are fine.\n\nHowever, `A = PermutedDimsArray(::Array, (3,2,1))` is not acceptable to BLAS,\nsince the batch dimension is the contiguous one: `stride(A,3) == 1`.\nThis will be copied, as doing so is faster than `batched_mul_generic!`.\n\nBoth this `copy` and `batched_mul_generic!` produce `@debug` messages,\nand setting for instance `ENV[\"JULIA_DEBUG\"] = NNlib` will display them.\n\"\"\"\nfunction batched_mul(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}\n    batch_size = size(x)[3:end]\n    @assert batch_size == size(y)[3:end] \"batch size has to be the same for the two arrays.\"\n    x2 = reshape(x, size(x, 1), size(x, 2), :)\n    y2 = reshape(y, size(y, 1), size(y, 2), :)\n    z = batched_mul(x2, y2)\n    return reshape(z, size(z, 1), size(z, 2), batch_size...)\n  end\n\nfunction batched_mul(A::AbstractArray{T1, 3}, B::AbstractArray{T2, 3}) where {T1, T2}\n    size(A, 3) == size(B, 3) || size(A, 3) == 1 || size(B, 3) == 1 ||\n        throw(DimensionMismatch(\"batch size mismatch: A != B\"))\n    _batched_mul(storage_typejoin(A, B), A, B)\nend\n\nconst ⊠ = batched_mul\n\nfunction _batched_mul(::Type, A, B)\n    T = promote_type(eltype(A), eltype(B))\n    C = similar(A, T, (size(A, 1), size(B, 2), max(size(A, 3), size(B, 3))))\n    batched_mul!(C, A, B)\n    C\nend\nfunction _batched_mul(::Type{<:DenseArray{T}}, A, B) where {T<:BlasFloat}\n    C = similar(A, T, (size(A, 1), size(B, 2), max(size(A, 3), size(B, 3))))\n    batched_mul!(C, _copy_if_faster(A), _copy_if_faster(B))\n    C\nend\n\nfunction _copy_if_faster(X::AbstractArray{<:Number, 3})\n    is_strided(X) || return X\n    if Base.stride(X, 3) == 1 && Base.stride(X, 1) != 1\n        @debug \"copying to avoid batched_mul_generic!\" typeof(X) size(X) strides(X)\n        return copy(X)\n    end\n    X\nend\nfunction _copy_if_faster(X::BatchedAdjoint{<:Complex})\n    Xbase = _unbatch(X)\n    is_strided(Xbase) || return X\n    if Base.stride(Xbase, 1) != 1\n        @debug \"copying to avoid batched_mul_generic!\" typeof(X) size(X) strides(_unbatch(X))\n        return copy(X) # or batched_adjoint(copy(Xbase)), may be better on GPU?\n    end\n    X\nend\n\n# Gradient, allowing that size(A,3)==1 means it's \"broadcasted\" out to size(B,3)\n\nfunction rrule(::typeof(batched_mul), A::AbstractArray{<:Any,3}, B::AbstractArray{<:Any,3})\n    function 
batched_mul_pullback(_Δ)\n        Δ = unthunk(_Δ)\n        Athunk = @thunk begin\n            tmp = batched_mul(Δ, batched_adjoint(B))\n            size(A,3) == 1 ? sum(tmp, dims=3) : tmp\n        end\n        Bthunk = @thunk begin\n            tmp = batched_mul(batched_adjoint(A), Δ)\n            size(B,3) == 1 ? sum(tmp, dims=3) : tmp\n        end\n        return (NoTangent(), Athunk, Bthunk)\n    end\n    batched_mul(A, B), batched_mul_pullback\nend\n\n\"\"\"\n    batched_mul(A::Array{T,3}, B::Matrix)\n    batched_mul(A::Matrix, B::Array{T,3})\n    A ⊠ B\n\nThis is always matrix-matrix multiplication, but\neither `A` or `B` may lack a batch index.\n\n* When `B` is a matrix, result has `C[:,:,k] == A[:,:,k] * B[:,:]` for all `k`.\n\n* When `A` is a matrix, then `C[:,:,k] == A[:,:] * B[:,:,k]`.\n  This can also be done by reshaping and calling `*`,\n  for instance `A ⊡ B` using TensorCore.jl, but is implemented here using\n  `batched_gemm` instead of `gemm`.\n\n```jldoctest\njulia> randn(16,8,32) ⊠ randn(8,4) |> size\n(16, 4, 32)\n\njulia> randn(16,8,32) ⊠ randn(8,4,1) |> size  # equivalent\n(16, 4, 32)\n\njulia> randn(16,8) ⊠ randn(8,4,32) |> size\n(16, 4, 32)\n```\n\nSee also `batched_vec` to regard `B` as a batch of vectors, `A[:,:,k] * B[:,k]`.\n\"\"\"\nbatched_mul(A::AbstractArray{T,3} where T, B::AbstractMatrix) = _semi_batched_mul(A,B)\n\n# Simplify signature of batched_mul by hiding dispatch on Adjoint etc:\n\n_semi_batched_mul(A::AbstractArray{<:Any,3}, B::AbstractMatrix) =\n    batched_mul(A, reshape(B, size(B)..., 1))\n\n_semi_batched_mul(A::AbstractArray{<:Any,3}, B::Adjoint{<:Number,<:AbstractMatrix}) =\n    batched_mul(A, batched_adjoint(reshape(parent(B), size(parent(B))..., 1)))\n\n_semi_batched_mul(A::AbstractArray{<:Any,3}, B::Transpose{<:Number,<:AbstractMatrix}) =\n    batched_mul(A, batched_transpose(reshape(parent(B), size(parent(B))..., 1)))\n\nbatched_mul(A::AbstractMatrix, B::AbstractArray{T,3} where T) = _semi_batched_mul(A,B)\n\n_semi_batched_mul(A::AbstractMatrix, B::AbstractArray{<:Any,3}) =\n    batched_mul(reshape(A, size(A)..., 1), B)\n\n_semi_batched_mul(A::Adjoint{<:Number,<:AbstractMatrix}, B::AbstractArray{<:Any,3}) =\n    batched_mul(batched_adjoint(reshape(parent(A), size(parent(A))..., 1)), B)\n\n_semi_batched_mul(A::Transpose{<:Number,<:AbstractMatrix}, B::AbstractArray{<:Any,3}) =\n    batched_mul(batched_transpose(reshape(parent(A), size(parent(A))..., 1)), B)\n\n\"\"\"\n    batched_vec(A::AbstractArray{T,3}, B::AbstractMatrix)\n    batched_vec(A::AbstractArray{T,3}, b::AbstractVector)\n    batched_vec(A::AbstractArray, B::AbstractArray)\n\nBatched matrix-vector multiplication. For the 3D case:\nthe result has `C[:,:,k] == A[:,:,k] * B[:,k]` for all `k`,\nor else `C[:,:,k] == A[:,:,k] * b` for `b::Vector`.\n\nFor the general N-D case where `ndims(A) == ndims(B) + 1`:\nthe result has `C[:,k...] == A[:,:,k...] * B[:,k...]` for all batch indices `k...`.\nThe batch dimensions must match: `size(A)[3:end] == size(B)[2:end]`.\n\nWith the same argument types, `batched_mul(A, B)` would regard `B` as\na fixed matrix, not a batch of vectors. 
Both reshape and then\ncall `batched_mul(::Array{T,3}, ::Array{T,3})`.\n\n```jldoctest\njulia> A, B, b = randn(16,8,32), randn(8,32), randn(8);\n\njulia> batched_vec(A,B) |> size\n(16, 32)\n\njulia> batched_vec(A,b) |> size\n(16, 32)\n\njulia> A4d, B3d = randn(16,8,10,32), randn(8,10,32);  # 4D and 3D arrays\n\njulia> batched_vec(A4d, B3d) |> size\n(16, 10, 32)\n```\n\"\"\"\nfunction batched_vec(A::AbstractArray, B::AbstractArray)\n    ndims(A) == ndims(B) + 1 || throw(DimensionMismatch(\n        \"batched_vec requires ndims(A) == ndims(B) + 1, got ndims(A)=$(ndims(A)) and ndims(B)=$(ndims(B))\"))\n    size(A)[3:end] == size(B)[2:end] || throw(DimensionMismatch(\n        \"batch dimensions must match: size(A)[3:end]=$(size(A)[3:end]) != size(B)[2:end]=$(size(B)[2:end])\"))\n    \n    # Reshape B to add a singleton dimension for matrix multiplication\n    B_reshaped = reshape(B, size(B, 1), 1, size(B)[2:end]...)\n    # Perform batched multiplication\n    C = batched_mul(A, B_reshaped)\n    # Remove the singleton dimension\n    return dropdims(C, dims=2)\nend\n\nbatched_vec(A::AbstractArray{T,3} where T, B::AbstractMatrix) =\n    reshape(batched_mul(A, reshape(B, size(B,1), 1, size(B,2))), size(A,1), size(A,3))\n\n# If B is transposed, then stride=1 is the batch dim, so we will end up copying anyway:\nbatched_vec(A::AbstractArray{T,3} where T, B::AdjOrTransAbsMat{<:BlasFloat, <:StridedMatrix}) =\n    batched_vec(A, copy(B))\n\nbatched_vec(A::AbstractArray{T,3} where T, b::AbstractVector) =\n    reshape(batched_mul(A, reshape(b, length(b), 1, 1)), size(A,1), size(A,3))\n\n\n\"\"\"\n    batched_mul!(C, A, B) -> C\n    batched_mul!(C, A, B, α=1, β=0)\n\nIn-place batched matrix multiplication, equivalent to\n`mul!(C[:,:,k], A[:,:,k], B[:,:,k], α, β)` for all `k`.\nIf `size(B,3) == 1` then every batch uses `B[:,:,1]` instead.\n\nThis will call `batched_gemm!` whenever possible. 
For real arrays this means that,\nfor `X ∈ [A,B,C]`, either `stride(X,1)==1` or `stride(X,2)==1`, the latter may\nbe caused by `batched_transpose` or by for instance `PermutedDimsArray(::Array, (3,1,2))`.\nUnlike `batched_mul` this will never make a copy.\n\nFor complex arrays, the wrapper made by `batched_adjoint` must be outermost to be seen.\nIn this case the strided accepted by BLAS are more restricted, if `stride(C,1)==1` then\nonly `stride(AorB::BatchedAdjoint,2) == 1` is accepted.\n\"\"\"\nfunction batched_mul!(C::AbstractArray{T,3}, A::AbstractArray{<:Any,3}, B::AbstractArray{<:Any,3},\n        α::Number=one(T), β::Number=zero(T)) where {T}\n    _batched_mul!(storage_typejoin(C,A,B), C, A, B, α, β)\n    C\nend\n\n_batched_mul!(::Type, C, A, B, α::Number, β::Number) = batched_mul_generic!(C, A, B, α, β)\n\n_batched_mul!(::Type{DT}, C, A, B, α::Number, β::Number) where {DT<:DenseArray{T}} where {T<:BlasFloat} =\n    _batched_try_gemm!(DT, C, A, B, α, β)\n\nfunction _batched_try_gemm!(::Type{DT}, C, A, B, α::Number, β::Number) where {DT<:DenseArray{T}} where {T<:BlasFloat}\n    alpha, beta = promote(α, β, zero(T))\n    alpha isa T && beta isa T || return batched_mul_generic!(C, A, B, α, β)\n\n    are_strided(_unbatch(A), _unbatch(B)) || return batched_mul_generic!(C, A, B, α, β)\n    C isa StridedArray || return batched_mul_generic!(C, A, B, α, β)\n\n    blasA, transA = if A isa BatchedAdjoint && T <: Complex\n        Base.stride(parent(A),1) == 1 || return batched_mul_generic!(C, A, B, α, β)\n        parent(A), 'C'\n    elseif Base.stride(A,2) == 1 && size(A,1) > 1\n        batched_transpose(A), 'T'\n    elseif Base.stride(A,1) == 1\n        A, 'N'\n    elseif Base.stride(A,2) == 1  # This is awful, but exhaustively tested. Issues 268, 282.\n        batched_transpose(A), 'T'\n    else\n        return batched_mul_generic!(C, A, B, α, β)\n    end\n\n    blasB, transB = if B isa BatchedAdjoint && T <: Complex\n        Base.stride(parent(B),1) == 1 || return batched_mul_generic!(C, A, B, α, β)\n        parent(B), 'C'\n    elseif Base.stride(B,2) == 1 && size(B,1) > 1\n        batched_transpose(B), 'T'\n    elseif Base.stride(B,1) == 1\n        B, 'N'\n    elseif Base.stride(B,2) == 1\n        batched_transpose(B), 'T'\n    else\n        return batched_mul_generic!(C, A, B, α, β)\n    end\n\n    _batched_gemm!(DT, transA, transB, alpha, blasA, blasB, beta, C)\n    C\nend\n\n_batched_gemm!(::Type{<:Array}, transA::Char, transB::Char, α::Number, A, B, β::Number, C) =\n    batched_gemm!(transA, transB, α, A, B, β, C)\n\n_BATCHED_LIST = [\n    (:(AbstractArray{<:Any, 3}), :identity),\n    (:BatchedTranspose,          :transpose),\n    (:BatchedAdjoint,            :adjoint),\n]\nfor (TA, fA) in _BATCHED_LIST, (TB, fB) in _BATCHED_LIST\n\n    @eval function batched_mul_generic!(C::AbstractArray{T, 3}, A::$TA, B::$TB,\n            α::Number=one(T), β::Number=zero(T)) where {T}\n\n        size(A, 3) == size(C, 3) || size(A, 3) == 1 || throw(DimensionMismatch(\"batch size mismatch: A != C\"))\n        size(B, 3) == size(C, 3) || size(B, 3) == 1 || throw(DimensionMismatch(\"batch size mismatch: B != C\"))\n        @debug \"calling fallback method for batched_mul!\" typeof(A) size(A) typeof(B) size(B) typeof(C)\n\n        Abase, Bbase = _unbatch(A), _unbatch(B)\n        sA, oA = size(A,3) == 1 ? (0,1) : (1,0)\n        sB, oB = size(B,3) == 1 ? 
(0,1) : (1,0)\n\n        @inbounds for k in 1:size(C,3)\n            @views mul!(C[:,:,k], $fA(Abase[:,:,k*sA+oA]), $fB(Bbase[:,:,k*sB+oB]), α, β)\n        end\n        C\n    end\n\nend\n\n\"\"\"\n    storage_type(A) -> Type\n\nRemoves all wrappers to return the `Array` or `CuArray` (or whatever) type within.\n```\njulia> view(reshape(ones(10)',2,5),:, 3:4) |> storage_type\nArray{Float64,1}\n\njulia> reshape(sparse(rand(10)), 5,2) |> storage_type\nSparseVector{Float64,Int64}\n```\n\"\"\"\nfunction storage_type(A::AbstractArray)\n    P = parent(A)\n    typeof(A) === typeof(P) ? typeof(A) : storage_type(P)\nend\nstorage_type(A) = typeof(A)\n\n\"\"\"\n    storage_typejoin(A, B, C, ...) -> Type\n\nReduces with `Base.promote_typejoin`, in order that this conveys useful information\nfor dispatching to BLAS. It does not tell you what container to allocate:\n```\njulia> storage_typejoin(rand(2), rand(Float32, 2))\nArray{T,1} where T\n\njulia> eltype(ans) <: LinearAlgebra.BlasFloat\nfalse\n\njulia> storage_typejoin(rand(2), rand(2,3), rand(2,3,4))\nArray{Float64,N} where N\n```\n\"\"\"\nstorage_typejoin(A, Bs...) = Base.promote_typejoin(storage_type(A), storage_typejoin(Bs...))\nstorage_typejoin(A) = storage_type(A)\n\n\"\"\"\n    is_strided(A::AbstractArray) -> Bool\n\nThis generalises `A isa StridedArray` to treat wrappers like `A::PermutedDimsArray`,\nfor which it returns `is_strided(parent(A))`.\n\nIt returns `true` for `CuArray`s, and `PermutedDimsArray`s of those.\n\nOther wrappers (defined outside Base, LinearAlgebra) are assumed not to break\nstrided-ness, and hence also return `is_strided(parent(A))`.\nThis correctly handles things like `NamedDimsArray` which don't alter indexing.\nHowever, it's a little pessimistic in that e.g. a `view` of such a container will return\n`false`, even in cases where the same `view` of `parent(A)` would be a `StridedArray`.\n\"\"\"\nis_strided(A::StridedArray) = true\nis_strided(A) = false\nfunction is_strided(A::AbstractArray)\n    M = parentmodule(typeof(A))\n    if parent(A) === A # SparseMatrix, StaticArray, etc\n        false\n    elseif M === Base || M === Core || M === LinearAlgebra\n        # bad reshapes, etc, plus Diagonal, UpperTriangular, etc.\n        false\n    else\n        is_strided(parent(A)) # PermutedDimsArray, NamedDimsArray\n    end\nend\n\nis_strided(A::BatchedAdjoint) = eltype(A) <: Real && is_strided(parent(A))\nis_strided(A::BatchedTranspose) = is_strided(parent(A))\n\nis_strided(A::LinearAlgebra.Transpose) = is_strided(parent(A))\nis_strided(A::LinearAlgebra.Adjoint) = eltype(A) <: Real && is_strided(parent(A))\n# This needs Compat 3.14, for any Julia < 1.6\n\nare_strided(As...) = mapfoldl(is_strided, &, As; init=true)\n"
  },
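A brief sketch of the in-place form and the lazy-transpose path described above, using only the functions defined in this file; the sizes are arbitrary.

```julia
using NNlib

A = randn(4, 3, 7)
B = randn(3, 5, 7)
C = zeros(4, 5, 7)

batched_mul!(C, A, B)             # C[:,:,k] = A[:,:,k] * B[:,:,k]
batched_mul!(C, A, B, 2.0, 1.0)   # C[:,:,k] = 2*A[:,:,k]*B[:,:,k] + C[:,:,k]

# Lazy transpose wrappers feed straight into the BLAS 'T' path:
D = batched_transpose(A) ⊠ C      # size (3, 5, 7)

# A batch of vectors, one matrix-vector product per slice:
batched_vec(A, randn(3, 7)) |> size   # (4, 7)
```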
  {
    "path": "src/bias_act.jl",
    "content": "\nusing NNlib: fast_act, tanh_fast\nusing ChainRulesCore\n\nconst RCR = RuleConfig{>:HasReverseMode}\n\n# This just saves typing `only.(only.(` many times:\n@inline only_derivative(y,f::F,x) where F = only(only(ChainRulesCore.derivatives_given_output(y, f, x)))\n\n# This has no methods, used for testing whether `derivatives_given_output(Ω, f, x)`\n# is independent of `x`, as `return_type` says `Union{}` when calling is an error.\nstruct NotaNumber <: Real end\n\n\"\"\"\n    bias_act!(σ, x, b)\n\nThis is equivalent to `x .= σ.(x .+ b)`, also replacing `sigmoid` & `tanh`\nwith `sigmoid_fast` & `tanh_fast`.\nIt will only overwrite `x` when `x isa StridedArray{<:AbstractFloat}`.\n\nWhen used within a gradient, it will overwrite only when `σ` has\na method of `derivatives_given_output` which does not need the input at all.\nSuch methods are defined by e.g. `@scalar_rule relu(x) Ω > 0` where the derivative\ncontains only `Ω` (the output) not `x`.\n\n!!! warning\n    This is not safe to use if `x` is still needed for the gradient\n    of some other function. Incorrect use will give silently wrong answers.\n    It is intended mainly for Flux layers, in which the previous operation is\n    known to be safe, e.g. `bias_act!(σ, weight * input, bias)` for a `Dense` layer.\n\"\"\"\nbias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::AbstractArray{<:Union{Bool, AbstractFloat}}) =\n    _fast_broadcast!(fast_act(σ, x)∘(+), x, b)  # works around a SIMD bug\n\nfunction bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::Bool)\n    b === true && error(\"bias=true is not accepted; layer constructors shoud guarantee this\")\n    _fast_broadcast!(fast_act(σ, x), x)\nend\n\nfunction bias_act!(::typeof(identity), x::StridedArray{<:AbstractFloat}, b::Bool)\n    b === true && error(\"bias=true is not accepted; layer constructors shoud guarantee this\")\n    x  # pass-through\nend\n\nfunction bias_act!(σ::Function, x::AbstractArray, b)\n    b === true && error(\"bias=true is not accepted; layer constructors shoud guarantee this\")\n    fast_act(σ, x).(x .+ b)  # fallback\nend\n\nfunction ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), σ::F, x::AbstractArray{T,N}, b::B) where {F,T,N,B}\n    biasgrad = if eltype(B) !== Bool\n        # Summing over ndims(x)+1 is a trick to make b_dims type-stable\n        dims = ntuple(d -> size(b,d)==1 ? d : N+1, N)\n        _biasgrad(dx) = reshape(sum(dx; dims), size(b))\n    else\n        Returns(NoTangent())\n    end\n\n    # Fast path: it is now safe to overwrite x, since this is not needed for gradient of σ\n    if isconcretetype(Core.Compiler.return_type(only_derivative, Tuple{T, F, NotaNumber}))\n        Ω = bias_act!(σ, x, b)  # now x === Ω, when x isa StridedArray{<:AbstractFloat}\n        function bias_act!_fastback(Δ)\n            # Tempting to overwrite x again, but only safe if you call pullback at most once,\n            # TODO with e.g. 
https://github.com/FluxML/Zygote.jl/pull/1340\n            # https://github.com/JuliaDiff/ChainRulesCore.jl/pull/592\n            dx = only_derivative.(Ω, σ, NotaNumber()) .* unthunk(Δ)\n            return (NoTangent(), NoTangent(), dx, biasgrad(dx))\n        end\n        return Ω, bias_act!_fastback\n\n    # # Slower path: can't overwrite x, but can use derivatives_given_output\n    # # This case is WRONG and tests fail, but not sure why\n    # elseif isconcretetype(Core.Compiler.return_type(only_derivative, Tuple{T, F, T}))\n    #     Ω2 = fast_act(σ, x).(x) .+ b\n    #     @show σ b\n    #     function bias_act!_back2(Δ)\n    #         dx = only_derivative.(Ω2, σ, x .+ b) .* unthunk(Δ)\n    #         return (NoTangent(), NoTangent(), dx, biasgrad(dx))\n    #     end\n    #     return Ω2, bias_act!_back2\n\n    # Fallback path: let AD handle the broadcast\n    else\n        Ω3, back = rrule_via_ad(cfg, broadcast, fast_act(σ, x), bias_act!(identity, x, b))\n        @inline function bias_act!_slowback(Δ)\n            _, _, dx = back(Δ)\n            return (NoTangent(), NoTangent(), dx, biasgrad(dx))\n        end\n        return Ω3, bias_act!_slowback\n    end\nend\n\n# Two easy cases with identity\nfunction rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArray{T,N}, b::B) where {T,N,B}\n    dims = ntuple(d -> size(b,d)==1 ? d : N+1, N)\n    biasgrad(dx) = reshape(sum(dx; dims), size(b))\n    function bias_act!_idback(Δ)\n        dx = unthunk(Δ)\n        return (NoTangent(), NoTangent(), dx,  biasgrad(dx))\n    end\n    return bias_act!(identity, x, b), bias_act!_idback\nend\nfunction rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArray{T,N}, b::Bool) where {T,N}\n    bias_act!_trivial(Δ) = (NoTangent(), NoTangent(), Δ, NoTangent())\n    return x, bias_act!_trivial\nend\n\n"
  },
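A usage sketch for `bias_act!`, assuming `relu` is available from NNlib; note that the input array is overwritten, as the warning in the docstring stresses.

```julia
using NNlib

x = randn(Float32, 4, 8)     # e.g. weight * input inside a Dense layer
b = randn(Float32, 4)        # bias, broadcast along the batch dimension

y = bias_act!(relu, x, b)    # fuses x .= relu.(x .+ b)
y === x                      # true for strided float arrays: no extra allocation

# `false` encodes "no bias"; with `identity` the call is a pure pass-through.
z = randn(Float32, 4, 8)
bias_act!(identity, z, false) === z   # true
```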
  {
    "path": "src/conv.jl",
    "content": "## Convolution API\n#\n#  We provide the following generic methods, for 3d, 4d, and 5d tensors, calculating 1d,\n#  2d and 3d convolutions, based on the rank of the input tensors, in both mutating and\n#  non-mutating auto-allocating variants:\n#   - Convolution:\n#     - conv(x, w, cdims)\n#     - conv!(y, x, w, cdims)\n#   - Convolution data backpropagation\n#     - ∇conv_data(dy, w, cdims)\n#     - ∇conv_data!(dx, dy, w, cdims)\n#   - Convolution filter backpropagation\n#     - ∇conv_filter(x, dy, cdims)\n#     - ∇conv_filter!(dw, x, dy, cdims)\n#\n#   All methods require a `ConvDims` object to define the dimensions and optional\n#   elements of the convolution (padding, stride, dilation, kernel-flipping, etc...),\n#   which is easily constructable through something like `DenseConvDims(x, w)`.  All\n#   methods take in the `ConvDims` of the associated normal, forward-pass convolution,\n#   that is, the following is legal:\n#\n#       cdims = ConvDims(x, w; stride=2, dilation=(3,2))\n#       dx = ∇conv_data(conv(x, w, cdims), w, cdims)\n\n#   The computational flow, starting from the user facing functions,\n#   goes through the following steps:\n#\n#   STEP 1:\n#       use ConvDims objects (only for `conv` and `depthwiseconv`)\n#   STEP 2:\n#        define autoallocating version (frontend and implementations)\n#   STEP 3:\n#        reshape to 3d convolutions (frontend and implementions)\n#   STEP 4:\n#        choose implementation\n\n# TODO: should we also add\n#   STEP X:\n#        use homogeneus datatypes\n# to handle etherogeneus inputs now handled by conv_direct?\n\n\n########## STEP 1 ############\n\"\"\"\n    conv(x, w; stride = 1, pad = 0, dilation = 1, flipped = false, groups = 1)\n\nApply convolution filter `w` to input `x`. `x` and `w` are 3d/4d/5d tensors\nin 1d/2d/3d convolutions respectively. `x` and `w` may have real or complex element types.\n\"\"\"\nfunction conv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1, flipped = false, groups = 1) where {T, N}\n    stride = expand(Val(N - 2), stride)\n    padding = expand(Val(N - 2), pad)\n    dilation = expand(Val(N - 2), dilation)\n    cdims = DenseConvDims(\n        size(x), size(w); stride, padding, dilation, flipkernel=flipped, groups)\n    return conv(x, w, cdims)\nend\n\n\"\"\"\n    depthwiseconv(x, w; stride=1, pad=0, dilation=1, flipped=false)\n\nDepthwise convolution operation with filter `w` on input `x`. `x` and `w`\nare 3d/4d/5d tensors in 1d/2d/3d convolutions respectively.\n\"\"\"\nfunction depthwiseconv(x, w::AbstractArray{T, N}; stride=1, pad=0, dilation=1, flipped=false) where {T, N}\n    stride = expand(Val(N-2), stride)\n    pad = expand(Val(N-2), pad)\n    dilation = expand(Val(N-2), dilation)\n    cdims = DepthwiseConvDims(x, w; stride=stride, padding=pad, dilation=dilation, flipkernel=flipped)\n    return depthwiseconv(x, w, cdims)\nend\n##############################\n\n\n########### STEP 2 ###################\n# Let's generate auto-allocating versions of all our functions, for all backends.\n# We `@timeit` these methods separately, as we want to know how much time is spent in\n# allocation.  :P\nfor backend in (Symbol(), :_direct, :_im2col)\n    # First make auto-allocating versions of the conv()-like calls:\n    for name in (:conv, :depthwiseconv)\n        @eval begin\n            function $(Symbol(\"$(name)$(backend)\"))(\n                            x::AbstractArray{xT,N}, w::AbstractArray{wT,N},\n                            cdims::ConvDims; kwargs...) 
where {xT, wT, N}\n                y = similar(x, promote_type(xT, wT), output_size(cdims)...,\n                               channels_out(cdims), size(x,N))\n                return $(Symbol(\"$(name)$(backend)!\"))(y, x, w, cdims; kwargs...)\n            end\n        end\n    end\n\n    for name in (:∇conv_data, :∇depthwiseconv_data)\n        @eval begin\n            function $(Symbol(\"$(name)$(backend)\"))(\n                            dy::AbstractArray{yT,N}, w::AbstractArray{wT,N},\n                            cdims::C; kwargs...) where {yT, wT, N, C <: ConvDims}\n                dx = similar(dy, input_size(cdims)..., channels_in(cdims), size(dy, N))\n                return $(Symbol(\"$(name)$(backend)!\"))(dx, dy, w, cdims; kwargs...)\n            end\n        end\n    end\n\n    # We do the conv/depthwiseconv filter backprops separately, as the shape calculation\n    # for `w` is slightly different for depthwise than for normal dense convolution.\n    @eval begin\n        function $(Symbol(\"∇conv_filter$(backend)\"))(\n                        x::AbstractArray{xT,N}, dy::AbstractArray{yT,N},\n                        cdims::ConvDims; kwargs...) where {xT, yT, N}\n            dw = similar(dy, kernel_size(cdims)..., channels_in(cdims) ÷ groupcount(cdims),\n                                                    channels_out(cdims))\n            return $(Symbol(\"∇conv_filter$(backend)!\"))(dw, x, dy, cdims; kwargs...)\n        end\n    end\n\n    @eval begin\n        function $(Symbol(\"∇depthwiseconv_filter$(backend)\"))(\n                        x::AbstractArray{xT,N}, dy::AbstractArray{yT,N},\n                        cdims::ConvDims; kwargs...) where {xT, yT, N}\n            dw = similar(dy, kernel_size(cdims)..., channel_multiplier(cdims),\n                                                    channels_in(cdims))\n            return $(Symbol(\"∇depthwiseconv_filter$(backend)!\"))(dw, x, dy, cdims;\n                                                                 kwargs...)\n        end\n    end\nend\n##########################################\n\n\n########## STEP 3 ############\n\n# Our strategy for 1d and 2d convolution is to reshape to 3d convolutions, which\n# makes things MUCH EASIER for us on the backend side, and is in general pretty fast,\n# since we can specialize on sizes.\nfor front_name in (:conv, :∇conv_data, :∇conv_filter,\n                   :depthwiseconv, :∇depthwiseconv_data, :∇depthwiseconv_filter)\n    for backend in (Symbol(), :_direct, :_im2col)\n        for N in (3, 4)\n            @eval begin\n                function $(Symbol(\"$(front_name)$(backend)!\"))(\n                                y::AbstractArray{yT,$N}, x::AbstractArray{xT,$N},\n                                w::AbstractArray{wT,$N}, cdims::ConvDims;\n                                kwargs...) 
where {yT, xT, wT}\n\n                    $(Symbol(\"$(front_name)$(backend)!\"))(\n                        insert_singleton_spatial_dimension(y, $(5 - N)),\n                        insert_singleton_spatial_dimension(x, $(5 - N)),\n                        insert_singleton_spatial_dimension(w, $(5 - N)),\n                        insert_singleton_spatial_dimension(cdims, $(5 - N));\n                        kwargs...\n                    )\n\n                    # We explicitly return `y` here, because the backend call\n                    # itself may return a reshaped view, which we don't want.\n                    return y\n                end\n            end\n        end\n    end\nend\n\n#######################################\n\n\n########### STEP 4 ############\n\n# First, we will define mappings from the generic API names to our accelerated backend\n# implementations. For homogeneous-datatype 1, 2 and 3d convolutions, we default to using\n# im2col + GEMM.\n# But we always support a fallback, non-accelerated path, where we use the direct, but\n# slow, implementations. These should not typically be used, hence the `@warn`,\n\n# These are the GEMM types we will accelerate with `im2col`\nconst G = Union{[x[2] for x in gemm_datatype_mappings]...}\n\nfor (front_name, backend, signature) in (\n    # This maps from public, front-facing name, to internal backend name, given the function signature and the where clause\n    # (frontend, backend, (out Array signature, in1 Array signature, in2 Array signature, (parametric Types)))\n    (:conv, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),\n    (:conv, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),\n)\n    # We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution\n    @eval begin\n\n        function $(Symbol(\"$(front_name)!\"))(\n                        out::AbstractArray{$(signature[1][1]), $(signature[1][2])},\n                        in1::AbstractArray{$(signature[2][1]), $(signature[1][2])},\n                        in2::AbstractArray{$(signature[3][1]), $(signature[1][2])},\n                        cdims::$(signature[4]);\n                        kwargs...) where {$(signature[5]...)}\n            if $(string(backend)) == \"direct\" && yT == Float64  # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual\n                @warn string(\"Slow fallback implementation invoked for \", $(string(front_name)), \"!  \",\n                        \"You probably don't want this; check your datatypes.\") yT T1 T2 maxlog=1\n            end\n\n            x_cs = Iterators.partition(1:size(in1, 4),\n                                    channels_in(cdims) ÷ groupcount(cdims))\n            w_cs = Iterators.partition(1:size(in2, 5),\n                                    channels_out(cdims) ÷ groupcount(cdims))\n            cdims2 = basetype(C)(cdims,\n                                G = 1,\n                                C_in = channels_in(cdims) ÷ groupcount(cdims),\n                                C_out = channels_out(cdims) ÷ groupcount(cdims))\n\n            function conv_group(xc, wc)\n                x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]\n                w = @view in2[ntuple(i -> i == 5 ? wc : Colon(), 5)...]\n                y = @view out[ntuple(i -> i == 4 ? 
wc : Colon(), 5)...]\n                $(Symbol(\"$(front_name)_$(backend)!\"))(y, x, w, cdims2; kwargs...)\n            end\n\n            if should_use_spawn() && length(x_cs) > 1\n                Threads.@sync for (xc, wc) in zip(x_cs, w_cs)\n                    Threads.@spawn conv_group(xc, wc)\n                end\n            else\n                for (xc, wc) in zip(x_cs, w_cs)\n                    conv_group(xc, wc)\n                end\n            end\n\n            return out\n        end\n    end\nend\n\n# im2col-accelerated function forwarding definition\nfor (front_name, backend, signature) in (\n    # This maps from public, front-facing name, to internal backend name, given the function signature and the where clause\n    # (frontend, backend, (out Array signature, in1 Array signature, in2 Array signature, (parametric Types)))\n    (:∇conv_data, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),\n    (:∇conv_data, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),\n)\n    # We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution\n    @eval begin\n        function $(Symbol(\"$(front_name)!\"))(\n                        out::AbstractArray{$(signature[1][1]), $(signature[1][2])},\n                        in1::AbstractArray{$(signature[2][1]), $(signature[1][2])},\n                        in2::AbstractArray{$(signature[3][1]), $(signature[1][2])},\n                        cdims::$(signature[4]);\n                        kwargs...) where {$(signature[5]...)}\n            if $(string(backend)) == \"direct\" && yT == Float64  # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual\n                @warn string(\"Slow fallback implementation invoked for \", $(string(front_name)), \"!  \",\n                        \"You probably don't want this; check your datatypes.\") yT T1 T2 maxlog=1\n            end\n\n\n            dx_cs = Iterators.partition(1:size(out, 4),\n                                        channels_in(cdims) ÷ groupcount(cdims))\n            w_cs = Iterators.partition(1:size(in2, 5),\n                                    channels_out(cdims) ÷ groupcount(cdims))\n            dy_cs = Iterators.partition(1:size(in1, 4),\n                                        channels_out(cdims) ÷ groupcount(cdims))\n            cdims2 = basetype(C)(cdims,\n                                G = 1,\n                                C_in = channels_in(cdims) ÷ groupcount(cdims),\n                                C_out = channels_out(cdims) ÷ groupcount(cdims))\n\n            function ∇conv_data_group(xc, yc, wc)\n                dxv = @view out[ntuple(i -> i == 4 ? xc : Colon(), 5)...]\n                dyv = @view in1[ntuple(i -> i == 4 ? yc : Colon(), 5)...]\n                wv = @view in2[ntuple(i -> i == 5  ? 
wc : Colon(), 5)...]\n                $(Symbol(\"$(front_name)_$(backend)!\"))(dxv, dyv, wv, cdims2; kwargs...)\n            end\n\n            if should_use_spawn() && length(dx_cs) > 1\n                Threads.@sync for (xc, yc, wc) in zip(dx_cs, dy_cs, w_cs)\n                    Threads.@spawn ∇conv_data_group(xc, yc, wc)\n                end\n            else\n                for (xc, yc, wc) in zip(dx_cs, dy_cs, w_cs)\n                    ∇conv_data_group(xc, yc, wc)\n                end\n            end\n\n            return out\n        end\n    end\nend\n\nfor (front_name, backend, signature) in (\n    # This maps from public, front-facing name, to internal backend name, given the function signature and the where clause\n    # (frontend, backend, (out Array signature, in1 Array signature, in2 Array signature, (parametric Types)))\n    (:∇conv_filter, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),\n    (:∇conv_filter, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),\n)\n    # We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution\n    @eval begin\n        function $(Symbol(\"$(front_name)!\"))(\n                        out::AbstractArray{$(signature[1][1]), $(signature[1][2])},\n                        in1::AbstractArray{$(signature[2][1]), $(signature[1][2])},\n                        in2::AbstractArray{$(signature[3][1]), $(signature[1][2])},\n                        cdims::$(signature[4]);\n                        kwargs...) where {$(signature[5]...)}\n            if $(string(backend)) == \"direct\" && yT == Float64  # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual\n                @warn string(\"Slow fallback implementation invoked for \", $(string(front_name)), \"!  \",\n                        \"You probably don't want this; check your datatypes.\") yT T1 T2 maxlog=1\n            end\n\n            dw_cs = Iterators.partition(1:size(out, 5),\n                                        channels_out(cdims) ÷ groupcount(cdims))\n            dy_cs = Iterators.partition(1:size(in2, 4),\n                                        channels_out(cdims) ÷ groupcount(cdims))\n            x_cs = Iterators.partition(1:size(in1, 4),\n                                    channels_in(cdims) ÷ groupcount(cdims))\n            cdims2 = basetype(C)(cdims,\n                                G = 1,\n                                C_in = channels_in(cdims) ÷ groupcount(cdims),\n                                C_out = channels_out(cdims) ÷ groupcount(cdims))\n\n            function ∇conv_filter_group(wc, xc, yc)\n                x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]\n                dy = @view in2[ntuple(i -> i == 4 ? yc : Colon(), 5)...]\n                dw = @view out[ntuple(i -> i == 5 ? 
wc : Colon(), 5)...]\n                $(Symbol(\"$(front_name)_$(backend)!\"))(dw, x, dy, cdims2; kwargs...)\n            end\n\n            if should_use_spawn() && length(dw_cs) > 1\n                Threads.@sync for (wc, xc, yc) in zip(dw_cs, x_cs, dy_cs)\n                    Threads.@spawn ∇conv_filter_group(wc, xc, yc)\n                end\n            else\n                for (wc, xc, yc) in zip(dw_cs, x_cs, dy_cs)\n                    ∇conv_filter_group(wc, xc, yc)\n                end\n            end\n\n            return out\n        end\n    end\nend\n\n\nfor (front_name, backend, signature) in (\n    # This maps from public, front-facing name, to internal backend name, given the function signature and the where clause\n    # (frontend, backend, (out Array signature, in1 Array signature, in2 Array signature, (parametric Types)))\n    (:depthwiseconv, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),\n    (:depthwiseconv, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),\n\n    (:∇depthwiseconv_data, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),\n    (:∇depthwiseconv_data, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),\n\n    (:∇depthwiseconv_filter, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),\n    (:∇depthwiseconv_filter, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),\n)\n\n    # We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution\n    @eval begin\n        # im2col-accelerated function forwarding definition\n        function $(Symbol(\"$(front_name)!\"))(\n                        out::AbstractArray{$(signature[1][1]), $(signature[1][2])},\n                        in1::AbstractArray{$(signature[2][1]), $(signature[1][2])},\n                        in2::AbstractArray{$(signature[3][1]), $(signature[1][2])},\n                        cdims::$(signature[4]);\n                        kwargs...) where {$(signature[5]...)}\n            if $(string(backend)) == \"direct\" && yT == Float64  # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual\n                @warn string(\"Slow fallback implementation invoked for \", $(string(front_name)), \"!  \",\n                        \"You probably don't want this; check your datatypes.\") yT T1 T2 maxlog=1\n            end\n            $(Symbol(\"$(front_name)_$(backend)!\"))(out, in1, in2, cdims; kwargs...)\n        end\n    end\nend\n\nfor Dims in [:DenseConvDims, :DepthwiseConvDims, :PoolDims]\n    @eval @non_differentiable $Dims(::Any...)\nend\n\ncolmajor(x) = (is_strided(x) && Base.stride(x, 1) == 1) ? 
x : collect(x)\n\nfor conv in [:conv, :depthwiseconv]\n    local ∇conv_data, ∇conv_filter = Symbol.(:∇, conv, [:_data, :_filter])\n    conv_pullback, ∇conv_data_pullback = Symbol.([conv, ∇conv_data], :_pullback)\n\n    @eval function rrule(::typeof($conv), x, w, cdims; kw...)\n        function $conv_pullback(Δraw)\n            Δ = colmajor(unthunk(Δraw))\n            return (\n                NoTangent(),\n                @thunk($∇conv_data(Δ, w, cdims, kw...)),\n                @thunk($∇conv_filter(x, Δ, cdims, kw...)),\n                NoTangent(),\n            )\n        end\n        return $conv(x, w, cdims; kw...), $conv_pullback\n    end\n\n    @eval function rrule(::typeof($∇conv_data), x, w, cdims; kw...)\n        function $∇conv_data_pullback(Δraw)\n            Δ = colmajor(unthunk(Δraw))\n            return (\n                NoTangent(),\n                @thunk($conv(Δ, w, cdims, kw...)),\n                @thunk($∇conv_filter(Δ, x, cdims, kw...)),\n                NoTangent(),\n            )\n        end\n        return $∇conv_data(x, w, cdims; kw...), $∇conv_data_pullback\n    end\nend\n\nfunction rrule(::typeof(∇conv_filter), x, dy, cdims; kw...)\n    function ∇conv_filter_pullback(Δ)\n        Δ1 = colmajor(unthunk(Δ))\n        return (\n            NoTangent(),\n            @thunk(∇conv_data(dy, Δ1, cdims, kw...)),\n            @thunk(conv(x, Δ1, cdims, kw...)),\n            NoTangent(),\n        )\n    end\n    return ∇conv_filter(x, dy, cdims; kw...), ∇conv_filter_pullback\nend\n"
  },
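A sketch of the user-facing API documented at the top of this file, using only `conv`, `conv!`, `DenseConvDims` and the gradient functions defined here; the array sizes are illustrative.

```julia
using NNlib

x = randn(Float32, 28, 28, 3, 8)   # width × height × channels × batch
w = randn(Float32, 5, 5, 3, 16)    # 5×5 kernel, 3 input channels, 16 output channels

y = conv(x, w; stride=2, pad=2)    # size (14, 14, 16, 8)

# The same convolution through an explicit ConvDims object, as the mutating API expects:
cdims = DenseConvDims(x, w; stride=2, padding=2)
y2 = similar(y)
conv!(y2, x, w, cdims)

# Backward passes reuse the forward-pass cdims:
dx = ∇conv_data(y, w, cdims)       # gradient w.r.t. the input,  same size as x
dw = ∇conv_filter(x, y, cdims)     # gradient w.r.t. the filter, same size as w
```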
  {
    "path": "src/conv_bias_act.jl",
    "content": "function conv_bias_act(x::AbstractArray{xT,N}, w::AbstractArray{wT,N},\n                cdims::ConvDims, b::AbstractArray{bT,N}, σ=identity; kwargs...) where {xT, wT, bT, N}\n    y = similar(x, promote_type(xT, wT, bT), output_size(cdims)..., channels_out(cdims), size(x,N))\n    conv_bias_act!(y, x, w, cdims, b, σ; kwargs...)\n    return y\nend\n\nfunction conv_bias_act!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5}, w::AbstractArray{wT,5},\n                cdims::ConvDims, b::AbstractArray{bT,5}, σ=identity; kwargs...) where {yT, xT, wT, bT}\n    conv!(y, x, w, cdims)\n    y .= σ.(y .+ b)\n    return y\nend\n\nfor N in (3, 4)\n    @eval begin\n        function $(Symbol(\"conv_bias_act!\"))(\n                        y::AbstractArray{yT,$N}, x::AbstractArray{xT,$N},\n                        w::AbstractArray{wT,$N}, cdims::ConvDims,\n                        b::AbstractArray{bT,$N}, σ=identity;\n                        kwargs...) where {yT, xT, wT, bT}\n            $(Symbol(\"conv_bias_act!\"))(\n                insert_singleton_spatial_dimension(y, $(5 - N)),\n                insert_singleton_spatial_dimension(x, $(5 - N)),\n                insert_singleton_spatial_dimension(w, $(5 - N)),\n                insert_singleton_spatial_dimension(cdims, $(5 - N)),\n                insert_singleton_spatial_dimension(b, $(5 - N)),\n                σ;\n                kwargs...\n            )\n\n            # We explicitly return `y` here, because the backend call\n            # itself may return a reshaped view, which we don't want.\n            return y\n        end\n    end\nend\n"
  },
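A minimal sketch of `conv_bias_act`; it is called with the module prefix here on the assumption that it may not be exported, and the shapes are illustrative.

```julia
using NNlib

x = randn(Float32, 32, 32, 3, 4)
w = randn(Float32, 3, 3, 3, 8)
cdims = DenseConvDims(x, w; padding=1)

b = zeros(Float32, 1, 1, 8, 1)                 # bias, broadcastable over (W, H, C, N)
y = NNlib.conv_bias_act(x, w, cdims, b, relu)  # conv! followed by y .= relu.(y .+ b)
size(y)                                        # (32, 32, 8, 4)
```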
  {
    "path": "src/ctc.jl",
    "content": "# CTC loss moved from Flux.jl to NNlib\n\n## CPU implementation\n\n\"\"\"\n    logaddexp(a, b)\n\nAdds log-space `a` and `b` such that the result equals `log(exp(a)+exp(b))`\n\"\"\"\nfunction logaddexp(a, b)\n  isinf(a) && return b\n  isinf(b) && return a\n\n  # always want the greater number on the left in the exponentiation;\n  # the magnitude difference may end up making the number very positive\n  # which will cause exp() to return Inf\n  # E.g., a = -900, b = -800, will give exp(-800 - -900), which will be\n  # Inf for Float32 values\n  if a < b\n    a, b = b, a\n  end\n  return a + log(1+exp(b-a))\nend\n\n\"\"\"\n    add_blanks(z)\n\nAdds blanks to the start and end of `z`, and between items in `z`\n\"\"\"\nfunction add_blanks(z, blank)\n  z′ = fill(blank, 2*length(z) + 1)\n  z′[2 .* eachindex(z)] = z\n  return z′\nend\n\nfunction ctc_alpha(ŷ::AbstractArray, y)\n  typed_zero = zero(ŷ[1])\n  ŷ = logsoftmax(ŷ)\n  blank = size(ŷ, 1)\n  z′ = add_blanks(y, blank)\n  T = size(ŷ, 2)\n  U′ = length(z′)\n\n  α = fill(log(typed_zero), U′, T)\n  α[1,1] = ŷ[blank, 1]\n  α[2,1] = ŷ[z′[2], 1]\n  for t=2:T\n    bound = max(1, U′ - 2(T - t) - 1)\n    for u=bound:U′\n      if u == 1\n        α[u,t] = α[u, t-1]\n      else\n      α[u,t] = logaddexp(α[u, t-1], α[u-1, t-1])\n      \n      # array bounds check and f(u) function from Eq. 7.9\n      if u > 2 && !(z′[u] == blank || z′[u-2] == z′[u])\n        α[u,t] = logaddexp(α[u,t], α[u-2,t-1])\n      end\n    end\n    α[u,t] += ŷ[z′[u], t]\n    end\n  end\n  return (loss=-1 * logaddexp(α[end,T], α[end-1, T]), alpha=α, zprime=z′, logsoftyhat=ŷ)\nend\n  \nfunction ∇ctc_loss(ŷ::AbstractArray, y, out)\n  loss, α, z′, ŷ = out\n  U′, T = size(α)\n  blank = size(ŷ, 1)\n  typed_zero = zero(first(α))\n\n  # Calculate beta coefficients, from the bottom-right, to the upper-left\n  β = fill(log(typed_zero), U′, T)\n\n  # Fill bottom-right corner so bounding errors can be avoided\n  # by starting `u` at `U′-1`\n  β[U′, T] = typed_zero\n  β[U′-1, T] = typed_zero\n  \n  # start at T-1 so that β(T, u) = log(0) for all u < U′ - 1\n  for t=(T-1):-1:1\n    bound = min(U′, 2t)\n    for u=bound:-1:1\n      if u == U′\n        β[u,t] = ŷ[z′[u], t+1] + β[u, t+1]\n      else\n        β[u,t] = logaddexp(ŷ[z′[u], t+1] + β[u, t+1], ŷ[z′[u+1], t+1] + β[u+1,t+1])\n\n        # array bounds check and g(u) function from Eq. 7.16\n        if u+2 <= U′ && z′[u] != blank && z′[u] != z′[u+2]\n          β[u,t] = logaddexp(β[u,t], ŷ[z′[u+2], t+1] + β[u+2, t+1])\n        end\n      end\n    end\n  end\n\n  # Accumulate alpha-beta products for each category,\n  # then calculate gradients\n  accum = fill(log(typed_zero), size(ŷ))\n  for t=1:T\n    for u=1:U′\n      accum[z′[u], t] = logaddexp(accum[z′[u], t], α[u,t] + β[u,t])\n    end\n  end\n  grads = exp.(ŷ) .- exp.(accum .+ loss)\n  return grads\nend\n\n\"\"\"\n    ctc_loss(ŷ, y)\n\nComputes the connectionist temporal classification loss between `ŷ`\nand `y`.\n`ŷ` must be a classes-by-time matrices, i.e., each row\nrepresents a class and each column represents a time step.\nAdditionally, the `logsoftmax` function will be applied to `ŷ`, so\n`ŷ` must be the raw activation values from the neural network and\nnot, for example, the activations after being passed through a\n`softmax` activation function. `y` must be a 1D array of the labels\nassociated with `ŷ`. 
The blank label is assumed to be the last label\ncategory in `ŷ`, so it is equivalent to `size(ŷ, 1)`.\nUsed for sequence-to-sequence classification problems such as\nspeech recognition and handwriting recognition where the exact\ntime-alignment of the output (e.g., letters) is not needed to\nsolve the problem. See [Graves et al. (2006)](https://www.cs.toronto.edu/~graves/icml_2006.pdf)\nor [Graves (2012)](https://www.cs.toronto.edu/~graves/preprint.pdf#chapter.7)\nfor mathematical details.\n\"\"\"\nctc_loss(ŷ::AbstractArray, y) = ctc_alpha(ŷ, y).loss\n\nfunction ChainRulesCore.rrule(::typeof(ctc_loss), ŷ, y)\n  tmp = ctc_alpha(ŷ, y)\n  ctc_loss_pullback(Δ) = (NoTangent(), Δ .* ∇ctc_loss(ŷ, y, tmp), NoTangent())\n  return tmp.loss, ctc_loss_pullback\nend\n"
  },
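A small, hypothetical usage sketch of `ctc_loss` as documented above: `ŷ` is a raw (pre-softmax) classes-by-time matrix whose last class is the blank, and `y` lists the target label indices without blanks. The sizes here are illustrative only.

```julia
using NNlib

ŷ = randn(Float32, 4, 6)   # 4 classes (class 4 is the blank), 6 time steps, raw activations
y = [1, 2, 1]              # target label sequence; blanks are inserted internally

loss = NNlib.ctc_loss(ŷ, y)
```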
  {
    "path": "src/deprecations.jl",
    "content": "### Deprecated while v0.8 was latest\n\nexport ∇softmax,\n    ∇softmax!,\n    logsoftmax,\n    logsoftmax!,\n    ∇logsoftmax,\n    ∇logsoftmax!\n\nfunction ∇softmax!(out::AbstractArray, Δ::AbstractArray, \n                    x::AbstractArray, y::AbstractArray; dims = 1)\n    Base.depwarn(\"`∇softmax!(dx, dy, x, y)` is deprecated, just use `∇softmax_data(dy, y)`\", :∇softmax!)\n    # Removed because using a mutating function blocks 2nd derivatives, and\n    # the CUDA overload was slow anyway, https://github.com/FluxML/NNlibCUDA.jl/issues/30\n    out .= Δ .* y\n    out .= out .- y .* sum(out; dims)\nend\n\nfunction ∇logsoftmax!(out::AbstractArray, Δ::AbstractArray,\n                    x::AbstractArray, y::AbstractArray; dims = 1) \n    Base.depwarn(\"`∇logsoftmax!(dx, dy, x, y)` is deprecated, just use `∇logsoftmax_data(dy, y)`\", :∇softmax!)\n    out .= Δ .- sum(Δ; dims) .* exp.(y)\nend\n\nfunction ∇softmax(dy::AbstractArray{T}, x::AbstractArray, y::AbstractArray{S}; dims = 1) where {T,S}\n    # Removed because there's no need to close over `x` here, that was done only to distinguish\n    # this from `∇softmax(Δ, x; dims = 1)` which re-computed `y = softmax(x)`, which is slow.\n    Base.depwarn(\"`∇softmax(dy, x, y)` should be replaced with `∇softmax_data(dy, y)`\", :∇softmax)\n    ∇softmax_data(dy, y)\nend\n\nfunction ∇logsoftmax(dy::AbstractArray, x::AbstractArray, y::AbstractArray; dims = 1)\n    Base.depwarn(\"`∇logsoftmax(dy, x, y)` should be replaced with `∇logsoftmax_data(dy, y)`\", :∇softmax)\n    ∇logsoftmax_data(dy, y)\nend\n\n"
  },
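For reference, a sketch of the migration these deprecation shims point to. It only exercises the replacement functions named in the warnings above (`∇softmax_data` with its `dims` keyword); the array sizes are hypothetical.

```julia
using NNlib

x  = randn(Float32, 5, 3)
y  = softmax(x; dims = 1)
dy = randn(Float32, 5, 3)

# old, deprecated: ∇softmax(dy, x, y; dims = 1)
dx = NNlib.∇softmax_data(dy, y; dims = 1)
```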
  {
    "path": "src/dim_helpers/ConvDims.jl",
    "content": "\"\"\"\n    ConvDims\n\nType system-level information about convolution dimensions. Critical for things like\n`im2col!()` to generate efficient code, and helpful to reduce the number of kwargs\ngetting passed around.\n\"\"\"\nabstract type ConvDims{N} end\n\n@inline spatial_dims(::ConvDims{N}) where N = N\n@inline groupcount(c::ConvDims) = 1\n\n# Below functions should be implemented by dims that subtype `ConvDims`.\nfunction input_size end\nfunction kernel_size end\nfunction stride end\nfunction padding end\nfunction dilation end\nfunction flipkernel end\n\n# Hack to get rid of type parameters\nfunction basetype(::Type{C}) where {C <: ConvDims}\n    if C <: DepthwiseConvDims\n        return DepthwiseConvDims\n    elseif C <: DenseConvDims\n        return DenseConvDims\n    elseif C <: PoolDims\n        return PoolDims\n    else\n        return nothing\n    end\nend\n\nfunction output_size(c::ConvDims)\n    I = input_size(c)\n    K = kernel_size(c)\n    S = stride(c)\n    P = padding(c)\n    D = dilation(c)\n\n    return ntuple(spatial_dims(c)) do i\n        return div(I[i] + P[(i-1)*2 + 1] + P[(i-1)*2 + 2] - (K[i] - 1) * D[i] - 1, S[i]) + 1\n    end\nend\n\nfunction Base.show(io::IO, cdims::C) where {C <: ConvDims}\n    I = (input_size(cdims)..., channels_in(cdims))\n    O = (output_size(cdims)..., channels_out(cdims))\n    K = kernel_size(cdims)\n    S = stride(cdims)\n    P = padding(cdims)\n    D = dilation(cdims)\n    F = flipkernel(cdims)\n    G = groupcount(cdims)\n    print(io, \"$(basetype(C)): $I * $K -> $O, stride: $S, pad: $P, dil: $D, flip: $F, groups: $G\")\nend\n\n\"\"\"\n    im2col_dims(c::ConvDims)\n\nim2col calculates, for each output pixel, the \"convolution\" of N kernels where N is the\nnumber of output channels, by doing a matrix multiply.  The dimensions of that matrix\nare given by this function.\n\nNote that because im2col is multithreaded, we need to allocate a separate workspace of\nmemory per-thread; hence the dimensions returned by this will depend on the number of\nthreads Julia is currently running with.\n\"\"\"\nfunction im2col_dims(c::ConvDims)\n    return (\n        # Output size\n        prod(output_size(c)),\n        # Size of single dotproduct within convolution\n        prod(kernel_size(c))*channels_in(c),\n        # One workspace per thread\n        Threads.nthreads(:default),\n    )\nend\n\n\"\"\"\n    ∇filter_im2col_dims(c::ConvDims)\n\nLike [`im2col_dims`](@ref), but saves some memory because multiple (Julia) threads are\nnot required for the filter gradient calculation.\n\nNote: in the future, this may return `Dims{2}` instead of `Dims{3}`.\n\"\"\"\nfunction ∇filter_im2col_dims(c::ConvDims)\n    return (\n        # Output size\n        prod(output_size(c)),\n        # Size of single dotproduct within convolution\n        prod(kernel_size(c))*channels_in(c),\n        # No threading, this is just here for backwards compat\n        1\n    )\nend\n\n# Protect your skin, kids.  Also do common validation of stride, padding, etc...\nfunction check_spdf(x_size::NTuple{N}, w_size::NTuple{N}, stride, padding, dilation) where {N}\n    # Number of spatial dimensions in `x` and `w`.\n    nd = N - 2\n\n    # Given a number, duplicate it out to have `nd` length.  If it's already a collection,\n    # just splat it out into a tuple so it's always a tuple.  We'll lint length later.\n    expand_size(p::Number) = ntuple(_ -> Int(p), nd)\n    expand_size(p) = tuple(p...)\n\n    # Convert stride, padding, dilation, etc.. 
to fully-specified tuples\n    pstride = expand_size(stride)\n    pdilation = expand_size(dilation)\n    ppadding = expand_size(padding)\n\n    if length(pstride) != nd\n        throw(DimensionMismatch(\"Stride $(length(stride))d, should be $(nd)d!\"))\n    end\n    if length(pdilation) != nd\n        throw(DimensionMismatch(\"Dilation $(length(pdilation))d, should be $(nd)d!\"))\n    end\n\n    # padding is kind of a special case; we allow it to be either 2-length or 4-length,\n    # since we support asymmetrical padding\n    if length(ppadding) == 2 * nd\n        _validate_padding(x_size, w_size, ppadding, pdilation)\n        return pstride, ppadding, pdilation\n    end\n\n    length(ppadding) != nd && throw(DimensionMismatch(\n        \"Padding $(length(ppadding))d, should be either $(nd)d or $(2*nd)d!\"))\n\n    # Do this repeat dance so that we get lo/hi symmetrical padding\n    ppadding_expanded = ntuple(i -> ppadding[(i - 1) ÷ 2 + 1], 2 * nd)\n    _validate_padding(x_size, w_size, ppadding_expanded, pdilation)\n    return pstride, ppadding_expanded, pdilation\nend\n\n# Assert that kernel size * dilation is <= padded input size\nfunction _validate_padding(x_size::NTuple{N}, w_size::NTuple{N}, padding, dilation) where N\n    for idx in 1:(N - 2)\n        Is = x_size[idx]\n        Ks = w_size[idx]\n        Pl = padding[(idx - 1) * 2 + 1]\n        Ph = padding[(idx - 1) * 2 + 2]\n        Ds = dilation[idx]\n        if Is + Pl + Ph < (Ks - 1) * Ds + 1\n            throw(DimensionMismatch(\"Kernel * dilation (($Ks - 1) * $Ds + 1) cannot be larger than input + padding ($Is + $Pl + $Ph)!\"))\n        end\n    end\n    nothing\nend\n"
  },
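A worked, hypothetical example of the `output_size` formula above, `div(I + pad_lo + pad_hi - (K - 1) * D - 1, S) + 1`, using a concrete `DenseConvDims`.

```julia
using NNlib

x = randn(Float32, 32, 32, 2, 8)       # 32×32 input, 2 channels, batch of 8
w = randn(Float32, 5, 5, 2, 4)         # 5×5 kernel, 4 output channels

cdims = NNlib.DenseConvDims(x, w; stride = 2, padding = 1)
NNlib.output_size(cdims)               # (15, 15): div(32 + 1 + 1 - 4 - 1, 2) + 1
```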
  {
    "path": "src/dim_helpers/DenseConvDims.jl",
    "content": "\"\"\"\n    DenseConvDims\n\nConcrete subclass of `ConvDims` for a normal, dense, conv2d/conv3d.\n\"\"\"\nstruct DenseConvDims{N, K, S, P, D} <: ConvDims{N}\n    input_size::NTuple{N, Int}\n\n    kernel_size::NTuple{K, Int}\n    channels_in::Int\n    channels_out::Int\n    groupcount::Int\n\n    stride::NTuple{S, Int}\n    padding::NTuple{P, Int}\n    dilation::NTuple{D, Int}\n    flipkernel::Bool\nend\n\nfunction DenseConvDims(\n    x_size::NTuple{M}, w_size::NTuple{M};\n    stride = 1, padding = 0, dilation = 1, groups = 1,\n    flipkernel::Bool = false,\n) where {M}\n    sstride, ppadding, ddilation = check_spdf(\n        x_size, w_size, stride, padding, dilation)\n\n    # Ensure channels are equal\n    if x_size[end - 1] != w_size[end - 1] * groups\n        xs = x_size[end - 1]\n        ws = w_size[end - 1]\n        throw(DimensionMismatch(\"Input channels must match! ($xs vs. $ws)\"))\n    end\n\n    # Ensure groups are valid\n    if x_size[end - 1] % w_size[end - 1] != 0 || w_size[end] % groups != 0\n        throw(DimensionMismatch(\n            \"Group count should be divisble by input and output channels ($groups vs. $(w_size[end-1:end]))\"))\n    end\n\n    DenseConvDims(\n        x_size[1:(end - 2)],\n        w_size[1:(end - 2)], x_size[end - 1], w_size[end], groups,\n        sstride, ppadding, ddilation, flipkernel)\nend\n\nfunction DenseConvDims(x::AbstractArray, w::AbstractArray; kwargs...)\n    if ndims(x) != ndims(w)\n        throw(DimensionMismatch(\n            \"Rank of x and w must match! ($(ndims(x)) vs. $(ndims(w)))\"))\n    end\n    return DenseConvDims(size(x), size(w); kwargs...)\nend\n\n# Useful for constructing a new DenseConvDims that has only a few elements different\n# from the original progenitor object that it inherits shapes from.\n@inline DenseConvDims(\n    c::C; I=input_size(c), K=kernel_size(c),\n    C_in=channels_in(c), C_out=channels_out(c), S=stride(c),\n    P=padding(c), D=dilation(c), F=flipkernel(c), G=groupcount(c),\n) where C <: ConvDims = DenseConvDims(\n    I,\n    K, C_in, C_out, G,\n    S, P, D, F)\n\n@inline groupcount(c::DenseConvDims) = c.groupcount\n@inline channels_in(c::DenseConvDims) = c.channels_in\n@inline channels_out(c::DenseConvDims) = c.channels_out\n\n@inline input_size(c::DenseConvDims) = c.input_size\n@inline kernel_size(c::DenseConvDims) = c.kernel_size\n\n@inline stride(c::DenseConvDims) = c.stride\n@inline padding(c::DenseConvDims) = c.padding\n@inline dilation(c::DenseConvDims) = c.dilation\n@inline flipkernel(c::DenseConvDims) = c.flipkernel\n\nfunction check_dims(x::NTuple{M}, w::NTuple{M}, y::NTuple{M}, cdims::DenseConvDims) where {M}\n    # First, check that channel counts are all correct:\n    @assert x[M-1] * groupcount(cdims) == channels_in(cdims) DimensionMismatch(\"Data input channel count ($(x[M-1]) vs. $(channels_in(cdims)))\")\n    @assert y[M-1] == channels_out(cdims) ÷ groupcount(cdims)  DimensionMismatch(\"Data output channel count ($(y[M-1]) vs. $(channels_out(cdims)))\")\n    @assert w[M-1] * groupcount(cdims) == channels_in(cdims) DimensionMismatch(\"Kernel input channel count ($(w[M-1]) vs. $(channels_in(cdims)))\")\n    @assert w[M] * groupcount(cdims) == channels_out(cdims) DimensionMismatch(\"Kernel output channel count ($(w[M]) vs. $(channels_out(cdims)))\")\n\n    # Next, check that the spatial dimensions match up\n    @assert x[1:M-2] == input_size(cdims) DimensionMismatch(\"Data input spatial size ($(x[1:M-2]) vs. 
$(input_size(cdims)))\")\n    @assert y[1:M-2] == output_size(cdims) DimensionMismatch(\"Data output spatial size ($(y[1:M-2]) vs. $(output_size(cdims)))\")\n    @assert w[1:M-2] == kernel_size(cdims) DimensionMismatch(\"Kernel spatial size ($(w[1:M-2]) vs. $(kernel_size(cdims)))\")\n\n    # Check that the input channels are divisible by the group count\n    @assert channels_in(cdims) % groupcount(cdims) == 0 DimensionMismatch(\"Input channels ($(channels_in(cdims))) should be divisible by the group count ($(groupcount(cdims)))\")\n\n    # Finally, check that the batch size matches\n    @assert x[M] == y[M] DimensionMismatch(\"Batch size ($(x[M]) vs. $(y[M]))\")\nend\n"
  },
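A hypothetical sketch of the grouped-convolution channel bookkeeping checked by the `DenseConvDims` constructor above: the weight array carries `C_in ÷ groups` input channels, and both channel counts must divide evenly.

```julia
using NNlib

x = randn(Float32, 16, 16, 8, 1)       # 8 input channels
w = randn(Float32, 3, 3, 2, 12)        # 8 ÷ 4 = 2 channels per group, 12 output channels

cdims = NNlib.DenseConvDims(x, w; groups = 4)
NNlib.channels_in(cdims), NNlib.channels_out(cdims), NNlib.groupcount(cdims)   # (8, 12, 4)
```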
  {
    "path": "src/dim_helpers/DepthwiseConvDims.jl",
    "content": "\"\"\"\n    DepthwiseConvDims\n\nConcrete subclass of `ConvDims` for a depthwise convolution.  Differs primarily due to\ncharacterization by `C_in`, `C_mult`, rather than `C_in`, `C_out`.  Useful to be separate from\nDenseConvDims primarily for channel calculation differences.\n\"\"\"\nstruct DepthwiseConvDims{N, K, S, P, D} <: ConvDims{N}\n    input_size::NTuple{N, Int}\n\n    kernel_size::NTuple{K, Int}\n    channels_in::Int\n    channels_multiplier::Int\n\n    stride::NTuple{S, Int}\n    padding::NTuple{P, Int}\n    dilation::NTuple{D, Int}\n    flipkernel::Bool\nend\n\nfunction DepthwiseConvDims(\n    x_size::NTuple{M}, w_size::NTuple{M};\n    stride = 1, padding = 0, dilation = 1, flipkernel::Bool = false,\n) where M\n    sstride, ppadding, ddilation = check_spdf(\n        x_size, w_size, stride, padding, dilation)\n\n    # Ensure channels are equal\n    if x_size[end-1] != w_size[end]\n        xs = x_size[end-1]\n        ws = w_size[end]\n        throw(DimensionMismatch(\"Input channels must match! ($xs vs. $ws)\"))\n    end\n\n    DepthwiseConvDims(\n        x_size[1:(end - 2)],\n        w_size[1:(end - 2)], x_size[end - 1], w_size[end - 1],\n        sstride, ppadding, ddilation, flipkernel)\nend\n\nfunction DepthwiseConvDims(x::AbstractArray, w::AbstractArray; kwargs...)\n    if ndims(x) != ndims(w)\n        throw(DimensionMismatch(\"Rank of x and w must match! ($(ndims(x)) vs. $(ndims(w)))\"))\n    end\n    return DepthwiseConvDims(size(x), size(w); kwargs...)\nend\n\n# Useful for constructing a new DepthwiseConvDims that has only a few elements different\n# from the original progenitor object.\n@inline DepthwiseConvDims(\n    c::DepthwiseConvDims; I=input_size(c), K=kernel_size(c),\n    C_in=channels_in(c), C_m=channel_multiplier(c), S=stride(c),\n    P=padding(c), D=dilation(c), F=flipkernel(c),\n) = DepthwiseConvDims(\n    I,\n    K, C_in, C_m,\n    S, P, D, F)\n\n@inline channels_in(c::DepthwiseConvDims) = c.channels_in\n@inline channels_out(c::DepthwiseConvDims) = c.channels_in * c.channels_multiplier\n@inline channel_multiplier(c::DepthwiseConvDims) = c.channels_multiplier\n\n@inline input_size(c::DepthwiseConvDims) = c.input_size\n@inline kernel_size(c::DepthwiseConvDims) = c.kernel_size\n\n@inline stride(c::DepthwiseConvDims) = c.stride\n@inline padding(c::DepthwiseConvDims) = c.padding\n@inline dilation(c::DepthwiseConvDims) = c.dilation\n@inline flipkernel(c::DepthwiseConvDims) = c.flipkernel\n\n# This one is basically the same as for DenseConvDims, we only change a few lines for kernel channel count\nfunction check_dims(x::NTuple{M}, w::NTuple{M}, y::NTuple{M}, cdims::DepthwiseConvDims) where {M}\n    # First, check that channel counts are all correct:\n    @assert x[M-1] == channels_in(cdims) DimensionMismatch(\"Data input channel count ($(x[M-1]) vs. $(channels_in(cdims)))\")\n    @assert y[M-1] == channels_out(cdims) DimensionMismatch(\"Data output channel count ($(y[M-1]) vs. $(channels_out(cdims)))\")\n    @assert w[M-1] == channel_multiplier(cdims) DimensionMismatch(\"Kernel multiplier channel count ($(w[M-1]) vs. $(channel_multiplier(cdims))\")\n    @assert w[M] == channels_in(cdims) DimensionMismatch(\"Kernel input channel count ($(w[M]) vs. $(channels_in(cdims)))\")\n\n    # Next, check that the spatial dimensions match up\n    @assert x[1:M-2] == input_size(cdims) DimensionMismatch(\"Data input spatial size ($(x[1:M-2]) vs. 
$(input_size(cdims)))\")\n    @assert y[1:M-2] == output_size(cdims) DimensionMismatch(\"Data output spatial size ($(y[1:M-2]) vs. $(output_size(cdims)))\")\n    @assert w[1:M-2] == kernel_size(cdims) DimensionMismatch(\"Kernel spatial size ($(w[1:M-2]) vs. $(kernel_size(cdims)))\")\n\n    # Finally, check that the batch size matches\n    @assert x[M] == y[M] DimensionMismatch(\"Batch size ($(x[M]) vs. $(y[M]))\")\nend\n"
  },
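A small, hypothetical sketch of the depthwise layout described above: the weight array is `(spatial..., channel_multiplier, C_in)`, and the output channel count is `C_in * C_mult`.

```julia
using NNlib

x = randn(Float32, 28, 28, 3, 2)       # 3 input channels, batch of 2
w = randn(Float32, 3, 3, 4, 3)         # channel multiplier 4, last axis must equal C_in

cdims = NNlib.DepthwiseConvDims(x, w)
NNlib.channels_out(cdims)              # 3 * 4 = 12
```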
  {
    "path": "src/dim_helpers/PoolDims.jl",
    "content": "\"\"\"\n    PoolDims(x_size::NTuple{M}, k::Union{NTuple{L, Int}, Int};\n             stride=k, padding=0, dilation=1)  where {M, L}\n\nDimensions for a \"pooling\" operation that can have an arbitrary input size, kernel size,\nstride, dilation, and channel count.  Used to dispatch onto efficient implementations at\ncompile-time.\n\"\"\"\nstruct PoolDims{N, K, S, P, D} <: ConvDims{N}\n    input_size::NTuple{N, Int}\n\n    kernel_size::NTuple{K, Int}\n    channels_in::Int\n\n    stride::NTuple{S, Int}\n    padding::NTuple{P, Int}\n    dilation::NTuple{D, Int}\nend\n\nfunction PoolDims(\n    x_size::NTuple{M}, k::Union{NTuple{L, Int}, Int};\n    stride = k, padding = 0, dilation = 1,\n) where {M, L}\n    _check_kernel(k::Number, N::Int) = ntuple(_ -> Int(k), N)\n    _check_kernel(k::NTuple, ::Int) = k\n\n    kernel = _check_kernel(k, M - 2)\n    length(x_size) == length(kernel) + 2 || error(\n        \"PoolDims expects ndim(x) == length(k)+2 or length(size(x)) == length(kernel)+2,\n        dimension of x_size is $(length(x_size)),\n        length of k need $(length(x_size) - 2),\n        but now it's $(length(kernel))\"\n    )\n    spdf_kernel = NTuple{M, Int}([kernel..., 1, 1])\n\n    sstride, ppadding, ddilation = check_spdf(\n        x_size, spdf_kernel, stride, padding, dilation)\n    PoolDims(\n        x_size[1:(end - 2)], kernel, x_size[end - 1],\n        sstride, ppadding, ddilation)\nend\n\nPoolDims(x::AbstractArray, k; kwargs...) = PoolDims(size(x), k; kwargs...)\n\n# Useful for constructing a new PoolDims that has only a few elements different\n# from the original progenitor object that it inherits shapes from.\nPoolDims(\n    c::C; I=input_size(c), K=kernel_size(c),\n    C_in=channels_in(c), S=stride(c), P=padding(c), D=dilation(c),\n) where C <: ConvDims = PoolDims(I, K, C_in, S, P, D)\n\n@inline channels_in(c::PoolDims) = c.channels_in\n@inline channels_out(c::PoolDims) = c.channels_in\n\n@inline input_size(c::PoolDims) = c.input_size\n@inline kernel_size(c::PoolDims) = c.kernel_size\n\n@inline stride(c::PoolDims) = c.stride\n@inline padding(c::PoolDims) = c.padding\n@inline dilation(c::PoolDims) = c.dilation\n@inline flipkernel(c::PoolDims) = false\n\nfunction check_dims(x::NTuple{M}, y::NTuple{M}, pdims::PoolDims) where {M}\n    # First, check that channel counts are all correct:\n    @assert x[end-1] == channels_in(pdims) DimensionMismatch(\"Data input channel count ($(x[end-1]) vs. $(channels_in(pdims)))\")\n    @assert y[end-1] == channels_out(pdims) DimensionMismatch(\"Data output channel count ($(y[end-1]) vs. $(channels_out(pdims)))\")\n\n    # Next, check that the spatial dimensions match up\n    @assert x[1:end-2] == input_size(pdims) DimensionMismatch(\"Data input spatial size ($(x[1:end-2]) vs. $(input_size(pdims)))\")\n    @assert y[1:end-2] == output_size(pdims) DimensionMismatch(\"Data output spatial size ($(y[1:end-2]) vs. $(output_size(pdims)))\")\n\n    # Finally, check that the batch size matches\n    @assert x[end] == y[end] DimensionMismatch(\"Batch size ($(x[end]) vs. $(y[end]))\")\nend\n"
  },
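A hypothetical usage sketch of `PoolDims`: the kernel argument may be a single integer, and `stride` defaults to the window size, so a plain `PoolDims(x, 2)` halves each spatial dimension.

```julia
using NNlib

x = randn(Float32, 32, 32, 4, 1)
pdims = NNlib.PoolDims(x, 2)           # 2×2 window; stride defaults to the window size
NNlib.output_size(pdims)               # (16, 16)

y = NNlib.maxpool(x, pdims)            # the dims object is passed straight to the pooling ops
size(y)                                # (16, 16, 4, 1)
```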
  {
    "path": "src/dim_helpers.jl",
    "content": "# Various helper functions to calculate dimensions for operations\ninclude(\"dim_helpers/ConvDims.jl\")\ninclude(\"dim_helpers/DenseConvDims.jl\")\ninclude(\"dim_helpers/DepthwiseConvDims.jl\")\ninclude(\"dim_helpers/PoolDims.jl\")\n\n\n\"\"\"\n    transpose_swapbatch(x::AbstractArray)\n\nGiven an AbstractArray, swap its batch and channel axes, as we must during transposed\nconvolution.  We do this to the operands during convolution, and then again to the\noutput once we're done.\n\"\"\"\nfunction transpose_swapbatch(x::AbstractArray)\n    return permutedims(x, ((1:(ndims(x)-2))..., ndims(x), ndims(x)-1))\nend\nfunction transpose_swapbatch(x::Tuple)\n    return (x[1:end-2]..., x[end], x[end-1])\nend\n\n\"\"\"\n    transpose_pad(cdims::ConvDims)\n\nTransposed convolution can be calculated in terms of typical convolution with some extra\npadding.  This method computes the padding of the convolution that would result in the\ntransposed convolution of two operands, in essence taking care of that \"extra padding\".\nNote that this method should almost always be accompanied by a call that predilates one\nof the operands.\n\"\"\"\nfunction transpose_pad(cdims::ConvDims)\n    I = input_size(cdims)\n    K = kernel_size(cdims)\n    D = dilation(cdims)\n    P = padding(cdims)\n    S = stride(cdims)\n    return ntuple(length(P)) do i\n        hi = ceil(Int, i/2)\n        if mod(i, 2) == 1\n            return (K[hi] - 1)*D[hi] - P[i]\n        else\n            return (K[hi] - 1)*D[hi] - P[i] + mod(I[hi] + P[i-1] + P[i] - (K[hi] - 1)*D[hi] - 1, S[hi])\n        end\n    end\nend\n\n\"\"\"\n    insert_singleton_spatial_dimension(cdims::ConvDims)\n\nWhen converting a 1d convolution to a 2d, or a 2d to a 3d, we need to insert a singleton\nspatial dimension at the end of the spatial dimensions.  This does so for a ConvDims.\n\"\"\"\n@inline function insert_singleton_spatial_dimension(cdims::C) where {C <: ConvDims}\n    return basetype(C)(cdims;\n        I=(input_size(cdims)..., 1),\n        K=(kernel_size(cdims)..., 1),\n        S=(stride(cdims)..., 1),\n        # Padding is always the problem child....\n        P=(padding(cdims)..., 0, 0),\n        D=(dilation(cdims)..., 1),\n    )\nend\n\n# We specialize common cases\n@inline function insert_singleton_spatial_dimension(x::AbstractArray{T,3}) where {T}\n    return reshape(x, size(x,1), 1, size(x,2), size(x,3))\nend\n@inline function insert_singleton_spatial_dimension(x::AbstractArray{T,4}) where {T}\n    return reshape(x, size(x,1), size(x,2), 1, size(x,3), size(x,4))\nend\n\n# Helper to do this as many times as needed\n@inline function insert_singleton_spatial_dimension(x, reps::Int)\n    for r in 1:reps\n        x = insert_singleton_spatial_dimension(x)\n    end\n    return x\nend\n\n\"\"\"\n    predilated_size(x_size::Tuple, dilation::Tuple)\n\nCalculate the size of a predilated `x` given a particular dilation factor.  This is used\nwithin `predilate()` and `transpose_cdims()`.\n\"\"\"\nfunction predilated_size(x_size::NTuple{N}, dilation::NTuple{M}) where {N, M}\n    @assert (M == N - 2) DimensionMismatch(\"len(dilation) != number of spatial dims\")\n    return ntuple(N) do idx\n        if idx <= N - 2\n            return (x_size[idx] - 1)*dilation[idx] + 1\n        else\n            x_size[idx]\n        end\n    end\nend\n\n\"\"\"\n    predilate(x, dilation::Tuple)\n\nPlaces elements of `x` within a lattice of zeros, used in expressing a transposed\nconvolution in terms of normal convolution.  
Note that while we call this \"predilation\"\nfor aesthetic reasons, you are typically passing a \"stride\" value into here.  Yes,\ntransposed convolution is confusing.\n\"\"\"\nfunction predilate(x::AbstractArray{T,N}, dilation::NTuple{M}) where {T, N, M}\n    @assert (M == N - 2) DimensionMismatch(\"len(dilation) != number of spatial dims\")\n\n    # If there is no dilation to be done, then ignore it.\n    if all(dilation .== 1)\n        return x\n    end\n\n    # Validate dilation factors\n    for idx in 1:length(dilation)\n        @assert dilation[idx] >= 1 ArgumentError(\"dilation cannot be less than 1\")\n    end\n\n    # Create new x that is bigger and holier\n    x_dil = zeros(eltype(x), predilated_size(size(x), dilation))\n\n    # Fill in strategic locations within `x_dil`, such that there are `dilation[idx] - 1`\n    # zeros between each element of `x` along each spatial dimension.\n    x_dil[(1:dilation[idx]:size(x_dil,idx) for idx in 1:(N-2))..., :, :] .= x\n    return x_dil\nend\n\n\"\"\"\n    flipweight(w::AbstractArray)\n\nReorders the weight tensor for supporting both convolution and cross-correlation operations.\n\"\"\"\n\n# For any array with ndims <= 3 it makes no sense to flip the weights so simply return the\n# original array\n@inline flipweight(w::AbstractArray) = w\n\n@inline flipweight(w::AbstractArray{T, 4}) where {T} = w[end:-1:1, end:-1:1, :, :]\n\n@inline flipweight(w::AbstractArray{T, 5}) where {T} = w[end:-1:1, end:-1:1, end:-1:1, :, :]\n"
  },
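A tiny sketch of the internal `predilate` helper documented above (it is not exported, and the values here are hypothetical): elements of `x` are placed on a lattice with `dilation - 1` zeros between them along each spatial dimension.

```julia
using NNlib

x  = reshape(Float32[1 2; 3 4], 2, 2, 1, 1)   # 2×2 spatial, 1 channel, batch of 1
xd = NNlib.predilate(x, (2, 2))               # zeros inserted between the original entries
size(xd)                                      # (3, 3, 1, 1)
```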
  {
    "path": "src/dropout.jl",
    "content": "\n\"\"\"\n    dropout([rng], A, p; [dims])\n\nReturns an array in which each element of `A` is either replaced with zero,\nwith probability `p`, or else multiplied by `1/(1-p)`.\n\nBy default every element is treated independently.\nWith keyword `dims=1`, a choice is made for every value of the 1st index\ni.e. each row of a matrix is either zero or not.\n\nOptional first argument is the random number generator used.\n\n# Examples\n```julia-repl\njulia> dropout(ones(2, 10), 0.2)\n2×10 Matrix{Float64}:\n 1.25  1.25  0.0   1.25  1.25  1.25  1.25  1.25  1.25  1.25\n 1.25  1.25  1.25  0.0   1.25  1.25  0.0   1.25  1.25  1.25\n\njulia> mean(dropout(ones(10^4, 5), 0.2), dims=1)\n1×5 Matrix{Float64}:\n 0.998  1.00075  0.99125  0.99575  1.00075\n\njulia> dropout(ones(5, 5), 0.7, dims=1)  # whole row the same\n5×5 Matrix{Float64}:\n 3.33333  3.33333  3.33333  3.33333  3.33333\n 0.0      0.0      0.0      0.0      0.0\n 0.0      0.0      0.0      0.0      0.0\n 3.33333  3.33333  3.33333  3.33333  3.33333\n 0.0      0.0      0.0      0.0      0.0\n\njulia> mean(dropout(ones(10^4, 5), 0.3, dims=1), dims=1)\n1×5 Matrix{Float64}:\n 1.00571  1.00571  1.00571  1.00571  1.00571\n```\n\"\"\"\ndropout(A::AbstractArray, p::Real; dims = :) = dropout(_rng_from_array(A), A, p; dims)\n\nfunction dropout(rng::AbstractRNG, A::AbstractArray, p::Real; dims = :)\n    _rng_compat_array(rng, A)\n    T = float(eltype(A))\n    0 <= p <= 1 || throw(ArgumentError(\"dropout expects a probability 0 <= p <= 1\"))\n    if p > 0\n        dst = similar(A, T, size(A))\n        pT = convert(real(T), p)\n        _dropout!(rng, dst, A, pT, dims)\n    else\n        # Not so sure we want fast paths... this tries but doesn't guarantee type-stability,\n        # and the rrule does not have such a fast paths.\n        convert(AbstractArray{T}, A)\n    end\nend\n\n\"\"\"\n    dropout!(B, A, p; [dims])\n\nThis does exactly `B .= dropout(A, p; dims)`,\nor rather, it's the implementation of out-of-place [`dropout`](@ref).\n\"\"\"\ndropout!(B::AbstractArray, A::AbstractArray, p::Real; dims = :) = dropout!(_rng_from_array(B), B, A, p; dims)\n\nfunction dropout!(rng::AbstractRNG, dst::AbstractArray, src::AbstractArray, p::Real; dims=:)\n    size(dst) == size(src) || throw(DimensionMismatch(\"dropout! expects output array the same size as input\"))\n    0 <= p <= 1 || throw(ArgumentError(\"dropout expects a probability 0 <= p <= 1\"))\n    _rng_compat_array(rng, src)\n    if p > 0\n        pT = convert(real(eltype(dst)), p)\n        _dropout!(rng, dst, src, pT, dims)\n    else\n        # This fast path isn't free, but no concerns about types changing:\n        copyto!(dst, src)\n    end\nend\n\n# This is the easy case in that we can safely use the output array for random numbers.\nfunction _dropout!(rng::AbstractRNG, dst::AbstractArray, src::AbstractArray, p::Real, dims::Colon)\n    T = real(eltype(dst))\n    val = convert(T, 1/(1-p))\n    rand!(rng, dst)\n    ## This is what we want, but it hits a SIMD bug, solved by _fast_broadcast!\n    # dst .= (dst.>p) .* val .* src\n    _fast_broadcast!(dst, src) do q, x\n        ((real(q)>p) * val) * x\n    end\n    dst\nend\n\n# For other dims, we we do need to allocate something.\nfunction _dropout!(rng::AbstractRNG, dst::AbstractArray, src::AbstractArray, p::Real, dims)\n    T = real(eltype(dst))\n    tmp = similar(dst, T, ntuple(d -> d in dims ? 
size(src,d) : 1, ndims(src)))\n    rand!(rng, tmp)\n    val = convert(T, 1/(1-p))\n    ## One-pass strategy -- faster on GPU\n    dst .= ((tmp.>p) .* val) .* src\n    ## Two-pass strategy -- slightly faster on some CPUs?\n    # _fast_broadcast!(tmp) do q\n    #     (q>p) * val\n    # end\n    # dst .= tmp .* src\nend\n\n# The gradient needs to keep the random choices made, thus store at least a BitArray,\n# but the following way turns out to be faster & simpler:\nfunction ChainRulesCore.rrule(::typeof(dropout), rng::AbstractRNG, A::AbstractArray, p::Real; dims = :)\n    T = float(real(eltype(A)))\n    val = convert(T, 1/(1-p))\n    keep = if dims isa Colon\n        similar(A, T, size(A))\n    else\n        similar(A, T, ntuple(d -> d in dims ? size(A,d) : 1, ndims(A)))\n    end\n    rand!(rng, keep)\n    Y = @. ((keep>p) * val) * A\n    function dropout_back(Δ)\n        dY = unthunk(Δ)\n        dA = @. ((keep>p) * val) * dY\n        (NoTangent(), NoTangent(), dA, NoTangent())\n    end\n    return Y, dropout_back\nend\n# Possibly TODO: another approach to the gradient would be to copy the RNG\n# and then re-generate the same mask, instead of storing it. This saves memory\n# and seems about as fast, but needs a method `copy(::CUDA.RNG)` and careful checking.\n# https://github.com/FluxML/NNlib.jl/pull/454#issuecomment-1369357402\n\n\n\"\"\"\n    _rng_from_array(x)\n\nReturn the random number generator most appropriate for `x`:\n`CUDA.default_rng()` for `CuArray`, else `Random.default_rng()`\n\"\"\"\n_rng_from_array(::AbstractArray) = Random.default_rng()\n\n@non_differentiable _rng_from_array(::Any)\n\n# This exists because `rand!(default_rng(), CUDA.rand(3))` ignores the RNG,\n# and Flux would prefer an error. NNlibCUDAExt will overload it to produce that.\n_rng_compat_array(::AbstractRNG, ::AbstractArray) = nothing\n"
  },
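A brief, hypothetical sketch of the in-place `dropout!` documented above, with and without an explicit RNG; the sizes and the `MersenneTwister` seed are illustrative only.

```julia
using NNlib, Random

A = ones(Float32, 3, 4)
B = similar(A)
NNlib.dropout!(B, A, 0.5)                    # in-place: B .= dropout(A, 0.5)

rng = Random.MersenneTwister(0)              # an explicit RNG can be passed first
NNlib.dropout!(rng, B, A, 0.5; dims = 1)     # one choice per row: whole rows kept or zeroed
```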
  {
    "path": "src/fold.jl",
    "content": "\"\"\"\n    unfold(x, kernel_size; stride = 1, pad = 0, dilation = 0, flipped = true)\n\nPlaces sliding windows of x into a container tensor of size `(num_windows,\nwindow_size, batchsize)`. The window size is determined by the `prod(spatial dims\nof kernel)*input_channels`. The number of sliding windows will match those of\nconvolution (`conv`) with the same kernel_size and arguments. Note that\nby default `conv` flips the spatial dimensions of its kernel (default\n`flipped=false`), whereas `unfold` does not (default `flipped=true`).\nUses `NNlib.im2col!` as backend.\n\nSee also [`fold`](@ref), the adjoint/transpose operator\nand a potential inverse of `unfold`.\n\n# Example\nThe below example demonstrates that `unfold` uses the same sliding windows as `conv`.\nIn general [`batched_mul`](@ref) + `unfold` should not be used to achieve convolution.\n```jldoctest\njulia> x = reshape([100 2 3 40 5 6 700], 7, 1, 1);  # 1D data, 1 channel, batch of 1\n\njulia> w = reshape([1 0 -1], 3, 1, 1);  # 1D conv kernel of length 3\n\njulia> kws = (pad=1, stride=2, flipped=true);  # use same args for conv and unfold\n\njulia> z = NNlib.unfold(x, size(w); kws...)\n4×3×1 Array{Int64, 3}:\n[:, :, 1] =\n  0  100   2\n  2    3  40\n 40    5   6\n  6  700   0\n\njulia> y1 = conv(x, w; kws...)\n4×1×1 Array{Int64, 3}:\n[:, :, 1] =\n  -2\n -38\n  34\n   6\n\njulia> y2 = z ⊠ w  # ⊠ (\\\\boxtimes) is NNlib.batched_mul\n4×1×1 Array{Int64, 3}:\n[:, :, 1] =\n  -2\n -38\n  34\n   6\n```\n\"\"\"\nfunction unfold(x::AbstractArray{T, N}, kernel_size::NTuple{K}; stride = 1, pad = 0, dilation = 1, flipped = true) where {T, K, N}\n    stride = expand(Val(N - 2), stride)\n    padding = expand(Val(N - 2), pad)\n    dilation = expand(Val(N - 2), dilation)\n    cdims = DenseConvDims(size(x), kernel_size; stride, padding, dilation, flipkernel=flipped)\n    return unfold(x, cdims)\nend\n\n\"\"\"\n    fold(y, output_size, kernel_size; stride = 1, pad = 0, dilation = 0, flipped = true)\n\nThe adjoint/transpose operator of `unfold`. It accumulates sliding windows from\nthe output of `unfold` into a container tensor of size `output_size`. An inverse\nto `unfold` may be obtained (in some cases) by using `fold` and accounting for scaling issues\nwith a divisor (see example). Uses `NNlib.col2im!` as backend.\n\nSee also [`unfold`](@ref).\n\n# Example\n```jldoctest\njulia> x = reshape([100 2 3 40 5 6 700], 7, 1, 1);  # 1D data, 1 channel, batch of 1\n\njulia> y = NNlib.unfold(x, (3,1,1))  # sliding window of size 3\n5×3×1 Array{Int64, 3}:\n[:, :, 1] =\n 100   2    3\n   2   3   40\n   3  40    5\n  40   5    6\n   5   6  700\n\njulia> z = NNlib.fold(y, size(x), (3,1,1))  # sum of contributions in y. 
100 appears once, 40 three times\n7×1×1 Array{Int64, 3}:\n[:, :, 1] =\n 100\n   4\n   9\n 120\n  15\n  12\n 700\n\njulia> divisor = NNlib.fold(NNlib.unfold(ones(size(x)...), (3,1,1)), size(x), (3,1,1))\n7×1×1 Array{Float64, 3}:\n[:, :, 1] =\n 1.0\n 2.0\n 3.0\n 3.0\n 3.0\n 2.0\n 1.0\n\njulia> z ./ divisor\n7×1×1 Array{Float64, 3}:\n[:, :, 1] =\n 100.0\n   2.0\n   3.0\n  40.0\n   5.0\n   6.0\n 700.0\n```\nIn general, an inverse to `unfold` does not exist if `divisor` contains zeros.\n\"\"\"\nfunction fold(x::AbstractArray{T, 3}, output_size::NTuple{N}, kernel_size::NTuple{K}; stride = 1, pad = 0, dilation = 1, flipped = true) where {T, K, N}\n    stride = expand(Val(N - 2), stride)\n    padding = expand(Val(N - 2), pad)\n    dilation = expand(Val(N - 2), dilation)\n    cdims = DenseConvDims(output_size, kernel_size; stride, padding, dilation, flipkernel=flipped)\n    return fold(x, output_size, cdims)\nend\n\n# im2col_dims returns (numblocks, blocksize, threadnum) where thread dim is used as thread-local\n# workspace for multithreaded conv. Ultimately, we want to threadnum with batchsize.\nunfold_dims(cdims::DenseConvDims) = im2col_dims(cdims)[1:2]\n\n# auto-allocating versions\nfunction unfold(x::AbstractArray{T, N}, cdims::DenseConvDims) where {T, N}\n    y = similar(x, unfold_dims(cdims)..., size(x, N)) # (numblocks, blocksize, batchsize)\n    return unfold!(y, x, cdims)\nend\n\nfunction fold(y::AbstractArray{T, 3}, output_size::NTuple, cdims::DenseConvDims) where {T}\n    x = similar(y, output_size)\n    return fold!(x, y, cdims)\nend\n\n# N < 5 -dimension in-place versions\nfunction unfold!(y::AbstractArray{yT, 3}, x::AbstractArray{xT, N}, cdims::DenseConvDims) where {yT, xT, N}\n    unfold!(\n        y,\n        insert_singleton_spatial_dimension(x, 5-N),\n        insert_singleton_spatial_dimension(cdims, 5-N),\n    )\n    return y\nend\n\nfunction fold!(x::AbstractArray{xT, N}, y::AbstractArray{yT, 3}, cdims::DenseConvDims) where {yT, xT, N}\n    fold!(\n        insert_singleton_spatial_dimension(x, 5-N),\n        y,\n        insert_singleton_spatial_dimension(cdims, 5-N),\n    )\n    return x\nend\n\n# 5-dimension in-place versions\nfunction unfold!(y::AbstractArray{yT, 3}, x::AbstractArray{xT, 5}, cdims::DenseConvDims) where {yT, xT}\n    @threads for batch_idx in 1:size(x, 5)\n        y_slice = view(y, :, :, batch_idx)\n        im2col!(y_slice, view(x, :, :, :, :, batch_idx), cdims)\n    end\n    return y\nend\n\nfunction fold!(x::AbstractArray{xT, 5}, y::AbstractArray{yT, 3}, cdims::DenseConvDims) where {xT, yT}\n    @threads for batch_idx in 1:size(x, 5)\n        y_slice = view(y, :, :, batch_idx)\n        col2im!(view(x, :, :, :, :, batch_idx), y_slice, cdims)\n    end\n    return x\nend\n\n@kernel function unfold_kernel!(\n    col::AbstractArray{T}, x, col_size,\n    input_size, output_size, kernel_size,\n    flipkernel, stride, pad_lo, dilation, max_idx,\n) where T\n    index = @index(Global)\n\n    @inbounds if index ≤ max_idx\n        i, kw, kh, kd, c, b = CartesianIndices(col_size)[index].I # col indices\n        w, h, d = CartesianIndices(output_size)[i].I # x indices\n\n        # project\n        w, h, d = @. 
((w, h, d) - 1) * stride - pad_lo + 1 + ((kw, kh, kd) - 1) * dilation\n\n        if !flipkernel\n            kw, kh, kd = kernel_size .- (kw, kh, kd) .+ 1\n        end\n\n        # check out of bounds\n        if !all(checkindex.(Bool, UnitRange.(1, input_size), (w, h, d)))\n            col[i, kw, kh, kd, c, b] = T(0)\n        else\n            xval::T = x[w, h, d, c, b]\n            col[i, kw, kh, kd, c, b] = xval\n        end\n    end\nend\n\n@kernel function fold_kernel!(\n    x::AbstractArray{T}, col, col_size,\n    input_size, output_size, kernel_size,\n    flipkernel, stride, pad_lo, dilation, max_idx,\n) where T\n    index = @index(Global)\n\n    @inbounds if index ≤ max_idx\n        i, kw, kh, kd, c, b = CartesianIndices(col_size)[index].I # col indices\n        w, h, d = CartesianIndices(output_size)[i].I # x indices\n\n        # project\n        w, h, d = @. ((w, h, d) - 1) * stride - pad_lo + 1 + ((kw, kh, kd) - 1) * dilation\n\n        # check out of bounds\n        if all(checkindex.(Bool, UnitRange.(1, input_size), (w, h, d)))\n            if !flipkernel\n                kw, kh, kd = kernel_size .- (kw, kh, kd) .+ 1\n            end\n\n            cval::T = col[i, kw, kh, kd, c, b]\n            @atomic x[w, h, d, c, b] += cval\n        end\n    end\nend\n\nfunction unfold!(\n    col::AnyGPUArray{cT,3}, x::AnyGPUArray{xT,5}, cdims::DenseConvDims,\n) where {cT, xT}\n    spatial_dims(cdims) != 3 && throw(DimensionMismatch(\n        \"unfold!() only accepts 3d convoluitional inputs\"))\n\n    C_in = channels_in(cdims)\n    ker_size = kernel_size(cdims)\n    pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = padding(cdims)\n    pad_lo = (pad_w_lo, pad_h_lo, pad_d_lo)\n\n    out_size = output_size(cdims)\n    col_reshaped = reshape(col, (prod(out_size), ker_size..., C_in, :))\n\n    max_idx = prod(size(col))\n    unfold_kernel!(get_backend(x))(\n        col_reshaped, x, size(col_reshaped),\n        input_size(cdims), out_size, ker_size,\n        flipkernel(cdims), stride(cdims), pad_lo, dilation(cdims), max_idx;\n        ndrange=max_idx)\n    return col\nend\n\nfunction fold!(\n    x::AnyGPUArray{xT,5}, col::AnyGPUArray{cT,3}, cdims::DenseConvDims,\n) where {xT, cT}\n    spatial_dims(cdims) != 3 && throw(DimensionMismatch(\n        \"fold!() only accepts 3d convoluitional inputs\"))\n\n    # going to accumulate into x\n    fill!(x, xT(0))\n\n    C_in = channels_in(cdims)\n    ker_size = kernel_size(cdims)\n    pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = padding(cdims)\n    pad_lo = (pad_w_lo, pad_h_lo, pad_d_lo)\n    out_size = output_size(cdims)\n\n    col_reshaped = reshape(col, (prod(out_size), ker_size..., C_in, :))\n\n    max_idx = prod(size(col))\n    fold_kernel!(get_backend(x))(\n        x, col_reshaped, size(col_reshaped),\n        input_size(cdims), out_size, ker_size,\n        flipkernel(cdims), stride(cdims), pad_lo, dilation(cdims), max_idx;\n        ndrange=max_idx)\n\n    return x\nend\n\n# reverse diff rules\nfunction rrule(::typeof(unfold), x, cdims::DenseConvDims; kw...)\n    function unfold_pullback(Δ)\n        return (\n            NoTangent(),\n            fold(unthunk(Δ), size(x), cdims; kw...),\n            NoTangent(),\n        )\n    end\n    return unfold(x, cdims; kw...), unfold_pullback\nend\n\nfunction rrule(::typeof(fold), x, output_size, cdims::DenseConvDims; kw...)\n    function fold_pullback(Δ)\n        return (\n            NoTangent(),\n            unfold(unthunk(Δ), cdims; kw...),\n            NoTangent(),\n            
NoTangent(),\n        )\n    end\n    return fold(x, output_size, cdims; kw...), fold_pullback\nend\n\n"
  },
  {
    "path": "src/functions.jl",
    "content": "\"\"\"\n    glu(x, dim = 1)\n\nThe gated linear unit from the [\"Language Modeling with Gated Convolutional Networks\"](https://arxiv.org/abs/1612.08083) paper.\n\nCalculates `a .* sigmoid(b)`, where `x` is split in half along given dimension `dim` to form `a` and `b`.\n\"\"\"\nfunction glu(x, dim = 1)\n    maxdim = size(x, dim)\n    @assert maxdim % 2 == 0 \"Dimension must be even\"\n    half = maxdim ÷ 2\n    a, b = selectdim(x, dim, 1:half), selectdim(x, dim, half+1:maxdim)\n    a .* sigmoid.(b)\nend\n\n"
  },
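A minimal, hypothetical usage sketch of `glu` as documented above; the split dimension must have even length.

```julia
using NNlib

x = randn(Float32, 6, 5)     # dimension 1 has even length, so it can be split in half
y = NNlib.glu(x, 1)          # first half .* sigmoid.(second half)
size(y)                      # (3, 5)
```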
  {
    "path": "src/gather.jl",
    "content": "\"\"\"\n    NNlib.gather(src, idx) -> dst\n\nReverse operation of [`scatter`](@ref). Gathers data from source `src`\nand writes it in a destination `dst` according to the index\narray `idx`.\nFor each `k` in `CartesianIndices(idx)`, assign values to `dst`\naccording to\n\n    dst[:, ... , k] .= src[:, ... , idx[k]...]\n\nNotice that if `idx` is a vector containing integers\nand `src` is a matrix, previous expression simplifies to\n\n    dst[:, k] .= src[:, idx[k]]\n\nand `k` will run over `1:length(idx)`.\n\nThe elements of `idx` can be integers or integer tuples and may be repeated.\nA single `src` column can end up being copied into zero, one,\nor multiple `dst` columns.\n\nSee [`gather!`](@ref) for an in-place version.\n\n# Examples\n\n```jldoctest\njulia> NNlib.gather([1,20,300,4000], [2,4,2])\n3-element Vector{Int64}:\n   20\n 4000\n   20\n\njulia> NNlib.gather([1 2 3; 4 5 6], [1,3,1,3,1])\n2×5 Matrix{Int64}:\n 1  3  1  3  1\n 4  6  4  6  4\n```\n\"\"\"\nfunction gather(\n    src::AbstractArray{Tsrc, Nsrc}, idx::AbstractArray{Tidx, Nidx},\n) where {Tsrc, Nsrc, Nidx, Tidx}\n    M = typelength(Tidx)\n    dstsize = (size(src)[1:Nsrc-M]..., size(idx)...)\n    dst = similar(src, Tsrc, dstsize)\n    return gather!(dst, src, idx)\nend\n\n\"\"\"\n    gather(src, IJK...)\n\nConvert the tuple of integer vectors `IJK` to a tuple of `CartesianIndex` and\ncall `gather` on it: `gather(src, CartesianIndex.(IJK...))`.\n\n# Examples\n\n```jldoctest\njulia> src = reshape([1:15;], 3, 5)\n3×5 Matrix{Int64}:\n 1  4  7  10  13\n 2  5  8  11  14\n 3  6  9  12  15\n\njulia> NNlib.gather(src, [1, 2], [2, 4])\n2-element Vector{Int64}:\n  4\n 11\n```\n\"\"\"\nfunction gather(\n    src::AbstractArray{Tsrc, Nsrc},\n    I::AbstractVector{<:Integer}, J::AbstractVector{<:Integer},\n    Ks::AbstractVector{<:Integer}...,\n) where {Nsrc, Tsrc}\n    return gather(src, to_cartesian_index(I, J, Ks...))\nend\n\nto_cartesian_index(IJK...) = CartesianIndex.(IJK...)\n\n@non_differentiable to_cartesian_index(::Any...)\n\"\"\"\n    NNlib.gather!(dst, src, idx)\n\nReverse operation of [`scatter!`](@ref). Gathers data from source `src`\nand writes it in destination `dst` according to the index array `idx`.\nFor each `k` in `CartesianIndices(idx)`, assign values to `dst` according to\n\n    dst[:, ... , k] .= src[:, ... 
, idx[k]...]\n\nNotice that if `idx` is a vector containing integers,\nand both `dst` and `src` are matrices, previous expression simplifies to\n\n    dst[:, k] .= src[:, idx[k]]\n\nand `k` will run over `1:length(idx)`.\n\nThe elements of `idx` can be integers or integer tuples and may be repeated.\nA single `src` column can end up being copied into zero, one,\nor multiple `dst` columns.\n\nSee [`gather`](@ref) for an allocating version.\n\"\"\"\nfunction gather!(dst::AbstractArray, src::AbstractArray, idx::AbstractArray)\n    dims = scatter_dims(src, dst, idx)\n    colons = ntuple(i -> Colon(), dims)\n    for k in CartesianIndices(idx)\n        _view(dst, colons, k) .= _view(src, colons, idx[k])\n    end\n    return dst\nend\n\nfunction gather!(dst::AnyGPUArray, src::AnyGPUArray, idx::AnyGPUArray)\n    isempty(dst) && return dst\n    n_dims = scatter_dims(src, dst, idx)\n    dims = size(src)[1:n_dims]\n    max_dims_idx = prod(dims)\n    ndrange = max_dims_idx * length(idx)\n    _gather!(KernelAbstractions.get_backend(src))(\n        dst, src, idx, CartesianIndices(dims), max_dims_idx; ndrange)\n    return dst\nend\n\n@kernel function _gather!(\n    dst, @Const(src), @Const(idx),\n    dim_ids::CartesianIndices, max_dims_idx::Int,\n)\n    i = @index(Global)\n    j, k = divrem(i - 1, max_dims_idx)\n    @inbounds dst[i] = src[dim_ids[k + 1], Tuple(idx[j + 1])...]\nend\n\n∇gather_src(Δ, src_size, idx) = scatter!(+, fill!(similar(Δ, eltype(Δ), src_size), 0), Δ, idx)\n\nfunction rrule(::typeof(gather!), dst::AbstractArray, src::AbstractArray, idx::AbstractArray)\n    y = gather!(dst, src, idx)\n    src_size = size(src)\n    gather!_pullback(Δ) = (NoTangent(), NoTangent(), ∇gather_src(unthunk(Δ), src_size, idx), NoTangent())\n    return y, gather!_pullback\nend\n"
  },
  {
    "path": "src/gemm.jl",
    "content": "## Low level gemm! call with pointers\n## Borrowed from Knet.jl, adapted for compile-time constants\n\nusing LinearAlgebra.BLAS: get_num_threads, set_num_threads\n\nif isdefined(LinearAlgebra.BLAS, :libblastrampoline)\n    const libblas = LinearAlgebra.BLAS.libblastrampoline\nelse\n    const libblas = Base.libblas_name\nend\n\n\"\"\"\n    gemm!()\n\nLow-level gemm!() call with pointers, borrowed from Knet.jl\n\nCalculates `C = alpha*op(A)*op(B) + beta*C`, where:\n  - `transA` and `transB` set `op(X)` to be either `identity()` or `transpose()`\n  - alpha and beta are scalars\n  - op(A) is an (M, K) matrix\n  - op(B) is a (K, N) matrix\n  - C is an (M, N) matrix.\n\"\"\"\ngemm!\n\n# These are the datatypes we have fast GEMM for\ngemm_datatype_mappings = (\n    (:dgemm_, Float64),\n    (:sgemm_, Float32),\n    (:zgemm_, ComplexF64),\n    (:cgemm_, ComplexF32),\n)\nfor (gemm, elt) in gemm_datatype_mappings\n    @eval begin\n        @inline function gemm!(transA::Val, transB::Val,\n                               M::Int, N::Int, K::Int,\n                               alpha::$(elt), A::Ptr{$elt}, B::Ptr{$elt},\n                               beta::$(elt), C::Ptr{$elt})\n            # Convert our compile-time transpose marker to a char for BLAS\n            convtrans(V::Val{false}) = 'N'\n            convtrans(V::Val{true})  = 'C'\n\n            if transA == Val(false)\n                lda = M\n            else\n                lda = K\n            end\n            if transB == Val(false)\n                ldb = K\n            else\n                ldb = N\n            end\n            ldc = M\n            ccall((@blasfunc($(gemm)), libblas), Nothing,\n                  (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},\n                   Ref{BlasInt}, Ref{$elt}, Ptr{$elt}, Ref{BlasInt},\n                   Ptr{$elt}, Ref{BlasInt}, Ref{$elt}, Ptr{$elt},\n                   Ref{BlasInt}),\n                  convtrans(transA), convtrans(transB), M, N, K,\n                  alpha, A, lda, B, ldb, beta, C, ldc)\n        end\n    end\nend\n\nfor (gemm, elt) in gemm_datatype_mappings\n    @eval begin\n        @inline function batched_gemm!(transA::AbstractChar,\n                               transB::AbstractChar,\n                               alpha::($elt),\n                               A::AbstractArray{$elt, 3},\n                               B::AbstractArray{$elt, 3},\n                               beta::($elt),\n                               C::AbstractArray{$elt, 3})\n            @assert !Base.has_offset_axes(A, B, C)\n            @assert size(A, 3) == 1 || size(A, 3) == size(C, 3) \"batch size mismatch: A != C\"\n            @assert size(B, 3) == 1 || size(B, 3) == size(C, 3) \"batch size mismatch: B != C\"\n\n            m = size(A, transA == 'N' ? 1 : 2)\n            ka = size(A, transA == 'N' ? 2 : 1)\n            kb = size(B, transB == 'N' ? 1 : 2)\n            n = size(B, transB == 'N' ? 2 : 1)\n            if ka != kb || m != size(C,1) || n != size(C,2)\n                throw(DimensionMismatch(\"A1 has size ($m,$ka), B1 has size ($kb,$n), C1 has size $(size(C)[1:2])\"))\n            end\n            LinearAlgebra.BLAS.chkstride1(A)\n            LinearAlgebra.BLAS.chkstride1(B)\n            LinearAlgebra.BLAS.chkstride1(C)\n\n            ptrA = pointer(A)\n            ptrB = pointer(B)\n            ptrC = pointer(C)\n\n            strA = size(A, 3) == 1 ? 0 : Base.stride(A, 3)\n            strB = size(B, 3) == 1 ? 
0 : Base.stride(B, 3)\n            strC = Base.stride(C, 3)\n\n            n_threads = min(\n                Threads.nthreads(:default),\n                1 + max(length(A), length(B)) ÷ 8000)\n            # In some tests, size (20,20,20) is worth splitting between two threads,\n            # as is size (32,32,8).\n\n            if n_threads > 1\n\n                old_threads = get_num_threads()\n                set_num_threads(1)\n\n                parts = Iterators.partition(1:size(C, 3), cld(size(C, 3), n_threads))\n\n                function gemm!_part(ks)\n                    for k in ks\n\n                        ptrAk = ptrA + (k-1) * strA * sizeof($elt)\n                        ptrBk = ptrB + (k-1) * strB * sizeof($elt)\n                        ptrCk = ptrC + (k-1) * strC * sizeof($elt)\n\n                        ccall((@blasfunc($(gemm)), libblas), Nothing,\n                            (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},\n                            Ref{BlasInt}, Ref{$elt}, Ptr{$elt}, Ref{BlasInt},\n                            Ptr{$elt}, Ref{BlasInt}, Ref{$elt}, Ptr{$elt},\n                            Ref{BlasInt}),\n                            transA, transB, m, n,\n                            ka, alpha, ptrAk, max(1,Base.stride(A,2)),\n                            ptrBk, max(1,Base.stride(B,2)), beta, ptrCk,\n                            max(1,Base.stride(C,2)))\n                    end\n                end\n                if should_use_spawn() && length(parts) > 1\n                    Threads.@sync for ks in parts\n                        Threads.@spawn gemm!_part(ks)\n                    end\n                else\n                    for ks in parts\n                        gemm!_part(ks)\n                    end\n                end\n                set_num_threads(old_threads)\n\n            else # small problem, no threads\n\n                for k in 1:size(C, 3)\n                    # Identical loop body\n\n                    ptrAk = ptrA + (k-1) * strA * sizeof($elt)\n                    ptrBk = ptrB + (k-1) * strB * sizeof($elt)\n                    ptrCk = ptrC + (k-1) * strC * sizeof($elt)\n\n                    ccall((@blasfunc($(gemm)), libblas), Nothing,\n                          (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},\n                           Ref{BlasInt}, Ref{$elt}, Ptr{$elt}, Ref{BlasInt},\n                           Ptr{$elt}, Ref{BlasInt}, Ref{$elt}, Ptr{$elt},\n                           Ref{BlasInt}),\n                          transA, transB, m, n,\n                          ka, alpha, ptrAk, max(1,Base.stride(A,2)),\n                          ptrBk, max(1,Base.stride(B,2)), beta, ptrCk,\n                          max(1,Base.stride(C,2)))\n                end\n\n            end\n\n            return C\n        end\n    end\nend\n"
  },
  {
    "path": "src/impl/conv_direct.jl",
    "content": "## This file contains direct Julia implementations of 2d and 3d convolutions\n\n# Helper functions for restricting x/w overreach\nfunction clamp_lo(x, w)\n    idx = 1\n    while idx <= length(x) && x[idx] <= 0\n        idx += 1\n    end\n    return (x[idx:end], w[idx:end])\nend\nfunction clamp_hi(x, w, L)\n    idx = length(x)\n    while idx >= 1 && x[idx] > L\n        idx -= 1\n    end\n    return (x[1:idx], w[1:idx])\nend\n\n\"\"\"\n    conv_direct!(y, x, w, cdims; alpha=1, beta=0)\n\nDirect convolution implementation; used for debugging, tests, and mixing/matching of\nstrange datatypes within a single convolution.  Uses naive nested for loop implementation\nand does not attempt to optimize performance.  Rather, this implementation is intended to\nbe maximally understandable and debuggable, to aid in testing other, more performant\nimplementations.  We also explicitly support mixing and matching of strange datatypes,\nso that if the user really wants to convolve an image of `UInt8`'s with a `Float16`\nkernel, storing the result in a `Float32` output, there is at least a function call\nfor that madness.\n\nThe keyword arguments `alpha` and `beta` control accumulation behavior; this function\ncalculates `y = alpha * x * w + beta * y`, therefore by setting `beta` to a nonzero\nvalue, the user is able to accumulate values into a preallocated `y` buffer, or by\nsetting `alpha` to a nonunitary value, an arbitrary gain factor can be applied.\n\nBy defaulting `beta` to `false`, we make use of the Bradbury promotion trick to override\n`NaN`'s that may pre-exist within our output buffer, as `false*NaN == 0.0`, whereas\n`0.0*NaN == NaN`.  Only set `beta` if you are certain that none of the elements within\n`y` are `NaN`.\n\nThe basic implementation performs 3-dimensional convolution; 1-dimensional and 2-\ndimensional cases are supported by simply reshaping `y`, `x` and `w`, for which\nwrapper methods are available.\n\"\"\"\nconv_direct!\n\nfunction conv_direct!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5},\n                      w::AbstractArray{wT,5}, cdims::DenseConvDims;\n                      alpha::yT = yT(1), beta = false) where {yT, xT, wT}\n    conv_direct!(\n        y, x, w, cdims,\n        Val(kernel_size(cdims)), Val(channels_out(cdims)),\n        Val(padding(cdims)), Val(dilation(cdims)), Val(stride(cdims)),\n        Val(flipkernel(cdims)); alpha, beta)\n    return y\nend\n\n\nfunction conv_direct!(\n    y::AbstractArray{yT,5}, x::AbstractArray{xT,5},\n    w::AbstractArray{wT,5}, cdims::DenseConvDims,\n    # kernel size, output channels, padding, dilation, stride, flipped kernel\n    ::Val{K}, ::Val{C}, ::Val{P}, ::Val{D}, ::Val{S}, fk::Val{F};\n    alpha::yT = yT(1), beta = false,\n) where {yT, xT, wT, K, C, P, D, S, F}\n    check_dims(size(x), size(w), size(y), cdims)\n\n    width, height, depth = input_size(cdims)\n    kernel_w, kernel_h, kernel_d = K\n    pad_w_lo, _, pad_h_lo, _, pad_d_lo, _ = P\n    dil_w, dil_h, dil_d = D\n    stride_w, stride_h, stride_d = S\n\n    # Create a method that determines how we're going to index into `w`.\n    kproj(k, _, ::Val{true}) = k\n    kproj(k, M, ::Val{false}) = M - k + 1\n\n    # A helper function to project from output (w, h) to input (input_w, input_h)\n    project(idx, stride, pad) = (idx - 1)*stride - pad + 1\n\n    # Use `calc_padding_regions` to determine where we do or don't need to worry about padding\n    padded_regions, central_region = calc_padding_regions(cdims)\n\n    # Set outputs to zero to support custom 
datatypes (https://github.com/FluxML/NNlib.jl/issues/490)\n    if iszero(beta)\n        y = fill!(y, zero(yT))\n    end\n\n    # Start with the central region\n    w_region, h_region, d_region = central_region\n    @inbounds for batch in 1:size(x, 5),\n        c_out in 1:C,\n        d_idx in d_region,\n        h_idx in h_region,\n        w_idx in w_region\n\n        # Since we're in the central region, we don't need to worry about clamping\n        dotprod = yT(0)\n        for c_in in 1:channels_in(cdims),\n            kd in 1:kernel_d,\n            kh in 1:kernel_h,\n            kw in 1:kernel_w\n\n            # Hoist me, you coward.\n            x_d = project(d_idx, stride_d, pad_d_lo) + (kd - 1)*dil_d\n            x_h = project(h_idx, stride_h, pad_h_lo) + (kh - 1)*dil_h\n            x_w = project(w_idx, stride_w, pad_w_lo) + (kw - 1)*dil_w\n\n            x_val = x[x_w, x_h, x_d, c_in, batch]\n            w_val = w[kproj(kw, kernel_w, fk),\n                    kproj(kh, kernel_h, fk),\n                    kproj(kd, kernel_d, fk),\n                    c_in, c_out]\n            dotprod = muladd(x_val, w_val, dotprod)\n        end\n        y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch]\n    end\n\n    # Next, do potentially-padded regions:\n    @inbounds for (w_region, h_region, d_region) in padded_regions,\n        batch in 1:size(x, 5),\n        c_out in 1:C,\n        d_idx in d_region,\n        h_idx in h_region,\n        w_idx in w_region\n\n        # Probe for out-of-bounds accesses on `x` and `continue` if we hit one\n        dotprod = yT(0)\n        for c_in in 1:channels_in(cdims),\n            kd in 1:kernel_d\n\n            x_d = project(d_idx, stride_d, pad_d_lo) + (kd - 1)*dil_d\n            if x_d <= 0 || x_d > depth\n                continue\n            end\n\n            for kh in 1:kernel_h\n                x_h = project(h_idx, stride_h, pad_h_lo) + (kh - 1)*dil_h\n                if x_h <= 0 || x_h > height\n                    continue\n                end\n\n                for kw in 1:kernel_w\n                    x_w = project(w_idx, stride_w, pad_w_lo) + (kw - 1)*dil_w\n                    if x_w <= 0 || x_w > width\n                        continue\n                    end\n\n                    x_val = x[x_w, x_h, x_d, c_in, batch]\n                    w_val = w[kproj(kw, kernel_w, fk),\n                            kproj(kh, kernel_h, fk),\n                            kproj(kd, kernel_d, fk),\n                            c_in, c_out]\n                    dotprod = muladd(x_val, w_val, dotprod)\n                end\n            end\n        end\n\n        y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch]\n    end\n\n    return y\nend\n\n## Gradient definitions\n\"\"\"\n    ∇conv_data_direct!(dx, dy, w, cdims; alpha=1, beta=0)\n\nCalculate the gradient imposed upon `x` in the convolution `y = x * w`.\n\"\"\"\n∇conv_data_direct!\n\nfunction ∇conv_data_direct!(dx::AbstractArray{xT,5}, dy::AbstractArray{yT,5},\n                            w::AbstractArray{wT,5}, cdims::DenseConvDims;\n                            alpha::xT=xT(1), beta=false) where {xT, yT, wT}\n    w = conj(transpose_swapbatch(w[end:-1:1, end:-1:1, end:-1:1, :, :]))\n    dy = predilate(dy, stride(cdims))\n    ctdims = DenseConvDims(dy, w; padding=transpose_pad(cdims),\n                                  dilation=dilation(cdims),\n                                  flipkernel=flipkernel(cdims))\n    dx = 
conv_direct!(dx, dy, w, ctdims; alpha=alpha, beta=beta)\n    return dx\nend\n\n\"\"\"\n    ∇conv_filter_direct!(dw, x, dy, cdims; alpha=1, beta=0)\n\nCalculate the gradient imposed upon `w` in the convolution `y = x * w`.\n\"\"\"\n∇conv_filter_direct!\n\nfunction ∇conv_filter_direct!(dw::AbstractArray{wT,5}, x::AbstractArray{xT,5},\n                              dy::AbstractArray{yT,5}, cdims::DenseConvDims;\n                              alpha::wT=wT(1), beta=false) where {xT, yT, wT}\n    x = conj(transpose_swapbatch(x[end:-1:1, end:-1:1, end:-1:1, :, :]))\n    dy = transpose_swapbatch(predilate(dy, stride(cdims)))\n    ctdims = DenseConvDims(dy, x; padding=transpose_pad(cdims),\n                                    stride=dilation(cdims))\n    dw_ = if flipkernel(cdims)\n        view(dw, reverse(axes(dw, 1)), reverse(axes(dw, 2)), reverse(axes(dw, 3)), :, :)\n    else\n        dw\n    end\n    conv_direct!(dw_, dy, x, ctdims; alpha=alpha, beta=beta)\n    return dw\nend\n"
  },
  {
    "path": "src/impl/conv_im2col.jl",
    "content": "## This file contains im2col-backed implementations of convolution for 2d and 3d\n## convolutions.  Expect to see a lot of indexing.\n\n# Helper function for flipkernel-induced dyslexia\nfunction kernel_index(w, h, d, cdims::ConvDims)\n    flipkernel(cdims) && return (w, h, d)\n    kernel_w, kernel_h, kernel_d = kernel_size(cdims)\n    return (kernel_w - w + 1, kernel_h - h + 1, kernel_d - d + 1)\nend\n\n\"\"\"\n    conv_im2col!(y, x, w, cdims, col=similar(x); alpha=1, beta=0)\n\nPerform a convolution using im2col and GEMM, store the result in `y`.  The  kwargs\n`alpha` and `beta` control accumulation behavior; internally this operation is\nimplemented as a matrix multiply that boils down to `y = alpha * x * w + beta * y`, thus\nby setting `beta` to a nonzero value, multiple results can be accumulated into `y`, or\nby setting `alpha` to a nonunitary value, various gain factors can be applied.\n\nNote for the particularly performance-minded, you can provide a pre-allocated `col`,\nwhich should eliminate any need for large allocations within this method.\n\"\"\"\nfunction conv_im2col!(\n                y::AbstractArray{T,5}, x::AbstractArray{T,5},\n                w::AbstractArray{T,5}, cdims::DenseConvDims;\n                col::AbstractArray{T,3}=similar(x, im2col_dims(cdims)),\n                alpha::T=T(1), beta::T=T(0),\n                ntasks::Int=nthreads()) where {T}\n    check_dims(size(x), size(w), size(y), cdims)\n\n    #   COL   *    W    ->    Y\n    # [M x K] * [K x N] -> [M x N]\n    #\n    #  M: output spatial resolution\n    #  N: output channels\n    #  K: size of input \"patch\" (kernel size and input channels combined)\n    #\n    # In english, we're grabbing each input patch and laying them out along\n    # the M dimension in `col`, so that the GEMM call below multiplies each\n    # kernel (which is kernel_h * kernel_w * channels_in elments long) is\n    # dotproducted with that input patch, effectively computing a convolution\n    # in a somewhat memory-wasteful but easily-computed way (since we already\n    # have an extremely highly-optimized GEMM call available in BLAS).\n    M = prod(output_size(cdims))\n    N = channels_out(cdims)\n    K = prod(kernel_size(cdims))*channels_in(cdims)\n\n    parts = Iterators.partition(axes(x, 5), ceil(Int, size(x, 5) / ntasks))\n\n    function conv_part(task_n, part)\n        col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace\n        for batch_idx in part\n            im2col!(col_slice, view(x, :, :, :, :, batch_idx), cdims)\n            GC.@preserve col_slice w y begin\n                col_ptr = pointer(col_slice)\n                w_ptr = pointer(w)\n                y_ptr = pointer(y, (batch_idx - 1)*M*N + 1)\n                gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)\n            end\n        end\n    end\n\n    if should_use_spawn() && length(parts) > 1\n        @sync for (task_n, part) in enumerate(parts)\n            Threads.@spawn conv_part(task_n, part)\n        end\n    else\n        for (task_n, part) in enumerate(parts)\n            conv_part(task_n, part)\n        end\n    end\n    return y\nend\n\n\"\"\"\n    ∇conv_filter_im2col!(dw, x, dy, cdims, col=similar(dw, ∇filter_im2col_dims(cdims));\n                         alpha=1, beta=0)\n\nConv backward pass onto the weights using im2col and GEMM; stores the result in `dw`.\nSee [`conv_im2col!`](@ref) for explanation of optional parameters.\n\"\"\"\nfunction ∇conv_filter_im2col!(\n           
     dw::AbstractArray{T,5}, x::AbstractArray{T,5},\n                dy::AbstractArray{T,5}, cdims::DenseConvDims;\n                col::AbstractArray{T,3} = similar(dw, ∇filter_im2col_dims(cdims)),\n                alpha::T=T(1), beta::T=T(0)) where {T}\n    check_dims(size(x), size(dw), size(dy), cdims)\n\n    #   COL'   *   dY   ->    dW\n    # [M x K] * [K x N] -> [M x N]\n    #\n    #  M: size of input \"patch\" (kernel size and input channels combined)\n    #  N: output channels\n    #  K: output spatial resolution\n    #\n    # In English, we're grabbing each input patch and laying them out along\n    # the K dimension in `col`, then multiplying in `dY` to compute a dot\n    # product between all pixels in the input that were multiplied by a single\n    # position in the W kernel, and all output pixels of the same location,\n    # across output channels.  This slice of `col` therefore constitutes every\n    # input pixel that touched a particular element of the kernel.\n    #\n    # This is identical to a convolution between x and a dimension-permuted dY.\n\n    M = prod(kernel_size(cdims))*channels_in(cdims)\n    N = channels_out(cdims)\n    K = prod(output_size(cdims))\n\n    for batch_idx in 1:size(x,5)\n        col_slice = view(col, :, :, 1)\n\n        im2col!(col_slice, view(x, :, :, :, :, batch_idx), cdims)\n        GC.@preserve col_slice dw dy begin\n            col_ptr = pointer(col_slice)\n            dy_ptr = pointer(dy,(batch_idx - 1)*K*N + 1)\n            dw_ptr = pointer(dw)\n            gemm!(Val(true), Val(false), M, N, K, alpha, col_ptr, dy_ptr, beta, dw_ptr)\n        end\n\n        # Because we accumulate over batches in this loop, we must set `beta` equal\n        # to `1.0` from this point on.\n        beta = T(1)\n    end\n    return dw\nend\n\n\"\"\"\n    ∇conv_data_im2col!(dx, dy, w, cdims, col=similar(dx); alpha=1, beta=0)\n\nConv2d backward pass onto the input using im2col and GEMM; stores the result in `dx`.\nSee [`conv_im2col!`](@ref) for explanation of optional parameters.\n\"\"\"\nfunction ∇conv_data_im2col!(\n                dx::AbstractArray{T,5}, dy::AbstractArray{T,5},\n                w::AbstractArray{T,5}, cdims::DenseConvDims;\n                col::AbstractArray{T,3} = similar(dx, im2col_dims(cdims)),\n                alpha::T=T(1), beta::T=T(0),\n                ntasks::Int=nthreads()) where {T}\n    check_dims(size(dx), size(w), size(dy), cdims)\n\n    #    dY        W'   ->    dX\n    # [M x K] * [K x N] -> [M x N]\n    #\n    #  M: output spatial resolution\n    #  N: size of input \"patch\" (kernel size and input channels combined)\n    #  K: output channels\n    #\n    # In English, we're taking the output image and laying it out by pixel,\n    # with channels lying along the `K` dimension in `col`.  We then multiply\n    # in `W'` to compute a dot product between each pixel location and the\n    # entire kernel.  
This dot product therefore constitutes every output pixel\n    # that was a function of a particular input pixel.\n    #\n    # This is identical to a transposed convolution between dY and W\n\n    M = prod(output_size(cdims))\n    N = prod(kernel_size(cdims))*channels_in(cdims)\n    K = channels_out(cdims)\n\n    parts = Iterators.partition(axes(dx, 5), ceil(Int, size(dx, 5) / ntasks))\n\n    function ∇conv_data_part(task_n, part)\n        col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace\n        for batch_idx in part\n            GC.@preserve col_slice w dy begin\n                dy_ptr = pointer(dy, (batch_idx - 1)*M*K + 1)\n                w_ptr = pointer(w)\n                col_ptr = pointer(col_slice)\n                gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)\n            end\n            col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims, beta)\n        end\n    end\n    if should_use_spawn() && length(parts) > 1\n        @sync for (task_n, part) in enumerate(parts)\n            Threads.@spawn ∇conv_data_part(task_n, part)\n        end\n    else\n        for (task_n, part) in enumerate(parts)\n            ∇conv_data_part(task_n, part)\n        end\n    end\n    return dx\nend\n\n\n\n\n\n\"\"\"\n    im2col!(col, x, cdims)\n\nConverts a 3d image `x` into a matrix `col` for usage with GEMM-calculated convolution.\nPatches of `x` of size (kernel_w, kernel_h, kernel_d, C_in) will be extracted and laid\nout along the rows of `col`, one for each output pixel.  This routine is used by all\nim2col-based convolutions, just with extra singleton dimensions added in the case of `2d`\nor `1d` images.\n\"\"\"\nfunction im2col!(col::AbstractArray{T,2}, x::AbstractArray{T,4}, cdims::ConvDims) where {T}\n    if spatial_dims(cdims) != 3\n        throw(DimensionMismatch(\"im2col!() only accepts 3d convoluitional inputs\"))\n    end\n\n    # Extract those nice, compile-time constant type parameters from `cdims`.\n    width, height, depth = input_size(cdims)\n    kernel_w, kernel_h, kernel_d = kernel_size(cdims)\n    C_in = channels_in(cdims)\n    pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = padding(cdims)\n    dil_w, dil_h, dil_d = dilation(cdims)\n    stride_w, stride_h, stride_d = stride(cdims)\n    out_width, out_height, out_depth = output_size(cdims)\n\n    # Reshape col for easy access.\n    col_reshaped = reshape(col, (\n        # Output resolution\n        out_width,\n        out_height,\n        out_depth,\n\n        # By input patch size\n        kernel_w,\n        kernel_h,\n        kernel_d,\n        C_in,\n    ))\n\n    padded_regions, central_region = calc_padding_regions(cdims)\n\n    # A helper function to project from output (w, h) to input (input_w, input_h)\n    @inline project(idx, stride, pad) = (idx - 1)*stride - pad + 1\n\n\n    # We begin by copying the central region of the image which requires no padding at all.\n    # Eliminating the branches of the fully generalized version below gives us a nice\n    # speedup on the majority of the data.\n    @inbounds for c in 1:C_in\n        # Unpack \"central region\"\n        w_region, h_region, d_region = central_region\n\n        for kd in 1:kernel_d,\n            kh in 1:kernel_h,\n            kw in 1:kernel_w,\n            d in d_region,\n            h in h_region,\n            w in w_region\n\n            input_kd = project(d, stride_d, pad_d_lo) + (kd - 1)*dil_d\n            input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h\n        
    input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w\n            kidxs = kernel_index(kw, kh, kd, cdims)\n\n            xval::T = x[input_kw, input_kh, input_kd, c]\n            col_reshaped[w, h, d, kidxs..., c] = xval\n        end\n    end\n\n\n    # For each \"padded region\", we run the fully general version\n    @inbounds for (w_region, h_region, d_region) in padded_regions\n        for c in 1:C_in,\n            d in d_region,\n            h in h_region,\n            w in w_region,\n            kd in 1:kernel_d,\n            kh in 1:kernel_h,\n            kw in 1:kernel_w\n\n            input_kd = project(d, stride_d, pad_d_lo) + (kd - 1)*dil_d\n            input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h\n            input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w\n\n            kidxs = kernel_index(kw, kh, kd, cdims)\n\n            out_of_bounds = (\n                input_kd <= 0 || input_kd > depth ||\n                input_kh <= 0 || input_kh > height ||\n                input_kw <= 0 || input_kw > width\n            )\n            if out_of_bounds\n                col_reshaped[w, h, d, kidxs..., c] = T(0)\n                continue\n            end\n\n            # Copy the data over\n            xval::T = x[input_kw, input_kh, input_kd, c]\n            col_reshaped[w, h, d, kidxs..., c] = xval\n        end\n    end\nend\n\n\n\"\"\"\n    col2im!(x, col, cdims, beta=0)\n\nDoes the inverse of `im2col!()`, converting `col` back into a 3d image, used for backward\npasses, transposed convolutions, etc...\n\nNote that this method has not been optimized in the same way as `im2col()` has, because\nit is slightly more complicated due to the more chaotic data access patterns, and I'm not\ndesperate enough yet.\n\"\"\"\ncol2im!\n\nfunction col2im!(x::AbstractArray{T,4}, col::AbstractArray{T,2}, cdims::ConvDims, beta::T=T(0)) where T\n    if spatial_dims(cdims) != 3\n        throw(DimensionMismatch(\"col2im!() only accepts 3d convoluitional inputs\"))\n    end\n\n    # Extract those nice, compile-time constant type parameters from `cdims`.\n    width, height, depth = input_size(cdims)\n    kernel_w, kernel_h, kernel_d = kernel_size(cdims)\n    C_in = channels_in(cdims)\n    pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = padding(cdims)\n    dil_w, dil_h, dil_d = dilation(cdims)\n    stride_w, stride_h, stride_d = stride(cdims)\n    out_width, out_height, out_depth = output_size(cdims)\n\n    # TODO: Rewrite this method so we don't have this fill!() at the beginning!\n    # Calculate each output pixel once rather than accumulating into it?\n    if beta == T(0)\n        fill!(x, T(0))\n    elseif beta == T(1)\n        # nothing\n    else\n        x .*= beta\n    end\n\n    # Reshape col for easy access.\n    col_reshaped = reshape(col, (\n        # Output resolution\n        out_width,\n        out_height,\n        out_depth,\n\n        # By input patch size\n        kernel_w,\n        kernel_h,\n        kernel_d,\n        C_in,\n    ))\n\n    # A helper function to project from output (w, h) to input (input_w, input_h)\n    @inline project(idx, stride, pad) = (idx - 1)*stride - pad + 1\n\n    @inbounds for c in 1:C_in\n        for kd in 1:kernel_d,\n            kh in 1:kernel_h,\n            kw in 1:kernel_w\n\n            for d in 1:out_depth\n                input_kd = project(d, stride_d, pad_d_lo) + (kd - 1)*dil_d\n\n                # If this d is off the edge, then deal with the entire plane\n                # in one fell swoop, like a ravenous 
flock of crows.  CAW CAW.\n                if input_kd <= 0 || input_kd > depth\n                    continue\n                end\n\n                for h in 1:out_height\n                    input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h\n\n                    # Same for `h`, but in this case it's only a line, not a plane.\n                    # This results in slightly less caw'ing.\n                    if input_kh <= 0 || input_kh > height\n                        continue\n                    end\n\n                    for w in 1:out_width\n                        input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w\n\n                        # If this `w` is off the edge, only it gets cleared out.\n                        if input_kw <= 0 || input_kw > width\n                            continue\n                        end\n\n                        # Copy the data over\n                        kidxs = kernel_index(kw, kh, kd, cdims)\n                        cval::T = col_reshaped[w, h, d, kidxs..., c]\n                        x[input_kw, input_kh, input_kd, c] += cval\n                    end\n                end\n            end\n        end\n    end\nend\n"
  },
  {
    "path": "src/impl/depthwiseconv_direct.jl",
    "content": "## This file contains direct Julia implementations of depwthwise convolutions\n\n\"\"\"\n    depthwiseconv_direct!(y, x, w, cdims; alpha=1, beta=0)\n\nDirect depthwise convolution implementation; used for debugging, tests, and mixing/\nmatching of strange datatypes within a single convolution.  Uses naive nested for loop\nimplementation and does not attempt to optimize performance.  Rather, this implementation\nis intended to be maximally understandable and debuggable, to aid in testing other, more\nperformant implementations.  We also explicitly support mixing and matching of strange\ndatatypes, so that if the user really wants to convolve an image of `UInt8`'s with a\n`Float16` kernel, storing the result in a `Float32` output, there is at least a function\ncall for that madness.\n\nOne subtlety about depthwise convolutions; the shape of a depthwise convolutional kernel\nis `(spatial_dims..., C_mult, C_in)`, so the axis that must match with the number of\nchannels in `x` is the last, not the second-to-last, as in a normal dense convolution.\n\nSee the docstring for `conv_direct!()` for more on the optional parameters.\n\"\"\"\nfunction depthwiseconv_direct!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5},\n                      w::AbstractArray{wT,5}, cdims::DepthwiseConvDims;\n                      alpha::yT=yT(1), beta=false) where {yT, xT, wT}\n    check_dims(size(x), size(w), size(y), cdims)\n\n    width, height, depth = input_size(cdims)\n    kernel_w, kernel_h, kernel_d = kernel_size(cdims)\n    pad_w_lo, _, pad_h_lo, _, pad_d_lo, _ = padding(cdims)\n    dil_w, dil_h, dil_d = dilation(cdims)\n    stride_w, stride_h, stride_d = stride(cdims)\n\n    # Create a method that determines how we're going to index into `w`\n    kproj(k, M, cdims::DepthwiseConvDims) = flipkernel(cdims) ? 
k : (M - k + 1)\n\n    # A helper function to project from output (w, h) to input (input_w, input_h)\n    project(idx, stride, pad) = (idx - 1)*stride - pad + 1\n\n    # Use `calc_padding_regions` to determine where we do or don't need to worry about padding\n    padded_regions, central_region = calc_padding_regions(cdims)\n\n    # Start with the central region\n    w_region, h_region, d_region = central_region\n    @inbounds for batch in 1:size(x)[end],\n        c_mult in 1:channel_multiplier(cdims),\n        c_in in 1:channels_in(cdims),\n        d_idx in d_region,\n        h_idx in h_region,\n        w_idx in w_region\n\n        # Since we're in the central region, we don't need to worry about clamping\n        dotprod = yT(0)\n        c_out = (c_in - 1)*channel_multiplier(cdims) + c_mult\n        for kd in 1:kernel_d,\n            kh in 1:kernel_h,\n            kw in 1:kernel_w\n\n            # Hoist me, you coward.\n            x_d = project(d_idx, stride_d, pad_d_lo) + (kd - 1)*dil_d\n            x_h = project(h_idx, stride_h, pad_h_lo) + (kh - 1)*dil_h\n            x_w = project(w_idx, stride_w, pad_w_lo) + (kw - 1)*dil_w\n\n            x_val = x[x_w, x_h, x_d, c_in, batch]\n            w_val = w[kproj(kw, kernel_w, cdims),\n                      kproj(kh, kernel_h, cdims),\n                      kproj(kd, kernel_d, cdims),\n                      c_mult, c_in]\n            dotprod = muladd(x_val, w_val, dotprod)\n        end\n        y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch]\n    end\n\n    # Next, do potentially-padded regions:\n    @inbounds for (w_region, h_region, d_region) in padded_regions,\n        batch in 1:size(x)[end],\n        c_mult in 1:channel_multiplier(cdims),\n        c_in in 1:channels_in(cdims),\n        d_idx in d_region,\n        h_idx in h_region,\n        w_idx in w_region\n\n        dotprod = yT(0)\n        c_out = (c_in - 1)*channel_multiplier(cdims) + c_mult\n        for kd in 1:kernel_d\n            # Probe for out-of-bounds accesses on `x` and `continue` if we hit one\n            x_d = project(d_idx, stride_d, pad_d_lo) + (kd - 1)*dil_d\n            if x_d <= 0 || x_d > depth\n                continue\n            end\n\n            for kh in 1:kernel_h\n                x_h = project(h_idx, stride_h, pad_h_lo) + (kh - 1)*dil_h\n                if x_h <= 0 || x_h > height\n                    continue\n                end\n\n                for kw in 1:kernel_w\n                    x_w = project(w_idx, stride_w, pad_w_lo) + (kw - 1)*dil_w\n                    if x_w <= 0 || x_w > width\n                        continue\n                    end\n\n                    x_val = x[x_w, x_h, x_d, c_in, batch]\n                    w_val = w[kproj(kw, kernel_w, cdims),\n                              kproj(kh, kernel_h, cdims),\n                              kproj(kd, kernel_d, cdims),\n                              c_mult, c_in]\n                    dotprod = muladd(x_val, w_val, dotprod)\n                end\n            end\n        end\n\n        y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch]\n    end\n\n    return y\nend\n\n\"\"\"\n    ∇depthwiseconv_data_direct!(dx, dy, w, cdims; alpha=1, beta=0)\n\nCalculate the gradient imposed upon `x` in the depthwise convolution `y = x * w`.\nWe make use of the fact that a depthwise convolution is equivalent to `C_in` separate\nnormal convolutions between that channel of `x` and the `C_mult` different kernels 
that\nget applied to it.  The output of such a convolution is the gradient imposed upon that\nparticular channel of `x`, and so we simply walk through `x`, calculating the gradient\nfor each batch and channel independently.\n\"\"\"\n∇depthwiseconv_data_direct!\n\nfunction ∇depthwiseconv_data_direct!(\n                dx::AbstractArray{xT,5}, dy::AbstractArray{yT,5},\n                w::AbstractArray{wT,5}, cdims::DepthwiseConvDims;\n                alpha::xT=xT(1), beta=false) where {xT, yT, wT}\n    # We do a separate convolution for each channel in x\n    @inbounds for cidx in 1:channels_in(cdims)\n        # For this batch and in-channel, we have a normal transposed convolution\n        # between this slice of `x` and the corresponding slices of `w` and `dy`:\n        dx_slice = view(dx, :, :, :, cidx:cidx, :)\n        C_mult = channel_multiplier(cdims)\n        dy_slice = view(dy, :, :, :, ((cidx-1)*C_mult + 1):cidx*C_mult, :)\n        w_slice = permutedims(view(w, :, :, :, :, cidx:cidx), (1, 2, 3, 5, 4))\n\n        # Adapt a DenseConvDims out of this DepthwiseConvDims, setting the in/out\n        # channels appropriately for this one convolution.\n        cdims_slice = DenseConvDims(cdims;\n            C_in=1,\n            C_out=channel_multiplier(cdims),\n        )\n\n        ∇conv_data_direct!(dx_slice, dy_slice, w_slice, cdims_slice;\n                                               alpha=alpha, beta=beta)\n    end\n    return dx\nend\n\n\"\"\"\n    ∇depthwiseconv_filter_direct!(dw, x, dy, cdims; alpha=1, beta=0)\n\nCalculate the gradient imposed upon `w` in the depthwise convolution `y = x * w`.\n\"\"\"\n∇depthwiseconv_filter_direct!\n\nfunction ∇depthwiseconv_filter_direct!(\n                dw::AbstractArray{wT,5}, x::AbstractArray{xT,5},\n                dy::AbstractArray{yT,5}, cdims::DepthwiseConvDims;\n                alpha::wT=wT(1),beta=false) where {xT, yT, wT}\n    # We do a separate convolution for each channel in x\n    @inbounds for cidx in 1:channels_in(cdims)\n        # For this batch and in-channel, we have a normal transposed convolution\n        # between this slice of `x` and the corresponding slices of `w` and `dy`:\n        x_slice = view(x, :, :, :, cidx:cidx, :)\n        C_mult = channel_multiplier(cdims)\n        dy_slice = view(dy, :, :, :, ((cidx-1)*C_mult + 1):cidx*C_mult, :)\n        dw_slice = permutedims(view(dw, :, :, :, :, cidx:cidx), (1, 2, 3, 5, 4))\n\n        # Adapt a DenseConvDims out of this DepthwiseConvDims, setting the in/out\n        # channels appropriately for this one convolution.\n        cdims_slice = DenseConvDims(cdims;\n            C_in=1,\n            C_out=channel_multiplier(cdims),\n        )\n\n        ∇conv_filter_direct!(dw_slice, x_slice, dy_slice, cdims_slice;\n                                                alpha=alpha, beta=beta)\n        dw[:, :, :, :, cidx:cidx] .= permutedims(dw_slice, (1, 2, 3, 5, 4))\n    end\n    return dw\nend\n\n\n"
  },
  {
    "path": "src/impl/depthwiseconv_im2col.jl",
    "content": "## This file contains adapter code for doing depthwise convolutions with im2col.\n\n\n\"\"\"\n    depthwiseconv_im2col!(y, x, w, cdims, col=similar(x); alpha=1, beta=0)\n\nPerform a depthwise convolution using im2col and GEMM, store the result in `y`.\nSee [`conv_im2col!`](@ref) for explanation of optional parameters.\n\"\"\"\ndepthwiseconv_im2col!\n\nfunction depthwiseconv_im2col!(\n                y::AbstractArray{T,5}, x::AbstractArray{T,5},\n                w::AbstractArray{T,5}, cdims::DepthwiseConvDims;\n                col::AbstractArray{T,3} = similar(x, im2col_dims(cdims)),\n                alpha::T=T(1), beta::T=T(0),\n                ntasks::Int=nthreads()) where T\n    check_dims(size(x), size(w), size(y), cdims)\n\n    # This functions exactly the same as conv_im2col!(), except that we shard the\n    # incoming data into slices of single channels.  This means that we need to walk\n    # each pointer forward individually, as done below, taking a single input channel\n    # and combining it with each kernel individually, before walking forward and doing\n    # the next input channel.\n    M = prod(output_size(cdims))\n    N = channel_multiplier(cdims)\n    K = prod(kernel_size(cdims))\n\n    parts = Iterators.partition(axes(y)[end], ceil(Int, size(y, 5) / ntasks))\n\n    dcdims = DenseConvDims(cdims)\n\n    function depthwiseconv_part(task_n, part)\n        col_slice = col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace\n        for batch_idx in part\n            im2col!(col_slice, view(x, :, :, :, :, batch_idx), dcdims)\n\n            # We do a separate convolution for each channel in x, as we must\n            for c_in in 1:channels_in(cdims)\n                # Walk each pointer forward as we process each input channel\n                GC.@preserve col_slice w y begin\n                    col_ptr = pointer(col_slice, (c_in-1)*M*K+1)\n                    w_ptr = pointer(w, (c_in-1)*K*N+1)\n                    y_ptr = pointer(y, ((batch_idx - 1)*channels_in(cdims) + c_in - 1)*M*N + 1)\n                    gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)\n                end\n            end\n        end\n    end\n    if should_use_spawn() && length(parts) > 1\n        @sync for (task_n, part) in enumerate(parts)\n            Threads.@spawn depthwiseconv_part(task_n, part)\n        end\n    else\n        for (task_n, part) in enumerate(parts)\n            depthwiseconv_part(task_n, part)\n        end\n    end\n    return y\nend\n\n\"\"\"\n    ∇depthwiseconv_filter_im2col!(dw, w, dy, cdims, col=similar(dw, ∇filter_im2col_dims(cdims));\n                                  alpha=1, beta=0)\n\nDepthwise conv backward pass onto the weights using im2col and GEMM.\nSee [`conv_im2col!`](@ref) for explanation of optional parameters.\n\"\"\"\n∇depthwiseconv_filter_im2col!\n\nfunction ∇depthwiseconv_filter_im2col!(\n                dw::AbstractArray{T,5}, x::AbstractArray{T,5},\n                dy::AbstractArray{T,5}, cdims::DepthwiseConvDims;\n                col::AbstractArray{T,3} = similar(dw, ∇filter_im2col_dims(cdims)),\n                alpha::T=T(1), beta::T=T(0)) where T\n    check_dims(size(x), size(dw), size(dy), cdims)\n\n    M = prod(kernel_size(cdims))\n    N = channel_multiplier(cdims)\n    K = prod(output_size(cdims))\n\n    for batch_idx in 1:size(x, 5)\n        # Because we accumulate over batches in this loop, we must set `beta` equal\n        # to `1.0` after the first sample.\n        beta′ = batch_idx == 1 ? 
beta : T(1)\n\n        # col_slice is a thread-local workspace\n        col_slice = view(col, :, :, 1)\n        im2col!(col_slice, view(x, :, :, :, :, batch_idx), cdims)\n\n        # We do a separate convolution for each channel in x, as we must\n        for c_in in 1:channels_in(cdims)\n            # Walk each pointer forward as we process each input channel\n            GC.@preserve col_slice dw dy begin\n                col_ptr = pointer(col_slice, (c_in - 1)*M*K + 1)\n                dy_ptr = pointer(dy, (batch_idx - 1)*N*K*channels_in(cdims) + (c_in - 1)*K*N + 1)\n                dw_ptr = pointer(dw, (c_in - 1)*M*N + 1)\n                gemm!(Val(true), Val(false), M, N, K, alpha, col_ptr, dy_ptr, beta′, dw_ptr)\n            end\n        end\n    end\n    return dw\nend\n\n\"\"\"\n    ∇depthwiseconv_data_im2col!(dx, dy, w, cdims, col=similar(dx); alpha=1, beta=0)\n\nDepthwise conv2d backward pass onto the input using im2col and GEMM.\nSee [`conv_im2col!`](@ref) for explanation of optional parameters.\n\"\"\"\n∇depthwiseconv_data_im2col!\n\nfunction ∇depthwiseconv_data_im2col!(\n                dx::AbstractArray{T,5}, dy::AbstractArray{T,5},\n                w::AbstractArray{T,5}, cdims::DepthwiseConvDims;\n                col::AbstractArray{T,3} = similar(dx, im2col_dims(cdims)),\n                alpha::T=T(1), beta::T=T(0),\n                ntasks::Int=nthreads()) where T\n    check_dims(size(dx), size(w), size(dy), cdims)\n\n    M = prod(output_size(cdims))\n    N = prod(kernel_size(cdims))\n    K = channel_multiplier(cdims)\n\n    parts = Iterators.partition(axes(dx)[end], ceil(Int, size(dx, 5) / ntasks))\n\n    function ∇depthwiseconv_data_part(task_n, part)\n        col_slice = view(col, :, :, task_n) # col_slice is a task-local workspace\n        for batch_idx in part\n            # We do a separate convolution for each channel in x, as we must\n            for cidx in 1:channels_in(cdims)\n                GC.@preserve col_slice w dy begin\n                    # Walk each pointer forward as we process each input channel\n                    dy_ptr = pointer(dy, (batch_idx - 1)*M*K*channels_in(cdims)+(cidx - 1)*K*M + 1)\n                    w_ptr = pointer(w, (cidx - 1)*K*N + 1)\n                    col_ptr = pointer(col_slice, (cidx - 1)*M*N + 1)\n                    gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)\n                end\n            end\n            col2im!(view(dx, :, :, :, :, batch_idx), col_slice, cdims, beta)\n        end\n    end\n    if should_use_spawn() && length(parts) > 1\n        @sync for (task_n, part) in enumerate(parts)\n            Threads.@spawn ∇depthwiseconv_data_part(task_n, part)\n        end\n    else\n        for (task_n, part) in enumerate(parts)\n            ∇depthwiseconv_data_part(task_n, part)\n        end\n    end\n    return dx\nend\n"
  },
  {
    "path": "src/impl/padding_edges.jl",
    "content": "\"\"\"\n    calc_padding_regions(dims)\n\nPadding is a jerk.  A HUGE jerk that tries to sneak a bunch of conditionals and edge\ncases (quite literally) into our beautiful stencil operations such as convolution,\npooling, etc...  The way we deal with this is to, first, deal with everything in 3d,\nand then define a single padding region helper function that returns the seven regions\nthat all 3d operations must deal with, including the central \"unpadded\" region where we\ncan run at full bore, not paying any attention to padding.\n\"\"\"\nfunction calc_padding_regions(dims)\n    width, height, depth = input_size(dims)\n    kernel_w, kernel_h, kernel_d = kernel_size(dims)\n    C_in = channels_in(dims)\n    pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = padding(dims)\n    dil_w, dil_h, dil_d = dilation(dims)\n    stride_w, stride_h, stride_d = stride(dims)\n    out_width, out_height, out_depth = output_size(dims)\n\n    # Let us first calculate the number of rows/cols within which we must zero out some\n    # portion of the image patches we're copying over.  The \"spillage\" here is the number\n    # of indices along a particular dimension for which a kernel will have some portion\n    # of its input domain overlapping the padding.  If padding is zero, these values are\n    # all trivially zero.  The low spillage is trivially the low padding divided by the\n    # stride; literally the number of shifts that overlap some padding.  The high\n    # spillage is slightly more complicated; we first figure out how many elements of\n    # high padding are wasted (e.g. through strides not fitting to the end perfectly)\n    # subtract that from the high padding, then do the same:\n    calc_lo_spill(O, S, P) = max(min(ceil(Int, P/S), O),0)\n    @inline function calc_hi_spill(O, S, Pl, Ph, K, D, I)\n        wasted_Ph = (I + Pl + Ph - (K - 1)*D - 1)%S\n        return max(min(ceil(Int, (Ph - wasted_Ph)/S), O), 0)\n    end\n\n    spill_w_lo = calc_lo_spill(out_width, stride_w, pad_w_lo)\n    spill_w_hi = calc_hi_spill(out_width, stride_w, pad_w_lo, pad_w_hi, kernel_w, dil_w, width)\n    spill_h_lo = calc_lo_spill(out_height, stride_h, pad_h_lo)\n    spill_h_hi = calc_hi_spill(out_height, stride_h, pad_h_lo, pad_h_hi, kernel_h, dil_h, height)\n    spill_d_lo = calc_lo_spill(out_depth, stride_d, pad_d_lo)\n    spill_d_hi = calc_hi_spill(out_depth, stride_d, pad_d_lo, pad_d_hi, kernel_d, dil_d, depth)\n\n    spill_w_hi_abs = out_width  - spill_w_hi + 1\n    spill_h_hi_abs = out_height - spill_h_hi + 1\n    spill_d_hi_abs = out_depth  - spill_d_hi + 1\n\n    # These are the regions we're going to have to run with cognizance of padding.\n    # There are six of them; one for each face of the cube image.  
We explicitly\n    # design this so that we run over `width` most tightly, in the expectation that\n    # this will generate better code for when `h` and `d` are singleton dimensions.\n    # We visualize this as a cube, indexed by dimensions (w, h, d).\n    padded_regions = (\n        # First region is the lower-d WH face:\n        (\n            1:out_width,\n            1:out_height,\n            1:spill_d_lo,\n        ),\n\n        # The next largest chunk we choose will be the lower-h WD faces; we always\n        # want to maximize going across full `w`, as its contiguous in memory.\n        (\n            1:out_width,\n            1:spill_h_lo,\n            (spill_d_lo+1):(spill_d_hi_abs-1),\n        ),\n        # Then the upper-h WD face\n        (\n            1:out_width,\n            spill_h_hi_abs:out_height,\n            (spill_d_lo+1):(spill_d_hi_abs-1),\n        ),\n\n        # Next, we fit the HD faces in, but without overlapping the `h` and `d`\n        # regions we've done before:\n        (\n            1:spill_w_lo,\n            (spill_h_lo+1):(spill_h_hi_abs-1),\n            (spill_d_lo+1):(spill_d_hi_abs-1),\n        ),\n        (\n            spill_w_hi_abs:out_width,\n            (spill_h_lo+1):(spill_h_hi_abs-1),\n            (spill_d_lo+1):(spill_d_hi_abs-1)\n        ),\n        \n        # Last region is the higher-d WH face:\n        (\n            1:out_width,\n            1:out_height,\n            spill_d_hi_abs:out_depth,\n        ),\n    )\n\n    # The central region that has no padding.\n    central_region = (\n        (spill_w_lo+1):(spill_w_hi_abs - 1),\n        (spill_h_lo+1):(spill_h_hi_abs - 1),\n        (spill_d_lo+1):(spill_d_hi_abs - 1),\n    )\n    return padded_regions, central_region\nend"
  },
  {
    "path": "src/impl/pooling_direct.jl",
    "content": "# Pooling is so similar, we abstract over meanpooling and maxpooling, simply replacing\n# the inner loop operation and a few initialization parameters.\nfor name in (:max, :mean, :lpnorm)\n    @eval function $((Symbol(\"$(name)pool_direct!\")))(\n                    y::AbstractArray{<:Any, 5}, x::AbstractArray{<:Any, 5},\n                    pdims::PoolDims; alpha=1, beta=0, kwargs...) \n        $((Symbol(\"$(name)pool_direct!\")))(\n            y, x, pdims,\n            Val(kernel_size(pdims)), Val(channels_out(pdims)),\n            Val(padding(pdims)), Val(dilation(pdims)), Val(stride(pdims));\n            alpha, beta, kwargs...)\n        return y\n    end\n\n    @eval function $((Symbol(\"$(name)pool_direct!\")))(\n        y::AbstractArray{T,5}, x::AbstractArray{<:Any,5},\n        pdims::PoolDims,\n        # kernel size, channels out, padding, dilation, stride\n        ::Val{K}, ::Val{C}, ::Val{P}, ::Val{D}, ::Val{S};\n        alpha=1, beta=0, kwargs...\n    ) where {T, K, C, P, D, S}\n        @assert iszero(beta) \"beta not supported yet\"\n        check_dims(size(x), size(y), pdims)\n\n        width, height, depth = input_size(pdims)\n        kernel_w, kernel_h, kernel_d = K\n        pad_w_lo, _, pad_h_lo, _, pad_d_lo, _ = P\n        dil_w, dil_h, dil_d = D\n        stride_w, stride_h, stride_d = S\n\n        # We use calc_padding_regions to split outselves up into separate regions that may or\n        # may not need to worry about padding:\n        padded_regions, central_region = calc_padding_regions(pdims)\n\n        # A helper function to project from output (w, h) to input (input_w, input_h)\n        @inline project(idx, stride, pad) = (idx - 1) * stride - pad + 1\n\n        # If we're doing mean pooling, we represent division by kernel size by rolling it\n        # into the `alpha` multiplier. \n        # The type might change here, that's why we prepend the underscore \n        # (does it make a difference, though?)\n        _alpha = if $(name == :mean)\n            T(alpha / prod(K))\n        else\n            T(alpha)\n        end\n        # _beta = T(beta)\n\n        # A quick note on the array element types `T` and `R`:\n        # Ideally, `T == R`, but in some edge-cases, this might not be the case \n        # (e.g. with `ReverseDiff.TrackedArray`, see issue #484).\n        # If the types differ, we will initialize variables (like `_alpha` above) with the \n        # target eltype `T`.\n\n        p = if $(name != :lpnorm) 0 else\n            !haskey(kwargs, :p) && error(\"lpnormpool needs keyword argument `p`\")\n            kwargs[:p]\n        end\n\n        # Each loop, we initialize `m` to something, set that here.\n        m_init = if $(name == :max)\n            T <: AbstractFloat ? 
nextfloat(typemin(T)) : typemin(T)\n        elseif $(name == :mean) || $(name == :lpnorm)\n            T(0)\n        else\n            error(\"Unimplemented codegen path\")\n        end\n\n        # Start with the central region\n        w_region, h_region, d_region = central_region\n\n        @inbounds for batch_idx in 1:size(x, 5), c in 1:C\n            for d in d_region\n            pd = project(d, stride_d, pad_d_lo)\n            for h in h_region\n            ph = project(h, stride_h, pad_h_lo)\n            for w in w_region\n            pw = project(w, stride_w, pad_w_lo)\n            m = m_init\n\n            for kd in 1:kernel_d,\n                kh in 1:kernel_h,\n                kw in 1:kernel_w\n\n                input_kd = pd + (kd - 1) * dil_d\n                input_kh = ph + (kh - 1) * dil_h\n                input_kw = pw + (kw - 1) * dil_w\n\n                # This conditional will be optimized away at compile time\n                if $(name == :max)\n                    xv = x[input_kw, input_kh, input_kd, c, batch_idx]\n                    if xv > m\n                        m = xv\n                    end\n                elseif $(name == :mean)\n                    m += x[input_kw, input_kh, input_kd, c, batch_idx]\n                elseif $(name == :lpnorm)\n                    # y = (∑ᵢ xᵢ^p)^(1 / p), here to calculate ∑ᵢ xᵢ^p\n                    m += x[input_kw, input_kh, input_kd, c, batch_idx]^p\n                else\n                    error(\"Unimplemented codegen path\")\n                end\n            end\n\n            # for lpnormpool, y = (∑ᵢ xᵢ^p)^(1 / p)\n            m = $(name == :lpnorm) ? m^(T(1) / p) : m\n\n            y[w, h, d, c, batch_idx] = _alpha * m # + _beta * y[w, h, d, c, batch_idx]\n            end\n            end\n            end\n        end\n\n        # Next, the padded regions\n        @inbounds for (w_region, h_region, d_region) in padded_regions\n            for batch_idx in 1:size(x, 5), c in 1:C\n                for d in d_region\n                pd = project(d, stride_d, pad_d_lo)\n                for h in h_region\n                ph = project(h, stride_h, pad_h_lo)\n                for w in w_region\n                pw = project(w, stride_w, pad_w_lo)\n                m = m_init\n\n                for kd in 1:kernel_d\n                    input_kd = pd + (kd - 1) * dil_d\n                    if input_kd <= 0 || input_kd > depth\n                        # add here condition for handling options for paded value handling\n                        continue\n                    end\n\n                    for kh in 1:kernel_h\n                        input_kh = ph + (kh - 1) * dil_h\n                        if input_kh <= 0 || input_kh > height\n                            # add here condition for handling options for paded value handling\n                            continue\n                        end\n\n                        for kw in 1:kernel_w\n                            input_kw = pw + (kw - 1) * dil_w\n                            if input_kw <= 0 || input_kw > width\n                                # add here condition for handling options for paded value handling\n                                continue\n                            end\n\n                            if $(name == :max)\n                                xv = x[input_kw, input_kh, input_kd, c, batch_idx]\n                                if xv > m\n                                    m = xv\n                                end\n                            
elseif $(name == :mean)\n                                m += x[input_kw, input_kh, input_kd, c, batch_idx]\n                            elseif $(name == :lpnorm)\n                                m += x[input_kw, input_kh, input_kd, c, batch_idx]^p\n                            else\n                                error(\"Unimplemented codegen path\")\n                            end\n                        end\n                    end\n                end\n                $(name == :lpnorm) && (m = m^(T(1) / p))\n                y[w, h, d, c, batch_idx] = _alpha * m # + _beta * y[w, h, d, c, batch_idx]\n                end\n                end\n                end\n            end\n        end\n\n        return y\n    end\n\n    @eval function $((Symbol(\"∇$(name)pool_direct!\")))(\n                    dx::AbstractArray{<:Any,5}, dy::AbstractArray{<:Any,5},\n                    y::AbstractArray{<:Any,5}, x::AbstractArray{<:Any,5},\n                    pdims::PoolDims; kwargs...)\n        $((Symbol(\"∇$(name)pool_direct!\")))(\n            dx, dy, y, x, pdims, Val(kernel_size(pdims)); kwargs...)\n        return dx\n    end\n\n    # Same story for gradients, and although this is very similar to the forward pass,\n    # it's unfortunately different enough that I think we need a separate function.  :(\n    @eval function $((Symbol(\"∇$(name)pool_direct!\")))(\n                    dx::AbstractArray{T,5}, dy::AbstractArray{<:Any,5},\n                    y::AbstractArray{<:Any,5}, x::AbstractArray{<:Any,5},\n                    pdims::PoolDims, ::Val{K}; # == kernel_size(pdims)\n                    alpha=1, beta=0, kwargs...) where {T, K}\n        check_dims(size(x), size(dy), pdims)\n\n        width, height, depth = input_size(pdims)\n        kernel_w, kernel_h, kernel_d = K\n        out_c = channels_out(pdims)\n        pad_w_lo, _, pad_h_lo, _, pad_d_lo, _ = padding(pdims)\n        dil_w, dil_h, dil_d = dilation(pdims)\n        stride_w, stride_h, stride_d = stride(pdims)\n\n        # Concerning array eltypes `DX, DY, X, Y`, we want handle them like above, i.e.,\n        # initialize everything with the left-hand-side type (target type).\n        # Of course, ideally the types are all the same anyways.\n\n        # We use calc_padding_regions to split outselves up into separate regions that\n        # may or may not need to worry about padding:\n        padded_regions, central_region = calc_padding_regions(pdims)\n\n        # A helper function to project from output (w, h) to input (input_w, input_h)\n        @inline project(idx, stride, pad) = (idx - 1) * stride - pad + 1\n\n        # If we're doing mean pooling, we represent division by kernel size by rolling\n        # it into the `_alpha` multiplier.\n        _alpha = if $(name == :mean)\n            T(alpha / prod(K))\n        else\n            T(alpha)\n        end\n\n        p = if $(name != :lpnorm) 0 else\n            !haskey(kwargs, :p) && error(\"lpnormpool must pass p\")\n            kwargs[:p]\n        end\n\n        # Start with the central region\n        w_region, h_region, d_region = central_region\n        @inbounds for batch_idx in 1:size(x, 5), c in 1:out_c\n            for d in d_region\n            pd = project(d, stride_d, pad_d_lo)\n            for h in h_region\n            ph = project(h, stride_h, pad_h_lo)\n            for w in w_region\n            pw = project(w, stride_w, pad_w_lo)\n\n            # Grab the output at this index for future use\n            y_idx = y[w, h, d, c, batch_idx]\n            dy_idx = 
dy[w, h, d, c, batch_idx]\n            maxpool_already_chose = false\n\n            for kd in 1:kernel_d,\n                kh in 1:kernel_h,\n                kw in 1:kernel_w\n\n                input_kd = pd + (kd - 1) * dil_d\n                input_kh = ph + (kh - 1) * dil_h\n                input_kw = pw + (kw - 1) * dil_w\n\n                # This conditional will be optimized away at compile time,\n                # or my name isn't shengdan jingyu\n                # x_idxs = (input_kw, input_kh, input_kd, c, batch_idx)\n                if $(name == :max)\n                    if maxpool_already_chose\n                        break\n                    end\n                    # If it's equal; this is the one we chose. We only choose one per\n                    # kernel window, all other elements of dx must be zero.\n                    # Uncomment line below if using with non-precise output (e.g. by NNPACK)\n                    # if abs(y_idx - x[x_idxs...]) < 1e-5 && !maxpool_already_chose\n                    if y_idx ≈ x[input_kw, input_kh, input_kd, c, batch_idx]\n                        dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * _alpha #+ _beta * dx[x_idxs...]\n                        maxpool_already_chose = true\n                    # Maxpooling does not support `beta` right now.  :(\n                    # else\n                    #    dx[x_idxs...] = T(0) + beta*dx[x_idxs...]\n                    end\n                elseif $(name == :mean)\n                    # Either does meanpool :(\n                    dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * _alpha\n                elseif $(name == :lpnorm)\n                    # y = (∑ᵢ xᵢ^p)^(1 / p), ∂y/∂xᵢ = xᵢ^(p-1) × y^(1-p)\n                    grad = x[input_kw, input_kh, input_kd, c, batch_idx]^(p-1) * y_idx^(1-p)\n                    dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * grad\n                else\n                    error(\"Unimplemented codegen path\")\n                end\n            end\n            end\n            end\n            end\n        end\n\n        # Next, the padded regions\n        @inbounds for (w_region, h_region, d_region) in padded_regions\n            for batch_idx in 1:size(x, 5), c in 1:out_c\n                for d in d_region\n                pd = project(d, stride_d, pad_d_lo)\n                for h in h_region\n                ph = project(h, stride_h, pad_h_lo)\n                for w in w_region\n                pw = project(w, stride_w, pad_w_lo)\n\n                # Grab the incoming gradient at this index for future use\n                y_idx = y[w, h, d, c, batch_idx]\n                dy_idx = dy[w, h, d, c, batch_idx]\n                maxpool_already_chose = false\n\n                # In these loops, we have to check that we're not reaching off the edge,\n                # we do so by putting in a bunch of conditionals.  
:/\n                for kd in 1:kernel_d\n                    input_kd = pd + (kd - 1) * dil_d\n                    if input_kd <= 0 || input_kd > depth\n                        continue\n                    end\n\n                    for kh in 1:kernel_h\n                        input_kh = ph + (kh - 1) * dil_h\n                        if input_kh <= 0 || input_kh > height\n                            continue\n                        end\n\n                        for kw in 1:kernel_w\n                            input_kw = pw + (kw - 1) * dil_w\n                            if input_kw <= 0 || input_kw > width\n                                continue\n                            end\n\n                            # Same as above\n                            # x_idxs = (input_kw, input_kh, input_kd, c, batch_idx)\n                            if $(name == :max)\n                                if maxpool_already_chose\n                                    break\n                                end\n                                # Uncomment line below if using with non-precise output\n                                # if abs(y_idx - x[x_idxs...]) < 1e-5 && !maxpool_already_chose\n                                if y_idx ≈ x[input_kw, input_kh, input_kd, c, batch_idx]\n                                    dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * _alpha #+ _beta * dx[x_idxs...]\n                                    maxpool_already_chose = true\n                                # else\n                                #    dx[x_idxs...] = T(0) + beta*dx[x_idxs...]\n                                end\n                            elseif $(name == :mean)\n                                dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * _alpha #+ _beta * dx[x_idxs...]\n                            elseif $(name == :lpnorm)\n                                grad = x[input_kw, input_kh, input_kd, c, batch_idx]^(p-1) * y_idx^(1-p)\n                                dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * grad\n                            else\n                                error(\"Unimplemented codegen path\")\n                            end\n                        end\n                    end\n                end\n            end\n            end\n            end\n            end\n        end\n\n        return dx\n    end\nend\n"
  },
  {
    "path": "src/normalization.jl",
    "content": "# TODO: add CPU implementation\nfunction batchnorm end\n\nfunction ∇batchnorm end\n\n\nfunction ChainRulesCore.rrule(::typeof(batchnorm), g, b, x, running_mean, running_var, momentum; kw...)\n  y = batchnorm(g, b, x, running_mean, running_var, momentum; kw...) \n  function batchnorm_pullback(Δ)\n    grad = ∇batchnorm(g, b, x, unthunk(Δ), running_mean, running_var, momentum; kw...)\n    (NoTangent(), grad..., NoTangent(), NoTangent(), NoTangent())\n  end\n  y, batchnorm_pullback\nend\n"
  },
  {
    "path": "src/padding.jl",
    "content": "\"\"\"\n    pad_zeros(x, pad::Tuple; [dims])\n    pad_zeros(x, pad::Int; [dims])\n\nPad the array `x` with zeros.\nEquivalent to [`pad_constant`](@ref) with the constant equal to 0. \n\"\"\"\npad_zeros(x::AbstractArray, pad; dims = :) =\n  pad_constant(x, pad, 0; dims = dims)\n\n\"\"\"\n    pad_constant(x, pad::Tuple, val = 0; [dims = :])\n    pad_constant(x, pad::Int, val = 0; [dims = :])\n\nPad the array `x` with the constant value `val`.\n\n`pad` can be a tuple of integers.\nIf it is of some length `2 * length(dims)` that specifies the left and right padding size\nfor each of the dimensions in `dims` as `(l1, r1, ..., ln, rn)`. \nIf supplied with a tuple of length `length(dims)` instead, it applies symmetric padding.\nIf `dims` is not given, it defaults to all dimensions.\n\nFor integer `pad` input, it is applied on both sides\non every dimension in `dims`.\n\nSee also [`pad_zeros`](@ref), [`pad_repeat`](@ref), [`pad_reflect`](@ref), [`pad_symmetric`](@ref), and [`pad_circular`](@ref).\n\n```jldoctest\njulia> r = reshape(1:4, 2, 2)\n2×2 reshape(::UnitRange{Int64}, 2, 2) with eltype Int64:\n 1  3\n 2  4\n\njulia> pad_constant(r, (1, 2, 3, 4), 8)\n5×9 Matrix{Int64}:\n 8  8  8  8  8  8  8  8  8\n 8  8  8  1  3  8  8  8  8\n 8  8  8  2  4  8  8  8  8\n 8  8  8  8  8  8  8  8  8\n 8  8  8  8  8  8  8  8  8\n\njulia> pad_constant(r, 1, 8)\n4×4 Matrix{Int64}:\n 8  8  8  8\n 8  1  3  8\n 8  2  4  8\n 8  8  8  8\n\njulia> r = reshape(1:27, 3, 3, 3)\n3×3×3 reshape(::UnitRange{Int64}, 3, 3, 3) with eltype Int64:\n[:, :, 1] =\n 1  4  7\n 2  5  8\n 3  6  9\n\n[:, :, 2] =\n 10  13  16\n 11  14  17\n 12  15  18\n\n[:, :, 3] =\n 19  22  25\n 20  23  26\n 21  24  27\n\njulia> pad_constant(r, (2,1), dims = 1) # assymetric padding\n6×3×3 Array{Int64, 3}:\n[:, :, 1] =\n 0  0  0\n 0  0  0\n 1  4  7\n 2  5  8\n 3  6  9\n 0  0  0\n\n[:, :, 2] =\n  0   0   0\n  0   0   0\n 10  13  16\n 11  14  17\n 12  15  18\n  0   0   0\n\n[:, :, 3] =\n  0   0   0\n  0   0   0\n 19  22  25\n 20  23  26\n 21  24  27\n  0   0   0\n\njulia> pad_constant(r, (2,1, 3), dims = (1,2)) # padding must always be either the same length as dims, or double it\nERROR: ArgumentError: Could not parse padding (2, 1, 3) and dims (1, 2)\nStacktrace:\n[...]\n```\n\"\"\"\npad_constant(x::AbstractArray{T,N}, pad::Int, val = 0; dims = :) where {T,N} =\n  pad_constant(x, gen_pad(pad, dims isa Colon ? dims : (dims...,), N), val)\npad_constant(x::AbstractArray{T,N}, pad::Tuple, val = 0; dims = :) where {T,N} =\n  pad_constant(x, gen_pad(pad, dims isa Colon ? dims : (dims...,), N), val)\n\nfunction pad_idx(pad, dims, N)\n  is = zip( (2 .* dims) .- 1, (2 .* dims))\nend\n\n@inline tuplejoin(x) = x\n@inline tuplejoin(x, y) = (x..., y...)\n@inline tuplejoin(x, y, z...) 
= tuplejoin(tuplejoin(x, y), z...)\n\ngen_pad(pad::Int, dims, N) = gen_pad(ntuple(_ -> pad, length(dims)), dims, N)\ngen_pad(pad::Int, dims::Colon, N) = ntuple(_ -> (pad, pad), N)\ngen_pad(pad, dims::Colon, N) = gen_pad(pad, ntuple(identity, N), N)\ngen_pad(pad, dims::Int, N) = gen_pad(pad, (dims,), N)\ngen_pad(pad::Int, dims::Int, N) = gen_pad((pad,pad), (dims,), N)\nfunction gen_pad(pad::NTuple{L,Int}, dims::NTuple{D,Int}, N) where {L,D}\n  ntuple(N) do d\n   if d in dims\n     if L == D\n       ix = findfirst(==(d), dims)\n       (pad[ix], pad[ix])\n     elseif L == 2D\n       ix = findfirst(==(d), dims)\n       (pad[2ix - 1], pad[2ix])\n     else\n       throw(ArgumentError(\"Could not parse padding $pad and dims $dims\"))\n     end\n   else\n     (0,0)\n   end\n\n  end\nend\n\n\n# Expects length(pad) == 2M\nfunction pad_constant(x::AbstractArray{T,M}, pad::NTuple{N,Tuple{Int,Int}}, val = 0) where {T,M,N}\n  sz, c = size_and_center(x, pad)\n  res = fill!(similar(x, sz...), val)\n  res[c...] = x\n  res\nend\n\nfunction size_and_center(x, pad::NTuple{N,NTuple{2, Int}}) where N\n  sz = ntuple(i -> pad[i][1] + pad[i][2], N) .+ size(x)\n  center = broadcast((x,y) -> x .+ y, axes(x), ntuple(i -> pad[i][1], N))\n  sz, center\nend\n\nfunction rrule(::typeof(pad_constant), x::AbstractArray{T,N},\n               pad, val; dims = :) where {T,N}\n  y = pad_constant(x, pad, val; dims = dims)\n  function pad_constant_pullback(Δ)\n    p = gen_pad(pad, dims, N)\n    outsize, center = size_and_center(x, p)\n    (NoTangent(), @thunk(unthunk(Δ)[center...]), NoTangent(), NoTangent(),)\n  end\n  return y, pad_constant_pullback\nend\n\n\n\"\"\"\n    pad_repeat(x, pad::Tuple; [dims])\n    pad_repeat(x, pad::Int; [dims])\n \nPad the array `x` repeating the values on the border.\n\n`pad` can a tuple of integers `(l1, r1, ..., ln, rn)`\nof some length `2n` that specifies the left and right padding size\nfor each of the dimensions in `dims`. If `dims` is not given, \nit defaults to the first `n` dimensions.\n\nIf `pad` is an integer, it is applied on both sides\non every dimension in `dims`. In this case, `dims` \ndefaults to the first `ndims(x)-2` dimensions \n(i.e. excludes the channel and batch dimension). \n\nSee also [`pad_reflect`](@ref), [`pad_symmetric`](@ref), [`pad_circular`](@ref), and [`pad_constant`](@ref).\n\n```jldoctest\njulia> r = reshape(1:9, 3, 3)\n3×3 reshape(::UnitRange{Int64}, 3, 3) with eltype Int64:\n 1  4  7\n 2  5  8\n 3  6  9\n\njulia> pad_repeat(r, (1,2,3,4))\n6×10 Matrix{Int64}:\n 1  1  1  1  4  7  7  7  7  7\n 1  1  1  1  4  7  7  7  7  7\n 2  2  2  2  5  8  8  8  8  8\n 3  3  3  3  6  9  9  9  9  9\n 3  3  3  3  6  9  9  9  9  9\n 3  3  3  3  6  9  9  9  9  9\n```\n\"\"\"\nfunction pad_repeat(x::AbstractArray, pad::NTuple{M,Int}; \n                    dims = 1:M÷2) where M\n  length(dims) == M ÷ 2 ||\n    throw(ArgumentError(\"The number of dims should be equal to the number of padding dimensions\"))\n  for (i, d) in enumerate(dims)\n    x = pad_repeat(x, (pad[2i-1], pad[2i]); dims=d)\n  end  \n  return x\nend\n\nfunction pad_repeat(x::AbstractArray{F,N}, pad::NTuple{2,Int}; \n                    dims::Int = 1) where {F,N}\n  lpad, rpad = pad\n\n  xlborder = selectdim(x, dims, 1:1)\n  nrepl = ntuple(i -> i == dims ? lpad : 1, N)\n  xl = repeat(xlborder, outer = nrepl)\n\n  n = size(x, dims)\n  xrborder = selectdim(x, dims, n:n)\n  nrepr = ntuple(i -> i == dims ? 
rpad : 1, N)\n  xr = repeat(xrborder, outer = nrepr)\n\n  return cat(xl, x, xr, dims = dims)\nend\n\n\"\"\"\n    pad_reflect(x, pad::Tuple; [dims])\n    pad_reflect(x, pad::Int; [dims])\n\nPad the array `x` reflecting its values across the border.\n\n`pad` can a tuple of integers `(l1, r1, ..., ln, rn)`\nof some length `2n` that specifies the left and right padding size\nfor each of the dimensions in `dims`. If `dims` is not given, \nit defaults to the first `n` dimensions.\n\nIf `pad` is an integer, it is applied on both sides\non every dimension in `dims`. In this case, `dims` \ndefaults to the first `ndims(x)-2` dimensions \n(i.e. excludes the channel and batch dimension). \n\nSee also [`pad_repeat`](@ref), [`pad_symmetric`](@ref), [`pad_circular`](@ref), and [`pad_constant`](@ref).\n\n```jldoctest\njulia> r = reshape(1:9, 3, 3)\n3×3 reshape(::UnitRange{Int64}, 3, 3) with eltype Int64:\n 1  4  7\n 2  5  8\n 3  6  9\n\njulia> pad_reflect(r, (1,2,1,2))\n6×6 Matrix{Int64}:\n 5  2  5  8  5  2\n 4  1  4  7  4  1\n 5  2  5  8  5  2\n 6  3  6  9  6  3\n 5  2  5  8  5  2\n 4  1  4  7  4  1\n```\n\"\"\"\nfunction pad_reflect(x::AbstractArray, pad::NTuple{M,Int};\n                     dims=1:M÷2) where M\n  length(dims) == M ÷ 2 ||\n    throw(ArgumentError(\"The number of dims should be equal to the number of padding dimensions\"))\n  for (i, d) in enumerate(dims)\n    x = pad_reflect(x, (pad[2i-1], pad[2i]); dims = d)\n  end\n  return x\nend\n\nfunction pad_reflect(\n  x::AbstractArray{F,N}, pad::NTuple{2,Int}; dims::Int = 1,\n) where {F,N}\n  lpad, rpad = pad\n  n = size(x, dims)\n  xl = lpad == 0 ?\n    similar(x, ntuple(i -> i == dims ? 0 : size(x, i), ndims(x))) :\n    reverse(selectdim(x, dims, 2:lpad+1); dims)\n  xr = rpad == 0 ?\n    similar(x, ntuple(i -> i == dims ? 0 : size(x, i), ndims(x))) :\n    reverse(selectdim(x, dims, n-rpad:n-1); dims)\n  return cat(xl, x, xr; dims)\nend\n\n\"\"\"\n    pad_symmetric(x, pad::Tuple; [dims])\n    pad_symmetric(x, pad::Int; [dims])\n\nPad the array `x` reflecting its values symmetrically across the border, i.e. the border values of `x` are present in the padding values, in contrast to [`pad_reflect`](@ref).\n\n`pad` can a tuple of integers `(l1, r1, ..., ln, rn)`\nof some length `2n` that specifies the left and right padding size\nfor each of the dimensions in `dims`. If `dims` is not given, \nit defaults to the first `n` dimensions.\n\nIf `pad` is an integer, it is applied on both sides\non every dimension in `dims`. In this case, `dims` \ndefaults to the first `ndims(x)-2` dimensions \n(i.e. excludes the channel and batch dimension). 
\n\nSee also [`pad_repeat`](@ref), [`pad_reflect`](@ref), [`pad_circular`](@ref), and [`pad_constant`](@ref).\n\n```jldoctest\njulia> r = reshape(1:9, 3, 3)\n3×3 reshape(::UnitRange{Int64}, 3, 3) with eltype Int64:\n 1  4  7\n 2  5  8\n 3  6  9\n\njulia> pad_symmetric(r, (1,2,1,2))\n6×6 Matrix{Int64}:\n 1  1  4  7  7  4\n 1  1  4  7  7  4\n 2  2  5  8  8  5\n 3  3  6  9  9  6\n 3  3  6  9  9  6\n 2  2  5  8  8  5\n```\n\"\"\"\nfunction pad_symmetric(x::AbstractArray, pad::NTuple{M,Int};\n                     dims=1:M÷2) where M\n  length(dims) == M ÷ 2 ||\n    throw(ArgumentError(\"The number of dims should be equal to the number of padding dimensions\"))\n  for (i, d) in enumerate(dims)\n    x = pad_symmetric(x, (pad[2i-1], pad[2i]); dims = d)\n  end\n  return x\nend\n\nfunction pad_symmetric(\n  x::AbstractArray{F,N}, pad::NTuple{2,Int}; dims::Int = 1,\n) where {F,N}\n  lpad, rpad = pad\n  n = size(x, dims)\n\n  xl = lpad == 0 ?\n    similar(x, ntuple(i -> i == dims ? 0 : size(x, i), ndims(x))) :\n    reverse(selectdim(x, dims, 1:lpad); dims)\n  xr = rpad == 0 ?\n    similar(x, ntuple(i -> i == dims ? 0 : size(x, i), ndims(x))) :\n    reverse(selectdim(x, dims, n-rpad+1:n); dims)\n  return cat(xl, x, xr; dims)\nend\n\n\"\"\"\n    pad_circular(x, pad::Tuple; [dims])\n    pad_circular(x, pad::Int; [dims])\n\nPad the array `x` \"circularly\" across the border by wrapping around values from the opposite side of `x`. \n\n`pad` can a tuple of integers `(l1, r1, ..., ln, rn)`\nof some length `2n` that specifies the left and right padding size\nfor each of the dimensions in `dims`. If `dims` is not given, \nit defaults to the first `n` dimensions.\n\nIf `pad` is an integer, it is applied on both sides\non every dimension in `dims`. In this case, `dims` \ndefaults to the first `ndims(x)-2` dimensions \n(i.e. excludes the channel and batch dimension). \n\nThe pad length on either side in any dimension must not exceed the\nsize of `x` in that dimension, i.e. 
`pad_circular` is not able to create abitrary sized tilings of `x`.\n\nSee also [`pad_repeat`](@ref), [`pad_reflect`](@ref), [`pad_symmetric`](@ref), and [`pad_constant`](@ref).\n\n```jldoctest\njulia> r = reshape(1:9, 3, 3)\n3×3 reshape(::UnitRange{Int64}, 3, 3) with eltype Int64:\n 1  4  7\n 2  5  8\n 3  6  9\n\njulia> pad_circular(r, (1,2,1,2))\n6×6 Matrix{Int64}:\n 9  3  6  9  3  6\n 7  1  4  7  1  4\n 8  2  5  8  2  5\n 9  3  6  9  3  6\n 7  1  4  7  1  4\n 8  2  5  8  2  5\n```\n\"\"\"\nfunction pad_circular(x::AbstractArray, pad::NTuple{M,Int}; \n                     dims=1:M÷2) where M\n  length(dims) == M ÷ 2 ||\n    throw(ArgumentError(\"The number of dims should be equal to the number of padding dimensions\"))\n\n  for (i, d) in enumerate(dims)\n    x = pad_circular(x, (pad[2i-1], pad[2i]); dims = d)\n  end  \n  return x\nend\n\nfunction pad_circular(x::AbstractArray{F,N}, pad::NTuple{2,Int}; \n                     dims::Int = 1) where {F,N}\n  lpad, rpad = pad\n  n = size(x, dims)\n\n  xl = selectdim(x, dims, n-lpad+1:n)\n  xr = selectdim(x, dims, 1:rpad)\n  return cat(xl, x, xr, dims = dims)\nend\n\n# convenience methods for symmetric and homogeneous padding\npad_repeat(x::AbstractArray{F,N}, pad::Int; dims=1:N-2) where {F,N} =\n  pad_repeat(x, ntuple(_ -> pad, 2length(dims)); dims = dims)\npad_reflect(x::AbstractArray{F,N}, pad::Int; dims=1:N-2) where {F,N} =\n  pad_reflect(x, ntuple(_ -> pad, 2length(dims)); dims = dims)\npad_symmetric(x::AbstractArray{F,N}, pad::Int; dims=1:N-2) where {F,N} =\n  pad_symmetric(x, ntuple(_ -> pad, 2length(dims)); dims = dims)\npad_circular(x::AbstractArray{F,N}, pad::Int; dims=1:N-2) where {F,N} =\n  pad_circular(x, ntuple(_ -> pad, 2length(dims)); dims = dims)\n\n"
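Editor's note — a brief usage sketch for the padding API above (not part of the source file; it assumes `NNlib` is loaded). It illustrates how an integer `pad` expands to the symmetric tuple form on the first `ndims(x) - 2` dimensions, as described in the docstrings:

```julia
using NNlib

x = reshape(Float32.(1:16), 4, 4, 1, 1)   # (W, H, C, N) layout

# Integer pad is shorthand for symmetric padding on the spatial dimensions:
pad_repeat(x, 1) == pad_repeat(x, (1, 1, 1, 1); dims = (1, 2))   # true

# pad_circular wraps values around from the opposite border:
size(pad_circular(x, 1))   # (6, 6, 1, 1)
```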
  },
  {
    "path": "src/pooling.jl",
    "content": "## Pooling API\n#\n#  We provide the following generic methods, for 3d, 4d, and 5d tensors, calculating 1d,\n#  2d and 3d pooling, based on the rank of the input tensors, in both mutating and\n#  non-mutating auto-allocating variants:\n#   - Pooling:\n#     - maxpool(x, pdims)\n#     - maxpool!(y, x, pdims)\n#     - meanpool(x, pdims)\n#     - meanpool!(y, x, pdims)\n#     - lpnormpool(x, pdims)\n#     - lpnormpool!(y, x, pdims)\n#   - Pooling input backprop\n#     - ∇maxpool(dy, y, x, pdims)\n#     - ∇maxpool!(dx, dy, y, x, pdims)\n#     - ∇meanpool(dy, y, x, pdims)\n#     - ∇meanpool!(dx, dy, y, x pdims)\n#     - ∇lpnormpool(dy, y, x, pdims)\n#     - ∇lpnormpool!(dx, dy, y, x pdims)\n#\n#   All methods require a `PoolDims` object to define the dimensions and optional\n#   elements of the convolution (stride, dilation, etc...), which is easily constructable\n#   through something like `PoolDims(x, w)`.\n\n\n# First, we will define mappings from the generic API names to our accelerated backend\n# implementations.  At the moment this is only the direct implementation, however this\n# exists here so that other packages (NNPACK, MAGMA, etc...) can override this easily.\nfor (front_name, backend) in (\n        # This maps from public, front-facing name, to internal backend name\n        :maxpool  => :direct,\n        :meanpool => :direct,\n        :lpnormpool => :direct,\n    )\n\n    # We only define 3d pooling primitives, we reshape lower down to get 1d and 2d pooling\n    @eval begin\n        function $(Symbol(\"$(front_name)!\"))(\n                y::AbstractArray{<:Any,5}, x::AbstractArray{<:Any,5},\n                pdims::PoolDims; kwargs...)\n            $(Symbol(\"$(front_name)_$(backend)!\"))(y, x, pdims; kwargs...)\n        end\n    end\nend\n\n# Do the same for backprops\nfor (front_name, backend) in (\n        :∇maxpool  => :direct,\n        :∇meanpool => :direct,\n        :∇lpnormpool => :direct,\n    )\n    @eval begin\n        function $(Symbol(\"$(front_name)!\"))(\n                        dx::AbstractArray{<:Any,5}, dy::AbstractArray{<:Any,5},\n                        y::AbstractArray{<:Any,5}, x::AbstractArray{<:Any,5},\n                        pdims::PoolDims; kwargs...)\n            $(Symbol(\"$(front_name)_$(backend)!\"))(dx, dy, y, x, pdims; kwargs...)\n        end\n    end\nend\n\n\n# Our strategy for pooling is to reshape to an array with three spatial dimensions, which\n# makes things MUCH EASIER for us on the backend side, and is in general pretty fast,\n# since we can specialize on sizes.\nfor front_name in (:maxpool, :meanpool, :lpnormpool)\n    for backend in (Symbol(), :_direct)\n        for N in (3, 4)\n            @eval begin\n                function $(Symbol(\"$(front_name)$(backend)!\"))(\n                                y::AbstractArray{<:Any,$N}, x::AbstractArray{<:Any,$N},\n                                pdims::PoolDims; kwargs...)\n                    $(Symbol(\"$(front_name)$(backend)!\"))(\n                        insert_singleton_spatial_dimension(y, $(5 - N)),\n                        insert_singleton_spatial_dimension(x, $(5 - N)),\n                        insert_singleton_spatial_dimension(pdims, $(5 - N));\n                        kwargs...\n                    )\n\n                    # We explicitly return `y` here, because the backend call\n                    # itself may return a reshaped view, which we don't want.\n                    return y\n                end\n\n                # backprops too\n                function 
$(Symbol(\"∇$(front_name)$(backend)!\"))(\n                                dx::AbstractArray{<:Any,$N}, dy::AbstractArray{<:Any,$N},\n                                y::AbstractArray{<:Any,$N}, x::AbstractArray{<:Any,$N},\n                                pdims::PoolDims; kwargs...)\n                    $(Symbol(\"∇$(front_name)$(backend)!\"))(\n                        insert_singleton_spatial_dimension(dx, $(5 - N)),\n                        insert_singleton_spatial_dimension(dy, $(5 - N)),\n                        insert_singleton_spatial_dimension(y, $(5 - N)),\n                        insert_singleton_spatial_dimension(x, $(5 - N)),\n                        insert_singleton_spatial_dimension(pdims, $(5 - N));\n                        kwargs...\n                    )\n\n                    # We explicitly return `dx` here, because the backend call\n                    # itself may return a reshaped view, which we don't want.\n                    return dx\n                end\n            end\n        end\n    end\nend\n\n\n# Finally, let's generate auto-allocating versions of all our functions, for all backends:\nfor backend in (Symbol(), :_direct)\n    # First make auto-allocating versions of the basic pooling calls:\n    for name in (:maxpool, :meanpool, :lpnormpool)\n        @eval begin\n            function $(Symbol(\"$(name)$(backend)\"))(\n                            x::AbstractArray{<:Any,N},\n                            pdims::PoolDims; kwargs...) where {N}\n                y = similar(x, output_size(pdims)..., channels_out(pdims), size(x, N))\n                fill!(y, 0)\n                return $(Symbol(\"$(name)$(backend)!\"))(y, x, pdims; kwargs...)\n            end\n\n            # Backprops too\n            function $(Symbol(\"∇$(name)$(backend)\"))(\n                            dy::AbstractArray{<:Any,N}, y::AbstractArray{<:Any,N},\n                            x::AbstractArray{<:Any,N}, pdims::PoolDims;\n                            kwargs...) where {N}\n                dx = similar(x, input_size(pdims)..., channels_in(pdims), size(dy, N))\n                fill!(dx, 0)\n                return $(Symbol(\"∇$(name)$(backend)!\"))(dx, dy, y, x, pdims; kwargs...)\n            end\n        end\n    end\nend\n\nexpand(N, i::Tuple) = i\nexpand(N, i::Integer) = ntuple(_ -> i, N)\n\n\n\"\"\"\n    maxpool(x, k::NTuple{N, Integer}; pad=0, stride=k)\n\nPerform max pool operation with window size `k` on input tensor `x`.\n\nArguments:\n\n* `x` and `k`: Expects `ndim(x) ∈ 3:5`, and always `length(k) == ndim(x) - 2`\n* `pad`: See [`pad_zeros`](@ref) for details.\n* `stride`: Either a tuple with the same length as `k`, or one integer for all directions. Default is `k`.\n\"\"\"\nfunction maxpool(x, k::NTuple{N, Integer}; pad=0, stride=k) where N\n    pad = expand(Val(N), pad)\n    stride = expand(Val(N), stride)\n    pdims = PoolDims(x, k; padding=pad, stride=stride)\n    return maxpool(x, pdims)\nend\n\n\n\"\"\"\n    meanpool(x, k::NTuple{N, Integer}; pad=0, stride=k)\n\nPerform mean pool operation with window size `k` on input tensor `x`.\n\nArguments:\n\n* `x` and `k`: Expects `ndim(x) ∈ 3:5`, and always `length(k) == ndim(x) - 2`\n* `pad`: See [`pad_zeros`](@ref) for details.\n* `stride`: Either a tuple with the same length as `k`, or one integer for all directions. 
Default is `k`.\n\"\"\"\nfunction meanpool(x, k::NTuple{N, Integer}; pad=0, stride=k) where N\n    pad = expand(Val(N), pad)\n    stride = expand(Val(N), stride)\n    pdims = PoolDims(x, k; padding=pad, stride=stride)\n    return meanpool(x, pdims)\nend\n\n\n\"\"\"\n    lpnormpool(x, p::Real, k::NTuple{N, Integer}; pad=0, stride=k)\n\nPerform Lp pool operation with value of the Lp norm `p` and window size `k` on input tensor `x`, also known as LPPool in pytorch.\nThis pooling operator from [Learned-Norm Pooling for Deep Feedforward and Recurrent Neural Networks](https://arxiv.org/abs/1311.1780).\n\nArguments:\n\n* `x` and `k`: Expects `ndim(x) ∈ 3:5`, and always `length(k) == ndim(x) - 2`\n* `p` is restricted to `0 < p < Inf`.\n* `pad`: See [`pad_zeros`](@ref) for details.\n* `stride`: Either a tuple with the same length as `k`, or one integer for all directions. Default is `k`.\n\nFor all elements `x` in a size `k` window, lpnormpool computes `(∑ᵢ xᵢ^p)^(1 / p)` as an element of the output.\n\nThus `lpnormpool(x, 1, k) ./ prod(k) ≈ meanpool(x, k)` and `lpnormpool(x, 2, k).^2 ./ prod(k) ≈ meanpool(x.^2, k)`.\n\"\"\"\nfunction lpnormpool(x, p::Real, k::NTuple{N, Integer}; pad=0, stride=k) where {N}\n    pow = p isa Integer ? p : convert(float(eltype(x)), p)\n    (isinf(pow) || pow < 0) && error(\"p value of Lp norm pool expects `0 < p < Inf`, but p is $(pow) now.\")\n    pdims = PoolDims(x, k; padding=expand(Val(N), pad), stride=expand(Val(N), stride))\n    return lpnormpool(x, pdims; p=pow)\nend\n\n\nfor pool in [:maxpool, :meanpool, :lpnormpool]\n    ∇pool = Symbol(:∇, pool)\n    pullback = Symbol(pool, :_pullback)\n    @eval function rrule(::typeof($pool), x, pdims::PoolDims; kw...)\n        Ω = $pool(x, pdims; kw...)\n        $pullback(Δ) = (NoTangent(), $∇pool(unthunk(Δ), Ω, x, pdims; kw...), NoTangent())\n        return Ω, $pullback\n    end\nend\n"
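Editor's note — a small usage sketch for the pooling front-ends above (illustrative only, assuming `NNlib` is loaded); it also checks the `lpnormpool`/`meanpool` relation stated in the docstring:

```julia
using NNlib

x = reshape(Float32.(1:16), 4, 4, 1, 1)   # (W, H, C, N)

y = maxpool(x, (2, 2))    # window (2, 2); stride defaults to the window size
size(y)                   # (2, 2, 1, 1)

# As noted in the lpnormpool docstring, p = 1 recovers meanpool up to a factor:
lpnormpool(x, 1, (2, 2)) ./ prod((2, 2)) ≈ meanpool(x, (2, 2))   # true
```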
  },
  {
    "path": "src/rotation.jl",
    "content": "\"\"\"\n    _rotate_coordinates(sinθ, cosθ, i, j, rotation_center, round_or_floor)\n\nThis rotates the coordinates and either applies round(nearest neighbour)\nor floor for :bilinear interpolation)\n\"\"\"\n@inline function _rotate_coordinates(sinθ, cosθ, i, j, rotation_center, round_or_floor)\n    y = i - rotation_center[1]\n    x = j - rotation_center[2]\n    yrot = cosθ * y - sinθ * x + rotation_center[1]\n    xrot = sinθ * y + cosθ * x + rotation_center[2]\n    yrot_f = round_or_floor(yrot)\n    xrot_f = round_or_floor(xrot)\n    yrot_int = round_or_floor(Int, yrot)\n    xrot_int = round_or_floor(Int, xrot)\n    return yrot, xrot, yrot_f, xrot_f, yrot_int, xrot_int\nend\n\n\n\"\"\"\n   _bilinear_helper(yrot, xrot, yrot_f, xrot_f, yrot_int, xrot_int) \n\nSome helper variables\n\"\"\"\n@inline function _bilinear_helper(yrot, xrot, yrot_f, xrot_f)\n    xdiff = (xrot - xrot_f)\n    xdiff_1minus = 1 - xdiff\n    ydiff = (yrot - yrot_f)\n    ydiff_1minus = 1 - ydiff\n    \n    return ydiff, ydiff_1minus, xdiff, xdiff_1minus\nend\n\n\n\"\"\"\n    _prepare_imrotate(arr, θ, rotation_center)\n\nPrepate `sin` and `cos`, creates the output array and converts type\nof `rotation_center` if required.\n\"\"\"\nfunction _prepare_imrotate(arr::AbstractArray{T}, θ, rotation_center) where T\n    # needed for rotation matrix\n    θ = mod(real(T)(θ), real(T)(2π))\n    rotation_center = real(T).(rotation_center)\n    sinθ, cosθ = sincos(real(T)(θ)) \n    out = similar(arr)\n    fill!(out, 0)\n    return sinθ, cosθ, rotation_center, out\nend\n\n\n\"\"\"\n    _check_trivial_rotations!(out, arr, θ, rotation_center) \n\nWhen `θ = 0 || π /2 || π || 3/2 || π` and if `rotation_center` \nis in the middle of the array.\nFor an even array of size 4, the rotation_center would need to be 2.5.\nFor an odd array of size 5, the rotation_center would need to be 3.\n\nIn those cases, rotations are trivial just by reversing or swapping some axes.\n\"\"\"\nfunction _check_trivial_rotations!(out, arr, θ, rotation_center; adjoint=false)\n    if iszero(θ)\n        out .= arr\n        return true \n    end\n    # check for special cases where rotations are trivial\n    if (iseven(size(arr, 1)) && iseven(size(arr, 2)) && \n        rotation_center[1] ≈ size(arr, 1) ÷ 2 + 0.5 && rotation_center[2] ≈ size(arr, 2) ÷ 2 + 0.5) ||\n        (isodd(size(arr, 1)) && isodd(size(arr, 2)) && \n        (rotation_center[1] == size(arr, 1) ÷ 2 + 1 && rotation_center[1] == size(arr, 2) ÷ 2 + 1))\n        if θ ≈ π / 2 \n            if adjoint == false\n                out .= reverse(PermutedDimsArray(arr, (2, 1, 3, 4)), dims=(2,))\n            else\n                out .= reverse(PermutedDimsArray(arr, (2, 1, 3, 4)), dims=(1,))\n            end\n            return true\n        elseif θ ≈ π\n            out .= reverse(arr, dims=(1,2))\n            return true\n        elseif θ ≈ 3 / 2 * π\n            if adjoint == false\n                out .= reverse(PermutedDimsArray(arr, (2, 1, 3, 4)), dims=(1,))\n            else\n                out .= reverse(PermutedDimsArray(arr, (2, 1, 3, 4)), dims=(2,))\n            end\n            return true\n        end\n    end\n\n    return false\nend\n\n\n\"\"\"\n    imrotate(arr::AbstractArray{T, 4}, θ; method=:bilinear, rotation_center=size(arr) .÷ 2 .+ 1)\n\nRotates an array in the first two dimensions around the center pixel `rotation_center`. 
\nThe default value of `rotation_center` is defined such that there is a integer center pixel for even and odd sized arrays which it is rotated around.\nFor an even sized array of size `(4,4)` this would be `(3,3)`, for an odd array of size `(3,3)` this would be `(2,2)`\nHowever, `rotation_center` can be also non-integer numbers if specified.\n\nThe angle `θ` is interpreted in radians.\n\nThe adjoint is defined with ChainRulesCore.jl. This method also runs with CUDA (and in principle all KernelAbstractions.jl supported backends).\n\n# Keywords\n* `method=:bilinear` for bilinear interpolation or `method=:nearest` for nearest neighbour\n* `rotation_center=size(arr) .÷ 2 .+ 1` means there is a real center pixel around it is rotated.\n\n# Examples\n```julia-repl\njulia> arr = zeros((4,4,1,1)); arr[2,2,1,1] = 1;\n\njulia> arr\n4×4×1×1 Array{Float64, 4}:\n[:, :, 1, 1] =\n 0.0  0.0  0.0  0.0\n 0.0  1.0  0.0  0.0\n 0.0  0.0  0.0  0.0\n 0.0  0.0  0.0  0.0\n\njulia> NNlib.imrotate(arr, deg2rad(90)) # rotation around (3,3)\n4×4×1×1 Array{Float64, 4}:\n[:, :, 1, 1] =\n 0.0  0.0  0.0  0.0\n 0.0  0.0  0.0  1.0\n 0.0  0.0  0.0  0.0\n 0.0  0.0  0.0  0.0\n\njulia> NNlib.imrotate(arr, deg2rad(90), rotation_center=(2,2))\n4×4×1×1 Array{Float64, 4}:\n[:, :, 1, 1] =\n 0.0  0.0  0.0  0.0\n 0.0  1.0  0.0  0.0\n 0.0  0.0  0.0  0.0\n 0.0  0.0  0.0  0.0\n\njulia> arr = zeros((3,3,1,1)); arr[1,2,1,1] = 1\n1\n\njulia> arr\n3×3×1×1 Array{Float64, 4}:\n[:, :, 1, 1] =\n 0.0  1.0  0.0\n 0.0  0.0  0.0\n 0.0  0.0  0.0\n\njulia> NNlib.imrotate(arr, deg2rad(45))\n3×3×1×1 Array{Float64, 4}:\n[:, :, 1, 1] =\n 0.0  0.207107  0.0\n 0.0  0.0       0.207107\n 0.0  0.0       0.0\n\njulia> NNlib.imrotate(arr, deg2rad(45), method=:nearest)\n3×3×1×1 Array{Float64, 4}:\n[:, :, 1, 1] =\n 0.0  0.0  1.0\n 0.0  0.0  0.0\n 0.0  0.0  0.0\n```\n\"\"\"\nfunction imrotate(arr::AbstractArray{T, 4}, θ; method=:bilinear, rotation_center::Tuple=size(arr) .÷ 2 .+ 1) where T\n    if (T <: Integer && method==:nearest || !(T <: Integer)) == false\n        throw(ArgumentError(\"If the array has an Int eltype, only method=:nearest is supported\"))\n    end\n    # prepare out, the sin and cos and type of rotation_center\n    sinθ, cosθ, rotation_center, out = _prepare_imrotate(arr, θ, rotation_center) \n    # such as 0°, 90°, 180°, 270° and only if the rotation_center is suitable\n    _check_trivial_rotations!(out, arr, θ, rotation_center) && return out\n\n    # KernelAbstractions specific\n    backend = KernelAbstractions.get_backend(arr)\n    if method == :bilinear\n        kernel! = imrotate_kernel_bilinear!(backend)\n    elseif method == :nearest\n        kernel! = imrotate_kernel_nearest!(backend)\n    else \n        throw(ArgumentError(\"No interpolation method such as $method\"))\n    end\n    kernel!(out, arr, sinθ, cosθ, rotation_center, size(arr, 1), size(arr, 2),\n            ndrange=size(arr))\n\treturn out\nend\n\n\n\"\"\"\n    ∇imrotate(dy, arr::AbstractArray{T, 4}, θ; method=:bilinear,\n                                               rotation_center=size(arr) .÷ 2 .+ 1)\n\nAdjoint for `imrotate`. 
Gradient only with respect to `arr` and not `θ`.\n\n# Arguments\n* `dy`: input gradient \n* `arr`: Input from primal computation\n* `θ`: rotation angle in radians\n* `method=:bilinear` or `method=:nearest`\n* `rotation_center=size(arr) .÷ 2 .+ 1` rotates around a real center pixel for even and odd sized arrays\n\"\"\"\nfunction ∇imrotate(dy, arr::AbstractArray{T, 4}, θ; method=:bilinear, \n                                               rotation_center::Tuple=size(arr) .÷ 2 .+ 1) where T\n    \n    sinθ, cosθ, rotation_center, out = _prepare_imrotate(arr, θ, rotation_center) \n    # for the adjoint, the trivial rotations go in the other direction!\n    # pass dy and not arr\n    _check_trivial_rotations!(out, dy, θ, rotation_center, adjoint=true) && return out\n\n    backend = KernelAbstractions.get_backend(arr)\n    if method == :bilinear\n        kernel! = ∇imrotate_kernel_bilinear!(backend)\n    elseif method == :nearest\n        kernel! = ∇imrotate_kernel_nearest!(backend)\n    else \n        throw(ArgumentError(\"No interpolation method such as $method\"))\n    end\n    # don't pass arr but dy! \n    kernel!(out, dy, sinθ, cosθ, rotation_center, size(arr, 1), size(arr, 2),\n            ndrange=size(arr))\n    return out\nend\n\n\n@kernel function imrotate_kernel_nearest!(out, arr, sinθ, cosθ, rotation_center, imax, jmax)\n    i, j, c, b = @index(Global, NTuple)\n\n    r(x...) = round(x..., RoundNearestTiesAway)\n    _, _, _, _, yrot_int, xrot_int = _rotate_coordinates(sinθ, cosθ, i, j, rotation_center, r) \n    if 1 ≤ yrot_int ≤ imax && 1 ≤ xrot_int ≤ jmax\n        @inbounds out[i, j, c, b] = arr[yrot_int, xrot_int, c, b]\n    end\nend\n\n\n@kernel function imrotate_kernel_bilinear!(out, arr, sinθ, cosθ, rotation_center, imax, jmax)\n    i, j, c, b = @index(Global, NTuple)\n    \n    yrot, xrot, yrot_f, xrot_f, yrot_int, xrot_int = _rotate_coordinates(sinθ, cosθ, i, j, rotation_center, floor) \n    if 1 ≤ yrot_int ≤ imax - 1 && 1 ≤ xrot_int ≤ jmax - 1 \n\n        ydiff, ydiff_1minus, xdiff, xdiff_1minus = \n            _bilinear_helper(yrot, xrot, yrot_f, xrot_f)\n        @inbounds out[i, j, c, b] = \n            (   xdiff_1minus    * ydiff_1minus  * arr[yrot_int      , xrot_int      , c, b]\n             +  xdiff_1minus    * ydiff         * arr[yrot_int + 1  , xrot_int      , c, b]\n             +  xdiff           * ydiff_1minus  * arr[yrot_int      , xrot_int + 1  , c, b] \n             +  xdiff           * ydiff         * arr[yrot_int + 1  , xrot_int + 1  , c, b])\n    end\nend\n\n\n@kernel function ∇imrotate_kernel_nearest!(out, arr, sinθ, cosθ, rotation_center, imax, jmax)\n    i, j, c, b = @index(Global, NTuple)\n\n    r(x...) 
= round(x..., RoundNearestTiesAway)\n    _, _, _, _, yrot_int, xrot_int = _rotate_coordinates(sinθ, cosθ, i, j, rotation_center, r) \n    if 1 ≤ yrot_int ≤ imax && 1 ≤ xrot_int ≤ jmax \n        Atomix.@atomic out[yrot_int, xrot_int, c, b] += arr[i, j, c, b]\n    end\nend\n\n\n@kernel function ∇imrotate_kernel_bilinear!(out, arr, sinθ, cosθ, rotation_center, imax, jmax)\n    i, j, c, b = @index(Global, NTuple)\n\n    yrot, xrot, yrot_f, xrot_f, yrot_int, xrot_int = _rotate_coordinates(sinθ, cosθ, i, j, rotation_center, floor) \n    if 1 ≤ yrot_int ≤ imax - 1 && 1 ≤ xrot_int ≤ jmax - 1\n        o = arr[i, j, c, b]\n        ydiff, ydiff_1minus, xdiff, xdiff_1minus = \n            _bilinear_helper(yrot, xrot, yrot_f, xrot_f)\n        Atomix.@atomic out[yrot_int     ,   xrot_int    , c, b]  += xdiff_1minus    * ydiff_1minus * o\n        Atomix.@atomic out[yrot_int + 1 ,   xrot_int    , c, b]  += xdiff_1minus    * ydiff      * o\n        Atomix.@atomic out[yrot_int     ,   xrot_int + 1, c, b]  += xdiff           * ydiff_1minus * o\n        Atomix.@atomic out[yrot_int + 1 ,   xrot_int + 1, c, b]  += xdiff           * ydiff      * o\n    end\nend\n\n\n# is this rrule good? \n# no @thunk and @unthunk\nfunction ChainRulesCore.rrule(::typeof(imrotate), arr::AbstractArray{T}, θ; \n                              method=:bilinear, rotation_center=size(arr) .÷ 2 .+ 1) where T\n    res = imrotate(arr, θ; method, rotation_center)\n    function pb_rotate(dy)\n        ad = ∇imrotate(unthunk(dy), arr, θ; method, rotation_center)\n        return NoTangent(), ad, NoTangent()\n    end    \n\n\treturn res, pb_rotate\nend\n"
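Editor's note — an illustrative sketch for `imrotate` (not part of the file; assumes `NNlib` is loaded). For an odd-sized square array with the default `rotation_center`, a 90° rotation is handled by the trivial-rotation fast path above, so composing four of them should recover the input up to floating-point error:

```julia
using NNlib

arr = zeros(Float32, 5, 5, 1, 1); arr[2, 3, 1, 1] = 1;

r90 = NNlib.imrotate(arr, deg2rad(90))   # exact axis swap / reversal, no interpolation

# Compose four 90° rotations; foldl avoids mutating a global inside a loop.
r360 = foldl((a, _) -> NNlib.imrotate(a, deg2rad(90)), 1:4; init = arr)
r360 ≈ arr   # true
```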
  },
  {
    "path": "src/sampling.jl",
    "content": "@inline in_bounds(h, w, H, W) = 1 ≤ h ≤ H && 1 ≤ w ≤ W\n@inline in_bounds(h, w, d, H, W, D) = 1 ≤ h ≤ H && 1 ≤ w ≤ W && 1 ≤ d ≤ D\n# Borders are considered out-of-bounds for gradient.\n@inline clip_coordinate(coordinate, dim_size) = min(dim_size, max(1, coordinate))\n@inline function ∇clip_coordinate(coordinate::C, dim_size) where {C}\n    if coordinate ≤ 1\n        return C(1), C(0)\n    elseif coordinate ≥ dim_size\n        return C(dim_size), C(0)\n    end\n    coordinate, C(1)\nend\n\n@inline unnormalize(coordinate, dim_size) = ((coordinate + 1.0) * 0.5) * (dim_size - 1.0) + 1.0\n@inline ∇unnormalize(coordinate, dim_size) = unnormalize(coordinate, dim_size), (dim_size - 1.0) * 0.5\n\n@inline compute_source_index(coordinate, dim_size, ::Val{:zeros}) = unnormalize(coordinate, dim_size)\n@inline compute_source_index(coordinate, dim_size, ::Val{:border}) = clip_coordinate(unnormalize(coordinate, dim_size), dim_size)\n\n@inline ∇compute_source_index(coordinate, dim_size, ::Val{:zeros}) = ∇unnormalize(coordinate, dim_size)\n@inline function ∇compute_source_index(coordinate, dim_size, ::Val{:border})\n    source_coordinate, grad_in = ∇unnormalize(coordinate, dim_size)\n    source_coordinate, grad_clip = ∇clip_coordinate(source_coordinate, dim_size)\n    source_coordinate, grad_in * grad_clip\nend\n\n\"\"\"\n    grid_sample(input::AbstractArray{T, 4}, grid::AbstractArray{T, 4}; padding_mode = :zeros)\n    grid_sample(input::AbstractArray{T, 5}, grid::AbstractArray{T, 4}; padding_mode = :zeros)\n\n    Given `input`, compute output by sampling `input` values at pixel\n    locations from `grid`. Uses bilinear interpolation to calculate output values.\n\n    This implementation assumes the extrema (`-1` and `1`) are considered\n    as referring to the center points of the input’s corner pixels\n    (i.e. 
align corners is `true`).\n\n    # Arguments\n\n    - `input`: Input array in `(W_in, H_in, [D_in,] C, N)` shape.\n    - `grid`: Input grid in `(2, W_out, H_out, [D_out,] N)` shape.\n        Where for each `(W_out, H_out, [D_out,] N)` grid contains `(x, y [,z])`\n        coordinates that specify sampling locations normalized by the `input` shape.\n\n        Therefore, `x`, `y` and [`z`] should have values in `[-1, 1]` range.\n        For example, `(x = -1, y = -1, [z = -1])` is the left-top[-front] pixel of `input`,\n        and `(x = 1, y = 1, [z = 1])` is the right-bottom-back pixel of `input`.\n\n        Out-of-bound values are handled according to the `padding_mode`.\n    - `padding_mode`: Out-of-bound padding.\n        `:zeros` to use `0` for out-of-bound grid locations.\n        `:border` to use border values for out-of-bound grid locations.\n        Default is `:zeros`.\n\n    # Returns\n\n    `(W_out, H_out, [D_out,] C, N)` sampled grid from `input`.\n\n    # Examples\n\n    In the example below, grid contains two out-of-bound sampling locations,\n    which are handled differently, depending on the `padding_mode`.\n\n    ```jldoctest\n    julia> x = reshape(collect(1.0:4.0), (2, 2, 1, 1))\n    2×2×1×1 Array{Float64, 4}:\n    [:, :, 1, 1] =\n    1.0  3.0\n    2.0  4.0\n\n    julia> grid = Array{Float64}(undef, 2, 3, 2, 1);\n\n    julia> grid[:, 1, 1, 1] .= (-3, -1);\n\n    julia> grid[:, 2, 1, 1] .= (0, -1);\n\n    julia> grid[:, 3, 1, 1] .= (1, -1);\n\n    julia> grid[:, 1, 2, 1] .= (-1, 1);\n\n    julia> grid[:, 2, 2, 1] .= (0, 1);\n\n    julia> grid[:, 3, 2, 1] .= (3, 1);\n\n    julia> grid_sample(x, grid; padding_mode=:zeros)\n    3×2×1×1 Array{Float64, 4}:\n    [:, :, 1, 1] =\n    0.0  3.0\n    1.5  3.5\n    2.0  0.0\n\n    julia> grid_sample(x, grid; padding_mode=:border)\n    3×2×1×1 Array{Float64, 4}:\n    [:, :, 1, 1] =\n    1.0  3.0\n    1.5  3.5\n    2.0  4.0\n    ```\n\"\"\"\nfunction grid_sample(input::AbstractArray{T,N}, grid; padding_mode = :zeros) where {T,N}\n    if N ∉ (4,5)\n        error(\"grid_sample is only supported for 4D and 5D arrays.\") \n    end\n    iC, iN = size(input)[end-1:end] \n    output_size = size(grid)[2:end-1] # W_out, H_out, [D_out]\n    output = similar(input, T, (output_size..., iC, iN))\n    grid_sample!(output, input, grid, padding_mode)\nend\n\nfunction grid_sample!(output::AbstractArray{T,4}, input::AbstractArray{T,4}, grid, padding_mode=:zeros) where {T}\n    pad = Val(padding_mode)\n    iW, iH, iC, iN = size(input)\n    _, gW, gH, _ = size(grid)\n    # Loop over each output pixel.\n    Threads.@threads for n in 1:iN\n        for w in 1:gW, h in 1:gH\n            _grid_sample_kernel!(output, input, grid, pad, w, h, n, iW, iH, iC)\n        end\n    end\n    output\nend\n\nfunction grid_sample!(output::AbstractArray{T,5}, input::AbstractArray{T,5}, grid, padding_mode=:zeros) where {T}\n    pad = Val(padding_mode)\n    iW, iH, iD, iC, iN = size(input)\n    _, gW, gH, gD, _ = size(grid)\n    # Loop over each output pixel.\n    Threads.@threads for n in 1:iN\n        for w in 1:gW, h in 1:gH, d in 1:gD\n            _grid_sample_kernel!(output, input, grid, pad, w, h, d, n, iW, iH, iD, iC)\n        end\n    end\n    output\nend\n\n@inline function _grid_sample_kernel!(\n    output::AbstractArray{T,4}, input::AbstractArray{T,4}, grid, padding_mode, w, h, n, iW, iH, iC,\n) where {T}\n    # Get the corresponding (x, y) coordinates from the grid.\n    @inbounds x, y = grid[1, w, h, n], grid[2, w, h, n]\n    ix = compute_source_index(x, iW, 
padding_mode)\n    iy = compute_source_index(y, iH, padding_mode)\n    # Get corner pixel values from (ix, iy) in north-east-south-west directions.\n    ix_nw, iy_nw = unsafe_trunc(Int, floor(ix)), unsafe_trunc(Int, floor(iy))\n    ix_ne, iy_ne = ix_nw + 1, iy_nw\n    ix_sw, iy_sw = ix_nw, iy_nw + 1\n    ix_se, iy_se = ix_ne, iy_sw\n    # Get surfaces to each neighbor (a.k.a. interpolation weights).\n    nw = (ix_se - ix) * (iy_se - iy)\n    ne = (ix - ix_sw) * (iy_sw - iy)\n    sw = (ix_ne - ix) * (iy - iy_ne)\n    se = (ix - ix_nw) * (iy - iy_nw)\n    # ∀ channel: Calculate bilinear weighted pixel value.\n    @inbounds for c in 1:iC\n        r = zero(T)\n        if in_bounds(iy_nw, ix_nw, iH, iW)\n            r += input[ix_nw, iy_nw, c, n] * nw\n        end\n        if in_bounds(iy_ne, ix_ne, iH, iW)\n            r += input[ix_ne, iy_ne, c, n] * ne\n        end\n        if in_bounds(iy_sw, ix_sw, iH, iW)\n            r += input[ix_sw, iy_sw, c, n] * sw\n        end\n        if in_bounds(iy_se, ix_se, iH, iW)\n            r += input[ix_se, iy_se, c, n] * se\n        end\n        output[w, h, c, n] = r\n    end\nend\n\n@inline function _grid_sample_kernel!(\n    output::AbstractArray{T,5}, input::AbstractArray{T,5}, grid, padding_mode, w, h, d, n, iW, iH, iD, iC,\n) where {T}\n    # Get the corresponding (x, y, z) coordinates from the grid.\n    @inbounds x, y, z = grid[1, w, h, d, n], grid[2, w, h, d, n], grid[3, w, h, d, n]\n    ix = compute_source_index(x, iW, padding_mode)\n    iy = compute_source_index(y, iH, padding_mode)\n    iz = compute_source_index(z, iD, padding_mode)\n\n    # Get corner voxel values from (ix, iy, iz) in 8 directions (north-east-south-west-bottom-up).\n    ix_nw, iy_nw, iz_nw = unsafe_trunc(Int, floor(ix)), unsafe_trunc(Int, floor(iy)), unsafe_trunc(Int, floor(iz))\n    ix_ne, iy_ne, iz_ne = ix_nw + 1, iy_nw, iz_nw\n    ix_sw, iy_sw, iz_sw = ix_nw, iy_nw + 1, iz_nw\n    ix_se, iy_se, iz_se = ix_ne, iy_sw, iz_nw\n    ix_nw_u, iy_nw_u, iz_nw_u = ix_nw, iy_nw, iz_nw + 1\n    ix_ne_u, iy_ne_u, iz_ne_u = ix_ne, iy_ne, iz_ne + 1\n    ix_sw_u, iy_sw_u, iz_sw_u = ix_sw, iy_sw, iz_sw + 1\n    ix_se_u, iy_se_u, iz_se_u = ix_se, iy_se, iz_se + 1\n\n    # Get volumes to each neighbor (a.k.a. 
interpolation weights).\n    nw = (ix_se - ix) * (iy_se - iy) * (iz_se_u - iz)\n    ne = (ix - ix_sw) * (iy_sw - iy) * (iz_sw_u - iz)\n    sw = (ix_ne - ix) * (iy - iy_ne) * (iz_ne_u - iz)\n    se = (ix - ix_nw) * (iy - iy_nw) * (iz_nw_u - iz)\n    nw_u = (ix_se - ix) * (iy_se - iy) * (iz - iz_nw)\n    ne_u = (ix - ix_sw) * (iy_sw - iy) * (iz - iz_sw)\n    sw_u = (ix_ne - ix) * (iy - iy_ne) * (iz - iz_ne)\n    se_u = (ix - ix_nw) * (iy - iy_nw) * (iz - iz_nw)\n\n    # ∀ channel: Calculate trilinear weighted voxel value.\n    @inbounds for c in 1:iC\n        r = zero(T)\n        if in_bounds(iy_nw, ix_nw, iz_nw, iH, iW, iD)\n            r += input[ix_nw, iy_nw, iz_nw, c, n] * nw\n        end\n        if in_bounds(iy_ne, ix_ne, iz_ne, iH, iW, iD)\n            r += input[ix_ne, iy_ne, iz_ne, c, n] * ne\n        end\n        if in_bounds(iy_sw, ix_sw, iz_sw, iH, iW, iD)\n            r += input[ix_sw, iy_sw, iz_sw, c, n] * sw\n        end\n        if in_bounds(iy_se, ix_se, iz_se, iH, iW, iD)\n            r += input[ix_se, iy_se, iz_se, c, n] * se\n        end\n        if in_bounds(iy_nw_u, ix_nw_u, iz_nw_u, iH, iW, iD)\n            r += input[ix_nw_u, iy_nw_u, iz_nw_u, c, n] * nw_u\n        end\n        if in_bounds(iy_ne_u, ix_ne_u, iz_ne_u, iH, iW, iD)\n            r += input[ix_ne_u, iy_ne_u, iz_ne_u, c, n] * ne_u\n        end\n        if in_bounds(iy_sw_u, ix_sw_u, iz_sw_u, iH, iW, iD)\n            r += input[ix_sw_u, iy_sw_u, iz_sw_u, c, n] * sw_u\n        end\n        if in_bounds(iy_se_u, ix_se_u, iz_se_u, iH, iW, iD)\n            r += input[ix_se_u, iy_se_u, iz_se_u, c, n] * se_u\n        end\n        output[w, h, d, c, n] = r\n    end\nend\n\n\n\"\"\"\n    ∇grid_sample(Δ::AbstractArray{T, 4}, input::AbstractArray{T, 4}, grid::AbstractArray{T, 4}; padding_mode = :zeros) where T\n\n# Arguments\n\n- `Δ`: Input gradient in `(W_out, H_out, C, N)` shape\n    (same as output of the primal computation).\n- `input`: Input from primal computation in `(W_in, H_in, C, N)` shape.\n- `grid`: Grid from primal computation in `(2, W_out, H_out, N)` shape.\n- `padding_mode`: Out-of-bound padding.\n    `:zeros` to use `0` for out-of-bound grid locations.\n    `:border` to use border values for out-of-bound grid locations.\n    Should be the same as in primal computation.\n    Default is `:zeros`.\n\n# Returns\n\n`dinput` (same shape as `input`) and `dgrid` (same shape as `grid`) gradients.\n\"\"\"\nfunction ∇grid_sample(Δ::AbstractArray{T,N}, input::AbstractArray{T,N}, grid; padding_mode=:zeros) where {T, N}\n    if N ∉ (4,5)\n        error(\"∇grid_sample is only supported for 4D and 5D arrays.\") \n    end\n    dx = zeros(T, size(input))\n    dgrid = similar(grid)\n    ∇grid_sample!(dx, dgrid, Δ, input, grid, padding_mode)\nend\n\nfunction ∇grid_sample!(dx::AbstractArray{T,4}, dgrid::AbstractArray{T,4}, Δ::AbstractArray{T,4}, input::AbstractArray{T,4}, grid::AbstractArray{T,4}, padding_mode) where {T}\n    pad = Val(padding_mode)\n    iW, iH, iC, iN = size(input)\n    gW, gH = size(grid, 2), size(grid, 3)\n    # Loop over each output pixel.\n    Threads.@threads for n in 1:iN\n        for w in 1:gW, h in 1:gH\n            _∇grid_sample_kernel!(dx, dgrid, Δ, input, grid, pad, w, h, n, iW, iH, iC)\n        end\n    end\n    dx, dgrid\nend\n\nfunction ∇grid_sample!(dx::AbstractArray{T,5}, dgrid::AbstractArray{T,5}, Δ::AbstractArray{T,5}, input::AbstractArray{T,5}, grid::AbstractArray{T,5}, padding_mode) where {T}\n    pad = Val(padding_mode)\n    iW, iH, iD, iC, iN = size(input)\n    gW, gH, gD = 
size(grid, 2), size(grid, 3), size(grid, 4)\n    # Loop over each output voxel.\n    Threads.@threads for n in 1:iN\n        for w in 1:gW, h in 1:gH, d in 1:gD\n            _∇grid_sample_kernel!(dx, dgrid, Δ, input, grid, pad, w, h, d, n, iW, iH, iD, iC)\n        end\n    end\n    dx, dgrid\nend\n\n@inline function _∇grid_sample_kernel!(\n    dx::AbstractArray{T,4}, dgrid::AbstractArray{V,4}, Δ::AbstractArray{T,4}, input::AbstractArray{T,4}, grid::AbstractArray{V,4}, padding_mode, w, h, n, iW, iH, iC,\n) where {T,V}\n    # Get corresponding (x, y) from grid.\n    @inbounds x, y = grid[1, w, h, n], grid[2, w, h, n]\n    # Compute multipliers for gradients on ix, iy.\n    ix, gix_mult = ∇compute_source_index(x, iW, padding_mode)\n    iy, giy_mult = ∇compute_source_index(y, iH, padding_mode)\n    # Get corner pixel values from (ix, iy) in north-east-south-west directions.\n    ix_nw, iy_nw = unsafe_trunc(Int, floor(ix)), unsafe_trunc(Int, floor(iy))\n    ix_ne, iy_ne = ix_nw + 1, iy_nw\n    ix_sw, iy_sw = ix_nw, iy_nw + 1\n    ix_se, iy_se = ix_ne, iy_sw\n    # Get surfaces to each neighbor (a.k.a. interpolation weights).\n    nw = (ix_se - ix) * (iy_se - iy)\n    ne = (ix - ix_sw) * (iy_sw - iy)\n    sw = (ix_ne - ix) * (iy - iy_ne)\n    se = (ix - ix_nw) * (iy - iy_nw)\n    # ∀ channel: Calculate billinear weighted pixel value.\n    gix, giy = zero(V), zero(V)\n    @inbounds for c in 1:iC\n        g_out = Δ[w, h, c, n]\n        # Calculate dx and dgrid partials.\n        if in_bounds(iy_nw, ix_nw, iH, iW)\n            _safe_add!(dx, g_out * nw, ix_nw, iy_nw, c, n)\n            nw_val = input[ix_nw, iy_nw, c, n]\n            gix -= nw_val * (iy_se - iy) * g_out\n            giy -= nw_val * (ix_se - ix) * g_out\n        end\n        if in_bounds(iy_ne, ix_ne, iH, iW)\n            _safe_add!(dx, g_out * ne, ix_ne, iy_ne, c, n)\n            ne_val = input[ix_ne, iy_ne, c, n]\n            gix += ne_val * (iy_sw - iy) * g_out\n            giy -= ne_val * (ix - ix_sw) * g_out\n        end\n        if in_bounds(iy_sw, ix_sw, iH, iW)\n            _safe_add!(dx, g_out * sw, ix_sw, iy_sw, c, n)\n            sw_val = input[ix_sw, iy_sw, c, n]\n            gix -= sw_val * (iy - iy_ne) * g_out\n            giy += sw_val * (ix_ne - ix) * g_out\n        end\n        if in_bounds(iy_se, ix_se, iH, iW)\n            _safe_add!(dx, g_out * se, ix_se, iy_se, c, n)\n            se_val = input[ix_se, iy_se, c, n]\n            gix += se_val * (iy - iy_nw) * g_out\n            giy += se_val * (ix - ix_nw) * g_out\n        end\n    end\n    @inbounds dgrid[1, w, h, n] = gix_mult * gix\n    @inbounds dgrid[2, w, h, n] = giy_mult * giy\nend\n\n@inline function _∇grid_sample_kernel!(\n    dx::AbstractArray{T,5}, dgrid::AbstractArray{V,5}, Δ::AbstractArray{T,5}, input::AbstractArray{T,5}, grid::AbstractArray{V,5}, padding_mode, w, h, d, n, iW, iH, iD, iC,\n) where {T,V}\n    # Get corresponding (x, y, z) from grid.\n    @inbounds x, y, z = grid[1, w, h, d, n], grid[2, w, h, d, n], grid[3, w, h, d, n]\n    # Compute multipliers for gradients on ix, iy, iz.\n    ix, gix_mult = ∇compute_source_index(x, iW, padding_mode)\n    iy, giy_mult = ∇compute_source_index(y, iH, padding_mode)\n    iz, giz_mult = ∇compute_source_index(z, iD, padding_mode)\n     \n    # Get corner pixel values from (ix, iy, iz)\n    ix_0 = unsafe_trunc(Int, floor(ix))\n    iy_0 = unsafe_trunc(Int, floor(iy))\n    iz_0 = unsafe_trunc(Int, floor(iz))\n    ix_1 = ix_0 + 1\n    iy_1 = iy_0 + 1\n    iz_1 = iz_0 + 1\n    \n    # Get difference of 
coordinate\n    wx_0 = ix - ix_0\n    wy_0 = iy - iy_0\n    wz_0 = iz - iz_0\n    wx_1 = ix_1 - ix\n    wy_1 = iy_1 - iy\n    wz_1 = iz_1 - iz\n    \n    # Calculate weights (volume of diagnal vertex cube) \n    # w_{abc} = wx_{¬a}*wy_{¬b}*wz_{¬c}\n    weight_000 = wx_1 * wy_1 * wz_1\n    weight_001 = wx_1 * wy_1 * wz_0\n    weight_010 = wx_1 * wy_0 * wz_1\n    weight_011 = wx_1 * wy_0 * wz_0\n    weight_100 = wx_0 * wy_1 * wz_1\n    weight_101 = wx_0 * wy_1 * wz_0\n    weight_110 = wx_0 * wy_0 * wz_1\n    weight_111 = wx_0 * wy_0 * wz_0\n\n    # ∂w_{abc}/∂x=(-1)^{¬a} wy_{¬b}*wz_{¬c}, ∂w/∂y = (-1)^{¬b} wx_{¬a}*wz_{¬c}, ∂w/∂z=(-1)^{¬c} wx_{¬a}*wy_{¬b}\n    # abc are the index of the vertex of the cube (001,010...)\n\n    # Initialize gradient accumulators\n    gix, giy, giz = zero(V), zero(V), zero(V)\n    \n    @inbounds for c in 1:iC\n        g_out = Δ[w, h, d, c, n]\n        \n        # Calculate dx and dgrid partials for all 8 corners\n        if in_bounds(iy_0, ix_0, iz_0, iH, iW, iD)\n            _safe_add!(dx, g_out * weight_000, ix_0, iy_0, iz_0, c, n)\n            val = input[ix_0, iy_0, iz_0, c, n]\n            gix -= val * wy_1 * wz_1 * g_out\n            giy -= val * wx_1 * wz_1 * g_out\n            giz -= val * wx_1 * wy_1 * g_out\n        end\n\n        if in_bounds(iy_0, ix_0, iz_1, iH, iW, iD)\n            _safe_add!(dx, g_out * weight_001, ix_0, iy_0, iz_1, c, n)\n            val = input[ix_0, iy_0, iz_1, c, n]\n            gix -= val * wy_1 * wz_0 * g_out\n            giy -= val * wx_1 * wz_0 * g_out\n            giz += val * wx_1 * wy_1 * g_out\n        end\n        \n        if in_bounds(iy_1, ix_0, iz_0, iH, iW, iD)\n            _safe_add!(dx, g_out * weight_010, ix_0, iy_1, iz_0, c, n)\n            val = input[ix_0, iy_1, iz_0, c, n]\n            gix -= val * wy_0 * wz_1 * g_out\n            giy += val * wx_1 * wz_1 * g_out\n            giz -= val * wx_1 * wy_0 * g_out\n        end\n        \n        if in_bounds(iy_1, ix_0, iz_1, iH, iW, iD)\n            _safe_add!(dx, g_out * weight_011, ix_0, iy_1, iz_1, c, n)\n            val = input[ix_0, iy_1, iz_1, c, n]\n            gix -= val * wy_0 * wz_0 * g_out\n            giy += val * wx_1 * wz_0 * g_out\n            giz += val * wx_1 * wy_0 * g_out\n        end\n\n        if in_bounds(iy_0, ix_1, iz_0, iH, iW, iD)\n            _safe_add!(dx, g_out * weight_100, ix_1, iy_0, iz_0, c, n)\n            val = input[ix_1, iy_0, iz_0, c, n]\n            gix += val * wy_1 * wz_1 * g_out\n            giy -= val * wx_0 * wz_1 * g_out\n            giz -= val * wx_0 * wy_1 * g_out\n        end\n        \n        if in_bounds(iy_0, ix_1, iz_1, iH, iW, iD)\n            _safe_add!(dx, g_out * weight_101, ix_1, iy_0, iz_1, c, n)\n            val = input[ix_1, iy_0, iz_1, c, n]\n            gix += val * wy_1 * wz_0 * g_out\n            giy -= val * wx_0 * wz_0 * g_out\n            giz += val * wx_0 * wy_1 * g_out\n        end\n\n        if in_bounds(iy_1, ix_1, iz_0, iH, iW, iD)\n            _safe_add!(dx, g_out * weight_110, ix_1, iy_1, iz_0, c, n)\n            val = input[ix_1, iy_1, iz_0, c, n]\n            gix += val * wy_0 * wz_1 * g_out\n            giy += val * wx_0 * wz_1 * g_out\n            giz -= val * wx_0 * wy_0 * g_out\n        end\n        \n        if in_bounds(iy_1, ix_1, iz_1, iH, iW, iD)\n            _safe_add!(dx, g_out * weight_111, ix_1, iy_1, iz_1, c, n)\n            val = input[ix_1, iy_1, iz_1, c, n]\n            gix += val * wy_0 * wz_0 * g_out\n            giy += val * wx_0 * wz_0 * g_out\n            giz += val 
* wx_0 * wy_0 * g_out\n        end\n    end\n    \n    @inbounds dgrid[1, w, h, d, n] = gix_mult * gix\n    @inbounds dgrid[2, w, h, d, n] = giy_mult * giy\n    @inbounds dgrid[3, w, h, d, n] = giz_mult * giz\nend\n\n@inline function _safe_add!(dx, value, ix, iy, c, n)\n    @inbounds dx[ix, iy, c, n] += value\nend\n\n@inline function _safe_add!(dx, value, ix, iy, iz, c, n)\n    @inbounds dx[ix, iy, iz, c, n] += value\nend\n\nfunction rrule(::typeof(grid_sample), x, grid; padding_mode)\n    y = grid_sample(x, grid; padding_mode=padding_mode)\n    function grid_sample_pullback(Δ)\n        ∇x, ∇grid = ∇grid_sample(unthunk(Δ), x, grid; padding_mode=padding_mode)\n        NoTangent(), ∇x, ∇grid\n    end\n    return y, grid_sample_pullback\nend\n"
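Editor's note — a usage sketch for `grid_sample` (illustrative, assuming `NNlib` is loaded). With the align-corners convention described above, grid coordinates of exactly `-1` and `1` address the corner pixel centres, so sampling on such an "identity" grid reproduces the input:

```julia
using NNlib

x = reshape(Float64.(1:4), 2, 2, 1, 1)    # (W, H, C, N)

grid = Array{Float64}(undef, 2, 2, 2, 1)  # (2, W_out, H_out, N)
grid[:, 1, 1, 1] .= (-1, -1)
grid[:, 2, 1, 1] .= ( 1, -1)
grid[:, 1, 2, 1] .= (-1,  1)
grid[:, 2, 2, 1] .= ( 1,  1)

grid_sample(x, grid; padding_mode = :zeros) ≈ x   # true
```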
  },
  {
    "path": "src/scatter.jl",
    "content": "## Scatter API\n#   - Scatter:\n#     - scatter(op, src, idx)\n#     - scatter!(op, dst, src, idx)\n#   - Scatter destination backpropagation\n#     - ∇scatter!_dst\n#   - Scatter source backpropagation\n#     - ∇scatter_src\n#     - ∇scatter!_src\n#\n\ntypelength(::Type{<:Number}) = 1\ntypelength(::Type{<:NTuple{M}}) where M = M\ntypelength(::Type{CartesianIndex{M}}) where M = M\n\n\"\"\"\nPerforms dimensional consistency checks and return the\ndimensionality of the scattered objects.\n\"\"\"\nfunction scatter_dims(\n    X::AbstractArray{Tx,Nx}, Y::AbstractArray{Ty,Ny},\n    idx::AbstractArray{Tidx,Nidx},\n) where {Tx,Ty,Tidx,Nx,Ny,Nidx}\n    dims = scatter_dims(Nx, Ny, typelength(Tidx), Nidx)\n    size(X)[1:dims] == size(Y)[1:dims] || throw(ArgumentError(\"Incompatible input shapes.\"))\n    size(Y)[dims+1:end] == size(idx) || throw(ArgumentError(\"Incompatible input shapes.\"))\n    return dims\nend\n\nfunction scatter_dims(Nx, Ny, M, Nidx)\n    @assert Nx - M == Ny - Nidx \"Incompatible input shapes of (dst, src, idx) = ($Nx, $Ny, $Nidx).\"\n    dims = Nx - M\n    dims < 0 && throw(ArgumentError(\"dims must be non-negative but got dims=$dims.\"))\n    return dims\nend\n\n_view(X, colons, k) = view(X, colons..., k...)\n_view(X, colons, k::Union{Integer, CartesianIndex}) = view(X, colons..., k)\n\n\"\"\"\n    NNlib.scatter!(op, dst, src, idx)\n\nScatter operation, which writes data in `src` into `dst` at locations `idx`.\nA binary reduction operator `op` is applied during the scatter.\nFor each index `k` in `idx`, accumulates values in `dst` according to\n\n    dst[:, ..., idx[k]...] = (op).(dst[:, ..., idx[k]...], src[:, ..., k...])\n\nSee also [`scatter`](@ref), [`gather`](@ref).\n\n# Arguments\n\n- `op`: Operations to be applied on `dst` and `src`, e.g. `+`, `-`, `*`, `/`, `max`, `min` and `mean`.\n- `dst`: The destination for `src` to aggregate to. 
This argument will be mutated.\n- `src`: The source data for aggregating.\n- `idx`: The mapping for aggregation from source (index) to destination (value).\n         The `idx` array can contain either integers or tuples.\n\n# Examples\n```jldoctest\njulia> NNlib.scatter!(+, ones(3), [10,100], [1,3])\n3-element Vector{Float64}:\n  11.0\n   1.0\n 101.0\n\njulia> NNlib.scatter!(*, fill(0.5, 2, 4), [1 10; 100 1000], [3,2])\n2×4 Matrix{Float64}:\n 0.5    5.0   0.5  0.5\n 0.5  500.0  50.0  0.5\n```\n\"\"\"\nfunction scatter!(op::OP, dst::AbstractArray, src::AbstractArray, idx::AbstractArray) where OP\n    dims = scatter_dims(dst, src, idx)\n    colons = Base.ntuple(_->Colon(), dims)\n    for k in CartesianIndices(idx)\n        dst_v = _view(dst, colons, idx[k])\n        src_v = _view(src, colons, k)\n        dst_v .= (op).(dst_v, src_v)\n    end\n    dst\nend\n\nfor AT in (AbstractArray, AnyGPUArray)\n    @eval function scatter!(op::typeof(mean), dst::$AT, src::$AT, idx::$AT)\n        Ns = scatter!(+, zero(dst), one.(src), idx)\n        dst_ = scatter!(+, zero(dst), src, idx)\n        dst .+= safe_div.(dst_, Ns)\n        return dst\n    end\nend\n\nfunction scatter!(op::OP, dst::AnyGPUArray, src::AnyGPUArray, idx::AnyGPUArray) where OP\n    n_dims = scatter_dims(dst, src, idx)\n    args = if n_dims == 0\n        ndrange = length(idx)\n        ()\n    else\n        dims = size(dst)[1:n_dims]\n        max_dims_idx = prod(dims)\n        ndrange = max_dims_idx * length(idx)\n        (CartesianIndices(dims), max_dims_idx)\n    end\n    _scatter!(KernelAbstractions.get_backend(dst))(\n        op, dst, src, idx, args...; ndrange)\n    dst\nend\n\n@kernel function _scatter!(op::OP, dst, src, idxs) where OP\n    i = @index(Global)\n    @inbounds idx = Tuple(_convert_i64(idxs[i]))\n    @inbounds Atomix.modify!(Atomix.IndexableRef(dst, idx), op, src[i])\n    # FIXME `@atomic` macro silently fails to perform atomic op below\n    # @atomic dst[idx...] = op(dst[idx...], src[i])\nend\n\n@kernel function _scatter!(\n    op::OP, dst, src, idxs, dim_ids::CartesianIndices, max_dims_idx::Int,\n) where OP\n    i = @index(Global)\n    j, k = divrem(i - 1, max_dims_idx)\n    @inbounds idx = (Tuple(dim_ids[k + 1])..., Tuple(_convert_i64(idxs[j + 1]))...)\n    @inbounds Atomix.modify!(Atomix.IndexableRef(dst, idx), op, src[i])\n    # FIXME `@atomic` macro silently fails to perform atomic op below\n    # dim_i = Tuple(dim_ids[k + 1])\n    # idx = idxs[j + 1]\n    # @atomic dst[dim_i..., idx...] = op(dst[dim_i..., idx...], src[i])\nend\n\n# Allow non-Int64 indices by converting them to Int64 when index eltype <: Integer.\n# All other index types (tuples, cartesian indices) must be in Int64 already.\n@inline _convert_i64(x::Int) = x\n@inline _convert_i64(x::Integer) = Int(x)\n@inline _convert_i64(x) = x\n\n\"\"\"\n    NNlib.scatter(op, src, idx; [init, dstsize])\n\nScatter operation allocating a destination array `dst` and\ncalling `scatter!(op, dst, src, idx)` on it.\n\n* If keyword `init` is provided, it is used to initialize the content of `dst`.\n  Otherwise, the init values is inferred from the reduction operator `op`\n  for some common operators (e.g. 
`init = 0` for `op = +`).\n\n* If `dstsize` is provided, it will be used to define the size of\n  destination array, otherwise it will be inferred by `src` and `idx`.\n\nSee [`scatter!`](@ref) for full details on how `idx` works.\n\n# Examples\n```jldoctest\njulia> NNlib.scatter(+, [10,100,1000], [3,1,2])\n3-element Vector{Int64}:\n  100\n 1000\n   10\n\njulia> NNlib.scatter(+, [1 2 3 4; 5 6 7 8], [2,1,1,5])\n2×5 Matrix{Int64}:\n  5  1  0  0  4\n 13  5  0  0  8\n\njulia> NNlib.scatter(*, [10,200,3000], [1,4,2]; init = 10, dstsize = 6)\n6-element Vector{Int64}:\n   100\n 30000\n    10\n  2000\n    10\n    10\n```\n\"\"\"\nfunction scatter(\n    op::OP, src::AbstractArray{Tsrc,Nsrc}, idx::AbstractArray{Tidx,Nidx};\n    init = nothing, dstsize = nothing,\n) where {Tsrc,Tidx,Nsrc,Nidx,OP}\n    dims = Nsrc - Nidx\n    dstsz = isnothing(dstsize) ? (size(src)[1:dims]..., maximum_dims(idx)...) : dstsize\n    dst = similar(src, Tsrc, dstsz)\n    xinit = isnothing(init) ? scatter_empty(op, Tsrc) : init\n    fill!(dst, xinit)\n    scatter!(op, dst, src, idx)\nend\n\nscatter_empty(op, T) = Base.reduce_empty(op, T)\nscatter_empty(op::typeof(-), T) = zero(T)\nscatter_empty(op::typeof(/), T) = one(T)\nscatter_empty(op::typeof(min), T) = typemax(T)\nscatter_empty(op::typeof(max), T) = typemin(T)\nscatter_empty(op::typeof(mean), T) = zero(T)\n\n## Gradients\n\n∇scatter!_src(op, Δ, dst, src, idx) = ∇scatter_src(op, Δ, dst, src, idx)\n∇scatter!_src(op::Union{typeof(*),typeof(/)}, Δ, dst, src, idx) =\n    gather(dst, idx) .* ∇scatter_src(op, Δ, dst, src, idx)\n∇scatter!_dst(op, Δ, dst, y) = Δ\n∇scatter!_dst(op::Union{typeof(max),typeof(min)}, Δ, dst_old, dst) =\n    (dst_old .== op.(dst_old, dst)) .* Δ\n\nmodify_src(::typeof(+), X) = X\nmodify_src(::typeof(-), X) = -X\nmodify_src(::typeof(*), X, Y) = X\nmodify_src(::typeof(/), X, Y) = .-X ./ Y.^2\n\n∇scatter_src(op::Union{typeof(+),typeof(-)}, Δ, dst, src, idx) =\n    modify_src(op, gather(Δ, idx))\n∇scatter_src(::Union{typeof(max),typeof(min)}, Δ, dst, src, idx) =\n    (src .== gather(dst, idx)) .* gather(Δ, idx)\n\nfunction ∇scatter_src(\n    op::Union{typeof(*),typeof(/)}, Δ, dst,\n    src::AbstractArray{Tsrc,Nsrc}, idx::AbstractArray{Tidx,Nidx},\n) where {Tsrc,Tidx,Nsrc,Nidx}\n    dims = Nsrc - Nidx\n    Δsrc = modify_src(op, gather(Δ, idx), src)\n    rev_idx = reverse_indices(idx)\n    ax = CartesianIndices(axes(src)[1:dims])\n    for k in CartesianIndices(idx)\n        inds = filter(x -> x != k, rev_idx[idx[k]])\n        for i in ax\n            Δsrc[i, k] = op(Δsrc[i, k], prod(j -> src[i, j], inds))\n        end\n    end\n    Δsrc\nend\n\nfunction ∇scatter_src(\n    op::Union{typeof(*), typeof(/)}, Δ, dst,\n    src::AnyGPUArray{Tsrc, Nsrc}, idx::AnyGPUArray{Tidx, Nidx},\n) where {Tsrc, Nsrc, Tidx, Nidx}\n    n_dims = Nsrc - Nidx\n    Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), src)\n    rev_idx = NNlib.reverse_indices(idx)\n\n    args = if n_dims == 0\n        ndrange = length(idx)\n        ()\n    else\n        dims = size(dst)[1:n_dims]\n        max_dims_idx = prod(dims)\n        ndrange = max_dims_idx * length(idx)\n        (CartesianIndices(dims), max_dims_idx)\n    end\n    _∇scatter_src(KernelAbstractions.get_backend(src))(\n        op, Δsrc, src, idx, rev_idx, args...; ndrange)\n    KernelAbstractions.unsafe_free!(rev_idx)\n    return Δsrc\nend\n\n@kernel function _∇scatter_src(op, Δsrc, src::AbstractArray{T}, idx, rev_idx) where T\n    i = @index(Global)\n    cart_j = CartesianIndices(idx)[i]\n    @inbounds begin\n        inds = 
rev_idx[Tuple(idx[cart_j])...]\n        x = one(T)\n        for k in inds\n            x *= src[k]\n        end\n        x /= src[cart_j]\n        Δsrc[cart_j] = op(Δsrc[cart_j], x)\n    end\nend\n\n@kernel function _∇scatter_src(\n    op, Δsrc, src::AbstractArray{T}, idx, rev_idx,\n    dim_ids::CartesianIndices, max_dims_idx::Int,\n) where T\n    i = @index(Global)\n    j, k = fldmod1(i, max_dims_idx)\n    @inbounds begin\n        cart_j = CartesianIndices(idx)[j]\n        cart_k = dim_ids[k]\n        inds = rev_idx[Tuple(cart_j)...]\n        x = one(T)\n        for s in inds\n            x *= src[Tuple(cart_k)..., Tuple(s)...]\n        end\n        x /= src[i]\n        Δsrc[i] = op(Δsrc[i], x)\n    end\nend\n\nfunction ∇scatter_src(\n    ::typeof(mean), Δ, dst,\n    src::AbstractArray{Tsrc,Nsrc}, idx::AbstractArray{Tidx,Nidx},\n) where {Tsrc,Tidx,Nsrc,Nidx}\n    M = typelength(Tidx)\n    num = gather(Δ, idx)\n    counts = fill!(similar(Δ, Int, size(Δ)[end-M+1:end]), 0)\n    scatter!(+, counts, fill!(similar(idx, Int), 1), idx)\n    den = gather(counts, idx)\n    # make num and den broadcast compatible\n    for i in 1:ndims(num)-ndims(den)\n        den = unsqueeze(den)\n    end\n    return safe_div.(num, den)\nend\n\n∇scatter_src(op, Δ, dst, src, idx) = ∇scatter_src(op, unthunk(Δ), dst, src, idx)\n\nfunction rrule(::typeof(scatter!), op, dst::AbstractArray, src::AbstractArray, idx::AbstractArray)\n    dst_old = copy(dst)\n    scatter!(op, dst, src, idx)\n    scatter!_pullback(Δ) = (NoTangent(), NoTangent(), ∇scatter!_dst(op, unthunk(Δ), dst_old, dst), ∇scatter!_src(op, unthunk(Δ), dst, src, idx), NoTangent())\n    dst, scatter!_pullback\nend\n\nfunction rrule(::typeof(scatter), op, src::AbstractArray, idx::AbstractArray; kws...)\n    y = scatter(op, src, idx; kws...)\n    scatter_pullback(Δ) = (NoTangent(), NoTangent(), ∇scatter_src(op, unthunk(Δ), y, src, idx), NoTangent())\n    y, scatter_pullback\nend\n"
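Editor's note — a small sketch of the scatter/gather pair (illustrative; assumes `NNlib` is loaded and uses `NNlib.gather`, which is defined elsewhere in the package):

```julia
using NNlib

src = [10, 20, 30, 40]
idx = [1, 2, 1, 2]

dst = NNlib.scatter(+, src, idx)   # [40, 60]: entries sharing an index are summed

# gather reads values back out at the same indices, which is what the
# ∇scatter_src rule for + and - relies on:
NNlib.gather(dst, idx)             # [40, 60, 40, 60]
```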
  },
  {
    "path": "src/softmax.jl",
    "content": "\n\"\"\"\n    softmax(x; dims = 1)\n\n[Softmax](https://en.wikipedia.org/wiki/Softmax_function) turns input array `x`\ninto probability distributions that sum to 1 along the dimensions specified by `dims`.\nIt is semantically equivalent to the following:\n\n    softmax(x; dims = 1) = exp.(x) ./ sum(exp.(x), dims = dims)\n\nwith additional manipulations enhancing numerical stability.\n\nFor a matrix input `x` it will by default (`dims = 1`) treat it as a batch of vectors,\nwith each column independent. Keyword `dims = 2` will instead treat rows independently, and so on.\n\nSee also [`logsoftmax`](@ref).\n\n# Examples\n\n```jldoctest; filter = r\"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?\"\njulia> softmax([1, 2, 3])\n3-element Vector{Float64}:\n 0.09003057317038046\n 0.24472847105479764\n 0.6652409557748218\n\njulia> softmax([1 2 3; 2 2 2])  # dims=1\n2×3 Matrix{Float64}:\n 0.268941  0.5  0.731059\n 0.731059  0.5  0.268941\n\njulia> softmax([1 2 3; 2 2 2]; dims=2)\n2×3 Matrix{Float64}:\n 0.0900306  0.244728  0.665241\n 0.333333   0.333333  0.333333\n```\n\nNote that, when used with Flux.jl, `softmax` must not be passed to layers like `Dense`\nwhich accept an activation function. The activation is broadcasted over the result,\nthus applies to individual numbers. But `softmax` always needs to see the whole column.\n\n```julia-repl\njulia> using Flux\n\njulia> x = randn(Float32, 4, 4, 3, 13);\n\njulia> model = Chain(Conv((4, 4), 3 => 8, tanh), Flux.flatten, Dense(8 => 7), softmax);\n\njulia> model(x) |> size\n(7, 13)\n\njulia> Dense(4 => 7, softmax)(x)\nERROR: `softmax(x)` called with a number, but it expects an array. \n```\n\"\"\"\nsoftmax(x::AbstractArray{T}; dims = 1) where {T} = softmax!(similar(x, float(T)), x; dims)\n\nsoftmax!(x::AbstractArray; dims = 1) = softmax!(x, x; dims)\n\nfunction softmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T}\n    max_ = fast_maximum(x; dims)\n    if all(isfinite, max_)\n        @fastmath out .= exp.(x .- max_)\n    else\n        _zero, _one, _inf = T(0), T(1), T(Inf)\n        @fastmath @. out = ifelse(isequal(max_,_inf), ifelse(isequal(x,_inf), _one, _zero), exp(x - max_))\n    end\n    tmp = dims isa Colon ? sum(out) : sum!(max_, out)\n    out ./= tmp\nend\n\nfunction ∇softmax_data(dy::AbstractArray{T}, y::AbstractArray{S}; dims = 1) where {T,S}\n    dx = if within_gradient(y)\n        tmp = dy .* y\n        tmp .- y .* sum(tmp; dims)\n    else\n        # This path is faster, only safe for 1st derivatives though.\n        # Was previously `∇softmax!(dx, dy, x, y; dims)` to allow CUDA overloads,\n        # but that was slow: https://github.com/FluxML/NNlibCUDA.jl/issues/30\n        out = similar(y, promote_type(T,S))  # sure to be mutable\n        out .= dy .* y\n        out .= out .- y .* sum(out; dims)\n    end\nend\n\nfunction rrule(::typeof(softmax), x; dims = 1)\n    y = softmax(x; dims)\n    softmax_pullback(dy) = (NoTangent(), ∇softmax_data(unthunk(dy), y; dims))\n    return y, softmax_pullback\nend\n\nfast_maximum(x::AbstractArray{T}; dims) where {T} = @fastmath reduce(max, x; dims, init = float(T)(-Inf))\n\n\"\"\"\n    logsoftmax(x; dims = 1)\n\nComputes the log of softmax in a more numerically stable\nway than directly taking `log.(softmax(xs))`. 
Commonly used in\ncomputing cross entropy loss.\n\nIt is semantically equivalent to the following:\n\n    logsoftmax(x; dims = 1) = x .- log.(sum(exp.(x), dims = dims))\n\nSee also [`softmax`](@ref).\n\"\"\"\nlogsoftmax(x::AbstractArray{T}; dims = 1) where {T} = logsoftmax!(similar(x, float(T)), x; dims)\n\nlogsoftmax!(x::AbstractArray; dims = 1) = logsoftmax!(x, x; dims)\n\nfunction logsoftmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T}\n    max_ = fast_maximum(x; dims)\n    if all(isfinite, max_)\n        out .= x .- max_\n    else\n        _zero, _minf, _inf = T(0), T(-Inf), T(Inf)\n        @. out = ifelse(isequal(max_,_inf), ifelse(isequal(x,_inf), _zero, _minf), x - max_)\n    end\n    @fastmath log_ = log.(sum(exp, out; dims))\n    out .-= log_\nend\n\nfunction ∇logsoftmax_data(dy::AbstractArray, y::AbstractArray; dims = 1)\n    # This was previously `∇logsoftmax!(dx, dy, x, y; dims)` to allow CUDA overloads, but that was slow.\n    dx = dy .- sum(dy; dims) .* exp.(y)\nend\n    \nfunction rrule(::typeof(logsoftmax), x; dims = 1)\n    y = logsoftmax(x; dims)\n    logsoftmax_pullback(dy) = (NoTangent(), ∇logsoftmax_data(unthunk(dy), y; dims))\n    return y, logsoftmax_pullback\nend\n\n\"\"\"\n    logsumexp(x; dims = :)\n\nComputes `log.(sum(exp.(x); dims))` in a numerically stable way.\nWithout `dims` keyword this returns a scalar.\n\nSee also [`logsoftmax`](@ref).\n\"\"\"\nfunction logsumexp(x::AbstractArray; dims = :)\n    max_ = fast_maximum(x; dims)\n    @fastmath max_ .+ log.(sum(exp.(x .- max_); dims))\nend\n\nfunction rrule(::typeof(logsumexp), x; dims = :)\n    # The gradient is `softmax`, but both compute `tmp` so it's worth saving.\n    max_ = fast_maximum(x; dims)\n    @fastmath tmp = exp.(x .- max_)\n    @fastmath y = max_ .+ log.(sum(tmp; dims))\n    logsumexp_pullback(dy) = (NoTangent(), unthunk(dy) .* tmp ./ sum(tmp; dims))\n    return y, logsumexp_pullback\nend\n\n# Informative error message if any of the softmax variants is called with a number\nfor f in (:softmax, :logsoftmax, :softmax!, :logsoftmax!, :logsumexp)\n    @eval $(f)(x::Number, args...) = \n      error(\"`\", $(string(f)), \"(x)` called with a number, but it expects an array. Usually this is because a layer like `Dense(3,4,softmax)` is broadcasting it like an activation function; `softmax` needs to be outside the layer.\")\nend\n"
  },
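A small sketch (not part of `src/softmax.jl`) checking the numerical-stability behaviour the docstrings above describe, assuming NNlib is loaded:

```julia
using NNlib

x = [1f3, 1f3 + 1, 1f3 + 2]          # exp.(x) overflows Float32

p = softmax(x)
@assert sum(p) ≈ 1
@assert p ≈ exp.(x .- maximum(x)) ./ sum(exp.(x .- maximum(x)))

# logsoftmax and logsumexp stay finite where the naive formulas do not:
@assert logsoftmax(x) ≈ log.(p)
@assert logsumexp(x) ≈ maximum(x) + log(sum(exp.(x .- maximum(x))))
@assert isinf(log(sum(exp.(x))))     # the unstabilised expression overflows
```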
  {
    "path": "src/upsample.jl",
    "content": "\"\"\"\n    pixel_shuffle(x, r::Integer)\n\nPixel shuffling operation, upscaling by a factor `r`.\n\nFor 4-arrays representing `N` images, the operation converts input `size(x) == (W, H, r^2*C, N)`\nto output of size `(r*W, r*H, C, N)`. For `D`-dimensional data, it expects `ndims(x) == D+2`\nwith channel and batch dimensions, and divides the number of channels by `r^D`.\n\nUsed in super-resolution networks to upsample towards high resolution features.\nReference: Shi et. al., \"Real-Time Single Image and Video Super-Resolution ...\", CVPR 2016, https://arxiv.org/abs/1609.05158\n\n# Examples\n\n```jldoctest\njulia> x = [10i + j + channel/10 for i in 1:2, j in 1:3, channel in 1:4, batch in 1:1]\n2×3×4×1 Array{Float64, 4}:\n[:, :, 1, 1] =\n 11.1  12.1  13.1\n 21.1  22.1  23.1\n\n[:, :, 2, 1] =\n 11.2  12.2  13.2\n 21.2  22.2  23.2\n\n[:, :, 3, 1] =\n 11.3  12.3  13.3\n 21.3  22.3  23.3\n\n[:, :, 4, 1] =\n 11.4  12.4  13.4\n 21.4  22.4  23.4\n\njulia> pixel_shuffle(x, 2)  # 4 channels used up as 2x upscaling of image dimensions\n4×6×1×1 Array{Float64, 4}:\n[:, :, 1, 1] =\n 11.1  11.3  12.1  12.3  13.1  13.3\n 11.2  11.4  12.2  12.4  13.2  13.4\n 21.1  21.3  22.1  22.3  23.1  23.3\n 21.2  21.4  22.2  22.4  23.2  23.4\n\njulia> y = [i + channel/10 for i in 1:3, channel in 1:6, batch in 1:1]\n3×6×1 Array{Float64, 3}:\n[:, :, 1] =\n 1.1  1.2  1.3  1.4  1.5  1.6\n 2.1  2.2  2.3  2.4  2.5  2.6\n 3.1  3.2  3.3  3.4  3.5  3.6\n\njulia> pixel_shuffle(y, 2)  # 1D image, with 6 channels reduced to 3\n6×3×1 Array{Float64, 3}:\n[:, :, 1] =\n 1.1  1.3  1.5\n 1.2  1.4  1.6\n 2.1  2.3  2.5\n 2.2  2.4  2.6\n 3.1  3.3  3.5\n 3.2  3.4  3.6\n```\n\"\"\"\nfunction pixel_shuffle(x::AbstractArray, r::Integer)\n    ndims(x) > 2 || throw(ArgumentError(\"expected x with at least 3 dimensions\"))\n    d = ndims(x) - 2\n    sizein = size(x)[1:d]\n    cin, n = size(x, d+1), size(x, d+2)\n    cin % r^d == 0 || throw(ArgumentError(\"expected channel dimension to be divisible by r^d = $(\n        r^d), where d=$d is the number of spatial dimensions. Given r=$r, input size(x) = $(size(x))\"))\n    cout = cin ÷ r^d\n    x = reshape(x, sizein..., ntuple(i->r, d)..., cout, n)\n    perm = hcat(d+1:2d, 1:d) |> transpose |> vec  # = [d+1, 1, d+2, 2, ..., 2d, d]\n    x = permutedims(x, (perm..., 2d+1, 2d+2))\n    return reshape(x, map(s -> s*r, sizein)..., cout, n)\nend\n\n#\n# Upsampling\n#\n# GPU based bilinear upsampling including its gradient\n#\n# Based on the Caffe2 implementation at:\n# The code is a translation from the following files:\n# - https://github.com/pytorch/pytorch/blob/v1.8.0-rc1/caffe2/operators/upsample_op.cu\n# - https://github.com/pytorch/pytorch/blob/v1.8.0-rc1/caffe2/core/common_gpu.h\n#\n# Copyright (c) 2016-2021 Facebook Inc.\n# Copyright (c) 2015 Google Inc.\n# Copyright (c) 2015 Yangqing Jia\n# Copyright 2019-2020 Kakao Brain\n#\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without modification, are\n# permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright notice, this list of\n#    conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright notice, this list of\n#    conditions and the following disclaimer in the documentation and/or other materials\n#    provided with the distribution.\n#\n# 3. 
Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America and\n#    IDIAP Research Institute nor the names of its contributors may be used to endorse or\n#    promote products derived from this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND ANY\n# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF\n# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE\n# COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,\n# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF\n# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)\n# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR\n# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n# Forward and backward pass have been tested to produce the same output\n# as pytorch with align_corners=True - it works modulo bit noise.\n# pytorch's default is align_corners=False, because otherwise the gradients depend on the\n# image size, which should be avoided -> this should be considered here as well\n\n\"\"\"\n    upsample_nearest(x, scale::NTuple{S,Int})\n    upsample_nearest(x; size::NTuple{S,Int})\n\nUpsamples the array `x` by integer multiples along the first `S` dimensions.\nSubsequent dimensions of `x` are not altered.\n\nEither the `scale` factors or the final output `size` can be specified.\n\nSee also [`upsample_bilinear`](@ref), for two dimensions of an `N=4` array.\n\n# Example\n```jldoctest\njulia> upsample_nearest([1 2 3; 4 5 6], (2, 3))\n4×9 Matrix{$Int}:\n 1  1  1  2  2  2  3  3  3\n 1  1  1  2  2  2  3  3  3\n 4  4  4  5  5  5  6  6  6\n 4  4  4  5  5  5  6  6  6\n\njulia> ans == upsample_nearest([1 2 3; 4 5 6]; size=(4, 9))  # equivalent\ntrue\n\njulia> upsample_nearest([1 2 3; 4 5 6], (2,))\n4×3 Matrix{$Int}:\n 1  2  3\n 1  2  3\n 4  5  6\n 4  5  6\n\njulia> ans == upsample_nearest([1 2 3; 4 5 6], size=(4,))\ntrue\n```\n\"\"\"\nfunction upsample_nearest(x::AbstractArray; size::NTuple{S,Int}) where S\n    xsize = Base.size(x)[1:S]\n    all(size .% xsize .== 0) || throw(ArgumentError(\"expected output size divisible by input size\"))\n    return upsample_nearest(x, size .÷ xsize)\nend\n\nfunction upsample_nearest(x::AbstractArray{T,N}, scales::NTuple{S, <:Integer}) where {T,N,S}\n    S in 1:N || throw(ArgumentError(\"can't upsample ndims(x)=$N with scale=$scales\"))\n    outsize = ntuple(d -> d<=S ? scales[d] * size(x,d) : size(x,d), N)\n    out = similar(x, T, outsize)\n    writesize = ntuple(N+S) do d\n        d > 2S && return size(x, d-S)\n        isodd(d) ? scales[cld(d,2)] : size(x, cld(d,2))\n    end\n    readsize = ntuple(N+S) do d\n        d > 2S && return size(x, d-S)\n        isodd(d) ? 
1 : size(x, cld(d,2))\n    end\n    reshape(out, writesize) .= reshape(x, readsize)\n    out\nend\n\n\"\"\"\n    ∇upsample_nearest(Δ::AbstractArray{T,3}, scales::NTuple{S, <:Integer}) where T\n\n# Arguments\n- `Δ`: Incoming gradient array, backpropagated from downstream layers\n- `scales`: scales by which the image was upsampled in the first place\n\n# Outputs\n- `dx`: Downsampled version of `Δ`\n\"\"\"\nfunction ∇upsample_nearest(x::AbstractArray{T,N}, scales::NTuple{S, <:Integer}) where {T,N,S}\n    outsize = ntuple(N) do d\n        d > S && return size(x,d)\n        rem(size(x,d), scales[d]) == 0 || throw(ArgumentError(\"expected input array evenly divisible by scale=$scales, got size(x)=$(size(x))\"))\n        div(size(x,d), scales[d])\n    end\n    tempsize = ntuple(N+S) do d\n        d > 2S && return size(x, d-S)\n        s = scales[cld(d,2)]\n        isodd(d) ? s : div(size(x, cld(d,2)),s)\n    end\n    mid = sum(reshape(x, tempsize), dims=ntuple(d -> 2d-1, S))\n    reshape(mid, outsize)\nend\n\nfunction rrule(::typeof(upsample_nearest), x::AbstractArray, s::Tuple)\n    Ω = upsample_nearest(x, s)\n    upsample_nearest_pullback(Δ) = (NoTangent(), ∇upsample_nearest(unthunk(Δ), s), NoTangent())\n    return Ω, upsample_nearest_pullback\nend\n\n\"\"\"\n    upsample_linear(x::AbstractArray{T,3}, scale::Real; align_corners::Bool = true)\n    upsample_linear(x::AbstractArray{T,3}; size::Integer, align_corners::Bool = true)\n\nUpsamples the first dimension of the array `x` by the upsample provided `scale`,\nusing linear interpolation. As an alternative to using `scale`, the resulting array `size`\ncan be directly specified with a keyword argument.\n\nThe size of the output is equal to\n`(scale*S1, S2, S3)`, where `S1, S2, S3 = size(x)`.\n\"\"\"  # the user facing function\nfunction upsample_linear(x::AbstractArray{<:Any,N}, scale::NTuple{M,Real}; align_corners::Bool = true) where {N,M}\n    M == N-2 || error(\"The scale argument should be an NTuple with length $(N-2), but it has length $M.\")\n    outsize = ntuple(i -> floor(Int, scale[i] * Base.size(x, i)), N-2)\n    return upsample_linear(x; size=outsize, align_corners)\nend\n\n# convenience for single-number scale\nupsample_linear(x::AbstractArray{<:Any,N}, scale::Real; align_corners::Bool = true) where N =\n    upsample_linear(x, ntuple(_ -> scale, N-2); align_corners)\n\n# this actually calls the upsamling kernel\nfunction upsample_linear(x::AbstractArray{T,N}; size::Union{Integer, NTuple{<:Any,Integer}}, align_corners::Bool = true) where {T,N}\n    length(size) == N-2 || error(\"The scale argument should be an NTuple with length $(N-2), but it has length $(length(size)).\")\n\n    if Base.size(x)[1:N-2] == size\n        return x\n    end\n\n    y = similar(x, T, size..., Base.size(x)[end-1:end]...)\n    return upsample_linear_kernel!(y, x; align_corners)\nend\n\n# Convenience definition for integers. 
The algo internally works with floats and then rounds.\nfunction upsample_linear(x::AbstractArray{T,<:Any}; size, align_corners::Bool = true) where T<:Integer\n    y = float.(x)\n    res = upsample_linear(y; size=size, align_corners)\n    return round.(T, res)\nend\n\n\"\"\"\n    ∇upsample_linear(Δ::AbstractArray{T,3}; size::Integer, align_corners::Bool = true) where T\n\n# Arguments\n- `Δ`: Incoming gradient array, backpropagated from downstream layers\n- `size`: Size of the image upsampled in the first place\n\n# Outputs\n- `dx`: Downsampled version of `Δ`\n\"\"\"\nfunction ∇upsample_linear(Δ::AbstractArray{T,N}; size::NTuple{<:Any,Integer}, align_corners::Bool = true) where {T,N}\n    if Base.size(Δ)[1:N-2] == size\n        return Δ\n    end\n    dx = fill!(similar(Δ, T, size..., Base.size(Δ)[end-1:end]...), zero(T))\n    return ∇upsample_linear_kernel!(dx, Δ; align_corners)\nend\n\n\nfunction rrule(::typeof(upsample_linear), x::AbstractArray{<:Any,N}; size, align_corners::Bool = true) where N\n    Ω = upsample_linear(x; size, align_corners)\n    function upsample_linear_pullback(Δ)\n        (NoTangent(), ∇upsample_linear(unthunk(Δ); size=Base.size(x)[1:N-2], align_corners))\n    end\n    return Ω, upsample_linear_pullback\nend\n\n\"\"\"\n    upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real}; align_corners::Bool = true)\n    upsample_bilinear(x::AbstractArray{T,4}; size::NTuple{2,Integer}, align_corners::Bool = true)\n\nUpsamples the first 2 dimensions of the array `x` by the upsample factors stored in `scale`,\nusing bilinear interpolation. As an alternative to using `scale`, the resulting image `size`\ncan be directly specified with a keyword argument.\n\nThe size of the output is equal to\n`(scale[1]*S1, scale[2]*S2, S3, S4)`, where `S1, S2, S3, S4 = size(x)`.\n\n# Examples\n\n```jldoctest\njulia> x = reshape(Float32[1 2 3; 4 5 6], (2,3,1,1))\n2×3×1×1 Array{Float32, 4}:\n[:, :, 1, 1] =\n 1.0  2.0  3.0\n 4.0  5.0  6.0\n\njulia> upsample_bilinear(x, (2, 3))\n4×9×1×1 Array{Float32, 4}:\n[:, :, 1, 1] =\n 1.0  1.25  1.5  1.75  2.0  2.25  2.5  2.75  3.0\n 2.0  2.25  2.5  2.75  3.0  3.25  3.5  3.75  4.0\n 3.0  3.25  3.5  3.75  4.0  4.25  4.5  4.75  5.0\n 4.0  4.25  4.5  4.75  5.0  5.25  5.5  5.75  6.0\n\njulia> ans == upsample_bilinear(x; size=(4, 9))  # specify ouput size instead\ntrue\n\njulia> upsample_bilinear(x, (2.5, 3.5))  # non-integer scaling factors are allowed\n5×10×1×1 Array{Float32, 4}:\n[:, :, 1, 1] =\n 1.0   1.22222  1.44444  1.66667  1.88889  …  2.33333  2.55556  2.77778  3.0\n 1.75  1.97222  2.19444  2.41667  2.63889     3.08333  3.30556  3.52778  3.75\n 2.5   2.72222  2.94444  3.16667  3.38889     3.83333  4.05556  4.27778  4.5\n 3.25  3.47222  3.69444  3.91667  4.13889     4.58333  4.80556  5.02778  5.25\n 4.0   4.22222  4.44444  4.66667  4.88889     5.33333  5.55556  5.77778  6.0\n```\n\"\"\"\nupsample_bilinear(x, scale; align_corners::Bool = true) = upsample_linear(x, scale; align_corners)\nupsample_bilinear(x; size, align_corners::Bool = true)  = upsample_linear(x; size, align_corners)\n\n\n\"\"\"\n    ∇upsample_bilinear(Δ::AbstractArray{T,4}; size::NTuple{2,Integer}, align_corners::Bool = true) where T\n\n# Arguments\n- `Δ`: Incoming gradient array, backpropagated from downstream layers\n- `size`: Lateral (W,H) size of the image upsampled in the first place\n\n# Outputs\n- `dx`: Downsampled version of `Δ`\n\"\"\"\n∇upsample_bilinear(Δ; size, align_corners::Bool = true) = ∇upsample_linear(Δ; size, align_corners)\n\n\"\"\"\n    
upsample_trilinear(x::AbstractArray{T,5}, scale::NTuple{3,Real}; align_corners::Bool = true)\n    upsample_trilinear(x::AbstractArray{T,5}; size::NTuple{3,Integer}, align_corners::Bool = true)\n\nUpsamples the first 3 dimensions of the array `x` by the upsample factors stored in `scale`,\nusing trilinear interpolation. As an alternative to using `scale`, the resulting image `size`\ncan be directly specified with a keyword argument.\n\nThe size of the output is equal to\n`(scale[1]*S1, scale[2]*S2, scale[3]*S3, S4, S5)`, where `S1, S2, S3, S4, S5 = size(x)`.\n\n# Examples\n\n```julia\nupsample_trilinear(x, (2, 3, 4))\nupsample_trilinear(x; size=(4, 9, 11))  # specify ouput size instead\nupsample_trilinear(x, (2.5, 3.5, pi))  # non-integer scaling factors are allowed\n```\n\"\"\"\nupsample_trilinear(x, scale; align_corners::Bool = true) = upsample_linear(x, scale; align_corners)\nupsample_trilinear(x; size, align_corners::Bool = true)  = upsample_linear(x; size, align_corners)\n\n\"\"\"\n    ∇upsample_trilinear(Δ::AbstractArray{T,5}; size::NTuple{3,Integer}, align_corners::Bool = true) where T\n\n# Arguments\n- `Δ`: Incoming gradient array, backpropagated from downstream layers\n- `size`: Lateral size & depth (W,H,D) of the image upsampled in the first place\n\n# Outputs\n- `dx`: Downsampled version of `Δ`\n\"\"\"\n∇upsample_trilinear(Δ; size, align_corners::Bool = true) = ∇upsample_linear(Δ; size, align_corners)\n\nfunction upsample_linear_kernel!(\n    y::AbstractArray{T, N}, x::AbstractArray{T, N}; align_corners::Bool = true,\n) where {T, N}\n    backend = KernelAbstractions.get_backend(x)\n    ndrange = backend isa CPU ?\n        size(y)[N - 1:end] : # Parallelization along channel x batch.\n        size(y)[1:N - 2] # Parallelization along WHD.\n    ratios = align_corners ?\n        ntuple(i -> real(T)((size(x, i) - 1) / (size(y, i) - 1)), N - 2) :\n        ntuple(i -> real(T)(size(x, i) / size(y, i)), N - 2)\n    _upsample_linear_kernel!(backend)(backend, y, x, ratios..., Val(align_corners); ndrange)\n    return y\nend\n\nfunction ∇upsample_linear_kernel!(\n    dx::AbstractArray{T, N}, Δ::AbstractArray{T, N}; align_corners::Bool = true,\n) where {T, N}\n    backend = KernelAbstractions.get_backend(dx)\n    ndrange = backend isa CPU ?\n        size(Δ)[N - 1:end] : # Parallelization along channel x batch.\n        size(Δ)[1:N - 2] # Parallelization along WHD.\n    ratios = align_corners ?\n        ntuple(i -> real(T)((size(dx, i) - 1) / (size(Δ, i) - 1)), N - 2) :\n        ntuple(i -> real(T)(size(dx, i) / size(Δ, i)), N - 2)\n    _∇upsample_linear_kernel!(backend)(backend, dx, Δ, ratios..., Val(align_corners); ndrange)\n    return dx\nend\n\n# Linear (CPU): parallelization along channel x batch dimensions.\n\n@kernel function _upsample_linear_kernel!(::CPU, y::T, x::T, rwidth, align::Val{A}) where {\n    T <: AbstractArray{<:Any, 3}, A,\n}\n    @uniform in_width, channels, batch = size(x)\n    @uniform out_width = size(y, 1)\n    c, n = @index(Global, NTuple)\n    yv, xv = @view(y[:, c, n]), @view(x[:, c, n])\n    @inbounds for i in 1:out_width\n        iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, in_width)\n        yv[i] = w0λ * xv[iw0] + w1λ * xv[iw1]\n    end\nend\n\n@kernel function _∇upsample_linear_kernel!(::CPU, dx::T1, Δ::T2, rwidth, align::Val{A}) where {\n    T1 <: AbstractArray{<:Any, 3}, T2 <: AbstractArray{<:Any, 3}, A,\n}\n    @uniform in_width, channels, batch = size(Δ)\n    @uniform out_width = size(dx, 1)\n    c, n = @index(Global, NTuple)\n    Δv, dxv = 
@view(Δ[:, c, n]), @view(dx[:, c, n])\n    @inbounds for i in 1:in_width\n        ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, out_width)\n        val = Δv[i]\n        dxv[ow0] += w0λ * val\n        dxv[ow1] += w1λ * val\n    end\nend\n\n# Linear (GPU): parallelization along width dimension.\n\n@kernel function _upsample_linear_kernel!(::B, y::T, x::T, rwidth, align::Val{A}) where {\n    B <: GPU, T <: AbstractArray{<:Any, 3}, A,\n}\n    @uniform in_width, channels, batch = size(x)\n    i = @index(Global)\n    iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, in_width)\n    @inbounds for n in 1:batch, c in 1:channels\n        y[i, c, n] = w0λ * x[iw0, c, n] + w1λ * x[iw1, c, n]\n    end\nend\n\n@kernel function _∇upsample_linear_kernel!(::B, dx::T, Δ::T, rwidth, align::Val{A}) where {\n    B <: GPU, T <: AbstractArray{<:Any, 3}, A,\n}\n    @uniform in_width, channels, batch = size(Δ)\n    @uniform out_width = size(dx, 1)\n    i = @index(Global)\n    ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, out_width)\n    @inbounds for n in 1:batch, c in 1:channels\n        val = Δ[i, c, n]\n        @atomic dx[ow0, c, n] += w0λ * val\n        @atomic dx[ow1, c, n] += w1λ * val\n    end\nend\n\n# Bilinear (CPU): parallelization along channel x batch dimensions.\n\n@kernel function _upsample_linear_kernel!(::CPU, y::T, x::T, rwidth, rheight, align::Val{A}) where {\n    T <: AbstractArray{<:Any, 4}, A,\n}\n    @uniform in_width, in_height, channels, batch = size(x)\n    @uniform out_width, out_height = size(y)[1:2]\n    c, n = @index(Global, NTuple)\n    yv, xv = @view(y[:, :, c, n]), @view(x[:, :, c, n])\n    for j in 1:out_height\n        ih0, ih1, h0λ, h1λ = source_idx_and_λ(rheight, j - 1, align, in_height)\n        for i in 1:out_width\n            iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, in_width)\n            @inbounds yv[i, j] =\n                h0λ * (w0λ * xv[iw0, ih0] + w1λ * xv[iw1, ih0]) +\n                h1λ * (w0λ * xv[iw0, ih1] + w1λ * xv[iw1, ih1])\n        end\n    end\nend\n\n@kernel function _∇upsample_linear_kernel!(::CPU, dx::T1, Δ::T2, rwidth, rheight, align::Val{A}) where {\n    T1 <: AbstractArray{<:Any, 4}, T2 <: AbstractArray{<:Any, 4}, A,\n}\n    @uniform in_width, in_height, channels, batch = size(Δ)\n    @uniform out_width, out_height = size(dx)[1:2]\n    c, n = @index(Global, NTuple)\n    Δv, dxv = @view(Δ[:, :, c, n]), @view(dx[:, :, c, n])\n    for j in 1:in_height\n        oh0, oh1, h0λ, h1λ = source_idx_and_λ(rheight, j - 1, align, out_height)\n        @inbounds for i in 1:in_width\n            ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, out_width)\n            val = Δv[i, j]\n            dxv[ow0, oh0] += w0λ * h0λ * val\n            dxv[ow1, oh0] += w1λ * h0λ * val\n            dxv[ow0, oh1] += w0λ * h1λ * val\n            dxv[ow1, oh1] += w1λ * h1λ * val\n        end\n    end\nend\n\n# Bilinear (GPU): parallelization along width, height dimensions.\n\n@kernel function _upsample_linear_kernel!(::B, y::T, x::T, rwidth, rheight, align::Val{A}) where {\n    B <: GPU, T <: AbstractArray{<:Any, 4}, A,\n}\n    @uniform in_width, in_height, channels, batch = size(x)\n    i, j = @index(Global, NTuple)\n    iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, in_width)\n    ih0, ih1, h0λ, h1λ = source_idx_and_λ(rheight, j - 1, align, in_height)\n    @inbounds for n in 1:batch, c in 1:channels\n        y[i, j, c, n] =\n            h0λ * (w0λ * x[iw0, ih0, c, n] + w1λ * x[iw1, ih0, c, n]) +\n        
    h1λ * (w0λ * x[iw0, ih1, c, n] + w1λ * x[iw1, ih1, c, n])\n    end\nend\n\n@kernel function _∇upsample_linear_kernel!(::B, dx::T, Δ::T, rwidth, rheight, align::Val{A}) where {\n    B <: GPU, T <: AbstractArray{<:Any, 4}, A,\n}\n    @uniform in_width, in_height, channels, batch = size(Δ)\n    @uniform out_width, out_height = size(dx)[1:2]\n    i, j = @index(Global, NTuple)\n    ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, out_width)\n    oh0, oh1, h0λ, h1λ = source_idx_and_λ(rheight, j - 1, align, out_height)\n    @inbounds for n in 1:batch, c in 1:channels\n        val = Δ[i, j, c, n]\n        @atomic dx[ow0, oh0, c, n] += w0λ * h0λ * val\n        @atomic dx[ow1, oh0, c, n] += w1λ * h0λ * val\n        @atomic dx[ow0, oh1, c, n] += w0λ * h1λ * val\n        @atomic dx[ow1, oh1, c, n] += w1λ * h1λ * val\n    end\nend\n\n# Trilinear (CPU): parallelization along channel x batch dimensions.\n\n@kernel function _upsample_linear_kernel!(::CPU, y::T, x::T, rwidth, rheight, rdepth, align::Val{A}) where {\n    T <: AbstractArray{<:Any, 5}, A,\n}\n    @uniform in_width, in_height, in_depth = size(x)[1:3]\n    @uniform channels, batch = size(x, 4), size(x, 5)\n    @uniform out_width, out_height, out_depth = size(y)[1:3]\n    c, n = @index(Global, NTuple)\n    yv, xv = @view(y[:, :, :, c, n]), @view(x[:, :, :, c, n])\n    for k in 1:out_depth\n        id0, id1, d0λ, d1λ = source_idx_and_λ(rdepth, k - 1, align, in_depth)\n        for j in 1:out_height\n            ih0, ih1, h0λ, h1λ = source_idx_and_λ(rheight, j - 1, align, in_height)\n            for i in 1:out_width\n                iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, in_width)\n                @inbounds yv[i, j, k] =\n                    d0λ * (\n                        h0λ * (w0λ * xv[iw0, ih0, id0] + w1λ * xv[iw1, ih0, id0]) +\n                        h1λ * (w0λ * xv[iw0, ih1, id0] + w1λ * xv[iw1, ih1, id0])) +\n                    d1λ * (\n                        h0λ * (w0λ * xv[iw0, ih0, id1] + w1λ * xv[iw1, ih0, id1]) +\n                        h1λ * (w0λ * xv[iw0, ih1, id1] + w1λ * xv[iw1, ih1, id1]))\n            end\n        end\n    end\nend\n\n@kernel function _∇upsample_linear_kernel!(::CPU, dx::T1, Δ::T2, rwidth, rheight, rdepth, align::Val{A}) where {\n    T1 <: AbstractArray{<:Any, 5}, T2 <: AbstractArray{<:Any, 5}, A,\n}\n    @uniform in_width, in_height, in_depth = size(Δ)[1:3]\n    @uniform channels, batch = size(Δ, 4), size(Δ, 5)\n    @uniform out_width, out_height, out_depth = size(dx)[1:3]\n    c, n = @index(Global, NTuple)\n    Δv, dxv = @view(Δ[:, :, :, c, n]), @view(dx[:, :, :, c, n])\n    for k in 1:in_depth\n        od0, od1, d0λ, d1λ = source_idx_and_λ(rdepth, k - 1, align, out_depth)\n        for j in 1:in_height\n            oh0, oh1, h0λ, h1λ = source_idx_and_λ(rheight, j - 1, align, out_height)\n            @inbounds for i in 1:in_width\n                ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, out_width)\n                val = Δv[i, j, k]\n                dxv[ow0, oh0, od0] += w0λ * h0λ * d0λ * val\n                dxv[ow1, oh0, od0] += w1λ * h0λ * d0λ * val\n                dxv[ow0, oh1, od0] += w0λ * h1λ * d0λ * val\n                dxv[ow1, oh1, od0] += w1λ * h1λ * d0λ * val\n\n                dxv[ow0, oh0, od1] += w0λ * h0λ * d1λ * val\n                dxv[ow1, oh0, od1] += w1λ * h0λ * d1λ * val\n                dxv[ow0, oh1, od1] += w0λ * h1λ * d1λ * val\n                dxv[ow1, oh1, od1] += w1λ * h1λ * d1λ * val\n            end\n        end\n    
end\nend\n\n# Trilinear (GPU): parallelization along width x height x depth dimensions.\n\n@kernel function _upsample_linear_kernel!(::B, y::T, x::T, rwidth, rheight, rdepth, align::Val{A}) where {\n    B <: GPU, T <: AbstractArray{<:Any, 5}, A,\n}\n    @uniform in_width, in_height, in_depth = size(x)[1:3]\n    @uniform channels, batch = size(x, 4), size(x, 5)\n    i, j, k = @index(Global, NTuple)\n    iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, in_width)\n    ih0, ih1, h0λ, h1λ = source_idx_and_λ(rheight, j - 1, align, in_height)\n    id0, id1, d0λ, d1λ = source_idx_and_λ(rdepth, k - 1, align, in_depth)\n    @inbounds for n in 1:batch, c in 1:channels\n        y[i, j, k, c, n] =\n            d0λ * (\n                h0λ * (w0λ * x[iw0, ih0, id0, c, n] + w1λ * x[iw1, ih0, id0, c, n]) +\n                h1λ * (w0λ * x[iw0, ih1, id0, c, n] + w1λ * x[iw1, ih1, id0, c, n])) +\n            d1λ * (\n                h0λ * (w0λ * x[iw0, ih0, id1, c, n] + w1λ * x[iw1, ih0, id1, c, n]) +\n                h1λ * (w0λ * x[iw0, ih1, id1, c, n] + w1λ * x[iw1, ih1, id1, c, n]))\n    end\nend\n\n@kernel function _∇upsample_linear_kernel!(::B, dx::T, Δ::T, rwidth, rheight, rdepth, align::Val{A}) where {\n    B <: GPU, T <: AbstractArray{<:Any, 5}, A,\n}\n    @uniform in_width, in_height, in_depth = size(Δ)[1:3]\n    @uniform channels, batch = size(Δ, 4), size(Δ, 5)\n    @uniform out_width, out_height, out_depth = size(dx)[1:3]\n    i, j, k = @index(Global, NTuple)\n    ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - 1, align, out_width)\n    oh0, oh1, h0λ, h1λ = source_idx_and_λ(rheight, j - 1, align, out_height)\n    od0, od1, d0λ, d1λ = source_idx_and_λ(rdepth, k - 1, align, out_depth)\n    @inbounds for n in 1:batch, c in 1:channels\n        val = Δ[i, j, k, c, n]\n        @atomic dx[ow0, oh0, od0, c, n] += w0λ * h0λ * d0λ * val\n        @atomic dx[ow1, oh0, od0, c, n] += w1λ * h0λ * d0λ * val\n        @atomic dx[ow0, oh1, od0, c, n] += w0λ * h1λ * d0λ * val\n        @atomic dx[ow1, oh1, od0, c, n] += w1λ * h1λ * d0λ * val\n\n        @atomic dx[ow0, oh0, od1, c, n] += w0λ * h0λ * d1λ * val\n        @atomic dx[ow1, oh0, od1, c, n] += w1λ * h0λ * d1λ * val\n        @atomic dx[ow0, oh1, od1, c, n] += w0λ * h1λ * d1λ * val\n        @atomic dx[ow1, oh1, od1, c, n] += w1λ * h1λ * d1λ * val\n    end\nend\n\n@inline function source_idx_and_λ(\n    ratio::T, out_idx::Int, ::Val{align}, in_width::Int,\n) where {T, align}\n    real_index = align ?\n        ratio * out_idx :\n        max(zero(T), ratio * (out_idx + T(0.5)) - T(0.5))\n\n    iw0 = if T <: Rational\n        floor(Int, real_index) # Not GPU-friendly, but allows for Rational support.\n    else\n        unsafe_trunc(Int, floor(real_index))\n    end\n    offset = ifelse(iw0 < in_width - 1, 1, 0)\n    iw1 = iw0 + offset + 1\n\n    w1lambda = real_index - iw0\n    w0lambda = one(T) - w1lambda\n    return iw0 + 1, iw1, w0lambda, w1lambda\nend\n"
  },
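A shape-oriented sketch (not part of `src/upsample.jl`) of the upsampling functions documented above, assuming NNlib is loaded:

```julia
using NNlib

x = rand(Float32, 3, 4, 2, 1)        # (W, H, C, N)

@assert size(upsample_nearest(x, (2, 3)))  == (6, 12, 2, 1)
@assert size(upsample_bilinear(x, (2, 3))) == (6, 12, 2, 1)
@assert upsample_bilinear(x, (2, 3)) ≈ upsample_bilinear(x; size = (6, 12))

# pixel_shuffle trades channels for resolution: r^2 * C channels -> C channels.
y = rand(Float32, 3, 4, 4 * 2, 1)    # r = 2, so r^2 * C = 4 * 2 channels in
@assert size(pixel_shuffle(y, 2)) == (6, 8, 2, 1)
```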
  {
    "path": "src/utils.jl",
    "content": "\"\"\"\n    within_gradient(x) --> Bool\n\nReturns `false` except when used inside a `gradient` call, when it returns `true`.\nUseful for Flux regularisation layers which behave differently during training and inference.\n\nThis should work with any ChainRules-based differentiation package, in which case `x` is ignored.\nBut Tracker.jl overloads `with_gradient(x::TrackedArray)`, thus for widest use you should\npass it an array whose gradient is of interest.\nThere is also an overload for ForwardDiff.jl's `Dual` types (and arrays of them).\n\n# Examples\n```julia-repl\njulia> using ForwardDiff, Zygote, NNlib\n\njulia> f_good(x) = if NNlib.within_gradient(x)\n                     @show 10x\n                   else\n                     x\n                   end;\n\njulia> Zygote.withgradient(f_good, 1.0)\n10x = 10.0\n(val = 10.0, grad = (10.0,))\n\njulia> ForwardDiff.derivative(f_good, 1.0)\n10x = Dual{ForwardDiff.Tag{typeof(f_good), Float64}}(10.0,10.0)\n10.0\n\njulia> f_bad(x, y) = if any(NNlib.within_gradient, (x, y))\n                       @show x * y\n                     else\n                       x / y\n                     end;\n\njulia> Zygote.withgradient(f_bad, 2.0, 3.0)\n(val = 0.6666666666666666, grad = (0.3333333333333333, -0.2222222222222222))\n\njulia> ForwardDiff.derivative(x -> f_bad(x, 3.0), 2.0)\nx * y = Dual{ForwardDiff.Tag{var\"#9#10\", Float64}}(6.0,3.0)\n3.0\n```\n\nWhat goes wrong in `f_bad` is that Zygote knows `any` to be non-differentiable,\nand thus completely ignores its contents. This is not a perfect mechanism,\nand the only style recommended is precisely that of `f_good` above.\n\"\"\"\nwithin_gradient(x) = false\n\nChainRulesCore.rrule(::typeof(within_gradient), x) = true, _ -> (NoTangent(), NoTangent())\n\n\n\"\"\"\n    safe_div(x, y)\n\nReturns `x/y` unless `y==0`, in which case it just returns `x`.\n(Used internally by `scatter`.)\n\"\"\"\nsafe_div(x, y) = ifelse(iszero(y), x, x/y)\n\n\"\"\"\n    maximum_dims(dims)\n\nGiven an array of `CartesianIndex{N}` or `NTuple{N,Int}`,\nreturns a tuple containing the maximum of all the 1st entries,\nall the 2nd entries, and so on up to `N`.\n\nGiven an array of integers, returns `(maximum(dims),)`.\n\n(These arguments are what [`scatter`](@ref NNlib.scatter) understands.)\n\"\"\"\nmaximum_dims(dims::AbstractArray{<:Integer}) = (maximum(dims), )\nmaximum_dims(dims::AbstractArray{NTuple{N, T}}) where {N,T} = ntuple(i -> maximum(x->x[i], dims), N)\nmaximum_dims(dims::AbstractArray{CartesianIndex{N}}) where {N} = ntuple(i -> maximum(x->x[i], dims), N)\n\nfunction reverse_indices!(rev::AbstractArray, idx::AbstractArray{<:Tuple})\n    for (ind, val) in pairs(Array(idx))\n        push!(rev[val...], ind)\n    end\n    # if CUDA supports `unique`, a more efficient version:\n    # cidx in CartesianIndices(idx)\n    # for i = unique(idx)\n    #     rev[i] = cidx[idx .== i]\n    # end\n    rev\nend\n\nfunction reverse_indices!(rev::AbstractArray, idx::AbstractArray)\n    for (ind, val) in pairs(Array(idx))\n        push!(rev[val], ind)\n    end\n    rev\nend\n\n\"\"\"\n    reverse_indices(idx)\n\nReturn the reverse indices of `idx`. The indices of `idx` will be values, and values of `idx` will be index.\n\n# Arguments\n\n- `idx`: The indices to be reversed. 
Accepts array or cuarray of integer, tuple or `CartesianIndex`.\n\"\"\"\nfunction reverse_indices(idx::AbstractArray{<:Any,N}) where N\n    max_dims = maximum_dims(idx)\n    T = CartesianIndex{N}\n    rev = Array{Vector{T}}(undef, max_dims...)\n    for i in eachindex(rev)\n        rev[i] = T[]\n    end\n    return reverse_indices!(rev, idx)\nend\n\nunsqueeze(x) = reshape(x, 1, size(x)...)\n\n\n\"\"\"\n    _fast_broadcast!(f, x, y, z...)\n\nThis does `x .= f.(x, y, z...)`, but works around\nan issue with broadcasting that prevents SIMD in such cases.\nCan perhaps be removed once https://github.com/JuliaLang/julia/issues/43153 is fixed.\n\nHas an `rrule` to avoid mutation within derivatives.\n\n!!! warning\n    Not intended for general use.\n    Uses `@inbounds` but does not check sizes!\n    Assumes that `f` has no derivative!\n\"\"\"\nfunction _fast_broadcast!(f::F, x::Array, yz...) where {F<:Function}\n    bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, yz...))\n    @simd ivdep for I in eachindex(bc)\n        @inbounds x[I] = bc[I]\n    end\n    return x\nend\nfunction _fast_broadcast!(f::F, x::AbstractArray, yz...) where {F<:Function}\n    # CUDA does not suffer from this bug\n    broadcast!(f, x, x, yz...)\nend\n\nfunction rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_broadcast!), f::F, x::AbstractArray, ys...)  where {F<:Function}\n    rrule_via_ad(cfg, broadcast, f, x, ys...)\nend\n\n\n"
  },
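A short sketch (not part of `src/utils.jl`) exercising the helpers above; `safe_div` and `maximum_dims` are internal, so they are reached via `NNlib.`, and Zygote is assumed to be available:

```julia
using NNlib, Zygote

@assert NNlib.within_gradient(1.0) == false
@assert Zygote.gradient(x -> NNlib.within_gradient(x) ? 2x : x, 1.0) == (2.0,)

@assert NNlib.safe_div(4.0, 2.0) == 2.0
@assert NNlib.safe_div(4.0, 0.0) == 4.0       # y == 0 falls back to x

@assert NNlib.maximum_dims([1, 2, 3, 2]) == (3,)
@assert NNlib.maximum_dims([(1, 2), (3, 1)]) == (3, 2)
```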
  {
    "path": "test/Project.toml",
    "content": "[deps]\nAdapt = \"79e6a3ab-5dfb-504d-930d-738a2a938a0e\"\nChainRulesCore = \"d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4\"\nChainRulesTestUtils = \"cdddcdb0-9152-4a09-a978-84456f9df70a\"\nDocumenter = \"e30172f5-a6a5-5a46-863b-614d45cd2de4\"\nEnzyme = \"7da242da-08ed-463a-9acd-ee780be4f1d9\"\nEnzymeCore = \"f151be2c-9106-41f4-ab19-57ee4f262869\"\nEnzymeTestUtils = \"12d8515a-0907-448a-8884-5fe00fdf1c5a\"\nFFTW = \"7a1cc6ca-52ef-59f5-83cd-3a7055c09341\"\nFiniteDifferences = \"26cc04aa-876d-5657-8c51-4c34ba976000\"\nForwardDiff = \"f6369f11-7733-5829-9624-2563aa707210\"\nImageTransformations = \"02fcd773-0e25-5acc-982a-7f6622650795\"\nInterpolations = \"a98d9a8b-a2ab-59e6-89dd-64a1c18fca59\"\nKernelAbstractions = \"63c18a36-062a-441e-b654-da1e3ab1ce7c\"\nLinearAlgebra = \"37e2e46d-f89d-539d-b4ee-838fcccc9c8e\"\nLogging = \"56ddb016-857b-54e1-b83d-db4d58db5568\"\nMLDataDevices = \"7e8f7934-dd98-4c1a-8fe8-92b47a384d40\"\nNNlib = \"872c559c-99b0-510c-b3b7-b6c96a88d5cd\"\nPkg = \"44cfe95a-1eb2-52ea-b672-e2afdf69b78f\"\nRandom = \"9a3f8284-a2c9-5f02-9a11-845980a1fd5c\"\nReverseDiff = \"37e2e3b7-166d-5795-8a7a-e32c996b4267\"\nSpecialFunctions = \"276daf66-3868-5448-9aa4-cd146d93841b\"\nStableRNGs = \"860ef19b-820b-49d6-a774-d7a799459cd3\"\nStatistics = \"10745b16-79ce-11e8-11f9-7d13ad32a3b2\"\nTest = \"8dfed614-e22c-5e08-85e1-65c5234f0b40\"\nUnicodePlots = \"b8865327-cd53-5732-bb35-84acbb429228\"\nZygote = \"e88e6eb3-aa80-5325-afca-941959d7151f\"\n"
  },
  {
    "path": "test/activations.jl",
    "content": "\nACTIVATION_FUNCTIONS = [@eval($a) for a in NNlib.ACTIVATIONS]\n\nBINARY_ACTIVATIONS = filter(f -> hasmethod(f, Tuple{Float64, Float64}), ACTIVATION_FUNCTIONS)\n\n@test sigmoid(0.0) == 0.5\n@test hardsigmoid(0.0) == 0.5\n@test hardtanh(0.0) == 0.0\n@test relu(0.0) == 0.0\n@test leakyrelu(0.0) == 0.0\n@test relu6(0.0) == 0.0\n@test rrelu(0.0) == 0.0\n@test elu(0.0) == 0.0\n@test gelu(0.0) == 0.0\n@test gelu_tanh(0.0) == 0.0\n@test gelu_sigmoid(0.0) == 0.0\n@test gelu_erf(0.0) == 0.0\n@test swish(0.0) == 0.0\n@test hardswish(0.0) == 0.0\n@test lisht(0.0) == 0.0\n@test softplus(0.0) ≈ log(2.0)\n@test softplus(1e8) ≈ 1e8\n@test softplus(-1e8) ≈ 0.0\n@test softsign(0.0) == 0.0\n@test selu(0.0) == 0.0\n@test celu(0.0) == 0.0\n@test trelu(0.0) == 0.0\n@test logcosh(0.0) == log(cosh(0.0))\n@test mish(0.0) == 0.0\n@test tanhshrink(0.0) == 0.0\n@test softshrink(0.0) == 0.0\n\n@test sigmoid(1.0) == 1.0 / (1.0 + exp(-1.0))\n@test hardsigmoid(1.0) == max(0,min(1, (1 + 3)/6))\n@test hardtanh(1.0) == 1.0\n@test relu(1.0) == 1.0\n@test leakyrelu(1.0) == 1.0\n@test relu6(1.0) == 1.0\n@test rrelu(1.0) == 1.0\n@test elu(1.0) == 1.0\n@test gelu(1.0) ≈ 0.8411919906082768\n@test gelu_tanh(1.0) ≈ 0.8411919906082768\n@test gelu_sigmoid(1.0) ≈ 0.8411919906082768\n@test gelu_erf(1.0) == 0.8413447460685429\n@test swish(1.0) == sigmoid(1.0)\n@test hardswish(1.0) == hardsigmoid(1.0)\n@test lisht(1.0) ≈ 1.0 * tanh(1.0)\n@test softplus(1.0) ≈ log(exp(1.0) + 1.0)\n@test softsign(1.0) == 0.5\n@test selu(1.0) == 1.0507009873554804934193349852946\n@test celu(1.0) == 1.0\n@test trelu(1.0) == 0.0\n@test logcosh(1.0) ≈ log(cosh(1.0))\n@test mish(1.0) ≈ tanh(log(1.0 + exp(1.0)))\n@test tanhshrink(1.0) ≈ 0.23840584404423515\n@test softshrink(1.0) == 0.5\n\n@test sigmoid(-1.0) == exp(-1.0) / (1.0 + exp(-1.0))\n@test hardsigmoid(-1.0) == max(0,min(1,(-1+3)/6 ))\n@test hardtanh(-1.0) == -1.0\n@test relu(-1.0) == 0.0\n@test leakyrelu(-1.0) == -0.01\n@test relu6(-1.0) == 0.0\n@test -1/3.0 <= rrelu(-1.0) <= -1/8.0\n@test elu(-1.0) == exp(-1.0) - 1.0\n@test gelu(-1.0) ≈ -0.15880800939172324\n@test gelu_tanh(-1.0) ≈ -0.15880800939172324\n@test gelu_sigmoid(-1.0) ≈ -0.15880800939172324\n@test gelu_erf(-1.0) == -0.15865525393145707\n@test swish(-1.0) == -sigmoid(-1.0)\n@test hardswish(-1.0) == -hardsigmoid(-1.0)\n@test lisht(-1.0) ≈ -1.0 * tanh(-1.0)\n@test softplus(-1.0) ≈ log(exp(-1.0) + 1.0)\n@test softsign(-1.0) == -0.5\n@test selu(-1.0) ≈ 1.0507009873554804934193349852946 * 1.6732632423543772848170429916717 * (exp(-1.0) - 1.0)\n@test celu(-1.0) == exp(-1.0) - 1\n@test trelu(-1.0) == 0.0\n@test log(cosh(-1.0)) ≈ log(cosh(-1.0))\n@test mish(-1.0) ≈ -tanh(log(1.0 + exp(-1.0)))\n@test tanhshrink(-1.0) ≈ -0.23840584404423515\n@test softshrink(-1.0) == -0.5\n\n@testset \"Float inference\" begin\n    @testset \"$(a): \" for a in ACTIVATION_FUNCTIONS\n        for T in [Float16, Float32, Float64]\n            for val in [-10, -1, 0, 1, 10]\n                out = @inferred a(T(val))\n                @test typeof(out) == T\n            end\n        end\n    end\n    @testset \"binary $a: \" for a in BINARY_ACTIVATIONS\n        for T in [Float16, Float32, Float64]\n            for val in [-10, -1, 0, 1, 10], beta in Any[0.1, 0.5f0, 1]\n                out = @inferred a(T(val), beta)\n                @test typeof(out) == T\n            end\n        end\n    end\nend\n\n@testset \"Array input -> error\" begin\n    x = rand(5)\n    for a in ACTIVATION_FUNCTIONS\n        @test size(a(x)) == size(x)\n        grad = Zygote.gradient(p 
-> sum(a(p)), x)\n        @test size(grad[1]) == size(x)\n    end\n    for a in BINARY_ACTIVATIONS\n        @test size(a(x, 0.1)) == size(x)\n        grad = Zygote.gradient(p -> sum(a(p, 0.1)), x)\n        @test size(grad[1]) == size(x)\n    end\nend\n\n@testset \"NaN propagation\" begin\n    @testset \"$a\" for a in ACTIVATION_FUNCTIONS\n        # With NaN input, all should produce NaN output:\n        @test isnan(a(NaN32))\n\n        # Ideally +-Inf would not lead to NaN, but perhaps\n        # these aren't worth the complication of fixing:\n        a == softsign && continue\n        @test !isnan(a(Inf32))\n\n        a in [gelu, gelu_tanh, gelu_sigmoid, gelu_erf, swish, hardswish, logcosh, mish] && continue\n        @test !isnan(a(-Inf32))\n    end\nend\n\n@testset \"Integer inputs\" begin\n    # These should work without error, for e.g. readme examples,\n    # but no serious use will involve integers, no need for performance.\n    @testset \"$a\" for a in ACTIVATION_FUNCTIONS\n        @test typeof(a(Int64(1))) <: Real\n        @test typeof(a(Int32(1))) <: Real\n    end\n\n    # The following ones can pass integers through. But it's not very important.\n    @testset \"relu: Int -> Int\" begin\n        @test typeof(relu(Int64(1))) == Int64\n        @test typeof(relu(Int32(1))) == Int32\n    end\n    @testset \"relu6: Int -> Int\" begin\n        @test typeof(relu6(Int64(1))) == Int64\n        @test typeof(relu6(Int32(1))) == Int32\n    end\n    @testset \"hardtanh: Int -> Int\" begin\n        @test typeof(hardtanh(Int64(1))) == Int64\n        @test typeof(hardtanh(Int32(1))) == Int32\n    end\n    @testset \"trelu: Int -> Int\" begin\n        @test typeof(trelu(Int64(1))) == Int64\n        @test typeof(trelu(Int32(1))) == Int32\n    end\nend\n\n@testset \"elu\" begin\n    @test elu(42) == 42\n    @test elu(42.) == 42.\n\n    @test elu(-4) ≈ (exp(-4) - 1)\nend\n\n@testset \"mish\" begin\n    @test mish(-5) ≈ -0.033576237730161704\n    @test mish(9) == 9*tanh(log(1 + exp(9)))\n    xs = Float32[1 2 3; 1000 2000 3000]\n    @test typeof(mish.(xs)) == typeof(xs)\nend\n\n@test leakyrelu( 0.4,0.3) ≈  0.4\n@test leakyrelu(-0.4,0.3) ≈ -0.12\n\n@test relu6(10.0) == 6.0\n\n@test -0.2 <= rrelu(-0.4,0.25,0.5) <= -0.1\n\n@testset \"celu\" begin\n    @test celu(42) == 42\n    @test celu(42.) == 42.\n\n    @test celu(-4, 0.5) ≈ 0.5*(exp(-4.0/0.5) - 1)\nend\n\n@testset \"softshrink\" begin\n    @test softshrink(15., 5.) == 10.\n    @test softshrink(4., 5.) == 0.\n    @test softshrink(-15., 5.) 
== -10.\nend\n\n@testset \"logsigmoid\" begin\n    xs = randn(10,10)\n    @test logsigmoid.(xs) ≈ log.(sigmoid.(xs))\n    for T in [:Float32, :Float64]\n        @eval @test logsigmoid.($T[-100_000, 100_000.]) ≈ $T[-100_000, 0.]\n    end\nend\n\n@test logcosh(1_000.0) + log(2) == 1_000.0\n\n@testset \"hardsigmoid\" begin\n    @test hardsigmoid(0.3) == max(0,min(1,(0.3+3)/6))\n    @test hardsigmoid(-0.3) == max(0,min(1,(-0.3+3)/6))\n    for T in [:Float32, :Float64]\n        @eval @test hardsigmoid.($T[-100_000, 100_000.]) ≈ $T[0., 1.]\n    end\nend\n\n@test hardtanh(10.0) == 1.0\n\n@test lisht(2.5) == 2.5*tanh(2.5)\n\n@testset \"trelu\" begin\n    @test trelu(0.5) == 0.0\n    @test trelu(1.0) == 0.0\n    @test trelu(1.1) == 1.1\n    @test trelu(0.9,0.5) == 0.9\nend\n\n## Faster variants\n\nusing NNlib: tanh_fast, sigmoid_fast\n\nfunction countepsfrom(x::T, xtrue) where {T<:AbstractFloat}\n    target = T(xtrue)\n    for n in Iterators.flatten(zip(0:100, -1:-1:-100))\n        nextfloat(x, n) === target && return n\n    end\n    return round(Int, (target - x) / eps(x))\nend\n\nmean_eps(f, g, xs) = mean(x -> abs(countepsfrom(f(x), g(big(x)))), xs)\nworst_eps(f, g, xs) = maximum(x -> abs(countepsfrom(f(x), g(big(x)))), xs)\nfunction find_worst(f, g, xs)\n    c, i = findmax(x -> abs(countepsfrom(f(x), g(big(x)))), xs)\n    c, xs[i]\nend\n\n@testset \"tanh_fast & sigmoid_fast: Float64\" begin\n    \n    x64 = 1e-6:1e-4:5\n    xbig = vcat(6:3:200.0, 1000, 10^6, typemax(Float64))\n    \n    @testset \"tanh\" begin\n        mean_eps(tanh, tanh, x64)  # 0.06582\n        worst_eps(tanh, tanh, x64) # 2\n\n        @test mean_eps(tanh_fast, tanh, x64) < 0.2  # 0.13164\n        @test worst_eps(tanh_fast, tanh, x64) <= 5  # 5\n\n        @test mean_eps(tanh_fast, tanh, -x64) < 0.6 # 0.5248\n        @test worst_eps(tanh_fast, tanh, -x64) <= 5 # 5\n\n        @test tanh_fast.(xbig) ≈ tanh.(xbig)\n        @test tanh_fast.(-xbig) ≈ tanh.(-xbig)\n    end\n    @testset \"sigmoid\" begin\n        mean_eps(sigmoid, sigmoid, x64)  # 0.39246\n        worst_eps(sigmoid, sigmoid, x64) # 1\n\n        @test mean_eps(sigmoid_fast, sigmoid, x64) < 0.5  # 0.40432\n        @test worst_eps(sigmoid_fast, sigmoid, x64) <= 5  # 2\n\n        mean_eps(sigmoid, sigmoid, -x64)  # 0.37672\n        worst_eps(sigmoid, sigmoid, -x64) # 2\n\n        @test mean_eps(sigmoid_fast, sigmoid, -x64) < 0.6  # 0.56478\n        @test worst_eps(sigmoid_fast, sigmoid, -x64) <= 5  # 4\n\n        @test sigmoid_fast.(xbig) ≈ sigmoid.(xbig)\n        @test sigmoid_fast.(-xbig) ≈ sigmoid.(-xbig)\n    end\nend\n\n@testset \"tanh_fast & sigmoid_fast: Float32\" begin\n    \n    x32 = 1f-6:1f-4:5\n    xbig32 = vcat(6:3:200f0, 1000, typemax(Float32))\n\n    @testset \"tanh\" begin\n        mean_eps(tanh, tanh, x32)  # 0.065\n        worst_eps(tanh, tanh, x32) # 1\n\n        @test mean_eps(tanh_fast, tanh, x32) < 0.8  # 0.65414\n        @test worst_eps(tanh_fast, tanh, x32) <= 5  # 5\n\n        @test mean_eps(tanh_fast, tanh, -x32) < 0.8 # 0.65414\n        @test worst_eps(tanh_fast, tanh, -x32) <= 5 # 5\n\n        @test tanh_fast.(xbig32) ≈ tanh.(xbig32)\n        @test tanh_fast.(-xbig32) ≈ tanh.(-xbig32)\n    end\n    @testset \"sigmoid\" begin\n        mean_eps(sigmoid, sigmoid, x32)  # 0.38896\n        worst_eps(sigmoid, sigmoid, x32) # 1\n\n        @test mean_eps(sigmoid_fast, sigmoid, x32) < 0.5  # 0.38896\n        @test worst_eps(sigmoid_fast, sigmoid, x32) <= 2  # 2\n\n        mean_eps(sigmoid, sigmoid, -x32)  # 0.38088\n        worst_eps(sigmoid, sigmoid, 
-x32) # 2\n\n        @test mean_eps(sigmoid_fast, sigmoid, -x32) < 0.5  # 0.38088\n        @test worst_eps(sigmoid_fast, sigmoid, -x32) <= 2  # 2\n\n        @test sigmoid_fast.(xbig32) ≈ sigmoid.(xbig32)\n        @test sigmoid_fast.(-xbig32) ≈ sigmoid.(-xbig32)\n    end\nend\n\n## Autodiff tests\n\nWITH_UNARY_RULE = [@eval($a) for (a, _) in NNlib.UNARY_ACTS]\n\nWITH_BINARY_RULE = [@eval($a) for (a, _, _) in NNlib.BINARY_ACTS]\n\nhas_rule(a) = rrule(a, 1f0) === nothing ? \"(no rule)\" : \"\"\n\n@testset \"Gradient inference\" begin\n    @testset \"$(a): $(has_rule(a))\" for a in ACTIVATION_FUNCTIONS\n        @testset \"$T\" for T in [Float16, Float32, Float64]\n            for val in [-10, -1, 0, 1, 10]\n                grad = @inferred gradient(a, T(val))\n                @test typeof(grad[1]) == T\n            end\n        end\n    end\nend\n\nusing Base.Broadcast: broadcasted\n\n@testset \"lazy broadcasting\" begin\n    # ChainRules returns a Broadcasted, check these rules accept it\n    @test rrule(broadcasted, relu, rrule(broadcasted, +, [1,2], 3)[1]) != nothing\n    @test rrule(broadcasted, leakyrelu, rrule(broadcasted, +, [1,2], 3)[1], 0.2) != nothing\nend\n\n@testset \"Gradient correctness\" begin\n    \n    local rng = StableRNG(17)\n\n    @testset \"$(f): $(has_rule(f))\" for f in ACTIVATION_FUNCTIONS\n        f == rrelu && continue # stocastich output\n        \n        ## Avoid singular points of some activations\n        ## problematic for finite diff methods\n        gradtest(f, +2 + rand(rng))\n        gradtest(f, -2 - rand(rng))\n        gradtest(f, +2 .+ rand(rng, 2, 2), check_broadcast=true)\n        gradtest(f, -2 .- rand(rng, 2, 2), check_broadcast=true)\n\n        if f in BINARY_ACTIVATIONS\n            gradtest(x -> f(x, 0.2), 1 + rand(rng))\n            gradtest(x -> f(x, 0.7), 1 + rand(rng))\n\n            gradtest(x -> f(x, 0.2), -2 + rand(rng))\n            gradtest(x -> f(x, 0.7), -2 + rand(rng))\n        end\n\n        ## Check that rules, including broadcast rules, are defined:\n        if f in WITH_UNARY_RULE\n            @test rrule(f, rand()) !== nothing\n            @test rrule(broadcasted, f, rand(2)) !== nothing\n        end\n        if f in WITH_BINARY_RULE\n            @test rrule(f, rand(), rand()) !== nothing\n            @test rrule(broadcasted, f, rand(2), rand()) !== nothing\n        end\n    end \n    \n    @testset \"Flux-like usage\" begin\n        ## This checks some broadcast rules for correctness:\n        gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)\n        gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)\n        gradtest((x, W, b) -> relu.(W*x .+ b), 5, (2,5), 2)\n        gradtest((x, W, b) -> relu.(W*x .+ b), (5,3), (2,5), 2)\n        gradtest((x, W, b) -> selu.(W*x .+ b), 5, (2,5), 2)\n        gradtest((x, W, b) -> selu.(W*x .+ b), (5,3), (2,5), 2, atol=1e-4)\n        gradtest((x, W, b) -> elu.(W*x .+ b, 2), 5, (2,5), 2)\n        gradtest((x, W, b) -> elu.(W*x .+ b, 2), (5,3), (2,5), 2, atol=1e-4)\n\n        gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)\n        gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)\n\n        ## Binary functions have their own broadcast rules:\n        gradtest((x, W, b) -> leakyrelu.(W*x .+ b, 0.2), 5, (2,5), 2)\n        gradtest((x, W, b) -> leakyrelu.(W*x .+ b, 0.7), (5,3), (2,5), 2)\n    end\n\n    @testset \"Zygote issue 758\" begin\n        ## Tests for https://github.com/FluxML/Zygote.jl/issues/758\n        @test gradient(xs -> sum(selu.(xs)), [1_000, 10_000])[1] ≈ 
[1.0507009873554805, 1.0507009873554805] rtol=1e-8\n        @test gradient(x -> selu(x), 1_000) == (1.0507009873554805,)\n        @test gradient(xs -> sum(elu.(xs, 2)), [1_000, 10_000]) == ([1., 1.],)\n        @test gradient(x -> elu(x, 2), 1_000) == (1.,)\n        @test gradient(x -> elu(x, 2), -1) == (2*exp(-1),)\n        gradtest(x-> selu.(x),[100., 1_000.])\n        gradtest(x -> elu.(x, 3.5),[100., 1_000.])\n        gradtest(x -> elu.(x, 3.5),[1_000., 10_000.])\n        gradtest(x -> selu.(x), [1_000., 10_000.])\n        gradtest(x -> selu.(x), 10, atol=1e-4)\n    end\n\nend\n\n@testset \"Second derivatives\" begin\n    ## Not extensive, but a start!\n    ## More careful tests could look for `nothing` gradients of piecewise functions\n    @testset \"$(f): $(has_rule(f))\" for f in ACTIVATION_FUNCTIONS\n        f == rrelu && continue\n\n        ## Scalar\n        h = Zygote.hessian_dual(x -> sin(f(x)), 0.23)\n        @test h ≈ Zygote.hessian_reverse(x -> sin(f(x)), 0.23)\n\n        ## Broadcasting\n        x = [-0.9, -0.2, 0.1, 0.3, 1.2]\n        H = Zygote.hessian_dual(x -> sum(abs2, f.(x .+ 0.1)), x)\n        @test H ≈ Zygote.hessian_reverse(x -> sum(abs2, f.(x .+ 0.1)), x)\n    end\n    @testset \"$(f): $(has_rule(f))\" for f in BINARY_ACTIVATIONS\n        f == rrelu && continue\n\n        ## Scalar\n        h = Zygote.hessian_dual(x -> sin(f(x, 0.3)), 0.45)\n        @test h ≈ Zygote.hessian_reverse(x -> sin(f(x, 0.3)), 0.45)\n\n        ## Broadcasting\n        x = [-0.9, -0.2, 0.1, 0.3, 1.2]\n        H = Zygote.hessian_dual(x -> sum(abs2, f.(x .+ 0.1, 0.3)), x)\n        @test H ≈ Zygote.hessian_reverse(x -> sum(abs2, f.(x .+ 0.1, 0.3)), x)\n    end\nend\n"
  },
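A sketch of how the ULP-counting helpers in this test file can be used interactively; it assumes `countepsfrom`, `mean_eps` and `worst_eps` defined above are in scope, plus `mean` from Statistics:

```julia
using NNlib: tanh_fast
using Statistics: mean

xs = 1f-3:1f-3:3f0
mean_eps(tanh_fast, tanh, xs)    # average distance from the BigFloat reference, in eps steps
worst_eps(tanh_fast, tanh, xs)   # largest such distance over the grid
```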
  {
    "path": "test/attention.jl",
    "content": "@testset \"different batchsizes\" begin\n    n = 15\n    lenq = 3\n    lenkv = 4\n    for batch_size in [(), 1, 2, (2,1,3)], nheads in [1, 3, 5]\n        q = rand(Float32, n, lenq, batch_size...)\n        k = rand(Float32, n, lenkv, batch_size...)\n        v = rand(Float32, n, lenkv, batch_size...)\n        y, α = dot_product_attention(q, k, v; nheads)\n        @test y isa Array{Float32}\n        @test size(y) == (n, lenq, batch_size...)\n        @test size(α) == (lenkv, lenq, nheads, batch_size...)\n        @test sum(α, dims=1) ≈ ones(1, lenq, nheads, batch_size...)\n    end\nend\n\n@testset \"dot_product_attention_scores\" begin\n    q = k = reshape([1:24;], 4, 2, 3, 1) ./ 24\n    α = dot_product_attention_scores(q, k)\n    q2, k2 = reshape.((q, k), 8, 3, 1)\n    y, α2 = dot_product_attention(q2, k2, k2; nheads=2)\n    @test α ≈ α2\nend\n\n@testset \"specific results\" begin\n    q = k = v = reshape([1:12;], 4, 3, 1) ./ 12\n    y, α = dot_product_attention(q, k, v; nheads=2)\n    ytrue = [0.429754, 0.513087, 0.613791, 0.697125, 0.46431, 0.547644, 0.647876, 0.73121, 0.49773, 0.581064, 0.680455, 0.763788]\n    ytrue = reshape(ytrue, 4, 3, 1)\n    αtrue = [0.313896, 0.332948, 0.353157, 0.264431, 0.328206, 0.407362, 0.219215, 0.31838, 0.462405, 0.288691, 0.331243, 0.380066, 0.241239, 0.323893, 0.434868, 0.198438, 0.311761, 0.489801]\n    αtrue = reshape(αtrue, 3, 3, 2, 1)\n    @test y ≈ ytrue atol=1e-5\n    @test α ≈ αtrue atol=1e-5\nend\n\n@testset \"mask\" begin\n    q = rand(4, 2, 3, 1)\n    k = rand(4, 2, 5, 1)\n\n    mask = rand(Bool, (5, 3))\n    α = dot_product_attention_scores(q, k; mask)\n    @test all((α[:,:,1,1].> 0) .== mask)\n    @test all((α[:,:,2,1].> 0) .== mask)\n\n    @testset \"causal\" begin\n        x = rand(4, 2, 3, 1)\n        mask = make_causal_mask(x, dims=3)\n        α = dot_product_attention_scores(x, x; mask)\n        @test all((α[:,:,1,1].> 0) .== mask)\n        @test all((α[:,:,2,1].> 0) .== mask)\n    end\nend\n\n@testset \"dropout\" begin\n    q = k = v = rand(10, 10, 10)\n    fdrop(x, p) = (rand!(similar(x)) .> p) .* x ./ (1-p)\n    y, α = dot_product_attention(q, k, v; nheads=2, fdrop = x -> fdrop(x, 0.5))\n    @test 0.6 > mean(>(0), α) > 0.4\nend\n\n@testset \"bias\" begin\n    q = rand(4, 5, 1)\n    k = v = rand(4, 3, 1)\n    bias = randn(3, 5)\n    y, α = dot_product_attention(q, k, v, bias; nheads=2)\n    @test size(α) == (3, 5, 2, 1)\n    @test size(y) == (4, 5, 1)\nend\n\n@testset \"gradient\" begin\n    q = rand(4, 5, 1)\n    k = v = rand(4, 3, 1)\n    bias = randn(3, 5)\n    y, α = dot_product_attention(q, k, v, bias; nheads=2)\n    gradtest((x...) -> dot_product_attention(x...; nheads=2)[1], q, k, v, bias)\nend\n"
  },
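A causal self-attention sketch (not part of `test/attention.jl`), built from the same calls the tests above use; it assumes `dot_product_attention` accepts the `mask` keyword shown here for `dot_product_attention_scores`:

```julia
using NNlib

x = rand(Float32, 8, 6, 1)            # (features, sequence length, batch)
mask = make_causal_mask(x, dims=2)    # 6×6, true where a query may attend

y, α = dot_product_attention(x, x, x; mask, nheads=2)
@assert size(y) == (8, 6, 1)
@assert size(α) == (6, 6, 2, 1)
@assert all((α[:, :, 1, 1] .> 0) .== mask)   # masked-out positions get zero weight
```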
  {
    "path": "test/batchedmul.jl",
    "content": "using NNlib, Test, LinearAlgebra, Logging\nusing NNlib: storage_type, storage_typejoin, is_strided,\n    batched_mul_generic!, _unbatch, _copy_if_faster,\n    BatchedAdjoint, BatchedTranspose\n\nfunction bmm_test(a,b; transA = false, transB = false)\n    bs = size(a,3)\n    transA && (a = permutedims(a, [2,1,3]))\n    transB && (b = permutedims(b, [2,1,3]))\n    c = []\n    for i = 1:bs\n        push!(c, a[:,:,i]*b[:,:,i])\n    end\n\n    cat(c...; dims = 3)\nend\n\nfunction bmm_adjtest(a,b; adjA = false, adjB = false)\n    bs = size(a,3)\n    c = []\n    for i = 1:bs\n        ai = adjA ? adjoint(a[:,:,i]) : a[:,:,i]\n        bi = adjB ? adjoint(b[:,:,i]) : b[:,:,i]\n        push!(c, ai*bi)\n    end\n\n    cat(c...; dims = 3)\nend\n\nfunction half_batched_mul(x,y)\n    @assert size(y,3) == 1\n    d = size(x,2)\n    x_mat = reshape(permutedims(x, (1,3,2)),:,d)\n    y_mat = reshape(y,d,:)\n    z_mat = x_mat * y_mat\n    permutedims(reshape(z_mat, size(x,1), size(x,3), :), (1,3,2))\nend\n\n@testset \"batched_mul: Float64 * $TB\" for TB in [Float64, Float32]\n\n    # Real\n    A = randn(7,5,3)\n    B = randn(TB, 5,7,3)\n    C = randn(7,6,3)\n\n    @test batched_mul(A, B) ≈ bmm_test(A, B)\n    @test batched_mul(batched_transpose(A), batched_transpose(B)) ≈ bmm_test(A, B; transA = true, transB = true)\n    @test batched_mul(batched_transpose(A), C) ≈ bmm_test(A, C; transA = true)\n    @test batched_mul(A, batched_transpose(A)) ≈ bmm_test(A, A; transB = true)\n\n    # Complex\n    cA = randn(Complex{Float64}, 7,5,3)\n    cB = randn(Complex{TB}, 5,7,3)\n    cC = randn(Complex{Float64}, 7,6,3)\n\n    @test batched_mul(cA, cB) ≈ bmm_adjtest(cA, cB)\n    @test batched_mul(batched_adjoint(cA), batched_adjoint(cB)) ≈ bmm_adjtest(cA, cB; adjA = true, adjB = true)\n    @test batched_mul(batched_adjoint(cA), cC) ≈ bmm_adjtest(cA, cC; adjA = true)\n    @test batched_mul(cA, batched_adjoint(cA)) ≈ bmm_adjtest(cA, cA; adjB = true)\n\n    # Wrappers which cancel\n    @test batched_transpose(batched_transpose(A)) === A\n    @test batched_transpose(PermutedDimsArray(A, (2,1,3))) === A\n    @test batched_adjoint(batched_adjoint(cA)) === cA\n    @test batched_transpose(batched_adjoint(cA)) isa NNlib.BatchedAdjoint\n\n    # Integers\n    TBi = TB==Float64 ? 
Int64 : Int32\n    iA = rand(1:99, 7,5,3)\n    iB = TB.(rand(1:99, 5,7,3))\n    iC = zeros(Int, 7,6,3)\n    @test batched_mul(iA, iB) == bmm_adjtest(iA, iB)\n    @test batched_mul(cA, iB) ≈ bmm_adjtest(cA, iB)\n\n    # Errors\n    @test_throws DimensionMismatch batched_mul(rand(2,2,2), rand(TB, 2,2,10))\n    @test_throws DimensionMismatch batched_mul(rand(2,2,2), rand(TB, 10,2,2))\n    @test_throws Exception batched_mul!(zeros(2,2,10), rand(2,2,2), rand(TB, 2,2,2))\n\n    # PermutedDimsArrays\n    for perm in [(1,3,2), (2,1,3), (3,2,1)], fun in [identity, batched_adjoint], ty in [identity, complex]\n        A = randn(ty(Float64), 4,4,4)\n        B = randn(ty(TB), 4,4,4)\n        @test batched_mul(fun(A), PermutedDimsArray(B, perm)) ≈ batched_mul(fun(A), permutedims(B, perm))\n        @test batched_mul(fun(PermutedDimsArray(A, perm)), B) ≈ batched_mul(fun(permutedims(A, perm)), B)\n        # when TB=Float64, only the case  perm=(2,1,3); fun=batched_adjoint; ty=complex;  goes to fallback\n        # but all the perm=(3,2,1); cases copy their inputs.\n    end\n\n    # PermutedDimsArray output\n    A′ = randn(4,3,2)\n    B′ = batched_adjoint(randn(TB, 5,3,2))\n    C1 = batched_mul(A′, B′) # size 4,5,2\n    C2 = PermutedDimsArray(zeros(5,2,4), (3,1,2)) # size 4,5,2\n    @test C1 ≈ batched_mul!(C2, A′, B′) # Float64: \"Debug: transposing C = A * B into Cᵀ = Bᵀ * Aᵀ\"\n    @test C1 ≈ C2\n\n    # 5-arg mul!\n    @test 10 .* C1 ≈ batched_mul!(C2, A′, B′, 10) rtol=1e-7\n    C2 .= 10\n    @test C1 .+ 100 ≈ batched_mul!(C2, A′, B′, 1, 10)\n\n    # Trivial batches for B\n    D′ = randn(TB, 3,5,1)\n    @test size(batched_mul(A′,D′)) == (4,5,2)\n    @test batched_mul(A′,D′) ≈ half_batched_mul(A′, D′)\n\n    # Large output, multi-threaded path\n    if TB == Float64\n        N = 50\n        A = rand(N,N,N)\n        B = rand(N,N,N)\n        C = reshape(reduce(hcat, [vec(A[:,:,k] * B[:,:,k]) for k in 1:N]), N,N,N)\n        @test C ≈ A ⊠ B\n\n        D = rand(N,N,1)\n        E = reshape(reduce(hcat, [vec(A[:,:,k] * D[:,:,1]) for k in 1:N]), N,N,N)\n        @test E ≈ A ⊠ D\n    end\nend\n\nperm_12(A) = PermutedDimsArray(A, (2,1,3))\nperm_23(A) = PermutedDimsArray(A, (1,3,2))\n\n@testset \"batched_mul: trivial dimensions & unit strides, $T\" for T in [Float64, ComplexF64]\n    @testset \"$tA(rand$((sA...,3))) ⊠ $tB(rand$((sB...,3)))\" for\n    tA in [identity, batched_adjoint, batched_transpose, perm_12, perm_23], sA in [(1,1), (1,3), (3,1), (3,3)],\n    tB in [identity, batched_adjoint, batched_transpose, perm_12, perm_23], sB in [(1,1), (1,3), (3,1), (3,3)]\n\n        A = tA(rand(T, sA..., 3))\n        B = tB(rand(T, sB..., 3))\n        size(A,2) == size(B,1) && size(A,3) == size(B,3) == 3 || continue\n\n        C = cat(A[:,:,1] * B[:,:,1], A[:,:,2] * B[:,:,2], A[:,:,3] * B[:,:,3]; dims=3)\n        @test A ⊠ B ≈ C\n        @test_logs min_level=Logging.Debug A ⊠ B\n\n        # In-place batched_mul!\n        α, β = rand(T), rand(T)\n        D = rand(T, size(C))\n        @test batched_mul!(copy(D), A, B, α, β) ≈ α .* C .+ β .* D\n        @test batched_mul_generic!(copy(D), A, B, α, β) ≈ α .* C .+ β .* D\n\n        # ... and with weird LHS -- all to batched_mul_generic! 
right now\n        C2 = batched_transpose(permutedims(C, (2,1,3)))\n        C3 = batched_adjoint(permutedims(conj(C), (2,1,3)))\n        @test C2 == C3 == C\n        C2 .= D\n        C3 .= D\n        @test batched_mul!(C2, A, B, α, β) ≈ α .* C .+ β .* D\n        @test C2 ≈ α .* C .+ β .* D\n        @test batched_mul!(C3, A, B, α, β) ≈ α .* C .+ β .* D\n        @test C3 ≈ α .* C .+ β .* D\n    end\nend\n\n@testset \"BatchedAdjOrTrans interface * $TB\" for TB in [Float64, Float32]\n    A = randn(7,5,3)\n    B = randn(TB, 5,7,3)\n    C = randn(7,6,3)\n\n    function interface_tests(X, _X)\n        @test length(_X) == length(X)\n        @test size(_X) == (size(X, 2), size(X, 1), size(X, 3))\n        @test axes(_X) == (axes(X, 2), axes(X, 1), axes(X, 3))\n        #\n        @test getindex(_X, 2, 3, 3) == getindex(X, 3, 2, 3)\n        @test getindex(_X, 5, 4, 1) == getindex(X, 4, 5, 1)\n        #\n        setindex!(_X, 2.0, 2, 4, 1)\n        @test getindex(_X, 2, 4, 1) == 2.0\n        setindex!(_X, 3.0, 1, 2, 2)\n        @test getindex(_X, 1, 2, 2) == 3.0\n\n        _sim = similar(_X, TB, (2, 3))\n        @test size(_sim) == (2, 3)\n        @test typeof(_sim) == Array{TB, 2}\n\n        _sim = similar(_X, TB)\n        @test length(_sim) == length(_X)\n        @test typeof(_sim) == Array{TB, 3}\n\n        _sim = similar(_X, (2, 3))\n        @test size(_sim) == (2, 3)\n        @test typeof(_sim) == Array{Float64, 2}\n\n        _sim = similar(_X)\n        @test length(_sim) == length(_X)\n        @test typeof(_sim) == Array{Float64, 3}\n\n        @test parent(_X) == _X.parent\n    end\n\n    for (X, _X) in zip([A, B, C], map(batched_adjoint, [A, B, C]))\n        interface_tests(X, _X)\n\n        @test -_X == NNlib.BatchedAdjoint(-_X.parent)\n\n        _copyX = copy(_X)\n        @test _X == _copyX\n\n        setindex!(_copyX, 2.0, 1, 2, 1)\n        @test _X != _copyX\n    end\n\n    for (X, _X) in zip([A, B, C], map(batched_transpose, [A, B, C]))\n        interface_tests(X, _X)\n\n        @test -_X == NNlib.BatchedTranspose(-_X.parent)\n\n        _copyX = copy(_X)\n        @test _X == _copyX\n\n        setindex!(_copyX, 2.0, 1, 2, 1)\n        @test _X != _copyX\n    end\nend\n\n@testset \"batched_mul(ndims < 3), $TM\" for TM in [ComplexF64, Int8]\n    A = randn(ComplexF64, 3,3,3)\n    M = rand(TM, 3,3) .+ im\n    V = rand(TM, 3)\n\n    # These are all reshaped and sent to batched_mul(3-array, 3-array)\n    @test batched_mul(A, M) ≈ cat([A[:,:,k] * M for k in 1:3]...; dims=3)\n    @test batched_mul(A, M') ≈ cat([A[:,:,k] * M' for k in 1:3]...; dims=3)\n    @test A ⊠ transpose(M) ≈ cat([A[:,:,k] * transpose(M) for k in 1:3]...; dims=3)\n\n    @test batched_mul(M, A) ≈ cat([M * A[:,:,k] for k in 1:3]...; dims=3)\n    @test batched_mul(M', A) ≈ cat([M' * A[:,:,k] for k in 1:3]...; dims=3)\n    @test transpose(M) ⊠ A ≈ cat([transpose(M) * A[:,:,k] for k in 1:3]...; dims=3)\n\n    # batched_vec\n    @test batched_vec(A, M) ≈ hcat([A[:,:,k] * M[:,k] for k in 1:3]...)\n    @test batched_vec(A, M') ≈ hcat([A[:,:,k] * (M')[:,k] for k in 1:3]...)\n    @test batched_vec(A, V) ≈ hcat([A[:,:,k] * V for k in 1:3]...)\nend\n\n@testset \"storage_type\" begin\n\n    @test storage_type(transpose(reshape(view(rand(10), 2:9),4,:))) == Vector{Float64}\n    @test storage_type(transpose(reshape(view(1:10,     2:9),4,:))) == UnitRange{Int}\n\n    @test storage_typejoin(rand(2), rand(Float32, 2)) == Vector{<:Any}\n    @test storage_typejoin(rand(2), rand(2,3)', rand(2,3,4)) == Array{Float64}\n    @test 
storage_typejoin([1,2,3], 4:5) == AbstractVector{Int}\n\nend\n\n@testset \"is_strided\" begin\n\n    M = ones(10,10)\n\n    @test is_strided(M)\n    @test is_strided(view(M, 1:2:5,:))\n    @test is_strided(PermutedDimsArray(M, (2,1)))\n\n    @test !is_strided(reshape(view(M, 1:2:10,:), 10,:))\n    @test !is_strided((M.+im)')\n    @test !is_strided(Diagonal(ones(3)))\n\n    A = ones(2,2,2)\n\n    @test is_strided(batched_adjoint(A))\n    @test is_strided(batched_transpose(A))\n    @test !is_strided(batched_adjoint(A .+ im))\n    @test is_strided(batched_transpose(A .+ im))\n\nend\n\nFiniteDifferences.to_vec(x::BatchedAdjoint) = FiniteDifferences.to_vec(collect(x))\nFiniteDifferences.to_vec(x::BatchedTranspose) = FiniteDifferences.to_vec(collect(x))\n\n@testset \"AutoDiff\" begin\n    M, P, Q = 13, 7, 11\n    B = 3\n    # Two 3-arrays\n    gradtest(batched_mul, randn(rng, M, P, B), randn(rng, P, Q, B))\n    gradtest(batched_mul, batched_adjoint(randn(rng, P, M, B)), randn(rng, P, Q, B))\n    gradtest(batched_mul, randn(rng, M, P, B), batched_transpose(randn(rng, Q, P, B)))\n\n    # One a matrix...\n    gradtest(batched_mul, randn(rng, M, P), randn(rng, P, Q, B))\n    gradtest(batched_mul, adjoint(randn(rng, P, M)), randn(rng, P, Q, B))\n    gradtest(batched_mul, randn(rng, M, P), batched_adjoint(randn(rng, Q, P, B)))\n\n    gradtest(batched_mul, randn(rng, M, P, B), randn(rng, P, Q))\n    gradtest(batched_mul, batched_transpose(randn(rng, P, M, B)), randn(rng, P, Q))\n    gradtest(batched_mul, randn(rng, M, P, B), transpose(randn(rng, Q, P)))\n\n    # ... or equivalent to a matrix\n    gradtest(batched_mul, randn(rng, M, P, 1), randn(rng, P, Q, B))\n    gradtest(batched_mul, batched_transpose(randn(rng, P, M, 1)), randn(rng, P, Q, B))\n    gradtest(batched_mul, randn(rng, M, P, 1), batched_transpose(randn(rng, Q, P, B)))\n\n    gradtest(batched_mul, randn(rng, M, P, B), randn(rng, P, Q, 1))\n    gradtest(batched_mul, batched_adjoint(randn(rng, P, M, B)), randn(rng, P, Q, 1))\n    gradtest(batched_mul, randn(rng, M, P, B), batched_adjoint(randn(rng, Q, P, 1)))\n\n    # batched_vec\n    gradtest(batched_vec, randn(rng, M, P, B), randn(rng, P, B))\n    gradtest(batched_vec, randn(rng, M, P, B), transpose(randn(rng, B, P)))\n\n    gradtest(batched_vec, randn(rng, M, P, B), randn(rng, P))\nend\n\n@testset \"batched_vec: N-D batches\" begin\n    # Test 4D case: A is 4D, B is 3D\n    A4d = randn(4, 5, 3, 2)  # (matrix_rows, matrix_cols, batch_dim1, batch_dim2)\n    B3d = randn(5, 3, 2)     # (vector_length, batch_dim1, batch_dim2)\n    \n    C = batched_vec(A4d, B3d)\n    @test size(C) == (4, 3, 2)\n    \n    # Manual verification\n    for i in 1:3, j in 1:2\n        @test C[:, i, j] ≈ A4d[:, :, i, j] * B3d[:, i, j]\n    end\n    \n    # Test 5D case: A is 5D, B is 4D\n    A5d = randn(3, 4, 2, 3, 2)  # (matrix_rows, matrix_cols, batch1, batch2, batch3)\n    B4d = randn(4, 2, 3, 2)     # (vector_length, batch1, batch2, batch3)\n    \n    C5 = batched_vec(A5d, B4d)\n    @test size(C5) == (3, 2, 3, 2)\n    \n    # Manual verification for a few cases\n    @test C5[:, 1, 1, 1] ≈ A5d[:, :, 1, 1, 1] * B4d[:, 1, 1, 1]\n    @test C5[:, 2, 3, 2] ≈ A5d[:, :, 2, 3, 2] * B4d[:, 2, 3, 2]\n    \n    # Test dimension mismatch errors\n    @test_throws DimensionMismatch batched_vec(randn(3, 4, 2), randn(4, 3))  # ndims mismatch\n    @test_throws DimensionMismatch batched_vec(randn(3, 4, 2, 3), randn(4, 2, 2))  # batch size mismatch\n    \nend\n"
  },
  {
    "path": "test/bias_act.jl",
    "content": "using NNlib, Zygote, ChainRulesCore, Test\nusing Zygote: ForwardDiff\n\nACTIVATION_FUNCTIONS =\n    [@eval($a) for a in NNlib.ACTIVATIONS]\n\n@testset \"bias_act!\" begin\n    x = randn(3,4)\n    b = randn(3)\n    @test @inferred(bias_act!(identity, x, false)) === x  # pass-through\n    @test @inferred(bias_act!(identity, copy(x), b)) ≈ (x .+ b)\n    @test @inferred(bias_act!(relu, copy(x), b)) ≈ relu.(x .+ b)\n    @test @inferred(bias_act!(tanh, copy(x), b)) ≈ tanh.(x .+ b)\n    @test @inferred(bias_act!(tanh, copy(x), false)) ≈ tanh.(x)\n\n    # Check that it does overwrite:\n    x32 = rand(Float32, 3, 4); x32copy = copy(x32)\n    @test @inferred(bias_act!(cbrt, x32, b)) ≈ cbrt.(x32copy .+ b)\n    @test x32 ≈ cbrt.(x32copy .+ b)\n\n    x32 = rand(Float32, 3, 4); x32copy = copy(x32)  # without bias\n    @test @inferred(bias_act!(tanh, x32, false)) ≈ tanh.(x32copy)\n    @test x32 ≈ tanh.(x32copy)\n\n    x32 = rand(Float32, 3, 4); x32copy = copy(x32)  # now check gradient rule\n    y, back = rrule(Zygote.ZygoteRuleConfig(), bias_act!, relu, x32, b)\n    @test y ≈ x32 ≈ relu.(x32copy .+ b)\n\n    x32 = rand(Float32, 3, 4); x32copy = copy(x32)  # without bias\n    y, back = rrule(Zygote.ZygoteRuleConfig(), bias_act!, relu, x32, false)\n    @test y ≈ x32 ≈ relu.(x32copy)\n\n    # Check that it doesn't try to overwrite non-float arrays:\n    xint = rand(-3:3, 3, 4)\n    bint = rand(-2:2, 3)\n    @test bias_act!(identity, copy(xint), bint) ≈ xint .+ bint\n    @test bias_act!(tanh, copy(xint), bint) ≈ tanh.(xint .+ bint)\n    @test bias_act!(tanh, copy(xint), false) ≈ tanh.(xint)\n\n    # Reject bias===true so that Bool means one thing:\n    @test_throws Exception bias_act!(identity, rand(3), true)\n    @test_throws Exception bias_act!(cbrt, rand(3), true)\n    @test_throws Exception bias_act!(cbrt, rand(1:3, 3), true)\n\n    @testset \"gradient with $fun\" for fun in vcat([identity, tanh, cbrt],\n                                                    ACTIVATION_FUNCTIONS,\n                                                    [x->x, x -> 1/(x^2+2), x -> leakyrelu(x, 0.33)])\n        # Only some of these go the fast path, `cbrt` is an example of a function NNlib knows nothing about.\n        fun == rrelu && continue # this one is randomised!\n        fun == hardσ && continue # this one has heisenbugs, not solved by discontinuity-avoidance code below\n\n        @test bias_act!(fun, copy(x), b) ≈ fun.(x .+ b)\n        @test bias_act!(fun, copy(x), false) ≈ fun.(x)\n\n        gx = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x)\n        gxplus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x .+ eps())\n        gxminus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x .- eps())\n        if !(gx ≈ gxplus ≈ gxminus)\n            @warn \"skipping gradient tests due to discontinuity\" fun x b\n            continue\n        end\n        @test gx ≈ Zygote.gradient(x -> sum(bias_act!(fun, copy(x), b)), x)[1]\n\n        gx2 = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x)\n        gx2plus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x .- eps())\n        gx2minus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x .- eps())\n        if !(gx2 ≈ gx2plus ≈ gx2minus)\n            @warn \"skipping gradient tests due to discontinuity\" fun x\n            continue\n        end\n        @test gx2 ≈ Zygote.gradient(x -> sum(bias_act!(fun, copy(x), false)), x)[1]\n\n        gb = ForwardDiff.gradient(b -> 
sum(bias_act!(fun, copy(x), b)), b)\n        @test gb ≈ Zygote.gradient(b -> sum(bias_act!(fun, copy(x), b)), b)[1]\n\n        @test Zygote.gradient(b -> sum(bias_act!(fun, copy(x), b)), false) == (nothing,)\n        @test Zygote.gradient(b -> sum(bias_act!(fun, copy(x), b)), b .> 0) == (nothing,)\n    end\n\n    @testset \"gradient for fast_broadcast!\" begin\n        # Gradient definition is just to disable mutation inside 2nd order AD\n        gx = ForwardDiff.gradient(x -> sum(NNlib._fast_broadcast!(cbrt∘(+), copy(x), b)), x)\n        @test gx ≈ Zygote.gradient(x -> sum(NNlib._fast_broadcast!(cbrt∘(+), copy(x), b)), x)[1]\n\n        # relu should take the fast path\n        g2 = ForwardDiff.gradient(x) do x\n            sum(abs2, Zygote.gradient(x -> sum(abs2, bias_act!(relu, copy(x), b)), x)[1])\n        end\n        @test_skip gx ≈ Zygote.gradient(x) do x  # Here global variable b causes an error\n            sum(abs2,Zygote. gradient(x -> sum(abs2, bias_act!(relu, copy(x), b)), x)[1])\n        end\n        # Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_eqtable_get), Any, svec(Any, Any, Any), 0, :(:ccall), %5, %3, %4)).\n        # [5] (::typeof(∂(accum_global)))(Δ::Nothing)\n        @test g2 ≈ Zygote.gradient(x, b) do x, b\n            sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(relu, copy(x), b)), x, b)[1])\n        end[1]\n\n       g3 = ForwardDiff.gradient(x) do x\n            sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(swish, copy(x), b)), x, b)[1])\n        end\n        @test g3 ≈ Zygote.gradient(x, b) do x, b\n            sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(swish, copy(x), b)), x, b)[1])\n        end[1]\n\n        # Anon function sure to take the generic path\n        g4 = ForwardDiff.gradient(x) do x\n            sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(y -> cbrt(y/3), copy(x), b)), x, b)[1])\n        end\n        @test g4 ≈ Zygote.gradient(x, b) do x, b\n            sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(y -> cbrt(y/3), copy(x), b)), x, b)[1])\n        end[1]\n    end\nend\n\n"
  },
  {
    "path": "test/conv.jl",
    "content": "using NNlib, Test\nusing NNlib: input_size, kernel_size, channels_in, channels_out, channel_multiplier,\n             stride, padding, dilation, flipkernel, output_size,\n             groupcount\nusing Random: AbstractRNG, SamplerType\n\n@testset \"ConvDims\" begin\n    for T in (DenseConvDims, DepthwiseConvDims)\n        @testset \"$(T)\" begin\n            x = randn(5,4,3,2)\n\n            if T == DenseConvDims\n                w = randn(1,2,3,4)\n            elseif T == DepthwiseConvDims\n                w = randn(1,2,4,3)\n            end\n\n            # First, getters:\n            cdims = T(x, w)\n            @test input_size(cdims) == size(x)[1:2]\n            @test kernel_size(cdims) == size(w)[1:2]\n            @test channels_in(cdims) == size(x, 3)\n            @test stride(cdims) == (1,1)\n            @test dilation(cdims) == (1,1)\n            @test padding(cdims) == (0,0,0,0)\n            @test flipkernel(cdims) == false\n            @test output_size(cdims) == (5,3)\n\n            # Special-case channel output tests\n            if T == DenseConvDims\n                @test channels_out(cdims) == size(w, 4)\n            elseif T == DepthwiseConvDims\n                @test channel_multiplier(cdims) == size(w, 3)\n                @test channels_out(cdims) == size(w,3)*size(w,4)\n            end\n\n            # Next, scalar settings:\n            cdims = T(x, w; stride=2, dilation=2, padding=3, flipkernel=true)\n            @test stride(cdims) == (2,2)\n            @test dilation(cdims) == (2,2)\n            @test padding(cdims) == (3,3,3,3)\n            @test flipkernel(cdims) == true\n            @test output_size(cdims) == (6,4)\n\n            # Next, tuple settings\n            cdims = T(x, w; stride=(1, 2), dilation=(1, 2), padding=(0,1))\n            @test stride(cdims) == (1,2)\n            @test dilation(cdims) == (1,2)\n            @test padding(cdims) == (0,0,1,1)\n            @test output_size(cdims) == (5,2)\n\n            # Special case for 4-d padding spec:\n            cdims = T(x, w; padding=(1,2,3,4))\n            @test padding(cdims) == (1,2,3,4)\n            @test output_size(cdims) == (8,10)\n\n            # Make sure we throw on invalid settings:\n            # Invalid dimensionality of settings:\n            @test_throws DimensionMismatch T(x, w; stride=(1,))\n            @test_throws DimensionMismatch T(x, w; stride=(1, 1, 1))\n            @test_throws DimensionMismatch T(x, w; padding=(1, 1, 1))\n            @test_throws DimensionMismatch T(x, w; padding=(1, 1, 1, 1, 1))\n            @test_throws DimensionMismatch T(x, w; dilation=(1,))\n            @test_throws DimensionMismatch T(x, w; dilation=(1, 1, 1))\n            # Dilation will cause us to reach beyond the end of input + padding here:\n            @test_throws DimensionMismatch T(x, w; dilation=(1, 5))\n            # Channel mismatch:\n            if T == DenseConvDims\n                @test_throws DimensionMismatch T(x, w[:,:,1:1,:])\n            elseif T == DepthwiseConvDims\n                @test_throws DimensionMismatch T(x, w[:,:,:,1:1])\n            end\n        end\n    end\nend\n\nconv_answer_dict = Dict(\n    # Known-good answers for 1d convolution operations\n    1 => Dict(\n        \"y_pad\"  => [1, 4,  7, 10, 13, 10.],\n        \"y_dil\"  => [5, 8, 11.],\n        \"y_flip\" => [5, 8, 11, 14.],\n\n        \"dx\"        => [ 8, 18, 27, 36, 13.],\n        \"dx_stride\" => [ 8,  4, 20, 10,  0.],\n        \"dx_pad\"    => [ 9, 18, 27, 36, 33.],\n        \"dx_dil\"    => 
[10, 16, 27,  8, 11.],\n        \"dx_flip\"   => [ 5, 18, 27, 36, 28.],\n\n        \"dw\"        => [134, 100.],\n        \"dw_stride\" => [ 48,  34.],\n        \"dw_pad\"    => [135, 150.],\n        \"dw_dil\"    => [102,  54.],\n        \"dw_flip\"   => [110, 148.],\n    ),\n\n    # Known-good answers for 2d convolution operations\n    2 => Dict(\n        \"y_pad\" => [\n            1  9  29  49  48;\n            4 29  79 129 115;\n            7 39  89 139 122;\n            10 49  99 149 129;\n            13 59 109 159 136;\n            10 40  70 100  80.\n        ],\n        \"y_dil\" => [\n            48   98;\n            58  108;\n            68  118.\n        ],\n        \"y_flip\" => [\n            51  101  151;\n            61  111  161;\n            71  121  171;\n            81  131  181.\n        ],\n\n        \"dx\" => [\n            116  374   674  258;\n            243  700  1200  407;\n            313  800  1300  437;\n            383  900  1400  467;\n            177  386   586  159.\n        ],\n        \"dx_stride\" => [\n            116  58  516  258;\n            87  29  387  129;\n            196  98  596  298;\n            147  49  447  149;\n            0   0    0    0.\n        ],\n        \"dx_pad\" => [\n            152  470   850   911;\n            261  700  1200  1240;\n            340  800  1300  1319;\n            419  900  1400  1398;\n            370  746  1126  1087.\n        ],\n        \"dx_dil\" => [\n            192  392   96  196;\n            232  432  116  216;\n            416  766  184  334;\n            174  324   58  108;\n            204  354   68  118.\n        ],\n        \"dx_flip\" => [\n            51  254   454   453;\n            163  700  1200  1087;\n            193  800  1300  1157;\n            223  900  1400  1227;\n            162  586   886   724.\n        ],\n\n        \"dw\" => [\n            17378  11738;\n            16250  10610.\n        ],\n        \"dw_stride\" => [\n            5668  3888;\n            5312  3532.\n        ],\n        \"dw_pad\" => [\n            18670  22550;\n            19850  23430.\n        ],\n        \"dw_dil\" => [\n            8632  3652;\n            7636  2656.\n        ],\n        \"dw_flip\" => [\n            12590  19550;\n            13982  20942.\n        ],\n    ),\n\n    # Known-good answers for 3d convolution operations (these are getting rather large)\n    3 => Dict(\n        \"y_pad\"  => reshape([\n            1, 4, 7, 10, 13, 10, 9, 29, 39, 49, 59, 40, 29, 79, 89, 99, 109, 70, 49, 129,\n            139, 149, 159, 100, 48, 115, 122, 129, 136, 80, 26, 80, 94, 108, 122, 80, 126,\n            322, 358, 394, 430, 260, 206, 502, 538, 574, 610, 360, 286, 682, 718, 754, 790,\n            460, 220, 502, 524, 546, 568, 320, 146, 360, 374, 388, 402, 240, 446, 1042, 1078,\n            1114, 1150, 660, 526, 1222, 1258, 1294, 1330, 760, 606, 1402, 1438, 1474, 1510,\n            860, 420, 942, 964, 986, 1008, 560, 205, 456, 467, 478, 489, 270, 517, 1133, 1159,\n            1185, 1211, 660, 577, 1263, 1289, 1315, 1341, 730, 637, 1393, 1419, 1445, 1471,\n            800, 392, 847, 862, 877, 892, 480.\n        ], (6,5,4)),\n        \"y_dil\"  => reshape([608, 644, 680, 788, 824, 860.], (3,2,1)),\n        \"y_flip\" => reshape([\n            686, 722, 758, 794, 866, 902, 938, 974, 1046, 1082, 1118, 1154, 1406, 1442,\n            1478, 1514, 1586, 1622, 1658, 1694, 1766, 1802, 1838, 1874.\n        ], (4,3,2)),\n\n        \"dx\"        => reshape([\n            2576, 5118, 5658, 6198, 3010, 5948, 
11576, 12512, 13448, 6420, 8468, 16256,\n            17192, 18128, 8580, 4092, 7718, 8114, 8510, 3950, 9624, 18316, 19108, 19900,\n            9340, 18680, 34992, 36288, 37584, 17320, 22280, 41472, 42768, 44064, 20200,\n            9776, 17756, 18260, 18764, 8340, 4168, 7438, 7690, 7942, 3450, 6972, 11896,\n            12256, 12616, 5140, 8052, 13696, 14056, 14416, 5860, 2804, 4278, 4386, 4494,\n            1510.\n        ], (5,4,3)),\n        \"dx_stride\" => reshape([\n            2576, 2254, 3152, 2758, 0, 1932, 1610, 2364, 1970, 0, 5456, 4774, 6032,\n            5278, 0, 4092, 3410, 4524, 3770, 0, 1288, 966, 1576, 1182, 0, 644, 322,\n            788, 394, 0, 2728, 2046, 3016, 2262, 0, 1364, 682, 1508, 754, 0, 0, 0, 0,\n            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.\n        ], (5,4,3)),\n        \"dx_pad\"    => reshape([\n            4220, 6343, 7116, 7889, 6550, 8490, 12276, 13312, 14348, 11606, 12350,\n            17456, 18492, 19528, 15546, 11989, 16664, 17469, 18274, 14333, 16200,\n            22628, 23616, 24604, 19392, 25336, 34992, 36288, 37584, 29320, 30216,\n            41472, 42768, 44064, 34200, 26236, 35664, 36652, 37640, 28940, 22816,\n            30831, 31636, 32441, 24794, 32522, 43668, 44704, 45740, 34742, 36462,\n            48848, 49884, 50920, 38602, 29501, 39264, 40037, 40810, 30733.\n        ], (5,4,3)),\n        \"dx_dil\"    => reshape([\n            4864, 5152, 9696, 4508, 4760, 6304, 6592, 12396, 5768, 6020, 3648,\n            3864, 7120, 3220, 3400, 4728, 4944, 9100, 4120, 4300, 0, 0, 0, 0, 0, 0,\n            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2432, 2576, 4544, 1932, 2040,\n            3152, 3296, 5804, 2472, 2580, 1216, 1288, 1968, 644, 680, 1576, 1648,\n            2508, 824, 860.\n        ], (5,4,3)),\n        \"dx_flip\"   => reshape([\n            686, 2094, 2202, 2310, 1588, 2924, 7544, 7904, 8264, 5124, 3644, 9344,\n            9704, 10064, 6204, 3138, 7430, 7682, 7934, 4616, 4836, 11980, 12484,\n            12988, 7792, 14936, 34992, 36288, 37584, 21640, 17816, 41472, 42768,\n            44064, 25240, 12620, 28412, 29204, 29996, 16728, 7030, 15646, 16042,\n            16438, 9084, 17772, 38968, 39904, 40840, 22276, 19932, 43648, 44584,\n            45520, 24796, 12362, 26742, 27282, 27822, 14992.\n        ], (5,4,3)),\n\n        \"dw\"        => reshape([1.058184e6, 1.0362e6,    948264,    926280,\n                                    618504,   596520,    508584,    486600], (2,2,2)),\n        \"dw_stride\" => reshape([    74760,     72608,     64000,     61848,\n                                    31720,     29568,     20960,     18808.], (2,2,2)),\n        \"dw_pad\"    => reshape([1.26055e6, 1.30805e6, 1.40327e6, 1.44923e6,\n                                1.73731e6, 1.77589e6, 1.83259e6, 1.86731e6], (2,2,2)),\n        \"dw_dil\"    => reshape([   250320,    241512,    206280,    197472,\n                                    74160,     65352,     30120,     21312.], (2,2,2)),\n        \"dw_flip\"   => reshape([    639480,   670200,    793080,    823800,\n                                    1.25388e6, 1.2846e6, 1.40748e6,  1.4382e6], (2,2,2)),\n    ),\n)\n\n# A \"drop channels and batch dimension\" helper\nddims(x) = dropdims(x, dims=(ndims(x)-1, ndims(x)))\n\n@testset \"Dense Convolution\" begin\n    # Start with some easy-to-debug cases that we have worked through and _know_ work\n    for rank in (1,2,3)\n        @testset \"conv$(rank)d\" begin\n            # Pull out known-good answers for y = conv(x, w)\n            y_pad = 
conv_answer_dict[rank][\"y_pad\"]\n            y_dil = conv_answer_dict[rank][\"y_dil\"]\n            y_flip = conv_answer_dict[rank][\"y_flip\"]\n\n            # We can always derive y_plain and y_stride from the other answers.\n            y_plain = y_pad[((2:(size(y_pad,idx)-1)) for idx in 1:rank)...]\n            y_stride = y_pad[((2:2:(size(y_pad,idx)-1)) for idx in 1:rank)...]\n\n            # Same for dx and dw:\n            dx = conv_answer_dict[rank][\"dx\"]\n            dx_stride = conv_answer_dict[rank][\"dx_stride\"]\n            dx_pad = conv_answer_dict[rank][\"dx_pad\"]\n            dx_dil = conv_answer_dict[rank][\"dx_dil\"]\n            dx_flip = conv_answer_dict[rank][\"dx_flip\"]\n\n            dw = conv_answer_dict[rank][\"dw\"]\n            dw_stride = conv_answer_dict[rank][\"dw_stride\"]\n            dw_pad = conv_answer_dict[rank][\"dw_pad\"]\n            dw_dil = conv_answer_dict[rank][\"dw_dil\"]\n            dw_flip = conv_answer_dict[rank][\"dw_flip\"]\n\n            # We generate x and w from the shapes we know they must be\n            x = reshape(Float64[1:prod(size(dx));], size(dx)..., 1, 1)\n            w = reshape(Float64[1:prod(size(dw));], size(dw)..., 1, 1)\n\n            convs = [NNlib.conv, NNlib.conv_im2col, NNlib.conv_direct,]\n            for conv in convs\n                @testset \"$(conv)\" begin\n                    cdims = DenseConvDims(x, w)\n                    # First, your basic convolution with no parameters\n                    @test isapprox(ddims(conv(x, w, cdims)), y_plain, rtol = 1.0e-7)\n\n                    # Next, test convolution on views and alternate datatypes:\n                    @test isapprox(ddims(conv(view(x, repeat([:], ndims(x))...), w, cdims)), y_plain, rtol = 1.0e-7)\n                    @test isapprox(ddims(conv(Float32.(x), Float32.(w), cdims)), Float32.(y_plain), rtol = 1.0e-7)\n\n                    # Next, introduce stride:\n                    cdims = DenseConvDims(x, w; stride=2)\n                    @test isapprox(ddims(conv(x, w, cdims)), y_stride, rtol = 1.0e-7)\n\n                    # Next, introduce dilation:\n                    cdims = DenseConvDims(x, w; dilation=2)\n                    @test isapprox(ddims(conv(x, w, cdims)), y_dil, rtol = 1.0e-7)\n\n                    # Next, introduce padding:\n                    cdims = DenseConvDims(x, w; padding=1)\n                    @test isapprox(ddims(conv(x, w, cdims)), y_pad, rtol = 1.0e-7)\n\n                    # Next, test crosscor/conv with a flipped kernel\n                    cdims = DenseConvDims(x, w; flipkernel=true)\n                    @test isapprox(ddims(conv(x, w, cdims)), y_flip, rtol = 1.0e-7)\n                end\n            end\n\n            # Test all in-place implementations/interfaces\n            convs = [NNlib.conv!, NNlib.conv_im2col!, NNlib.conv_direct!,]\n            for conv! 
in convs\n                α, β = 2e0, -1e0\n\n                @testset \"$(conv!)\" begin\n                    # First, your basic convolution with no parameters\n                    cdims = DenseConvDims(x, w)\n                    y0 = rand(rng, -9e0:9e0, size(y_plain)..., 1, 1)\n                    @test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_plain + β*y0, rtol = 1.0e-7)\n\n                    # Next, test convolution on views and alternate datatypes:\n                    @test isapprox(ddims(conv!(copy(y0), view(x, repeat([:], ndims(x))...), w, cdims; alpha=α, beta=β)), α*y_plain + β*y0, rtol = 1.0e-7)\n                    @test isapprox(ddims(conv!(Float32.(copy(y0)), Float32.(x), Float32.(w), cdims; alpha=Float32(α), beta=Float32(β))), Float32.(α*y_plain + β*y0), rtol = 1.0e-7)\n\n                    # Next, introduce stride:\n                    cdims = DenseConvDims(x, w; stride=2)\n                    y0 = rand(rng, -9e0:9e0, size(y_stride)..., 1, 1)\n                    @test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_stride + β*y0, rtol = 1.0e-7)\n\n                    # Next, introduce dilation:\n                    cdims = DenseConvDims(x, w; dilation=2)\n                    y0 = rand(rng, -9e0:9e0, size(y_dil)..., 1, 1)\n                    @test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_dil + β*y0, rtol = 1.0e-7)\n\n                    # Next, introduce padding:\n                    cdims = DenseConvDims(x, w; padding=1)\n                    y0 = rand(rng, -9e0:9e0, size(y_pad)..., 1, 1)\n                    @test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_pad + β*y0, rtol = 1.0e-7)\n\n                    # Next, test crosscor/conv with a flipped kernel\n                    cdims = DenseConvDims(x, w; flipkernel=true)\n                    y0 = rand(rng, -9e0:9e0, size(y_flip)..., 1, 1)\n                    @test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_flip + β*y0, rtol = 1.0e-7)\n                end\n            end\n\n            # Test all implementations/interfaces\n            for (∇conv_filter, ∇conv_data) in (\n                    (NNlib.∇conv_filter,        NNlib.∇conv_data),\n                    (NNlib.∇conv_filter_im2col, NNlib.∇conv_data_im2col),\n                    (NNlib.∇conv_filter_direct, NNlib.∇conv_data_direct),\n                )\n                @testset \"$(∇conv_filter)/$(∇conv_data)\" begin\n                    # First, your basic convolution with no parameters\n                    cdims = DenseConvDims(x, w)\n                    dy = NNlib.conv(x, w, cdims)\n                    @test isapprox(ddims(∇conv_filter(x, dy, cdims)), dw, rtol = 1.0e-7)\n                    @test isapprox(ddims(∇conv_data(dy, w,  cdims)), dx, rtol = 1.0e-7)\n\n                    # Next, test convolution on views and alternate datatypes:\n                    @test isapprox(ddims(∇conv_filter(x, view(dy, repeat([:], ndims(dy))...), cdims)), dw, rtol = 1.0e-7)\n                    @test isapprox(ddims(∇conv_data(view(dy, repeat([:], ndims(dy))...), w,   cdims)), dx, rtol = 1.0e-7)\n\n                    @test isapprox(ddims(∇conv_filter(Float32.(x), Float32.(dy), cdims)), dw, rtol = 1.0e-7)\n                    @test isapprox(ddims(∇conv_data(Float32.(dy),  Float32.(w),  cdims)), dx, rtol = 1.0e-7)\n\n                    # Next, introduce stride:\n                    cdims = DenseConvDims(x, w; stride=2)\n                    dy = NNlib.conv(x, w, 
cdims)\n                    @test isapprox(ddims(∇conv_filter(x, dy, cdims)), dw_stride, rtol = 1.0e-7)\n                    @test isapprox(ddims(∇conv_data(dy, w,  cdims)), dx_stride, rtol = 1.0e-7)\n\n                    # Next, introduce dilation:\n                    cdims = DenseConvDims(x, w; dilation=2)\n                    dy = NNlib.conv(x, w, cdims)\n                    @test isapprox(ddims(∇conv_filter(x, dy, cdims)), dw_dil, rtol = 1.0e-7)\n                    @test isapprox(ddims(∇conv_data(dy, w,  cdims)), dx_dil, rtol = 1.0e-7)\n\n                    # Next, introduce padding:\n                    cdims = DenseConvDims(x, w; padding=1)\n                    dy = NNlib.conv(x, w, cdims)\n                    @test isapprox(ddims(∇conv_filter(x, dy, cdims)), dw_pad, rtol = 1.0e-7)\n                    @test isapprox(ddims(∇conv_data(dy, w,  cdims)), dx_pad, rtol = 1.0e-7)\n\n                    # Next, test crosscor/conv with a flipped kernel\n                    cdims = DenseConvDims(x, w; flipkernel=true)\n                    dy = NNlib.conv(x, w, cdims)\n                    @test isapprox(ddims(∇conv_filter(x, dy, cdims)), dw_flip, rtol = 1.0e-7)\n                    @test isapprox(ddims(∇conv_data(dy, w,  cdims)), dx_flip, rtol = 1.0e-7)\n                end\n            end\n\n            # Test im2col\n\n            for beta in (-2.0, -1.0, 0.0, 0.5, 1.0, 2.0)\n                cache_dx, cache_dy, cache_w = ([0.17;;; 0.19;;; 0.23], [0.11;;; 0.13;;; 0.15], [1.0;;;])\n                dx_old = copy(cache_dx)\n                cdims = DenseConvDims(cache_dx, cache_w)\n                NNlib.∇conv_data_im2col!(cache_dx, cache_dy, cache_w, cdims; alpha=1.0, beta)\n                @test isapprox(cache_dx, dx_old * beta + cache_dy, rtol = 1.0e-7)\n            end\n\n            # Test all in-place implementations/interfaces\n            for (∇conv_filter!, ∇conv_data!) 
in (\n                    (NNlib.∇conv_filter!,        NNlib.∇conv_data!),\n                    (NNlib.∇conv_filter_im2col!, NNlib.∇conv_data_im2col!),\n                    (NNlib.∇conv_filter_direct!, NNlib.∇conv_data_direct!),\n                )\n                #α, β = 2*rand(rng) - 1, 2*rand(rng) - 1\n                α, β = 2e0, -1e0\n\n                @testset \"$(∇conv_filter!)/$(∇conv_data!)\" begin\n                    # First, your basic convolution with no parameters\n                    cdims = DenseConvDims(x, w)\n                    dy = NNlib.conv(x, w, cdims)\n                    @test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw + β*w, rtol = 1.0e-7)\n                    @test isapprox(ddims(∇conv_data!(copy(x), dy, w,   cdims; alpha=α, beta=β)), α*dx + β*x, rtol = 1.0e-7)\n\n                    # Next, test convolution on views and alternate datatypes:\n                    @test isapprox(ddims(∇conv_filter!(copy(w), x, view(dy, repeat([:], ndims(dy))...), cdims; alpha=α, beta=β)), α*dw + β*w, rtol = 1.0e-7)\n                    @test isapprox(ddims(∇conv_data!(copy(x), view(dy, repeat([:], ndims(dy))...), w,   cdims; alpha=α, beta=β)), α*dx + β*x, rtol = 1.0e-7)\n\n                    @test isapprox(ddims(∇conv_filter!(Float32.(copy(w)), Float32.(x), Float32.(dy), cdims; alpha=Float32(α), beta=Float32(β))), α*dw + β*w, rtol = 1.0e-7)\n                    @test isapprox(ddims(∇conv_data!(Float32.(copy(x)), Float32.(dy),  Float32.(w),  cdims; alpha=Float32(α), beta=Float32(β))), α*dx + β*x, rtol = 1.0e-7)\n\n                    # Next, introduce stride:\n                    cdims = DenseConvDims(x, w; stride=2)\n                    dy = NNlib.conv(x, w, cdims)\n                    flag_ = ∇conv_filter! == NNlib.∇conv_filter_direct! && rank in (1,3)\n                    @test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_stride + β*w, rtol = 1.0e-7) broken=flag_\n                    @test isapprox(ddims(∇conv_data!(copy(x), dy, w,   cdims; alpha=α, beta=β)), α*dx_stride + β*x, rtol = 1.0e-7)\n\n                    # Next, introduce dilation:\n                    cdims = DenseConvDims(x, w; dilation=2)\n                    dy = NNlib.conv(x, w, cdims)\n                    flag_ = ∇conv_data! == NNlib.∇conv_data_direct! 
&& rank == 3\n                    @test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_dil + β*w, rtol = 1.0e-7)\n                    @test isapprox(ddims(∇conv_data!(copy(x), dy, w,   cdims; alpha=α, beta=β)), α*dx_dil + β*x, rtol = 1.0e-7) broken=flag_\n\n                    # Next, introduce padding:\n                    cdims = DenseConvDims(x, w; padding=1)\n                    dy = NNlib.conv(x, w, cdims)\n                    @test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_pad + β*w, rtol = 1.0e-7)\n                    @test isapprox(ddims(∇conv_data!(copy(x), dy, w,   cdims; alpha=α, beta=β)), α*dx_pad + β*x, rtol = 1.0e-7)\n\n                    # Next, test crosscor/conv with a flipped kernel\n                    cdims = DenseConvDims(x, w; flipkernel=true)\n                    dy = NNlib.conv(x, w, cdims)\n                    @test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_flip + β*w, rtol = 1.0e-7)\n                    @test isapprox(ddims(∇conv_data!(copy(x), dy, w,   cdims; alpha=α, beta=β)), α*dx_flip + β*x, rtol = 1.0e-7)\n                end\n            end\n        end\n    end\nend\n\n@testset \"Complex Dense Convolution\" begin\n    # For now only 1 dimensional 1x1 convolution\n    x = reshape(complex.(Float64[1:4;], Float64[1:4;] .+ 1), 1, 4, 1)\n    w = reshape(complex.(Float64[1:4;] .+ 2, Float64[1:4;] .+ 3), 1, 4, 1)\n    cdims = DenseConvDims(x, w)\n    convs = [NNlib.conv, NNlib.conv_im2col, NNlib.conv_direct,]\n    for conv in convs\n        @testset \"$(conv)\" begin\n            @test isapprox(ddims(conv(x, w, cdims)), [transpose(vec(w)) * vec(x)], rtol = 1.0e-7)\n        end\n    end\n    dy = NNlib.conv(x, w, cdims)\n    for (∇conv_filter, ∇conv_data) in (\n        (NNlib.∇conv_filter,        NNlib.∇conv_data),\n        (NNlib.∇conv_filter_im2col, NNlib.∇conv_data_im2col),\n        (NNlib.∇conv_filter_direct, NNlib.∇conv_data_direct),\n    )\n        @testset \"$(∇conv_filter)/$(∇conv_data)\" begin\n            @test isapprox(∇conv_filter(x, dy, cdims), conj(x) .* dy, rtol = 1.0e-7)\n            @test isapprox(∇conv_data(dy, w, cdims), dy .* conj(w), rtol = 1.0e-7)\n        end\n    end\nend\n\nif get(ENV, \"NNLIB_TEST_FUZZING\", \"false\") == \"true\"\n    # @info(\"Skipping Convolutional fuzzing tests, set NNLIB_TEST_FUZZING=true to run them\")\n    @testset \"fuzzing\" begin\n        @info(\"Starting Convolutional fuzzing tests; this can take a few minutes...\")\n        # Now that we're fairly certain things are working, let's fuzz things a little bit:\n        for x_size in (\n                # 1d tests\n                (1,), (3,), (7,),\n                # 2d tests\n                (1, 3), (3, 3), (12, 3), (20, 17),\n                # 3d tests\n                (1, 1, 3), (3, 5, 4), (20, 17, 14),\n            ),\n            C_in in (1, 3),\n            batch in (1, 5)\n\n            # Allocate x in this outer loop to save on allocations and speed things up\n            x = rand(x_size..., C_in, batch)\n            dx_direct = similar(x)\n            dx_im2col = similar(x)\n\n            for w_size in (\n                    (1,), (3,), (7,),\n                    (1,1), (1,3), (3,4), (7, 4),\n                    (1,1,1), (1,1,3,), (3,4,3), (7,3,2)),\n                C_out in (1, 4)\n\n                # Give some output to the user that something is in fact happening.\n                print(\".\")\n\n                # Allocate w in this outer loop to save on 
allocations and speed things up\n                w = rand(w_size..., C_in, C_out)\n                dw_direct = similar(w)\n                dw_im2col = similar(w)\n\n                for S_size in (1, 2, 4, (1,2), (4,1), (2,1,4)),\n                    P_size in (0, 1, 2, (0,3,0,3), (4,1,4,2), (1,2,3,4,5,6)),\n                    D_size in (1, 2, 4, (1,2), (3,2), (4,2,3))\n\n                    # Skip tests that are impossible due to mismatched sizes\n                    try\n                        DenseConvDims(x, w;\n                            stride=S_size, padding=P_size, dilation=D_size,\n                        )\n                    catch e\n                        if isa(e, DimensionMismatch) || isa(e, MethodError)\n                            continue\n                        end\n                        rethrow(e)\n                    end\n\n                    # Do the actual convolution, comparing convolution implementations\n                    cdims = DenseConvDims(x, w; stride=S_size, padding=P_size, dilation=D_size)\n\n                    # We use mutating calls with explicitly different initial values, so as\n                    # to be sure to catch when we're leaving pieces of the output untouched.\n                    y_direct = ones(output_size(cdims)..., C_out, batch) .* 666.666\n                    y_im2col = ones(output_size(cdims)..., C_out, batch) .* 777.777\n\n                    # Do the convolutions\n                    NNlib.conv_direct!(y_direct, x, w, cdims)\n                    NNlib.conv_im2col!(y_im2col, x, w, cdims)\n\n                    # Compare!\n                    @test y_direct ≈ y_im2col\n                    dy = y_im2col\n\n                    # Now push backwards; first for the filter.  Again, we initialize our\n                    # memory so that segments that never get touched are immediately noticable\n                    fill!(dw_direct, 666.666)\n                    fill!(dw_im2col, 777.777)\n                    NNlib.∇conv_filter_direct!(dw_direct, x, dy, cdims)\n                    NNlib.∇conv_filter_im2col!(dw_im2col, x, dy, cdims)\n                    @test dw_direct ≈ dw_im2col\n\n                    # And then for the input\n                    fill!(dx_direct, 666.666)\n                    fill!(dx_im2col, 777.777)\n                    NNlib.∇conv_data_direct!(dx_direct, dy, w, cdims)\n                    NNlib.∇conv_data_im2col!(dx_im2col, dy, w, cdims)\n                    @test dx_direct ≈ dx_im2col\n                end\n            end\n        end\n        println()\n    end\nelse\n    @info \"Skipping Convolutional fuzzing tests, set NNLIB_TEST_FUZZING=true to run them\"\nend\n\n@testset \"Depthwise Convolution\" begin\n    # Start with some easy-to-debug cases that we have worked through and _know_ work.\n    # NOTE: these examples are all single-channel... 
which doesn't really stress test\n    # the important parts of depthwise convolution!\n    for rank in (1,2,3)\n        @testset \"depthwiseconv$(rank)d\" begin\n            # Pull out known-good answers for y = depthwiseconv(x, w)\n            y_pad = conv_answer_dict[rank][\"y_pad\"]\n            y_dil = conv_answer_dict[rank][\"y_dil\"]\n            y_flip = conv_answer_dict[rank][\"y_flip\"]\n\n            # We can always derive y_plain and y_stride from the other answers.\n            y_plain = y_pad[((2:(size(y_pad,idx)-1)) for idx in 1:rank)...]\n            y_stride = y_pad[((2:2:(size(y_pad,idx)-1)) for idx in 1:rank)...]\n\n            # Same for dx and dw:\n            dx = conv_answer_dict[rank][\"dx\"]\n            dx_stride = conv_answer_dict[rank][\"dx_stride\"]\n            dx_pad = conv_answer_dict[rank][\"dx_pad\"]\n            dx_dil = conv_answer_dict[rank][\"dx_dil\"]\n            dx_flip = conv_answer_dict[rank][\"dx_flip\"]\n\n            dw = conv_answer_dict[rank][\"dw\"]\n            dw_stride = conv_answer_dict[rank][\"dw_stride\"]\n            dw_pad = conv_answer_dict[rank][\"dw_pad\"]\n            dw_dil = conv_answer_dict[rank][\"dw_dil\"]\n            dw_flip = conv_answer_dict[rank][\"dw_flip\"]\n\n            # We generate x and w from the shapes we know they must be\n            x = reshape(Float64[1:prod(size(dx));], size(dx)..., 1, 1)\n            w = reshape(Float64[1:prod(size(dw));], size(dw)..., 1, 1)\n\n            for conv in (NNlib.depthwiseconv, NNlib.depthwiseconv_im2col, NNlib.depthwiseconv_direct)\n                @testset \"$(conv)\" begin\n                    # First, your basic convolution with no parameters\n                    cdims = DepthwiseConvDims(x, w)\n                    @test ddims(conv(x, w, cdims)) == y_plain\n\n                    # Next, test convolution on views and alternate datatypes:\n                    @test isapprox(ddims(conv(view(x, repeat([:], ndims(x))...), w, cdims)), y_plain, rtol = 1.0e-7)\n                    @test isapprox(ddims(conv(Float32.(x), Float32.(w), cdims)), Float32.(y_plain), rtol = 1.0e-7)\n\n                    # Next, introduce stride:\n                    cdims = DepthwiseConvDims(x, w; stride=2)\n                    @test isapprox(ddims(conv(x, w, cdims)), y_stride, rtol = 1.0e-7)\n\n                    # Next, introduce dilation:\n                    cdims = DepthwiseConvDims(x, w; dilation=2)\n                    @test isapprox(ddims(conv(x, w, cdims)), y_dil, rtol = 1.0e-7)\n\n                    # Next, introduce padding:\n                    cdims = DepthwiseConvDims(x, w; padding=1)\n                    @test isapprox(ddims(conv(x, w, cdims)), y_pad, rtol = 1.0e-7)\n\n                    # Next, test crosscor/conv with a flipped kernel\n                    cdims = DepthwiseConvDims(x, w; flipkernel=true)\n                    @test isapprox(ddims(conv(x, w, cdims)), y_flip, rtol = 1.0e-7)\n                end\n            end\n\n            # Test all implementations/interfaces\n            for (∇conv_filter, ∇conv_data) in (\n                    (NNlib.∇depthwiseconv_filter,        NNlib.∇depthwiseconv_data),\n                    (NNlib.∇depthwiseconv_filter_im2col, NNlib.∇depthwiseconv_data_im2col),\n                    (NNlib.∇depthwiseconv_filter_direct, NNlib.∇depthwiseconv_data_direct),\n                )\n                @testset \"$(∇conv_filter)/$(∇conv_data)\" begin\n                    # First, your basic convolution with no parameters\n                    cdims = 
DepthwiseConvDims(x, w)\n                    dy = NNlib.depthwiseconv(x, w, cdims)\n                    @test ddims(∇conv_filter(x, dy, cdims)) == dw\n                    @test ddims(∇conv_data(dy, w,  cdims)) == dx\n\n                    # Next, test convolution on views and alternate datatypes:\n                    @test ddims(∇conv_filter(x, view(dy, repeat([:], ndims(dy))...), cdims)) == dw\n                    @test ddims(∇conv_data(view(dy, repeat([:], ndims(dy))...), w,   cdims)) == dx\n\n                    @test ddims(∇conv_filter(Float32.(x), Float32.(dy), cdims)) == dw\n                    @test ddims(∇conv_data(Float32.(dy),  Float32.(w),  cdims)) == dx\n\n                    # Next, introduce stride:\n                    cdims = DepthwiseConvDims(x, w; stride=2)\n                    dy = NNlib.depthwiseconv(x, w, cdims)\n                    @test ddims(∇conv_filter(x, dy, cdims)) == dw_stride\n                    @test ddims(∇conv_data(dy, w,  cdims)) == dx_stride\n\n                    # Next, introduce dilation:\n                    cdims = DepthwiseConvDims(x, w; dilation=2)\n                    dy = NNlib.depthwiseconv(x, w, cdims)\n                    @test ddims(∇conv_filter(x, dy, cdims)) == dw_dil\n                    @test ddims(∇conv_data(dy, w,  cdims)) == dx_dil\n\n                    # Next, introduce padding:\n                    cdims = DepthwiseConvDims(x, w; padding=1)\n                    dy = NNlib.depthwiseconv(x, w, cdims)\n                    @test ddims(∇conv_filter(x, dy, cdims)) == dw_pad\n                    @test ddims(∇conv_data(dy, w,  cdims)) == dx_pad\n\n                    # Next, test crosscor/conv with a flipped kernel\n                    cdims = DepthwiseConvDims(x, w; flipkernel=true)\n                    dy = NNlib.depthwiseconv(x, w, cdims)\n                    @test ddims(∇conv_filter(x, dy, cdims)) == dw_flip\n                    @test ddims(∇conv_data(dy, w,  cdims)) == dx_flip\n                end\n            end\n        end\n    end\n\n    # Do some real depthwise convolution tests\n    x = Float64.(reshape(1:2, (1,2,1)))\n    w = Float64.(reshape(1:6, (3,1,2)))\n    cdims = DepthwiseConvDims(x, w; padding=1)\n    for conv in (NNlib.depthwiseconv, NNlib.depthwiseconv_im2col, NNlib.depthwiseconv_direct)\n        @test conv(x, w, cdims)[:] ≈ [2, 10]  rtol=1e-7\n    end\nend\n\n\nif get(ENV,\"NNLIB_TEST_FUZZING\",\"false\") == \"true\"\n    @testset \"fuzzing\" begin\n        @info(\"Starting Depthwise Convolutional fuzzing tests; this can take a few minutes...\")\n        # Now that we're fairly certain things are working, let's fuzz things a little bit:\n        for x_size in (\n                # 1d tests\n                (1,), (3,), (7,),\n                # 2d tests\n                (1, 3), (3, 3), (12, 3), (20, 17),\n                # 3d tests\n                (1, 1, 3), (3, 5, 4), (20, 17, 14),\n            ),\n            C_in in (1, 3),\n            batch in (1, 5)\n\n            # Allocate x in this outer loop to save on allocations and speed things up\n            x = rand(x_size..., C_in, batch)\n            dx_direct = similar(x)\n            dx_im2col = similar(x)\n\n            for w_size in (\n                    (1,), (3,), (7,),\n                    (1,1), (1,3), (3,4), (7, 4),\n                    (1,1,1), (1,1,3,), (3,4,3), (7,3,2)),\n                C_mult in (1, 4)\n\n                # Give some output to the user that something is in fact happening.\n                print(\".\")\n\n                # Allocate w in 
this outer loop to save on allocations and speed things up\n                w = rand(w_size..., C_mult, C_in)\n                dw_direct = similar(w)\n                dw_im2col = similar(w)\n\n                for S_size in (1, 2, 4, (1,2), (4,1), (2,1,4)),\n                    P_size in (0, 1, 2, (0,3,0,3), (4,1,4,2), (1,2,3,4,5,6)),\n                    D_size in (1, 2, 4, (1,2), (3,2), (4,2,3))\n\n                    # Skip tests that are impossible due to mismatched sizes\n                    try\n                        DepthwiseConvDims(x, w;\n                            stride=S_size, padding=P_size, dilation=D_size,\n                        )\n                    catch e\n                        if isa(e, DimensionMismatch) || isa(e, MethodError)\n                            continue\n                        end\n                        rethrow(e)\n                    end\n\n                    # Do the actual convolution, comparing convolution implementations\n                    cdims = DepthwiseConvDims(x, w; stride=S_size, padding=P_size, dilation=D_size)\n\n                    # We use mutating calls with explicitly different initial values, so as\n                    # to be sure to catch when we're leaving pieces of the output untouched.\n                    y_direct = ones(output_size(cdims)..., channels_out(cdims), batch) .* 666.666\n                    y_im2col = ones(output_size(cdims)..., channels_out(cdims), batch) .* 777.777\n\n                    # Do the convolutions\n                    NNlib.depthwiseconv_direct!(y_direct, x, w, cdims)\n                    NNlib.depthwiseconv_im2col!(y_im2col, x, w, cdims)\n\n                    # Compare!\n                    @test y_direct ≈ y_im2col\n                    dy = y_im2col\n\n                    # Now push backwards; first for the filter.  
Again, we initialize our\n                    # memory so that segments that never get touched are immediately noticable\n                    fill!(dw_direct, 666.666)\n                    fill!(dw_im2col, 777.777)\n                    NNlib.∇depthwiseconv_filter_direct!(dw_direct, x, dy, cdims)\n                    NNlib.∇depthwiseconv_filter_im2col!(dw_im2col, x, dy, cdims)\n                    @test dw_direct ≈ dw_im2col\n\n                    # And then for the input\n                    fill!(dx_direct, 666.666)\n                    fill!(dx_im2col, 777.777)\n                    NNlib.∇depthwiseconv_data_direct!(dx_direct, dy, w, cdims)\n                    NNlib.∇depthwiseconv_data_im2col!(dx_im2col, dy, w, cdims)\n                    @test dx_direct ≈ dx_im2col\n                end\n            end\n        end\n        println()\n    end\nelse\n    @info \"Skipping Depthwise Convolutional fuzzing tests, set NNLIB_TEST_FUZZING=true to run them\"\nend\n\n@testset \"Grouped Convolutions\" begin\n   x′ = rand(Float32, 28, 28, 100, 2)\n   w′ = rand(Float32, 3, 3, 20, 15)\n\n   @test_throws DimensionMismatch DenseConvDims(x′, w′)\n   cdims = DenseConvDims(x′, w′, groups = 5)\n\n   @test groupcount(cdims) == 5\n\n   y = conv(x′, w′, cdims)\n   _, back = Zygote.pullback((x, w) -> sum(conv(x, w, cdims)), x′, w′)\n   gs_x, gs_w = back(1.f0)\n\n\n   ips = Iterators.partition(1:100, 20)\n   ops = Iterators.partition(1:15, 3)\n   for (i,o) in zip(ips,ops)\n      _, back_reg = Zygote.pullback((x, w) -> sum(conv(x, w)), x′[:,:,i,:], w′[:,:,:,o])\n      gs_x_reg, gs_w_reg = back_reg(1.f0)\n      @test conv(x′[:,:,i,:], w′[:,:,:,o]) ≈ y[:,:,o,:]\n      @test gs_x_reg ≈ gs_x[:,:,i,:]\n      @test gs_w_reg ≈ gs_w[:,:,:,o]\n   end\n\n   # Currently hangs due to a FiniteDifferences issue\n   @test_skip gradtest((x, w) -> sum(conv(x, w, cdims)), x′, w′)\nend\n\n@testset \"conv_wrapper\" begin\n    x = rand(10, 10, 3, 10)\n    w = rand(2, 2, 3, 16)\n    w1 = rand(3, 4, 3, 16)\n    @test size(conv(x, w)) == (9, 9, 16, 10)\n    @test size(conv(x, w; stride = (2, 2), pad = (2, 2))) == (7, 7, 16, 10)\n    @test size(conv(x, w1; stride = (1, 2), pad = (2, 3))) == (12, 7, 16, 10)\n    @test size(conv(x, w; stride = (1, 2), pad = (2, 3), dilation = (2, 2))) == (12, 7, 16, 10)\n    @test size(conv(x, w; stride = (1, 2), pad = (2, 3), dilation = (2, 2), flipped = true)) == (12, 7, 16, 10)\nend\n\n# https://github.com/FluxML/NNlib.jl/issues/369\n@testset \"conv_wrapper with groups - not equal types that trigger direct backend\" begin\n    x = rand(Float32, 10, 10, 32, 8)\n    w = rand(Float64, 2, 2, 16, 4)\n    g = 2\n    @test conv(x, w; groups=g) ≈ conv(x, Float32.(w); groups=g)\n    @test conv(x, w; stride = (2, 2), pad = (2, 2), groups=g) ≈ conv(x, w; stride = (2, 2), pad = (2, 2), groups=g)\n    @test conv(x, w; stride = (1, 2), pad = (2, 3), dilation = (2, 2), groups=g) ≈ conv(x, w; stride = (1, 2), pad = (2, 3), dilation = (2, 2), groups=g)\n    @test conv(x, w; stride = (1, 2), pad = (2, 3), dilation = (2, 2), flipped = true, groups=g) ≈ conv(x, w; stride = (1, 2), pad = (2, 3), dilation = (2, 2), flipped = true, groups=g)\nend\n\n@testset \"depthwiseconv_wrapper\" begin\n    x = rand(10, 10, 3, 10)\n    w = rand(2, 2, 3, 3)\n    w1 = rand(3, 4, 3, 3)\n    @test size(depthwiseconv(x, w)) == (9, 9, 9, 10)\n    @test size(depthwiseconv(x, w; stride = (2, 2), pad = (2, 2))) == (7, 7, 9, 10)\n    @test size(depthwiseconv(x, w1; stride = (1, 2), pad = (2, 3))) == (12, 7, 9, 10)\n    @test size(depthwiseconv(x, 
w1; stride = (1, 2), pad = (2, 3), dilation = (2, 2))) == (10, 5, 9, 10)\n    @test size(depthwiseconv(x, w1; stride = (1, 2), pad = (2, 3), dilation = (2, 2), flipped = true)) == (10, 5, 9, 10)\nend\n\n# https://github.com/FluxML/NNlib.jl/pull/171\n@testset \"conv_direct! - Check Sizes\" begin\n    x_size = (6, 7, 8, 5, 3)\n    y_size = (5, 6, 7, 4, 3)\n    w_size = (2, 2, 2, 5, 4)\n    x = randn(Float32, x_size);\n    y = randn(Float32, y_size);\n    w = randn(Float32, w_size);\n    cdims = DenseConvDims(x_size, w_size)\n    @test size(NNlib.conv_direct!(y, x, w, cdims)) == y_size\n    @test size(NNlib.∇conv_data_direct!(x, y, w, cdims)) == x_size\n    @test size(NNlib.∇conv_filter_direct!(w, x, y, cdims)) == w_size\nend\n\n# https://github.com/FluxML/NNlib.jl/issues/490\n# https://github.com/FluxML/NNlib.jl/issues/405\n@testset \"conv_direct! - Unusual input types\" begin\n    # Create test type that can't be indexed when undefined.\n    # This simulates the worst-case scenario for custom types.\n    struct MyFloat <: Real\n        set::Set{Float32}\n    end\n\n    # Test that direct indexing fails when undefined.\n    v = Array{MyFloat}(undef, 3)\n    @test_throws UndefRefError v[1]\n\n    # Define minimal set of functions required for conv_direct!\n    MyFloat(x::MyFloat) = x\n    MyFloat(x::Real) = MyFloat(Set(Float32(x)))\n\n    Base.:+(x::MyFloat, y::MyFloat) = MyFloat(only(x.set) + only(y.set))\n    Base.:*(x::MyFloat, y::MyFloat) = MyFloat(only(x.set) * only(y.set))\n    Base.promote_rule(::Type{MyFloat}, ::Type{Float32})   = MyFloat\n    Base.rand(::AbstractRNG, ::SamplerType{MyFloat}) = MyFloat(rand(Float32))\n    Base.zero(::MyFloat) = MyFloat(zero(Float32))\n    Base.zero(::Type{MyFloat}) = MyFloat(zero(Float32))\n\n    # Test conv_direct!\n    x_size = (6, 7, 8, 5, 3)\n    y_size = (5, 6, 7, 4, 3)\n    w_size = (2, 2, 2, 5, 4)\n    x = rand(MyFloat, x_size);\n    w = randn(Float32, w_size);\n    y = Array{MyFloat}(undef, y_size...);\n    cdims = DenseConvDims(x_size, w_size)\n    y_out = NNlib.conv_direct!(y, x, w, cdims)\n\n    @test eltype(y_out) == MyFloat\n    @test size(y_out) == y_size\nend\n\n@testset \"AutoDiff: spatial_rank=$spatial_rank\" for spatial_rank in (1, 2, 3)\n  x = rand(rng, repeat([5], spatial_rank)..., 3, 2)\n  w = rand(rng, repeat([3], spatial_rank)..., 3, 3)\n  cdims = DenseConvDims(x, w)\n  gradtest((x, w) -> conv(x, w, cdims), x, w)\n  gradtest((x, w) -> sum(conv(x, w, cdims)), x, w)  # https://github.com/FluxML/Flux.jl/issues/1055\n\n  y = conv(x, w, cdims)\n  gradtest((y, w) -> ∇conv_data(y, w, cdims), y, w)\n  gradtest((y, w) -> sum(∇conv_data(y, w, cdims)), y, w)\n  gradtest((x, y) -> ∇conv_filter(x, y, cdims), x, y)\n  gradtest((x, y) -> sum(∇conv_filter(x, y, cdims)), x, y)\n\n  dcdims = DepthwiseConvDims(x, w)\n  gradtest((x, w) -> depthwiseconv(x, w, dcdims), x, w)\n\n  # FIXME fails\n  y = depthwiseconv(x, w, dcdims)\n  gradtest((y, w) -> ∇depthwiseconv_data(y, w, dcdims), y, w)\n  gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w)\nend\n\n@static if Test_Enzyme\n\n@testset \"EnzymeRules: conv! spatial_rank=$spatial_rank\" for spatial_rank in (1, 2, 3)\n  x = rand(rng, repeat([5], spatial_rank)..., 3, 2)\n  w = rand(rng, repeat([3], spatial_rank)..., 3, 3)\n\n  cdims = DenseConvDims(x, w)\n\n  curconv = conv\n  curconv! 
= conv!\n  dst = curconv(x, w, cdims)\n\n  for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),\n    Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),\n    Tx in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),\n    Tw in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated)\n\n    Tret == EnzymeCore.Const && continue # ERROR\n    EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tx, Tw) || continue\n\n    EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (x, Tx), (w, Tw), (cdims, EnzymeCore.Const), atol=1e-6, rtol=1e-6)\n  end\nend\n\n@testset \"EnzymeRules: ∇conv_data! spatial_rank=$spatial_rank\" for spatial_rank in (1, 2, 3)\n  x = rand(rng, repeat([5], spatial_rank)..., 3, 2)\n  w = rand(rng, repeat([3], spatial_rank)..., 3, 3)\n  cdims = DenseConvDims(x, w)\n  y = conv(x, w, cdims)\n  dy = randn(rng, size(y)...)\n\n  dx = ∇conv_data(dy, w, cdims)\n\n  for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),\n    Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),\n    Ty in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),\n    Tw in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated)\n\n    Tret == EnzymeCore.Const && continue # ERROR\n    EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Ty, Tw) || continue\n\n    EnzymeTestUtils.test_reverse(∇conv_data!, Tret, (dx, Tdst), (dy, Ty), (w, Tw), (cdims, EnzymeCore.Const), atol=1e-6, rtol=1e-6)\n  end\nend\n\n@testset \"EnzymeRules: ∇conv_filter! spatial_rank=$spatial_rank\" for spatial_rank in (1, 2, 3)\n  x = rand(rng, repeat([5], spatial_rank)..., 3, 2)\n  w = rand(rng, repeat([3], spatial_rank)..., 3, 3)\n  cdims = DenseConvDims(x, w)\n  y = conv(x, w, cdims)\n  dy = randn(rng, size(y)...)\n\n  dw = ∇conv_filter(x, dy, cdims)\n\n  for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),\n    Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),\n    Tx in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),\n    Ty in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated)\n\n    Tret == EnzymeCore.Const && continue # ERROR\n    EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tx, Ty) || continue\n\n    EnzymeTestUtils.test_reverse(∇conv_filter!, Tret, (dw, Tdst), (x, Tx), (dy, Ty), (cdims, EnzymeCore.Const), atol=1e-6, rtol=1e-6)\n  end\nend\n\n@testset \"EnzymeRules: depthwiseconv! spatial_rank=$spatial_rank\" for spatial_rank in (1, 2, 3)\n  x = rand(rng, repeat([5], spatial_rank)..., 3, 2)\n  w = rand(rng, repeat([3], spatial_rank)..., 3, 3)\n\n  cdims = DepthwiseConvDims(x, w)\n\n  curconv = depthwiseconv\n  curconv! = depthwiseconv!\n  dst = curconv(x, w, cdims)\n\n  for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),\n    Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),\n    Tx in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),\n    Tw in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated)\n\n    Tret == EnzymeCore.Const && continue # ERROR\n    EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tx, Tw) || continue\n\n    EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (x, Tx), (w, Tw), (cdims, EnzymeCore.Const), atol=1e-6, rtol=1e-6)\n  end\nend\n\nend\n"
  },
  {
    "path": "test/conv_bias_act.jl",
    "content": "@testset \"conv_bias_act\" begin\n    x = rand(4,4,3,3)\n    w = rand(2,2,3,3)\n    b = rand(1,1,1,3)\n    cdims = DenseConvDims(x, w; stride=2)\n    @test NNlib.conv_bias_act(x, w, cdims, b, relu) ≈ relu.(conv(x, w, cdims) .+ b) atol=1e-5\nend\n"
  },
  {
    "path": "test/ctc.jl",
    "content": "using Test\nusing NNlib: ctc_loss\nusing Zygote: gradient\nusing LinearAlgebra\n\n# Custom function to check numerical gradient of ctc loss,\n# based on `ngradient` in `Tracker.jl`\nfunction ctc_ngradient(x, y)\n  f = ctc_loss\n  grads = zero(x)\n  for i in 1:length(x)\n    δ = sqrt(eps())\n    tmp = x[i]\n    x[i] = tmp - δ/2\n    y1 = f(x, y)\n    x[i] = tmp + δ/2\n    y2 = f(x, y)\n    x[i] = tmp\n    grads[i] = (y2-y1)/δ\n  end\n  return grads\nend\n\n@testset \"ctc_loss\" begin\n  x = rand(10, 50)\n  y = rand(1:9, 30)\n  g1 = gradient(ctc_loss, x, y)[1]\n  g2 = ctc_ngradient(x, y)\n  @test g1 ≈ g2 rtol=1e-5 atol=1e-5\n  \n  # tests using hand-calculated values\n  x = [1. 2. 3.; 2. 1. 1.; 3. 3. 2.]\n  y = [1, 2]\n  @test ctc_loss(x, y) ≈ 3.6990738275138035\n\n  g = [-0.317671 -0.427729 0.665241; 0.244728 -0.0196172 -0.829811; 0.0729422 0.447346 0.16457]\n  ghat = gradient(ctc_loss, x, y)[1]\n  @test g ≈ ghat rtol=1e-5 atol=1e-5\n\n  x = [-3. 12. 8. 15.; 4. 20. -2. 20.; 8. -33. 6. 5.]\n  y = [1, 2]\n  @test ctc_loss(x, y) ≈ 8.02519869363453\n\n  g = [-2.29294774655333e-06 -0.999662657278862 1.75500863563993e-06 0.00669284889063; 0.017985914969696 0.999662657278861 -1.9907078755387e-06 -0.006693150917307; -0.01798362202195 -2.52019580677916e-20 2.35699239251042e-07 3.02026677058789e-07]\n  ghat = gradient(ctc_loss, x, y)[1]\n  @test g ≈ ghat rtol=1e-5 atol=1e-5\nend"
  },
  {
    "path": "test/dropout.jl",
    "content": "using NNlib, Test, Statistics, Random, LinearAlgebra\nusing Zygote, StableRNGs, ChainRulesCore, Enzyme\n\n@testset \"dropout\" begin\n    # Basics\n    x1 = randn(Float32, 3, 4)\n    @test size(@inferred dropout(x1, 0.1)) == (3, 4)\n    @test size(@inferred dropout(x1, 0.2; dims=2)) == (3, 4)\n    @test size(@inferred dropout(x1, 0.3; dims=(1,2))) == (3, 4)\n    @test eltype(dropout(x1, 0.1)) == Float32\n    @test eltype(dropout(x1, 0.1; dims=1)) == Float32\n    @test eltype(dropout(x1, 0.1; dims=(1,2))) == Float32\n\n    rng =  Random.default_rng()\n    @test size(@inferred dropout(rng, x1, 0.1)) == (3, 4)\n    @test size(@inferred dropout(rng, x1, 0.1; dims=2)) == (3, 4)\n\n    x2 = Diagonal(randn(Float32, 10))  # Just to check it runs on weird matrices.\n    @test dropout(x2, 0.3) isa Matrix{Float32}  # does not infer, but that's OK?\n    \n    # Values\n    @test dropout(x1, 0) == x1\n    @test dropout(x1.+0im, 0) == x1\n    @test dropout(x1, 1) == zero.(x1)\n    @test dropout(x1.+im, 1) == zero.(x1)\n\n    d45 = dropout(trues(100, 100, 100), 0.45)\n    @test mean(d45) ≈ 1 atol=1e-2\n    dpi2 = dropout(fill(pi, 1000), 0.2)\n    @test sort(unique(dpi2)) ≈ [0, 5pi/4]\n    d33 = dropout(fill(3, 10, 1000), 0.3, dims=2)\n    @test sort(unique(vec(d33))) ≈ [0, 3/(1-0.3)]\n\n    # Complex -- not worth too much optimisation, but should work!\n    x2 = [1.0+0im,2.0+1im,3.0+3im]  # from Flux's tests\n    @test dropout(x2, 0.5) isa Vector{ComplexF64}\n    @test dropout(x2, 0.5; dims=1) isa Vector{ComplexF64}\n\n    # Gradient rule\n    y, back = rrule(dropout, rng, hcat(trues(1000), falses(1000)), 0.45)\n    dx = back(fill(3, 1000, 2))[3]\n    @test !all(iszero, dx[:,2])  # this is why we save the random choices\n    @test sort(unique(vec(dx))) ≈ [0, 3/(1-0.45)]\n\n    y2, back2 = rrule(dropout, rng, x2, 0.5)\n    @test y2 isa Vector{ComplexF64}\n    @test back2(one.(y2))[3] isa Vector{ComplexF64}\n\n    @testset \"Zygote\" begin\n        @test Zygote.gradient(x -> sum(dropout(x, 0.3)), x1)[1] isa Matrix{Float32}\n        @test Zygote.gradient(x -> sum(dropout(rng, x, 0.3)), x1)[1] isa Matrix{Float32}\n        @test Zygote.gradient(x -> sum(dropout(x, 0.3, dims=1)), x1)[1] isa Matrix{Float32}\n\n        # p=0 & p=1\n        @test Zygote.gradient(x -> sum(dropout(x, 0)), x1)[1] == ones(3,4)\n        @test Zygote.gradient(x -> sum(dropout(x, 1)), x1)[1] == zeros(3,4)\n\n        # Second order\n        f1(x) = sum(dropout(x, 0.5))\n        @test_broken Zygote.hessian(f1, [1.0,2.0,3.0]) == zeros(3, 3)  # forward over reverse\n        @test Zygote.hessian_reverse(f1, [1.0,2.0,3.0]) == zeros(3, 3)\n    end\n\n    # Bang\n    y1 = fill!(similar(x1), NaN)\n    @test dropout!(y1, x1, 0.0) == x1\n    @test y1 == x1\n    @test dropout!(rng, y1, x1, 1) == zero(x1)\n    @test y1 == zero(x1)\n\n    # Errors\n    @test_throws ArgumentError dropout(x1, -1)\n    @test_throws ArgumentError dropout(x1, 2)\n    @test_throws ArgumentError dropout!(y1, x1, 3)\nend\n\n@static if Test_Enzyme\n\n@testset \"EnzymeRules: dropout \" begin\n    rng = Random.default_rng()\n\n    x1 = randn(Float32, 3000, 4000)\n    dx1 = zeros(Float32, 3000, 4000)\n\n    dout = randn(Float32, 3000, 4000)\n\n    p = 0.2f0\n\n    forward, reverse = Enzyme.autodiff_thunk(ReverseSplitWithPrimal, typeof(Const(dropout)), Duplicated, typeof(Const(rng)), typeof(Duplicated(x1, dx1)), typeof(Const(0.2f0)))\n\n    tape, primal, shadow = forward(Const(dropout), Const(rng), Duplicated(x1, dx1), Const(p))\n\n    shadow .= dout\n\n    
reverse(Const(dropout), Const(rng), Duplicated(x1, dx1), Const(p), tape)\n\n    @test dx1[.!tape[1]] ≈ zero(x1)[.!tape[1]]\n\n    val = convert(Float32, 1/(1-p))\n\n    @test dx1[tape[1]] ≈ (val * dout)[tape[1]]\nend\n\nend"
  },
  {
    "path": "test/ext_amdgpu/activations.jl",
    "content": "@testset \"Compare CPU & GPU\" begin\n    for (T, atol) in ((Float16, 1.0f-2), (Float32, 1.0f-5))\n        @testset \"ndims: $(ndims(x))\" for x in (randn(T, 16), randn(T, ntuple(_ -> 2, 5)...), randn(T, ntuple(_ -> 2, 6)...))\n            gputest(x -> NNlib.relu.(x), x; atol)\n            gputest(x -> NNlib.relu6.(x), x; atol)\n            gputest(x -> NNlib.softplus.(x), x; atol)\n            gputest(x -> tanh.(x), x; atol)\n            gputest(x -> identity.(x), x; atol)\n        end\n    end\nend\n"
  },
  {
    "path": "test/ext_amdgpu/attention.jl",
    "content": "@testset \"Compare CPU & GPU\" begin\n    n = 15\n    lenq = 3\n    lenkv = 4\n    for batch_size in [(), 1, 2, (2, 1, 3)], nheads in [1, 3, 5]\n        q = AMDGPU.rand(Float32, n, lenq, batch_size...)\n        k = AMDGPU.rand(Float32, n, lenkv, batch_size...)\n        v = AMDGPU.rand(Float32, n, lenkv, batch_size...)\n        y, α = @inferred dot_product_attention(q, k, v; nheads)\n\n        @test y isa ROCArray{Float32}\n        @test size(y) == (n, lenq, batch_size...)\n        @test size(α) == (lenkv, lenq, nheads, batch_size...)\n        @test sum(Array(α), dims=1) ≈ ones(1, lenq, nheads, batch_size...)\n\n        qh = rand(Float32, n, lenq, batch_size...)\n        kh = rand(Float32, n, lenkv, batch_size...)\n        vh = rand(Float32, n, lenkv, batch_size...)\n        gputest(\n            (x...) -> dot_product_attention(x...; nheads)[1], qh, kh, vh;\n            atol=1f-5)\n    end\nend\n\n@testset \"Mask\" begin\n    x = AMDGPU.rand(Float32, 4, 2, 3, 1)\n    mask = make_causal_mask(x, dims=3)\n    @test mask isa ROCArray{Bool}\n    α = dot_product_attention_scores(x, x; mask)\n\n    α_host, mask_host = Array.((α, mask))\n    @test all((α_host[:, :, 1, 1] .> 0) .== mask_host)\n    @test all((α_host[:, :, 2, 1] .> 0) .== mask_host)\nend\n\n@testset \"Dropout\" begin\n    q = k = v = AMDGPU.rand(Float32, 10, 10, 10)\n    fdrop(x, p) = (rand!(similar(x)) .> p) .* x ./ (1-p)\n    y, α = dot_product_attention(\n        q, k, v; nheads=2, fdrop=x -> dropout(x, 0.5))\n    @test 0.6 > mean(>(0), α) > 0.4\nend\n"
  },
  {
    "path": "test/ext_amdgpu/batched_mul.jl",
    "content": "@testset \"batched_mul\" begin\n    A = rand(Float32, 3, 3, 2)\n    B = rand(Float32, 3, 3, 2)\n    dA, dB = ROCArray.((A, B))\n\n    C = batched_mul(A, B)\n    @test ROCArray(C) ≈ batched_mul(dA, dB)\n\n    Ct = batched_mul(batched_transpose(A), B)\n    @test ROCArray(Ct) ≈ batched_mul(batched_transpose(dA), dB)\n\n    Ca = batched_mul(A, batched_adjoint(B))\n    @test ROCArray(Ca) ≈ batched_mul(dA, batched_adjoint(dB))\n\n    # 5-arg batched_mul!\n    C .= pi\n    batched_mul!(C, A, B, 2f0, 3f0)\n    Cpi = ROCArray(similar(C)) .= pi\n    @test ROCArray(C) ≈ batched_mul!(Cpi, dA, dB, 2f0, 3f0)\n\n    # PermutedDimsArray\n    @test ROCArray(Ct) ≈ batched_mul(PermutedDimsArray(dA, (2, 1, 3)), dB)\n\n    # FIXME same but with (1, 3, 2) errors\n    D = permutedims(B, (2, 1, 3))\n    Cp = batched_mul(batched_adjoint(A), B)\n    @test ROCArray(Cp) ≈ batched_mul(\n        batched_adjoint(dA), PermutedDimsArray(ROCArray(D), (2, 1, 3)))\n\n    # Methods which reshape\n    M = randn(Float32, 3, 3)\n    Cm = batched_mul(A, M)\n    @test ROCArray(Cm) ≈ batched_mul(dA, ROCArray(M))\nend\n"
  },
  {
    "path": "test/ext_amdgpu/batched_repr.jl",
    "content": "function print_array_strs(x)\n    str = sprint((io, x)->show(io, MIME\"text/plain\"(), x), x)\n    return @view split(str, '\\n')[2:end]\nend\n\n@testset \"BatchedAdjOrTrans\" begin\n    x = rand(Float32, 3, 4, 2)\n    y = ROCArray(x)\n\n    bax = batched_adjoint(x)\n    btx = batched_transpose(x)\n    bay = batched_adjoint(y)\n    bty = batched_transpose(y)\n\n    @test sprint(show, bax) == sprint(show, bay)\n    @test sprint(show, btx) == sprint(show, bty)\n\n    @test print_array_strs(bax) == print_array_strs(bay)\n    @test print_array_strs(btx) == print_array_strs(bty)\n\n    @test Array(bax) == Array(bay)\n    @test collect(bax) == collect(bay)\n    @test Array(btx) == Array(bty)\n    @test collect(btx) == collect(bty)\n\n    for shape in (:, (12, 2))\n        rbax = reshape(bax, shape)\n        rbtx = reshape(btx, shape)\n        rbay = reshape(bay, shape)\n        rbty = reshape(bty, shape)\n\n        @test sprint(show, rbax) == sprint(show, rbay)\n        @test sprint(show, rbtx) == sprint(show, rbty)\n\n        @test print_array_strs(rbax) == print_array_strs(rbay)\n        @test print_array_strs(rbtx) == print_array_strs(rbty)\n\n        @test Array(rbax) == Array(rbay)\n        @test collect(rbax) == collect(rbay)\n        @test Array(rbtx) == Array(rbty)\n        @test collect(rbtx) == collect(rbty)\n    end\nend\n"
  },
  {
    "path": "test/ext_amdgpu/conv.jl",
    "content": "@testset \"Compare CPU & GPU\" begin\n    channels, batch = 3, 2\n    for T in (Float16, Float32), nd in (1, 2, 3)\n        x = rand(Float32, fill(4, nd)..., 3, 1)\n        w = rand(Float32, fill(2, nd)..., channels, 4)\n\n        cdims = DenseConvDims(x, w, flipkernel=true)\n        gputest((x, w) -> NNlib.conv(x, w, cdims), x, w; atol=1e-4)\n\n        # This one flips manually kernel for AMDGPU.\n        cdims = DenseConvDims(x, w)\n        gputest((x, w) -> NNlib.conv(x, w, cdims), x, w; atol=1e-4)\n    end\nend\n"
  },
  {
    "path": "test/ext_amdgpu/dropout.jl",
    "content": "@testset \"Test API\" begin\n    x = AMDGPU.randn(Float32, 3, 4)\n    @test size(@inferred dropout(x, 0.1)) == (3, 4)\n    @test size(@inferred dropout(x, 0.2; dims=2)) == (3, 4)\n    @test size(@inferred dropout(x, 0.3; dims=(1, 2))) == (3, 4)\n\n    rng = AMDGPU.rocrand_rng()\n    @test size(@inferred dropout(rng, x, 0.1)) == (3, 4)\n    @test size(@inferred dropout(rng, x, 0.1; dims=2)) == (3, 4)\n\n    # Values\n    d45 = dropout(AMDGPU.ones(100, 100, 100), 0.45)\n    @test mean(d45) ≈ 1 atol=1e-2\n    dpi2 = dropout(AMDGPU.fill(1f0 * pi, 1000), 0.2)\n    @test sort(unique(Array(dpi2))) ≈ [0, 5 * pi / 4]\n    d33 = dropout(AMDGPU.fill(3f0, 10, 1000), 0.3, dims=2)\n    @test sort(unique(vec(Array(d33)))) ≈ [0, 3 / (1 - 0.3)]\n\n    @test Zygote.gradient(x -> sum(dropout(x, 0.1)), x)[1] isa ROCArray{Float32}\nend\n"
  },
  {
    "path": "test/ext_amdgpu/pool.jl",
    "content": "@testset \"Compare CPU & GPU\" begin\n    channels, batch = 3, 2\n    for T in (Float16, Float32), nd in (1, 2, 3)\n        x = rand(T, fill(8, nd)..., channels, batch)\n        pdims = PoolDims(x, 2)\n        # NOTE: Disable grad check for maxpool as *sometimes*\n        # it does not *completely* agree with CPU :/\n        gputest(x -> NNlib.maxpool(x, pdims), x; checkgrad=false)\n        gputest(x -> NNlib.meanpool(x, pdims), x)\n    end\nend\n"
  },
  {
    "path": "test/ext_amdgpu/runtests.jl",
    "content": "using NNlib: batched_adjoint, batched_mul, batched_mul!, batched_transpose\nusing NNlib: is_strided, storage_type\nusing LinearAlgebra\n\nAMDGPU.allowscalar(false)\n\nfunction gputest(f, xs...; checkgrad=true, atol=1e-6, kws...)\n    cpu_in = xs\n    gpu_in = ROCArray.(xs)\n\n    cpu_out = f(cpu_in...; kws...)\n    gpu_out = f(gpu_in...; kws...)\n    @test collect(cpu_out) ≈ collect(gpu_out)\n\n    if checkgrad\n        cpu_grad = gradient((x...) -> sum(f(x...; kws...)), cpu_in...)\n        gpu_grad = gradient((x...) -> sum(f(x...; kws...)), gpu_in...)\n        for (cpu_g, gpu_g) in zip(cpu_grad, gpu_grad)\n            if cpu_g === nothing\n                @test gpu_g === nothing\n            else\n                @test collect(cpu_g) ≈ collect(gpu_g) atol=atol\n            end\n        end\n    end\nend\n\n@testset \"Storage types\" begin\n    include(\"storage_type.jl\")\nend\n\n@testset \"Batched repr\" begin\n    include(\"batched_repr.jl\")\nend\n\n@testset \"Batched multiplication\" begin\n    include(\"batched_mul.jl\")\nend\n\n@testset \"Convolution\" begin\n    include(\"conv.jl\")\nend\n\n@testset \"Pooling\" begin\n    include(\"pool.jl\")\nend\n\n@testset \"Softmax\" begin\n    include(\"softmax.jl\")\nend\n\n@testset \"Activations\" begin\n    include(\"activations.jl\")\nend\n\n@testset \"Dropout\" begin\n    include(\"dropout.jl\")\nend\n\n@testset \"Attention\" begin\n    include(\"attention.jl\")\nend\n"
  },
  {
    "path": "test/ext_amdgpu/softmax.jl",
    "content": "@testset \"Compare CPU & GPU\" begin\n    for (T, atol) in ((Float16, 1f-2), (Float32, 1f-5))\n        for (sz, dims) in [\n            ((5,), :), ((5,), 1),\n            ((5, 5), :), ((5, 5), 1), ((5, 5), 2),\n            ((5, 5, 5, 5), (2, 3)), ((5, 5, 5, 5), (2, 4)),\n        ]\n            if T == Float16\n                x = ones(T, sz) # Really low precision.\n            else\n                x = randn(T, sz)\n            end\n            gputest(NNlib.softmax, x; atol)\n            gputest(NNlib.logsoftmax, x; atol)\n        end\n    end\nend\n"
  },
  {
    "path": "test/ext_amdgpu/storage_type.jl",
    "content": "@testset \"NNlib storage type\" begin\n    x = ROCArray(ones(Float32, 10, 10))\n    @test storage_type(x) <: ROCArray{Float32, 2}\n    @test storage_type(reshape(view(x, 1:2:10,:), 10, :)) <: ROCArray{Float32, 2}\n\n    @test is_strided(x)\n    @test is_strided(view(x, 1:2:5,:))\n    @test is_strided(PermutedDimsArray(x, (2, 1)))\n\n    @test !is_strided(reshape(view(x, 1:2:10, :), 10, :))\n    @test !is_strided((x .+ im)')\n    @test !is_strided(Diagonal(ROCArray(ones(3))))\nend\n"
  },
  {
    "path": "test/ext_cuda/activations.jl",
    "content": "@testset \"activation broadcast\" begin\n    for f in NNlib.ACTIVATIONS\n        if f ∉ [:rrelu]\n            @eval gputest(x -> $f.(x), rand(Float64, 5))\n        end\n    end\nend\n\n@testset \"forward diff\" begin\n    for f in NNlib.ACTIVATIONS\n        if f ∉ [:rrelu]\n            @eval gputest(x -> $f.(x), Dual.(rand(5), 1))\n        end\n    end\nend\n\n# Broadcasting over complex CuArray works without NNlibCUDAExt, this test checks that\n# NNlibCUDAExt does not cause such operations to take a fast path which does not support\n# complex numbers (e.g. cuDNN)\n@testset \"complex\" begin\n    f(x) = tanh.(x)\n    cs = rand(ComplexF64, 5)\n    @test f(cs) ≈ collect(f(CuArray(cs)))\nend\n\n@testset \"softplus\" begin \n  # softplus does not give `Inf` for large arguments\n   x = CuArray([1000.])\n   @test all(softplus.(x) .== x)\nend\n\n@testset \"input is preserved\" begin\n    x = CUDA.ones(1)\n    @test Array(x) == [1f0]\n    tanh.(x)\n    @test Array(x) == [1f0]\n    y = tanh.(x)\n    @test Array(x) == [1f0]\n    @test Array(y) == [tanh(1f0)]\n    x .= tanh.(y)\n    @test Array(y) == [tanh(1f0)]\n    @test Array(x) == [tanh(tanh(1f0))]\nend\n\n@testset \"fused act addition broadcast\" begin\n    x = CUDA.rand(Float32, 10, 10)\n    b = CUDA.rand(Float32, 10)\n\n    for act in getfield.((NNlib,), NNlib.ACTIVATIONS)\n        fused_act_add = act ∘ +\n        @test fused_act_add.(x, b) ≈ act.(x .+ b)\n    end\nend\n"
  },
  {
    "path": "test/ext_cuda/batchedadjtrans.jl",
    "content": "function print_array_strs(x)\n    str = sprint((io, x)->show(io, MIME\"text/plain\"(), x), x)\n    return @view split(str, '\\n')[2:end]\nend\n\n@testset \"BatchedAdjOrTrans\" begin\n    x = randn(Float32, 3,4,2)\n    y = cu(x)\n\n    bax = batched_adjoint(x)\n    btx = batched_transpose(x)\n    bay = batched_adjoint(y)\n    bty = batched_transpose(y)\n\n    @test sprint(show, bax) == sprint(show, bay)\n    @test sprint(show, btx) == sprint(show, bty)\n\n    @test print_array_strs(bax) == print_array_strs(bay)\n    @test print_array_strs(btx) == print_array_strs(bty)\n\n    @test Array(bax) == Array(bay)\n    @test collect(bax) == collect(bay)\n    @test Array(btx) == Array(bty)\n    @test collect(btx) == collect(bty)\n    \n    for shape in (:, (12, 2))\n        rbax = reshape(bax, shape)\n        rbtx = reshape(btx, shape)\n        rbay = reshape(bay, shape)\n        rbty = reshape(bty, shape)\n\n        @test sprint(show, rbax) == sprint(show, rbay)\n        @test sprint(show, rbtx) == sprint(show, rbty)\n    \n        @test print_array_strs(rbax) == print_array_strs(rbay)\n        @test print_array_strs(rbtx) == print_array_strs(rbty)\n    \n        @test Array(rbax) == Array(rbay)\n        @test collect(rbax) == collect(rbay)\n        @test Array(rbtx) == Array(rbty)\n        @test collect(rbtx) == collect(rbty)\n    end\n\nend\n"
  },
  {
    "path": "test/ext_cuda/batchedmul.jl",
    "content": "@testset \"batched_mul\" begin\n    using NNlib: batched_mul, batched_mul!, batched_vec, \n                 batched_adjoint, batched_transpose\n\n    A = randn(Float32, 3,3,2);\n    B = randn(Float32, 3,3,2);\n\n    C = batched_mul(A, B)\n    @test CuArray(C) ≈ batched_mul(CuArray(A), CuArray(B))\n\n    Ct = batched_mul(batched_transpose(A), B)\n    @test CuArray(Ct) ≈ batched_mul(batched_transpose(CuArray(A)), CuArray(B))\n\n    Ca = batched_mul(A, batched_adjoint(B))\n    @test CuArray(Ca) ≈ batched_mul(CuArray(A), batched_adjoint(CuArray(B)))\n\n    # 5-arg batched_mul!\n    C .= pi\n    batched_mul!(C, A, B, 2f0, 3f0)\n    cuCpi = CuArray(similar(C)) .= pi\n    @test CuArray(C) ≈ batched_mul!(cuCpi, CuArray(A), CuArray(B), 2f0, 3f0)\n\n    # PermutedDimsArray\n    @test CuArray(Ct) ≈ batched_mul(PermutedDimsArray(CuArray(A), (2,1,3)), CuArray(B))\n\n    D = permutedims(B, (1,3,2))\n    Cp = batched_mul(batched_adjoint(A), B)\n    @test CuArray(Cp) ≈ batched_mul(batched_adjoint(CuArray(A)), PermutedDimsArray(CuArray(D), (1,3,2)))\n\n    # Methods which reshape\n    M = randn(Float32, 3,3)\n\n    Cm = batched_mul(A, M)\n    @test CuArray(Cm) ≈ batched_mul(CuArray(A), CuArray(M))\n\n    Cv = batched_vec(permutedims(A,(3,1,2)), M)\n    @test CuArray(Cv) ≈ batched_vec(PermutedDimsArray(CuArray(A),(3,1,2)), CuArray(M))\nend\n\n@testset \"NNlib storage_type etc.\" begin\n    using LinearAlgebra\n    using NNlib: is_strided, are_strided, storage_type\n\n    M = cu(ones(10,10))\n\n    @test is_strided(M)\n    @test is_strided(view(M, 1:2:5,:))\n    @test is_strided(PermutedDimsArray(M, (2,1)))\n\n    @test !is_strided(reshape(view(M, 1:2:10,:), 10,:))\n    @test !is_strided((M .+ im)')\n    @test !is_strided(Diagonal(cu(ones(3))))\n\n    @test storage_type(M) <: CuArray{Float32,2}\n    @test storage_type(reshape(view(M, 1:2:10,:), 10,:)) <: CuArray{Float32,2}\nend\n"
  },
  {
    "path": "test/ext_cuda/batchnorm.jl",
    "content": "using Statistics\n\n@testset \"Batchnorm\" begin\n    v = CUDA.rand(Float32, 2)\n    m = CUDA.rand(Float32, 2, 5)\n\n    @testset for training in (true, false), track_stats in (true, false)\n        kws = (training=training, track_stats=track_stats)\n\n        # Normal\n        batchnorm(v, v, m, v, v, 1.0; kws...)\n        ∇batchnorm(v, v, m, m, v, v, 1.0; kws...)\n\n        # No affine\n        batchnorm(nothing, nothing, m, v, v, 1.0; kws...)\n        ∇batchnorm(nothing, nothing, m, m, v, v, 1.0; kws...)\n\n        # No tracking\n        batchnorm(v, v, m, nothing, nothing, 1.0; kws...)\n        ∇batchnorm(v, v, m, m, nothing, nothing, 1.0; kws...)\n\n        # Both or neither tracked or affine params must be set\n        for (α, β) in ((v, nothing), (nothing, v))\n            @test_throws MethodError batchnorm(α, β, m, v, v, 1.0; kws...)\n            @test_throws MethodError ∇batchnorm(α, β, m, m, v, v, 1.0; kws...)\n            @test_throws ArgumentError batchnorm(v, v, m, α, β, 1.0; kws...)\n        end\n    end \n    @testset \"test mode\" begin\n        y_no_track_stats = batchnorm(v, v, m, nothing, nothing, 1.0; training=false, track_stats=false)\n        running_mean = mean(m, dims=[2])\n        running_var = var(m, mean=running_mean, dims=[2], corrected=false)\n        y_track_stats = batchnorm(v, v, m, running_mean, running_var, 1.0; training=false, track_stats=true)\n        # batchnorm without tracked stats should equal bathnorm with tracked stats where the\n        # stats are calculated only on the input.\n        @test y_no_track_stats ≈ y_track_stats\n    end\nend\n"
  },
  {
    "path": "test/ext_cuda/conv.jl",
    "content": "using NNlib: DenseConvDims\n\n@testset \"convolution\" begin\n@testset \"$T\" for T in (Float64, ComplexF64)\n    a, b, c = rand(T, 10, 10, 3, 1), rand(T, 2, 2, 3, 4), rand(T, 9, 9, 4, 1)\n    da, db, dc = CuArray(a), CuArray(b), CuArray(c)\n    cdims = DenseConvDims(a, b)\n    @test NNlib.conv(a, b, cdims) ≈ collect(NNlib.conv(da, db, cdims))\n    @test ∇conv_data(c, b, cdims) ≈ collect(∇conv_data(dc, db, cdims))\n    @test ∇conv_filter(a, c, cdims) ≈ collect(∇conv_filter(da, dc, cdims))\n\n    if T <: Complex\n        @testset \"mixed real and complex\" begin\n            @test NNlib.conv(real(a), b, cdims) ≈ collect(NNlib.conv(real(da), db, cdims))\n            @test NNlib.conv(a, real(b), cdims) ≈ collect(NNlib.conv(da, real(db), cdims))\n            @test ∇conv_data(c, real(b), cdims) ≈ collect(∇conv_data(dc, real(db), cdims))\n            @test ∇conv_filter(real(a), c, cdims) ≈ collect(∇conv_filter(real(da), dc, cdims))\n        end\n    end\n\n    # Test Conv Bias Activation\n    bias = rand(T, 1, 1, 4, 1)\n    dbias = CuArray(bias)\n    act = T <: Complex ? abs2 : NNlib.relu \n    @test conv_bias_act(a, b, cdims, bias, act) ≈ collect(conv_bias_act(da, db, cdims, dbias, act))\n    @test conv_bias_act(a, b, cdims, bias, identity) ≈ collect(conv_bias_act(da, db, cdims, dbias, identity))\n\n    # Test for agreement between CPU NNlib and CuDNN versions, across a variety of kwargs\n    options = Dict{Any, Any}.((\n        (), (:dilation => 2), (:flipkernel => true), (:stride => 2),\n        (:padding => 1),\n        (:padding => (1,0)),\n        (:padding => (0,1)),\n        (:padding => (2,3)),\n    ))\n    C_in_ = 3\n    C_out = 4\n    batch_size = 1\n\n    # we use this activation for the gpu tests\n    # as we can't take gradients of complex quantities\n    act = T <: Complex ? x-> abs2(x) : identity\n    @testset \"groups=$groups, num_spatial_dims=$num_spatial_dims\" for groups in (1, 2, 4), num_spatial_dims in (1, 2, 3)\n        # Make `C_in = C_out` when using grouped convolution.\n        C_in = groups == 1 ? 
C_in_ : C_out\n        # Initialize data we'll run our tests over\n        x = rand(T, fill(8, num_spatial_dims)..., C_in, batch_size)\n        w = rand(T, fill(2, num_spatial_dims)..., C_in ÷ groups, C_out)\n\n        @testset \"opts #$i\" for (i,opts) in enumerate(options)\n            opts[:groups] = groups\n\n\n            if :padding in keys(opts)\n                padding = opts[:padding]\n                if 1 < length(padding) && length(padding) != 2num_spatial_dims\n                    opts[:padding] = ntuple(i -> padding[mod1(i,2)] .+ 2div(i-1,2), 2num_spatial_dims)   \n                end\n            end\n\n            cdims = DenseConvDims(x, w; opts...)\n            y = NNlib.conv(x, w, cdims)\n\n            # Test that basic convolution is equivalent across GPU/CPU\n            @testset \"cpu==gpu\" begin\n                @testset \"conv\" begin\n                    gputest((x, w) -> act.(NNlib.conv(x, w, cdims)), x, w)\n                    if T <: Complex\n                        gputest((x, w) -> act.(NNlib.conv(x, w, cdims)), real(x), w)\n                        gputest((x, w) -> act.(NNlib.conv(x, w, cdims)), x, real(w))\n                    end\n                end\n                @testset \"∇conv_data\" begin\n                    gputest((y, w) -> act.(NNlib.∇conv_data(y, w, cdims)), y, w)\n                    if T <: Complex\n                        gputest((y, w) -> act.(NNlib.∇conv_data(y, w, cdims)), y, real(w))\n                    end\n                end\n                @testset \"∇conv_filter\" begin\n                    gputest((x, y) -> act.(NNlib.∇conv_filter(x, y, cdims)), x, y) \n                    if T <: Complex\n                        gputest((x, y) -> act.(NNlib.∇conv_filter(x, y, cdims)), real(x), y)\n                    end\n                end\n            end\n\n            # Scaling factors\n            @testset \"scale-alpha\" begin\n                gputest((x, w) -> act.(NNlib.conv(x, w, cdims; alpha=T(2.0))), x, w, checkgrad=false) # TODO\n                gputest((y, w) -> act.(NNlib.∇conv_data(y, w, cdims; alpha=T(2.0))), y, w, checkgrad=false) # TODO\n                gputest((x, y) -> act.(NNlib.∇conv_filter(x, y, cdims; alpha=T(2.0))), x, y, checkgrad=false) # TODO \n\n                if T <: Complex\n                    gputest((x, w) -> act.(NNlib.conv(x, w, cdims; alpha=T(2.0))), real(x), w, checkgrad=false) \n                    gputest((x, w) -> act.(NNlib.conv(x, w, cdims; alpha=T(2.0))), x, real(w), checkgrad=false) # TODO\n                    gputest((y, w) -> act.(NNlib.∇conv_data(y, w, cdims; alpha=T(2.0))), y, real(w), checkgrad=false) # TODO\n                    gputest((x, y) -> act.(NNlib.∇conv_filter(x, y, cdims; alpha=T(2.0))), real(x), y, checkgrad=false) # TODO\n                end\n            end\n\n            @testset \"scale-beta\" begin\n                gputest((y, x, w) -> act.(NNlib.conv!(copy(y), x, w, cdims; beta=T(2.0))), y, x, w, checkgrad=false, broken=false)\n                gputest((w, x, y) -> act.(NNlib.∇conv_filter!(copy(w), x, y, cdims; beta=T(2.0))), w, x, y, checkgrad=false, broken=false) \n                gputest((x, y, w) -> act.(NNlib.∇conv_data!(copy(x), y, w, cdims; beta=T(2.0))), x, y, w, checkgrad=false, broken=false) \n\n                if T <: Complex\n                    gputest((y, x, w) -> act.(NNlib.conv!(copy(y), x, w, cdims; beta=T(2.0))), y, real(x), w, checkgrad=false) \n                    gputest((y, x, w) -> act.(NNlib.conv!(copy(y), x, w, cdims; beta=T(2.0))), y, x, real(w), 
checkgrad=false) \n                    gputest((x, y, w) -> act.(NNlib.∇conv_data!(copy(x), y, w, cdims; beta=T(2.0))), x, y, real(w), checkgrad=false) \n                    gputest((w, x, y) -> act.(NNlib.∇conv_filter!(copy(w), x, y, cdims; beta=T(2.0))), w, real(x), y, checkgrad=false)\n                end\n            end\n\n        end\n    end\nend\nend\n"
  },
  {
    "path": "test/ext_cuda/ctc.jl",
    "content": "# Custom function to check numerical gradient of ctc loss,\n# based on `ngradient` in `Tracker.jl`\nfunction ctc_ngradient(x, y)\n  f = ctc_loss\n  grads = zero(x)\n  for i in 1:length(x)\n    δ = sqrt(eps())\n    tmp = x[i]\n    x[i] = tmp - δ/2\n    y1 = f(x, y)\n    x[i] = tmp + δ/2\n    y2 = f(x, y)\n    x[i] = tmp\n    grads[i] = (y2-y1)/δ\n  end\n  return grads\nend\n\n@testset \"ctc-gpu\" begin\n  x = rand(10, 50)\n  y = rand(1:9, 30)\n  x_cu = CuArray(x)\n  g1 = gradient(ctc_loss, x_cu, y)[1]\n  g1 = g1 |> collect\n  g2 = ctc_ngradient(x, y)\n  @test g1 ≈ g2 rtol=1e-5 atol=1e-5\n  \n  # test that GPU loss matches CPU implementation\n  l1 = ctc_loss(x_cu, y)\n  l2 = ctc_loss(x, y)\n  @test l1 ≈ l2\n  \n  # tests using hand-calculated values\n  x_cu = [1. 2. 3.; 2. 1. 1.; 3. 3. 2.] |> CuArray\n  y = [1, 2]\n  @test ctc_loss(x_cu, y) ≈ 3.6990738275138035\n  \n  g = [-0.317671 -0.427729 0.665241; 0.244728 -0.0196172 -0.829811; 0.0729422 0.447346 0.16457]\n  ghat = gradient(ctc_loss, x_cu, y)[1] |> collect\n  @test g ≈ ghat rtol=1e-5 atol=1e-5\n\n  x_cu = [-3. 12. 8. 15.; 4. 20. -2. 20.; 8. -33. 6. 5.] |> CuArray\n  y = [1, 2] |> CuArray\n  @test ctc_loss(x_cu, y) ≈ 8.02519869363453\n\n  g = [-2.29294774655333e-06 -0.999662657278862 1.75500863563993e-06 0.00669284889063; 0.017985914969696 0.999662657278861 -1.9907078755387e-06 -0.006693150917307; -0.01798362202195 -2.52019580677916e-20 2.35699239251042e-07 3.02026677058789e-07]\n  ghat = gradient(ctc_loss, x_cu, y)[1] |> collect\n  @test g ≈ ghat rtol=1e-5 atol=1e-5\nend"
  },
  {
    "path": "test/ext_cuda/dropout.jl",
    "content": "@testset \"dropout + CUDA\" begin\n    # Basics\n    x1 = CUDA.randn(3, 4)\n    @test size(@inferred dropout(x1, 0.1)) == (3, 4)\n    @test size(@inferred dropout(x1, 0.2; dims=2)) == (3, 4)\n    @test size(@inferred dropout(x1, 0.3; dims=(1,2))) == (3, 4)\n\n    rng =  CUDA.default_rng()\n    @test size(@inferred dropout(rng, x1, 0.1)) == (3, 4)\n    @test size(@inferred dropout(rng, x1, 0.1; dims=2)) == (3, 4)\n\n    # Values\n    d45 = dropout(CUDA.ones(100, 100, 100), 0.45)\n    @test mean(d45) ≈ 1 atol=1e-2\n    dpi2 = dropout(CUDA.fill(1f0 * pi, 1000), 0.2)\n    @test sort(unique(Array(dpi2))) ≈ [0, 5pi/4]\n    d33 = dropout(CUDA.fill(3f0, 10, 1000), 0.3, dims=2)\n    @test sort(unique(vec(Array(d33)))) ≈ [0, 3/(1-0.3)]\n\n    # Gradient rule\n    y, back = rrule(dropout, rng, hcat(CUDA.ones(1000), CUDA.zeros(1000)), 0.45)\n    dx = back(CUDA.fill(3f0, 1000, 2))[3]\n    @test !all(iszero, dx[:,2])  # this is why we save the random choices\n    @test sort(unique(vec(Array(dx)))) ≈ [0, 3/(1-0.45)]\n\n    @testset \"Zygote\" begin\n        @test Zygote.gradient(x -> sum(dropout(x, 0.3)), x1)[1] isa CuArray{Float32}\n        @test Zygote.gradient(x -> sum(dropout(rng, x, 0.3)), x1)[1] isa CuArray{Float32}\n        @test Zygote.gradient(x -> sum(dropout(x, 0.3, dims=1)), x1)[1] isa CuArray{Float32}\n    end\nend\n"
  },
  {
    "path": "test/ext_cuda/fold.jl",
    "content": "\n@testset \"fold\" begin\n    # Test for agreement between CPU/GPU versions, across a variety of kwargs\n    options = Dict{Any, Any}.((\n        (), (:dilation => 2), (:flipkernel => true), (:stride => 2),\n        (:padding => 1),\n        (:padding => (1,0)),\n        (:padding => (0,1)),\n        (:padding => (2,3)),\n    ))\n\n    C_in = 3\n    C_out = 4\n    batch_size = 1\n\n    @testset \"spatial_rank=$spatial_rank\" for spatial_rank in (1, 2, 3)\n        for opts in options\n            if :padding in keys(opts)\n                padding = opts[:padding]\n                if 1 < length(padding) && length(padding) != 2spatial_rank\n                    opts[:padding] = ntuple(i -> padding[mod1(i,2)] .+ 2div(i-1,2), 2spatial_rank)   \n                end\n            end\n\n            x = rand(Float64, fill(8, spatial_rank)..., C_in, batch_size)\n            w = rand(Float64, fill(2, spatial_rank)..., C_in, C_out)\n            cdims = DenseConvDims(x, w; opts...)\n            y = NNlib.unfold(x, cdims)\n\n            # test equivalence of fold/unfold across GPU/CPU\n            gputest(x -> NNlib.unfold(x, cdims), x) \n            gputest(y -> NNlib.fold(y, size(x), cdims), y) \n        end\n    end\nend\n\n"
  },
  {
    "path": "test/ext_cuda/gather.jl",
    "content": "@testset \"gather\" begin\n    T = Float32\n    CT = CuArray{Float32}\n\n    ## 1d src, 2d index of ints -> 2d output\n    src = CT([3, 4, 5, 6, 7])\n    index = cu([1 2 3 4;\n                4 2 1 3;\n                3 5 5 3])\n    output = CT([3 4 5 6;\n                6 4 3 5;\n                5 7 7 5])\n\n    y = NNlib.gather(src, index)\n    @test y isa CuArray{Float32,2}\n    @test size(y) == size(index)\n    gputest(src -> NNlib.gather(src, index), src, checkgrad=true)\n    @test NNlib.gather!(CUDA.zeros(T, size(index)...), src, index) == output\n    @test_throws ArgumentError NNlib.gather!(zeros(T, 3, 5), src, index)\n\n    ## 1d src, 2d index of tuples -> 2d output\n    src = CT([3, 4, 5, 6, 7])\n    index = cu([(1,) (2,) (3,) (4,);\n                (4,) (2,) (1,) (3,);\n                (3,) (5,) (5,) (3,)])\n    output = CT([3 4 5 6;\n                6 4 3 5;\n                5 7 7 5])\n\n    y = NNlib.gather(src, index)\n    @test y isa CuArray{Float32,2}\n    @test size(y) == size(index)\n    gputest(src -> NNlib.gather(src, index), src, checkgrad=true)\n    @test NNlib.gather!(CUDA.zeros(T, size(index)...), src, index) == output\n    @test_throws ArgumentError NNlib.gather!(zeros(T, 3, 5), src, index)\n\n    ## 1d src, 2d index of CartesianIndex -> 2d output\n    src = CT([3, 4, 5, 6, 7])\n    index = cu(CartesianIndex.([(1,) (2,) (3,) (4,);\n                (4,) (2,) (1,) (3,);\n                (3,) (5,) (5,) (3,)]))\n    output = CT([3 4 5 6;\n                6 4 3 5;\n                5 7 7 5])\n\n    y = NNlib.gather(src, index)\n    @test y isa CuArray{Float32,2}\n    @test size(y) == size(index)\n    gputest(src -> NNlib.gather(src, index), src, checkgrad=true)\n    @test NNlib.gather!(CUDA.zeros(T, size(index)...), src, index) == output\n    @test_throws ArgumentError NNlib.gather!(zeros(T, 3, 5), src, index)\n\n    ## 1d src, 3d index of ints -> 3d output\n    src = CT([3, 4, 5, 6, 7])\n    index = cu([1 2 3 4;\n                4 2 1 3;\n                3 5 5 3][:,:,1:1])\n    output = CT([3 4 5 6;\n                6 4 3 5;\n                5 7 7 5][:,:,1:1])\n\n    y = NNlib.gather(src, index)\n    @test y isa CuArray{Float32,3}\n    @test size(y) == size(index)\n    gputest(src -> NNlib.gather(src, index), src, checkgrad=true)\n\n\n    ## 2d src, 2d index of ints -> 3d output\n    src = CT([3 5 7\n             4 6 8])\n    index = cu([1 2 3;\n                2 2 1;\n                3 1 3])\n\n    output = zeros(T, 2, 3, 3)\n\n    output[:,:,1] = [3 5 7\n                    4 6 8]\n\n    output[:,:,2] = [5 5 3\n                    6 6 4]\n\n    output[:,:,3] = [7 3 7\n                    8 4 8]\n\n    y = NNlib.gather(src, index)\n    M = NNlib.typelength(eltype(index))\n    Nsrc = ndims(src)\n    @test y isa CuArray{Float32,3}\n    @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)\n    gputest(src -> NNlib.gather(src, index), src, checkgrad=true)\n\n    @testset \"views\" begin\n        x = cu(rand(2, 5))\n        v = view(x, axes(x)...)\n        i = cu([1, 2])   \n        outx = NNlib.gather(x, i)\n        outv = NNlib.gather(v, i)\n        @test outx == outv\n\n        # discontinuous view\n        v2 = view(x, :, [1,3,5])\n        outv2 = NNlib.gather(v2, i)\n        @test collect(outv2) == NNlib.gather(collect(v2), collect(i))        \n    end\n\n    # Zero-sized\n    x = CT([1,2,3])\n    i = CT(Int[])\n    y = NNlib.gather(x, i)\n    @test isempty(y)\nend\n"
  },
  {
    "path": "test/ext_cuda/pooling.jl",
    "content": "@testset \"pooling\" begin\n\n    # Test for agreement between CPU NNlib and CuDNN versions, across a variety of kwargs\n    for num_spatial_dims in (1, 2, 3)\n        # Initialize data we'll run our tests over\n        C_in = 3\n        batch_size = 1\n        x = rand(Float64, fill(8, num_spatial_dims)..., C_in, batch_size)\n       \n        # Test that pooling is equivalent across GPU/CPU\n        pdims = PoolDims(x, 2)\n        y = maxpool(x, pdims)\n        dy = ones(size(y))\n        gputest(x -> maxpool(x, pdims), x)\n        gputest((dy, y, x) -> ∇maxpool(dy, y, x, pdims), dy, y, x, checkgrad=false)\n        gputest(x -> maxpool(x, pdims), x)\n        gputest((dy, y, x) -> ∇maxpool(dy, y, x, pdims), dy, y, x, checkgrad=false)\n    end\nend\n"
  },
  {
    "path": "test/ext_cuda/runtests.jl",
    "content": "using Test\nusing NNlib\nusing Zygote\nusing ForwardDiff: Dual\nusing Statistics: mean\nusing CUDA, cuDNN\nimport CUDA.CUSPARSE: CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO\nusing NNlib: batchnorm, ∇batchnorm\nCUDA.allowscalar(false)\n\ninclude(\"test_utils.jl\")\ninclude(\"activations.jl\")\ninclude(\"dropout.jl\")\ninclude(\"batchedadjtrans.jl\")\ninclude(\"batchedmul.jl\")\ninclude(\"conv.jl\")\ninclude(\"ctc.jl\")\ninclude(\"fold.jl\")\ninclude(\"pooling.jl\")\ninclude(\"softmax.jl\")\ninclude(\"batchnorm.jl\")\ninclude(\"scatter.jl\")\ninclude(\"gather.jl\")\ninclude(\"sampling.jl\")\n"
  },
  {
    "path": "test/ext_cuda/sampling.jl",
    "content": "@testset \"Grid Sampling\" begin\n    for T in (Float32, Float64)\n        x = ones(T, (2, 2, 1, 1))\n        grid = Array{T}(undef, 2, 2, 2, 1)\n        grid[:, 1, 1, 1] .= (-1, -1)\n        grid[:, 2, 1, 1] .= (1, -1)\n        grid[:, 1, 2, 1] .= (-1, 1)\n        grid[:, 2, 2, 1] .= (1, 1)\n\n        ∇grid_true = Array{T}(undef, size(grid))\n        ∇grid_true[:, :, 1, 1] = [[0.0, 0.0] [-0.5, 0.0]]\n        ∇grid_true[:, :, 2, 1] = [[0.0, -0.5] [-0.5, -0.5]]\n\n        x_gpu, grid_gpu = CuArray(x), CuArray(grid)\n\n        padding_mode = :zeros\n        y_gpu = grid_sample(x_gpu, grid_gpu; padding_mode=padding_mode)\n        @test x == collect(y_gpu)\n        @test eltype(y_gpu) == T\n\n        external_grad = CUDA.ones(T, size(y_gpu))\n        ∇input, ∇grid = ∇grid_sample(external_grad, x_gpu, grid_gpu; padding_mode=padding_mode)\n        @test x == collect(∇input)\n        @test ∇grid_true == collect(∇grid)\n        @test eltype(∇input) == T\n        @test eltype(∇grid) == T\n\n        padding_mode = :border\n        fill!(∇grid_true, 0.0)\n        sampled = grid_sample(x_gpu, grid_gpu; padding_mode=padding_mode)\n        @test x == collect(sampled)\n        @test eltype(sampled) == T\n\n        ∇input, ∇grid = ∇grid_sample(external_grad, x_gpu, grid_gpu; padding_mode=padding_mode)\n        @test x == collect(∇input)\n        @test ∇grid_true == collect(∇grid)\n        @test eltype(∇input) == T\n        @test eltype(∇grid) == T\n    end\nend\n\n@testset \"Compare grid sampling with NNlib\" begin\n    w, h, c, n = 16, 16, 2, 4\n    input = rand(Float64, w, h, c, n)\n    grid = zeros(Float64, 2, w, h, n)\n    @inbounds for xi in 1:w, yi in 1:h, ni in 1:n\n        grid[1, xi, yi, ni] = (xi / w) * 2.0 - 1.0 + 0.01\n        grid[2, xi, yi, ni] = (yi / h) * 2.0 - 1.0\n    end\n    for padding_mode in (:zeros, :border)\n        gputest(grid_sample, input, grid; atol=1e-6, padding_mode=padding_mode)\n    end\nend\n\n@testset \"Grid Sampling 3D\" begin\n    for T in (Float32, Float64)\n        x = ones(T, (2, 2, 2, 1, 1))  # 3D input with depth=2\n        grid = Array{T}(undef, 3, 2, 2, 2, 1)  # 3D grid with depth=2\n        grid[:, 1, 1, 1, 1] .= (-1, -1, -1)\n        grid[:, 2, 1, 1, 1] .= (1, -1, -1)\n        grid[:, 1, 2, 1, 1] .= (-1, 1, -1)\n        grid[:, 2, 2, 1, 1] .= (1, 1, -1)\n        grid[:, 1, 1, 2, 1] .= (-1, -1, 1)\n        grid[:, 2, 1, 2, 1] .= (1, -1, 1)\n        grid[:, 1, 2, 2, 1] .= (-1, 1, 1)\n        grid[:, 2, 2, 2, 1] .= (1, 1, 1)\n\n        ∇grid_true = Array{T}(undef, size(grid))\n        ∇grid_true[:, 1, 1, 1, 1] .= (0.0, 0.0, 0.0)\n        ∇grid_true[:, 2, 1, 1, 1] .= (-0.5, 0.0, 0.0)\n        ∇grid_true[:, 1, 2, 1, 1] .= (0.0, -0.5, 0.0)\n        ∇grid_true[:, 2, 2, 1, 1] .= (-0.5, -0.5, 0.0)\n        ∇grid_true[:, 1, 1, 2, 1] .= (0.0, 0.0, -0.5)\n        ∇grid_true[:, 2, 1, 2, 1] .= (-0.5, 0.0, -0.5)\n        ∇grid_true[:, 1, 2, 2, 1] .= (0.0, -0.5, -0.5)\n        ∇grid_true[:, 2, 2, 2, 1] .= (-0.5, -0.5, -0.5)\n\n\n        x_gpu, grid_gpu = CuArray(x), CuArray(grid)\n\n        padding_mode = :zeros\n        y_gpu = grid_sample(x_gpu, grid_gpu; padding_mode=padding_mode)\n        @test x == collect(y_gpu)\n        @test eltype(y_gpu) == T\n\n        external_grad = CUDA.ones(T, size(y_gpu))\n        ∇input, ∇grid = ∇grid_sample(external_grad, x_gpu, grid_gpu; padding_mode=padding_mode)\n        @test x == collect(∇input)\n        @test ∇grid_true == collect(∇grid)\n        @test eltype(∇input) == T\n        @test eltype(∇grid) == T\n\n        
padding_mode = :border\n        fill!(∇grid_true, 0.0)\n        sampled = grid_sample(x_gpu, grid_gpu; padding_mode=padding_mode)\n        @test x == collect(sampled)\n        @test eltype(sampled) == T\n\n        ∇input, ∇grid = ∇grid_sample(external_grad, x_gpu, grid_gpu; padding_mode=padding_mode)\n        @test x == collect(∇input)\n        @test ∇grid_true == collect(∇grid)\n        @test eltype(∇input) == T\n        @test eltype(∇grid) == T\n    end\nend\n\n@testset \"Compare grid sampling with NNlib 3D\" begin\n    w, h, d, c, n = 16, 16, 16, 2, 4  # Added depth dimension `d`\n    input = rand(Float64, w, h, d, c, n)\n    grid = zeros(Float64, 3, w, h, d, n)  # 3D grid with depth `d`\n    @inbounds for xi in 1:w, yi in 1:h, zi in 1:d, ni in 1:n\n        grid[1, xi, yi, zi, ni] = (xi / w) * 2.0 - 1.0 + 0.01\n        grid[2, xi, yi, zi, ni] = (yi / h) * 2.0 - 1.0\n        grid[3, xi, yi, zi, ni] = (zi / d) * 2.0 - 1.0\n    end\n    for padding_mode in (:zeros, :border)\n        gputest(grid_sample, input, grid; atol=1e-6, padding_mode=padding_mode)\n    end\nend\n"
  },
  {
    "path": "test/ext_cuda/scatter.jl",
    "content": "dsts = Dict(\n    0 => cu([3, 4, 5, 6, 7]),\n    1 => cu([3 3 4 4 5;\n             5 5 6 6 7]),\n)\nsrcs = Dict(\n    (0, true) => cu(ones(Int, 3, 4)),\n    (0, false) => cu(ones(Int, 3) * collect(1:4)'),\n    (1, true) => cu(ones(Int, 2, 3, 4)),\n    (1, false) => cu([1, 2] .* reshape(ones(Int, 3) * collect(1:4)', 1,3,4)),\n)\nidxs = [\n    cu([1 2 3 4;\n        4 2 1 3;\n        3 5 5 3]),  # integer index\n    cu([(1,) (2,) (3,) (4,);\n        (4,) (2,) (1,) (3,);\n        (3,) (5,) (5,) (3,)]),  # tuple index\n    cu(CartesianIndex.([(1,) (2,) (3,) (4,);\n        (4,) (2,) (1,) (3,);\n        (3,) (5,) (5,) (3,)])),  # CartesianIndex index\n]\n\ntypes = [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}]\n\n\n@testset \"scatter\" begin\n    for T = types\n        @testset \"$(T)\" begin\n            @testset \"+\" begin\n                for idx = idxs, dims = [0, 1]\n                    mutated = true\n                    gputest((dst, src) -> NNlib.scatter!(+, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true)\n\n                    mutated = false\n                    gputest(src -> NNlib.scatter(+, src, idx), T(srcs[(dims, mutated)]), checkgrad=true)\n                end\n            end\n\n            @testset \"-\" begin\n                for idx = idxs, dims = [0, 1]\n                    mutated = true\n                    gputest((dst, src) -> NNlib.scatter!(-, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true)\n\n                    mutated = false\n                    gputest(src -> NNlib.scatter(-, src, idx), T(srcs[(dims, mutated)]), checkgrad=true)\n                end\n            end\n\n            @testset \"max\" begin\n                for idx = idxs, dims = [0, 1]\n                    mutated = true\n                    gputest((dst, src) -> NNlib.scatter!(max, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true)\n\n                    mutated = false\n                    gputest(src -> NNlib.scatter(max, src, idx), T(srcs[(dims, mutated)]), checkgrad=true)\n                end\n            end\n\n            @testset \"min\" begin\n                for idx = idxs, dims = [0, 1]\n                    mutated = true\n                    gputest((dst, src) -> NNlib.scatter!(min, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true)\n\n                    mutated = false\n                    gputest(src -> NNlib.scatter(min, src, idx), T(srcs[(dims, mutated)]), checkgrad=true)\n                end\n            end\n        end\n    end\n\n\n    for T = [CuArray{Float32}, CuArray{Float64}]\n        @testset \"$(T)\" begin\n            @testset \"*\" begin\n                for idx = idxs, dims = [0, 1]\n                    mutated = true\n                    gputest((dst, src) -> NNlib.scatter!(*, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true)\n\n                    mutated = false\n                    gputest(src -> NNlib.scatter(*, src, idx), T(srcs[(dims, mutated)]), checkgrad=true)\n                end\n            end\n\n            @testset \"/\" begin\n                for idx = idxs, dims = [0, 1]\n                    mutated = true\n                    gputest((dst, src) -> NNlib.scatter!(/, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true)\n\n                    mutated = false\n                    gputest(src -> NNlib.scatter(/, src, idx), T(srcs[(dims, mutated)]), 
checkgrad=true)\n                end\n            end\n\n            @testset \"mean\" begin\n                for idx = idxs, dims = [0, 1]\n                    mutated = true\n                    gputest((dst, src) -> NNlib.scatter!(mean, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true)\n\n                    mutated = false\n                    gputest(src -> NNlib.scatter(mean, src, idx), T(srcs[(dims, mutated)]), checkgrad=true)\n                end\n            end\n        end\n    end\nend\n"
  },
  {
    "path": "test/ext_cuda/softmax.jl",
    "content": "@testset \"softmax\" begin\n    for (sz, dims) in [((5,), :), ((5,), 1), ((5,5), :), ((5,5), 1), ((5,5), 2), ((5,5,5,5), (2,3)), ((5,5,5,5), (2,4))]\n        x = randn(Float64, sz)\n        dy = randn(Float64, sz)\n\n        y = softmax(x, dims=dims)\n        gputest(softmax, x, dims=dims)\n        gputest(NNlib.∇softmax_data, dy, y; dims=dims)\n\n        y2 = logsoftmax(x, dims=dims)\n        gputest(logsoftmax, x, dims=dims)\n        gputest(NNlib.∇logsoftmax_data, dy, y2; dims=dims)\n\n        # From NNlib 0.8.3, ∇softmax! is not used in the gradient.\n        # But NNlibCUDA still knows how to call cuDNN routines, let's test they agree:\n        @test NNlib.∇softmax_data(dy, y; dims=dims) ≈ collect(∇softmax!(similar(cu(x)), cu(dy), cu(x), cu(y); dims=dims)) atol=1e-4\n        @test NNlib.∇logsoftmax_data(dy, y2; dims=dims) ≈ collect(∇logsoftmax!(similar(cu(x)), cu(dy), cu(x), cu(y2); dims=dims)) atol=1e-4\n        # (Note that ∇softmax! does not depend on x, it's just there to disambiguate from an even older signature.)\n    end\nend\n"
  },
  {
    "path": "test/ext_cuda/test_utils.jl",
    "content": "function gputest(f, xs...; checkgrad=true, rtol=1e-7, atol=1e-10, broken=false, broken_grad=false, kws...)\n    cpu_in = xs\n    gpu_in = CuArray.(xs)\n\n    cpu_out = f(cpu_in...; kws...)\n    gpu_out = f(gpu_in...; kws...)\n    @test collect(cpu_out) ≈ collect(gpu_out) rtol=rtol atol=atol broken=broken \n\n    if checkgrad\n        # use mean instead of sum to prevent error accumulation (for larger\n        # tensors) which causes error to go above atol\n        cpu_grad = gradient((x...) -> mean(f(x...; kws...)), cpu_in...)\n        gpu_grad = gradient((x...) -> mean(f(x...; kws...)), gpu_in...)\n        for (cpu_g, gpu_g) in zip(cpu_grad, gpu_grad)\n            if cpu_g === nothing\n                @test gpu_g === nothing\n            else\n                @test collect(cpu_g) ≈ collect(gpu_g) rtol=rtol atol=atol broken=broken_grad \n            end\n        end\n    end\nend\n"
  },
  {
    "path": "test/ext_metal/activations.jl",
    "content": "@testset \"activation broadcast\" begin\n    broken_f = (:hardσ, :leakyrelu) \n    for name in NNlib.ACTIVATIONS\n        # println(\"Testing forward diff for activation: \", name)\n        f = @eval $name\n        @test gputest(DEVICE, x -> f.(x), rand(5)) broken=name ∈ broken_f\n    end\nend\n\n@testset \"forward diff\" begin\n    broken_f = (:hardσ, :leakyrelu) \n    for name in NNlib.ACTIVATIONS\n        # println(\"Testing forward diff for activation: \", name)\n        f = @eval $name\n        @test gputest(DEVICE, x -> f.(x), Dual.(rand(Float32, 5), 1)) broken=name ∈ broken_f\n    end\nend\n"
  },
  {
    "path": "test/ext_metal/runtests.jl",
    "content": "using NNlib\nusing Test\nusing Metal\nusing Zygote: gradient\nusing MLDataDevices: gpu_device\nusing ForwardDiff: Dual\n\nMetal.allowscalar(false)\n\n#TODO move this to test/ test_utils.jl and use it with all backends\nfunction gputest(device, f, xs...; checkgrad=true, atol=1e-6, kws...)\n    cpu_in = xs\n    gpu_in = device(xs)\n\n    cpu_out = f(cpu_in...; kws...)\n    gpu_out = f(gpu_in...; kws...)\n    @test collect(cpu_out) ≈ collect(gpu_out)\n\n    if checkgrad\n        cpu_grad = gradient((x...) -> sum(f(x...; kws...)), cpu_in...)\n        gpu_grad = gradient((x...) -> sum(f(x...; kws...)), gpu_in...)\n        for (cpu_g, gpu_g) in zip(cpu_grad, gpu_grad)\n            if cpu_g === nothing\n                @test gpu_g === nothing\n            else\n                @test collect(cpu_g) ≈ collect(gpu_g) atol=atol\n            end\n        end\n    end\n    return true\nend\n\nDEVICE = gpu_device(force=true)\n\ninclude(\"activations.jl\")"
  },
  {
    "path": "test/functions.jl",
    "content": "using NNlib: glu\nusing Zygote\n\n@testset \"glu\" begin\n    x = [1 2 3; 4 5 6; 7 8 9; 10 11 12]\t \n    @test ceil.(glu(x, 1)) == [1 2 3; 4 5 6]\n    @test_throws AssertionError glu(x, 2)\nend\n\n@testset \"AutoDiff\" begin\n    local rng = StableRNG(17)\n    gradtest(glu, rand(rng, 4, 3))\nend\n\n"
  },
  {
    "path": "test/inference.jl",
    "content": "import NNlib: conv_direct, conv_im2col, channels_in, channels_out\n\n@testset \"Conv Inference\" begin\n    for T in (Float32, Float64)\n        impl = [conv, conv_direct, conv_im2col]\n\n        x = rand(T, 10, 10, 3, 2)\n        w = rand(T, 3, 3, 3, 1)\n        cdims = DenseConvDims(x, w)\n        dy = conv(x, w, cdims)\n\n        for f in impl\n            @test @inferred(f(x, w, cdims)) isa Array{T,4}\n        end\n\n        @test @inferred(conv(x, w)) isa Array{T,4}\n        @test @inferred(∇conv_filter(x, dy, cdims)) isa Array{T,4}\n        @test @inferred(∇conv_data(dy, w, cdims)) isa Array{T,4}\n    end\nend\n\n@testset \"DepthwiseConv Inference\" begin\n    for T in (Float32, Float64)\n        x = rand(T, 10, 10, 3, 2)\n        w = rand(T, 3, 3, 3, 3)\n        cdims = DepthwiseConvDims(x, w)\n        dy = depthwiseconv(x, w)\n\n        @test @inferred(depthwiseconv(x, w)) isa Array{T,4}\n        @test @inferred(∇depthwiseconv_filter(x, dy, cdims)) isa Array{T,4}\n        @test @inferred(∇depthwiseconv_data(dy, w, cdims)) isa Array{T,4}\n    end\nend\n\n@testset \"DenseConvDims Inference\" begin\n    # this needs to be in a function to trigger inference problems\n    function channels_in_test(w::AbstractArray)\n        cdims = DenseConvDims((1,1,1,1), size(w))\n        channels_in(cdims)\n    end\n\n    # this needs to be in a function to trigger inference problems\n    function channels_out_test(w::AbstractArray)\n        cdims = DenseConvDims((1,1,1,1), size(w))\n        channels_out(cdims)\n    end\n\n    w = rand(Float32, 1, 1, 1, 1)\n    @test @inferred(channels_in_test(w)) isa Int\n    @test @inferred(channels_out_test(w)) isa Int\nend\n\n@testset \"Pooling inference\" begin\n    for T in (Float32, Float64)\n        x = rand(T, 10, 10, 3, 2)\n        pdims = PoolDims(x, 3)\n\n        y_maxpool = NNlib.maxpool(x, pdims)\n        y_meanpool = NNlib.meanpool(x, pdims)\n        dy = ones(T, size(y_maxpool)...)\n\n        @test @inferred(NNlib.maxpool(x, pdims)) isa Array{T, 4}\n        @test @inferred(NNlib.meanpool(x, pdims)) isa Array{T, 4}\n        @test @inferred(NNlib.∇maxpool(dy, y_maxpool, x, pdims)) isa Array{T, 4}\n        @test @inferred(NNlib.∇maxpool(dy, y_meanpool, x, pdims)) isa Array{T, 4}\n    end\nend\n"
  },
  {
    "path": "test/padding.jl",
    "content": "using NNlib: pad_constant, pad_repeat, pad_zeros, pad_reflect, pad_symmetric, pad_circular\n\n@testset \"padding constant\" begin\n    x = rand(2, 2, 2)\n\n    p = NNlib.gen_pad((1,2,3,4,5,6), (1,2,3), 4)\n    @test p == ((1, 2), (3, 4), (5, 6), (0, 0))\n\n    @test_throws ArgumentError NNlib.gen_pad((1,2,3,4,5,), (1,2,3), 4)\n\n    p = NNlib.gen_pad((1,3), (1,3), 4)\n    @test p == ((1, 1), (0, 0), (3, 3), (0, 0))\n\n    p = NNlib.gen_pad(1, (1,2,3), 4)\n    @test p == ((1, 1), (1, 1), (1, 1), (0, 0))\n\n    p = NNlib.gen_pad(3, :, 2)\n    @test p == ((3, 3), (3, 3))\n\n    p = NNlib.gen_pad((1,0), 1, 2)\n    @test p == ((1,0), (0,0))\n\n    y = pad_constant(x, (3, 2, 4))\n    @test size(y) == (8, 6, 10)\n    @test y[4:5, 3:4, 5:6] ≈ x\n    y[4:5, 3:4, 5:6] .= 0\n    @test all(y .== 0)\n\n    @test pad_constant(x, (3, 2, 4)) ≈ pad_zeros(x, (3, 2, 4))\n    @test pad_zeros(x, 2) ≈ pad_zeros(x, (2,2,2))\n\n    y = pad_constant(x, (3, 2, 4, 5), 1.2, dims = (1,3))\n    @test size(y) == (7, 2, 11)\n    @test y[4:5, 1:2, 5:6] ≈ x\n    y[4:5, 1:2, 5:6] .= 1.2\n    @test all(y .== 1.2)\n\n    @test pad_constant(x, (2,2,2,2), 1.2, dims = (1,3)) ≈\n        pad_constant(x, 2, 1.2, dims = (1,3))\n\n    @test pad_constant(x, 1, dims = 1:2) ==\n        pad_constant(x, 1, dims = (1,2))\n\n    @test size(pad_constant(x, 1, dims = 1)) == (4,2,2)\n\n    @test all(pad_zeros(randn(2), (1, 2))[[1, 4, 5]] .== 0)\n\n    gradtest(x -> pad_constant(x, 2), rand(2,2,2))\n    gradtest(x -> pad_constant(x, (2, 1, 1, 2)), rand(2,2))\n    gradtest(x -> pad_constant(x, (2, 1,)), rand(2))\nend\n\n@testset \"padding repeat\" begin\n    x = rand(2, 2, 2)\n\n    # y = @inferred pad_repeat(x, (3, 2, 4, 5))\n    y = pad_repeat(x, (3, 2, 4, 5))\n    @test size(y) == (7, 11, 2)\n    @test y[4:5, 5:6, :] ≈ x\n\n    # y = @inferred pad_repeat(x, (3, 2, 4, 5), dims=(1,3))\n    y = pad_repeat(x, (3, 2, 4, 5), dims=(1,3))\n    @test size(y) == (7, 2, 11)\n    @test y[4:5, :, 5:6] ≈ x\n\n    @test pad_repeat(reshape(1:9, 3, 3), (1,2)) ==\n        [1  4  7\n         1  4  7\n         2  5  8\n         3  6  9\n         3  6  9\n         3  6  9]\n\n    @test pad_repeat(reshape(1:9, 3, 3), (2,2), dims=2) ==\n        [1  1  1  4  7  7  7\n         2  2  2  5  8  8  8\n         3  3  3  6  9  9  9]\n\n    @test pad_repeat(x, (2, 2, 2, 2), dims=(1,3)) ≈\n        pad_repeat(x, 2, dims=(1,3))\n\n    gradtest(x -> pad_repeat(x, (2,2,2,2)), rand(2,2,2))\nend\n\n@testset \"padding reflect\" begin\n    y = pad_reflect(reshape(1:9, 3, 3), (2,2), dims=2)\n    @test y == [7  4  1  4  7  4  1\n                8  5  2  5  8  5  2\n                9  6  3  6  9  6  3]\n\n    y = pad_reflect(reshape(1:9, 3, 3), (2,2,2,2))\n    @test y == [9  6  3  6  9  6  3\n                8  5  2  5  8  5  2\n                7  4  1  4  7  4  1\n                8  5  2  5  8  5  2\n                9  6  3  6  9  6  3\n                8  5  2  5  8  5  2\n                7  4  1  4  7  4  1]\n\n    x = rand(4, 4, 4)\n    @test pad_reflect(x, (2, 2, 2, 2), dims=(1,3)) ≈\n        pad_reflect(x, 2, dims=(1,3))\n\n    # pad_reflect needs larger test input as padding must\n    # be strictly less than array size in that dimension\n    gradtest(x -> pad_reflect(x, (2,2,2,2)), rand(3,3,3))\n\n    x = reshape(1:9, 3, 3, 1, 1)\n    @test NNlib.pad_reflect(x, (1, 0, 1, 0); dims=1:2) == [\n        5 2 5 8;\n        4 1 4 7;\n        5 2 5 8;\n        6 3 6 9;;;;]\n    @test NNlib.pad_reflect(x, (0, 1, 0, 1); dims=1:2) == [\n        1 4 7 4;\n        2 5 8 5;\n    
    3 6 9 6;\n        2 5 8 5;;;;]\nend\n\n@testset \"padding symmetric\" begin\n    y = pad_symmetric(reshape(1:9, 3, 3), (2,2), dims=2)\n    @test y == [4  1  1  4  7  7  4\n                5  2  2  5  8  8  5\n                6  3  3  6  9  9  6]\n\n    y = pad_symmetric(reshape(1:9, 3, 3), (2,2,2,2))\n    @test y == [5  2  2  5  8  8  5\n                4  1  1  4  7  7  4\n                4  1  1  4  7  7  4\n                5  2  2  5  8  8  5\n                6  3  3  6  9  9  6\n                6  3  3  6  9  9  6\n                5  2  2  5  8  8  5]\n\n    x = rand(4, 4, 4)\n    @test pad_symmetric(x, (2, 2, 2, 2), dims=(1,3)) ≈\n        pad_symmetric(x, 2, dims=(1,3))\n\n    gradtest(x -> pad_symmetric(x, (2,2,2,2)), rand(2,2,2))\n\n    x = reshape(1:9, 3, 3, 1, 1)\n    @test NNlib.pad_symmetric(x, (1, 0, 1, 0); dims=1:2) == [\n        1 1 4 7;\n        1 1 4 7;\n        2 2 5 8;\n        3 3 6 9;;;;]\n    @test NNlib.pad_symmetric(x, (0, 1, 0, 1); dims=1:2) == [\n        1 4 7 7;\n        2 5 8 8;\n        3 6 9 9;\n        3 6 9 9;;;;]\nend\n\n@testset \"padding circular\" begin\n    y = pad_circular(reshape(1:9, 3, 3), (2,2), dims=2)\n    @test y == [4  7  1  4  7  1  4\n                5  8  2  5  8  2  5\n                6  9  3  6  9  3  6]\n\n    y = pad_circular(reshape(1:9, 3, 3), (2,2,2,2))\n    @test y == [5  8  2  5  8  2  5\n                6  9  3  6  9  3  6\n                4  7  1  4  7  1  4\n                5  8  2  5  8  2  5\n                6  9  3  6  9  3  6\n                4  7  1  4  7  1  4\n                5  8  2  5  8  2  5]\n\n    x = rand(4, 4, 4)\n    @test pad_circular(x, (2, 2, 2, 2), dims=(1,3)) ≈\n        pad_circular(x, 2, dims=(1,3))\n\n    gradtest(x -> pad_circular(x, (2,2,2,2)), rand(2,2,2))\nend\n"
  },
  {
    "path": "test/pooling.jl",
    "content": "# using NNlib, Test\n\nmaxpool_answer_dict = Dict(\n    1 => Dict(\n        \"y\"          => [2, 4.],\n        \"y_nostride\" => [2, 3, 4, 5.],\n        \"y_pad\"      => [1, 3, 5.],\n\n        \"dx\"          => [0, 2, 0, 4, 0.],\n        \"dx_nostride\" => [0, 2, 3, 4, 5.],\n        \"dx_pad\"      => [1, 0, 3, 0, 5.],\n    ),\n    2 => Dict(\n        \"y\" => [\n            7 17.;\n            9 19.\n        ],\n        \"y_nostride\" => [\n            7  12 17;\n            8  13 18;\n            9  14 19;\n            10 15 20.\n        ],\n        \"y_pad\" => [\n            1  11 16;\n            3  13 18;\n            5  15 20.\n        ],\n\n        \"dx\" => [\n            0 0 0  0;\n            0 7 0 17;\n            0 0 0  0;\n            0 9 0 19;\n            0 0 0  0.\n        ],\n        \"dx_nostride\" => [\n            0  0  0  0;\n            0  7 12 17;\n            0  8 13 18;\n            0  9 14 19;\n            0 10 15 20.\n        ],\n        \"dx_pad\"      => [\n            1 0 11 16;\n            0 0  0  0;\n            3 0 13 18;\n            0 0  0  0;\n            5 0 15 20.\n        ],\n    ),\n    3 => Dict(\n        \"y\" => reshape([\n            27, 29,\n            37, 39.\n        ], (2, 2, 1)),\n        \"y_nostride\" => reshape([\n            27, 28, 29, 30,\n            32, 33, 34, 35,\n            37, 38, 39, 40,\n\n            47, 48, 49, 50,\n            52, 53, 54, 55,\n            57, 58, 59, 60.\n        ], (4, 3, 2)),\n        \"y_pad\" => reshape([\n             1,  3, 5,\n            11, 13, 15,\n            16, 18, 20,\n\n            41, 43, 45,\n            51, 53, 55,\n            56, 58, 60.\n        ], (3, 3, 2)),\n\n        \"dx\" => reshape([\n            0, 0, 0, 0, 0,\n            0, 0, 0, 0, 0,\n            0, 0, 0, 0, 0,\n            0, 0, 0, 0, 0,\n\n            0,  0, 0,  0, 0,\n            0, 27, 0, 29, 0,\n            0,  0, 0,  0, 0,\n            0, 37, 0, 39, 0,\n\n            0, 0, 0, 0, 0,\n            0, 0, 0, 0, 0,\n            0, 0, 0, 0, 0,\n            0, 0, 0, 0, 0.\n        ], (5, 4, 3)),\n        \"dx_nostride\" => reshape([\n            0, 0, 0, 0, 0,\n            0, 0, 0, 0, 0,\n            0, 0, 0, 0, 0,\n            0, 0, 0, 0, 0,\n\n            0,  0,  0,  0,  0,\n            0, 27, 28, 29, 30,\n            0, 32, 33, 34, 35,\n            0, 37, 38, 39, 40,\n\n            0,  0,  0,  0,  0,\n            0, 47, 48, 49, 50,\n            0, 52, 53, 54, 55,\n            0, 57, 58, 59, 60.\n        ], (5, 4, 3)),\n        \"dx_pad\" => reshape([\n             1, 0,  3, 0,  5,\n             0, 0,  0, 0,  0,\n            11, 0, 13, 0, 15,\n            16, 0, 18, 0, 20,\n\n            0, 0, 0, 0, 0,\n            0, 0, 0, 0, 0,\n            0, 0, 0, 0, 0,\n            0, 0, 0, 0, 0,\n\n            41, 0, 43, 0, 45,\n             0, 0,  0, 0,  0,\n            51, 0, 53, 0, 55,\n            56, 0, 58, 0, 60.\n        ], (5, 4, 3)),\n    )\n)\n\nmeanpool_answer_dict = Dict(\n    1 => Dict(\n        \"y\"          => [1.5, 3.5],\n        \"y_nostride\" => [1.5, 2.5, 3.5, 4.5],\n        \"y_pad\"      => [0.5, 2.5, 4.5],\n\n        \"dx\"          => [0.75, 0.75, 1.75, 1.75,  0.0],\n        \"dx_nostride\" => [0.75,  2.0,  3.0,  4.0, 2.25],\n        \"dx_pad\"      => [0.25, 1.25, 1.25, 2.25, 2.25],\n    ),\n    2 => Dict(\n        \"y\" => [\n            4.0 14.0;\n            6.0 16.0\n        ],\n        \"y_nostride\" => [\n            4.0  9.0 14.0\n            5.0 10.0 15.0\n            6.0 11.0 
16.0\n            7.0 12.0 17.0\n        ],\n        \"y_pad\" => [\n            0.25  4.25 4.0\n            1.25 10.0  8.75\n            2.25 12.0  9.75\n        ],\n\n        \"dx\" => [\n            1.0 1.0 3.5 3.5;\n            1.0 1.0 3.5 3.5;\n            1.5 1.5 4.0 4.0;\n            1.5 1.5 4.0 4.0;\n            0.0 0.0 0.0 0.0\n        ],\n        \"dx_nostride\" => [\n            1.0  3.25  5.75 3.5;\n            2.25 7.0  12.0  7.25;\n            2.75 8.0  13.0  7.75;\n            3.25 9.0  14.0  8.25;\n            1.75 4.75  7.25 4.25\n        ],\n        \"dx_pad\"      => [\n            0.0625 1.0625 1.0625 1.0;\n            0.3125 2.5    2.5    2.1875;\n            0.3125 2.5    2.5    2.1875;\n            0.5625 3.0    3.0    2.4375;\n            0.5625 3.0    3.0    2.4375\n        ],\n    ),\n    3 => Dict(\n        \"y\" => reshape([\n            14.0, 16.0,\n            24.0, 26.0\n        ], (2, 2, 1)),\n        \"y_nostride\" => reshape([\n            14.0, 15.0, 16.0, 17.0,\n            19.0, 20.0, 21.0, 22.0,\n            24.0, 25.0, 26.0, 27.0,\n\n            34.0, 35.0, 36.0, 37.0,\n            39.0, 40.0, 41.0, 42.0,\n            44.0, 45.0, 46.0, 47.0\n        ], (4, 3, 2)),\n        \"y_pad\" => reshape([\n            0.125, 0.625, 1.125,\n            2.125, 5.0,   6.0,\n            2.0,   4.375, 4.875,\n\n             7.75, 16.25, 17.25,\n            19.25, 40.0,  42.0,\n            11.5,  23.75, 24.75,\n        ], (3, 3, 2)),\n\n        \"dx\" => reshape([\n            1.75, 1.75, 2.0, 2.0, 0.0,\n            1.75, 1.75, 2.0, 2.0, 0.0,\n            3.0, 3.0, 3.25, 3.25, 0.0,\n            3.0, 3.0, 3.25, 3.25, 0.0,\n\n            1.75, 1.75, 2.0, 2.0, 0.0,\n            1.75, 1.75, 2.0, 2.0, 0.0,\n            3.0, 3.0, 3.25, 3.25, 0.0,\n            3.0, 3.0, 3.25, 3.25, 0.0,\n\n            0.0, 0.0, 0.0, 0.0, 0.0,\n            0.0, 0.0, 0.0, 0.0, 0.0,\n            0.0, 0.0, 0.0, 0.0, 0.0,\n            0.0, 0.0, 0.0, 0.0, 0.0,\n        ], (5, 4, 3)),\n        \"dx_nostride\" => reshape([\n            1.75,   3.625,  3.875,  4.125, 2.125,\n            4.125,  8.5,    9.0,    9.5,   4.875,\n            5.375, 11.0,   11.5,   12.0,   6.125,\n            3.0,    6.125,  6.375,  6.625, 3.375,\n\n             6.0,  12.25, 12.75, 13.25,  6.75,\n            13.25, 27.0,  28.0,  29.0,  14.75,\n            15.75, 32.0,  33.0,  34.0,  17.25,\n             8.5,  17.25, 17.75, 18.25,  9.25,\n\n             4.25,   8.625,  8.875,  9.125,  4.625,\n             9.125, 18.5,   19.0,   19.5,    9.875,\n            10.375, 21.0,   21.5,   22.0,   11.125,\n             5.5,   11.125, 11.375, 11.625,  5.875\n        ], (5, 4, 3)),\n        \"dx_pad\" => reshape([\n            0.015625, 0.078125, 0.078125, 0.140625, 0.140625,\n            0.265625, 0.625,    0.625,    0.75,     0.75,\n            0.265625, 0.625,    0.625,    0.75,     0.75,\n            0.25,     0.546875, 0.546875, 0.609375, 0.609375,\n\n            0.96875, 2.03125, 2.03125, 2.15625, 2.15625,\n            2.40625, 5.0,     5.0,     5.25,    5.25,\n            2.40625, 5.0,     5.0,     5.25,    5.25,\n            1.4375,  2.96875, 2.96875, 3.09375, 3.09375,\n\n            0.96875, 2.03125, 2.03125, 2.15625, 2.15625,\n            2.40625, 5.0,     5.0,     5.25,    5.25,\n            2.40625, 5.0,     5.0,     5.25,    5.25,\n            1.4375,  2.96875, 2.96875, 3.09375, 3.09375\n        ], (5, 4, 3)),\n    )\n)\n\nlpnormpool_answer_dict = Dict(\n    1 => Dict(\n        \"y\"           => [2.019312856150994, 
4.221163518110637],\n        \"y_nostride\"  => [\n            2.080083823051904, 3.2710663101885897,\n            4.497941445275415, 5.738793548317167\n        ],\n        \"y_pad\"       => [1.0, 3.605551275463989, 6.4031242374328485],\n        \"dx\"          => [\n            0.17258020254042603, 1.9525221042381296,\n            1.2774501198988355, 3.496467771732918, 0.0\n        ],\n        \"dx_nostride\" => [\n            0.48074985676913606, 3.1458422620080637,\n            4.752311710531486, 6.345225258061685, 4.356316321455918\n        ],\n        \"dx_pad\"       => [1.0, 2.0, 3.0, 4.0, 5.0],\n        \"p\"           => 4.5,\n        \"p_nostride\"  => 3.0,\n        \"p_pad\"       => 2.0\n    ),\n    2 => Dict(\n        \"y\"           => [\n            8.71909  24.9703;\n            11.7336  28.3804\n        ],\n        \"y_nostride\"  => [\n            11.1128  23.134   35.5704;\n            13.4219  25.6082  38.0707;\n            15.8033  28.0907  40.5735;\n            18.2249  30.5795  43.0782\n        ],\n        \"y_pad\"       => [\n            1.0      11.3616  16.0;\n            3.19158  15.9662  21.3545;\n            5.56869  18.7771  23.7903\n        ],\n        \"dx\"          => [\n            0.33866   4.97727  7.30092  12.8076;\n            0.957876  6.27208  8.31879  14.0269;\n            1.51693   6.6057   8.79844  14.3351;\n            2.33547   7.8822   9.83293  15.5461;\n            0.0       0.0      0.0      0.0 \n        ],\n        \"dx_nostride\" => [\n            3.33359  19.9471  35.7329  23.8564;\n            9.89551  44.627   76.2257  50.0307;\n           13.231    50.9101  82.5686  53.2022;\n           16.4888   57.223   88.9133  56.3742;\n            9.54591  30.9869  46.8371  29.3524\n        ],\n        \"dx_pad\"      => [\n            1.0       2.30261  10.4791   16.0;\n            0.992125  2.0321    7.81903  12.075;\n            2.73398   2.83743   9.5512   13.9299;\n            2.43512   2.98652   9.0132   13.5608;\n            4.25398   3.8865   10.7099   15.4161\n        ],\n        \"p\"           => 2.5,\n        \"p_nostride\"  => 1.5,\n        \"p_pad\"       => 3.5\n    )\n)\n\nfor rank in (1, 2, 3)\n    @testset \"pool$(rank)d\" begin\n        for (pool, ∇pool, answer_dict) in (\n                # Main API name\n                (maxpool, ∇maxpool, maxpool_answer_dict),\n                (meanpool, ∇meanpool, meanpool_answer_dict),\n\n                # _direct name\n                (NNlib.maxpool_direct, NNlib.∇maxpool_direct, maxpool_answer_dict),\n                (NNlib.meanpool_direct, NNlib.∇meanpool_direct, meanpool_answer_dict),)\n\n            @testset \"$(pool)$(rank)d\" begin\n                y = answer_dict[rank][\"y\"]\n                y_nostride = answer_dict[rank][\"y_nostride\"]\n                y_pad = answer_dict[rank][\"y_pad\"]\n                dx = answer_dict[rank][\"dx\"]\n                dx_nostride = answer_dict[rank][\"dx_nostride\"]\n                dx_pad = answer_dict[rank][\"dx_pad\"]\n\n                x = reshape(Float64[1:prod(size(dx));], size(dx)..., 1, 1)\n\n                # A \"drop channels and batch dimension\" helper\n                ddims(x) = dropdims(x, dims=(rank + 1, rank + 2))\n\n                # Let's ensure that a 1x1x1 pooling kernel always just returns `x`\n                @test pool(x, PoolDims(x, 1)) == x\n\n                # Test vanilla pooling\n                pdims = PoolDims(x, 2)\n                y_hat = pool(x, pdims)\n                @test ddims(y_hat) == y\n                
@test ddims(∇pool(y_hat, y_hat, x, pdims)) == dx\n\n                # Strided pooling\n                pdims = PoolDims(x, 2; stride=1)\n                y_hat = pool(x, pdims)\n                @test ddims(y_hat) == y_nostride\n                @test ddims(∇pool(y_hat, y_hat, x, pdims)) == dx_nostride\n\n                # Padded pooling\n                pdims = PoolDims(x, 2; padding=1)\n                y_hat = pool(x, pdims)\n                @test ddims(y_hat) == y_pad\n                @test ddims(∇pool(y_hat, y_hat, x, pdims)) == dx_pad\n            end\n        end\n    end\nend\n\nfor rank in (1, 2)\n    for (pool, ∇pool, answer_dict) in (\n            (lpnormpool, ∇lpnormpool, lpnormpool_answer_dict),\n            (NNlib.lpnormpool_direct, NNlib.∇lpnormpool_direct, lpnormpool_answer_dict),)\n        @testset \"$(pool)$(rank)d\" begin\n            y = answer_dict[rank][\"y\"]\n            y_nostride = answer_dict[rank][\"y_nostride\"]\n            y_pad = answer_dict[rank][\"y_pad\"]\n            dx = answer_dict[rank][\"dx\"]\n            dx_nostride = answer_dict[rank][\"dx_nostride\"]\n            dx_pad = answer_dict[rank][\"dx_pad\"]\n            p = answer_dict[rank][\"p\"]\n            p_nostride = answer_dict[rank][\"p_nostride\"]\n            p_pad = answer_dict[rank][\"p_pad\"]\n\n            x = reshape(Float64[1:prod(size(dx));], size(dx)..., 1, 1)\n\n            ddims(x) = dropdims(x, dims=(rank + 1, rank + 2))\n\n            @test pool(x, PoolDims(x, 1); p=p) ≈ x atol = 1e-3\n\n            # Test vanilla pooling\n            pdims = PoolDims(x, 2)\n            y_hat = pool(x, pdims; p=p)\n            @test ddims(y_hat) ≈ y atol = 1e-3\n            @test ddims(∇pool(y_hat, y_hat, x, pdims; p=p)) ≈ dx atol = 1e-3\n\n            # Strided pooling\n            pdims = PoolDims(x, 2; stride=1)\n            y_hat = pool(x, pdims; p=p_nostride)\n            @test ddims(y_hat) ≈ y_nostride atol = 1e-3\n            @test ddims(∇pool(y_hat, y_hat, x, pdims; p=p_nostride)) ≈ dx_nostride atol = 1e-3\n\n            # Padded pooling\n            pdims = PoolDims(x, 2; padding=1)\n            y_hat = pool(x, pdims; p=p_pad)\n            @test ddims(y_hat) ≈ y_pad atol = 1e-3\n            @test ddims(∇pool(y_hat, y_hat, x, pdims; p=p_pad)) ≈ dx_pad atol = 1e-3\n        end\n    end\nend\n\n@testset \"Pooling - Check Sizes\" begin\n    x = rand(10, 10, 3, 10)\n    @test size(maxpool(x, (2, 2))) == (5, 5, 3, 10)\n    @test size(maxpool(x, (2, 2); pad=(1, 1), stride=(2, 2))) == (6, 6, 3, 10)\n    @test size(meanpool(x, (2, 2))) == (5, 5, 3, 10)\n    @test size(meanpool(x, (2, 2); pad=(1, 1), stride=(2, 2))) == (6, 6, 3, 10)\nend\n\n# Add another test for 2d maxpool that uses an odd-length size:\n@testset \"Issue #133\" begin\n    x = reshape([(1.:9.)...], 3, 3, 1, 1)\n    pdims = PoolDims(size(x), (2, 2), padding=(1, 1), stride=(2, 2))\n    y = maxpool(x, pdims)\n\n    dy = y .* 0 .+ 1\n    dx = ∇maxpool(dy, y, x, pdims)\n    @test dx[:,:,1,1] == [1.0 0.0 1.0; 0.0 0.0 0.0; 1.0 0.0 1.0]\nend\n\n# test \"true\" strided case, see https://github.com/FluxML/NNlib.jl/issues/205\n\n\n# obtained with\n# using FiniteDifferences\nmaxpool_answer_nature = Dict(\n    \"rank1\" => Dict(\n        # kernel size 2, stride 1, pad 0\n        \"k2s1p0\" => (size = (2,),\n            stride = 1,\n            pad = 0,\n\n            x = reshape([\n                0.0299635,  0.233456,  0.596161,   0.161514,  0.0094027\n            ], 5, 1, 1), # width, channel, batch_size\n\n            dx_maxpool = reshape([\n    
             0.0, 1.0, 2.0, 1.0, 0.0\n            ], 5, 1, 1),\n\n            dx_meanpool = reshape([\n                 0.5, 1.0, 1.0, 1.0, 0.5\n            ], 5, 1, 1),),\n        \"k2s1p1\" => (size = (2,),\n            stride = 1,\n            pad = 1,\n\n            x = reshape([\n                0.0299635,  0.233456,  0.596161,   0.161514,  0.0094027\n            ], 5, 1, 1),\n\n            dx_maxpool = reshape([\n                 1.0, 1.0, 2.0, 1.0, 1.0\n            ], 5, 1, 1),\n\n            dx_meanpool = reshape([\n                 1.0, 1.0, 1.0, 1.0, 1.0\n            ], 5, 1, 1),),\n        \"k3s1p1\" => (size = (3,),\n            stride = 1,\n            pad = 1,\n\n            x = reshape([\n                0.0299635,  0.233456,  0.596161,   0.161514,  0.0094027\n            ], 5, 1, 1),\n\n            dx_maxpool = reshape([\n                 0.0, 1.0, 3.0, 1.0, 0.0\n            ], 5, 1, 1),\n\n            dx_meanpool = reshape([\n                 0.6666666666, 1.0, 1.0, 1.0, 0.6666666666\n            ], 5, 1, 1),),\n        \"k3s2p1\" => (size = (3,),\n            stride = 2,\n            pad = 1,\n\n            x = reshape([\n                0.0299635,  0.233456,  0.596161,   0.161514,  0.0094027\n            ], 5, 1, 1),\n\n            dx_maxpool = reshape([\n                 0.0, 1.0, 1.0, 1.0, 0.0\n            ], 5, 1, 1),\n\n            dx_meanpool = reshape([\n                 0.333333333,\n                 0.666666666,\n                 0.333333333,\n                 0.666666666,\n                 0.333333333,\n            ], 5, 1, 1),)\n    ),\n    \"rank2\" => Dict(\n        # kernel size 2, stride 1, pad 0\n        \"k2s1p0\" => (size = (2, 2),\n            stride = 1,\n            pad = 0,\n\n            x = reshape([\n                0.0299635  0.233456  0.596161   0.161514  0.0094027\n                0.389984   0.235158  0.579525   0.301893  0.561358\n                0.0830242  0.483759  0.914904   0.253871  0.820061\n                0.425287   0.53451   0.0405225  0.729861  0.403925\n                0.473724   0.571418  0.558427   0.552183  0.561624\n            ], 5, 5, 1, 1),\n\n            dx_maxpool = reshape([\n                0.0  0.0  2.0  0.0  0.0\n                1.0  0.0  0.0  0.0  1.0\n                0.0  1.0  4.0  0.0  2.0\n                0.0  1.0  0.0  2.0  0.0\n                0.0  2.0  0.0  0.0  0.0\n            ], 5, 5, 1, 1),\n\n            dx_meanpool = reshape([\n                0.25  0.5  0.5  0.5  0.25\n                0.5   1.0  1.0  1.0  0.5\n                0.5   1.0  1.0  1.0  0.5\n                0.5   1.0  1.0  1.0  0.5\n                0.25  0.5  0.5  0.5  0.25\n            ], 5, 5, 1, 1)),\n        \"k2s1p1\" => (size = (2, 2),\n            stride = 1,\n            pad = 1,\n\n            x = reshape([\n                0.0299635  0.233456  0.596161   0.161514  0.0094027\n                0.389984   0.235158  0.579525   0.301893  0.561358\n                0.0830242  0.483759  0.914904   0.253871  0.820061\n                0.425287   0.53451   0.0405225  0.729861  0.403925\n                0.473724   0.571418  0.558427   0.552183  0.561624\n            ], 5, 5, 1, 1),\n\n            dx_maxpool = reshape([\n                1.0  1.0  4.0  1.0  1.0\n                3.0  0.0  0.0  0.0  2.0\n                0.0  1.0  4.0  0.0  4.0\n                1.0  1.0  0.0  2.0  0.0\n                2.0  4.0  1.0  0.0  3.0\n            ], 5, 5, 1, 1),\n\n            dx_meanpool = reshape([\n                1.0  1.0  1.0  1.0  1.0\n                1.0  
1.0  1.0  1.0  1.0\n                1.0  1.0  1.0  1.0  1.0\n                1.0  1.0  1.0  1.0  1.0\n                1.0  1.0  1.0  1.0  1.0\n            ], 5, 5, 1, 1)),\n        \"k3s1p1\" => (size = (3, 3),\n            stride = 1,\n            pad = 1,\n\n            x = reshape([\n                0.0299635  0.233456  0.596161   0.161514  0.0094027\n                0.389984   0.235158  0.579525   0.301893  0.561358\n                0.0830242  0.483759  0.914904   0.253871  0.820061\n                0.425287   0.53451   0.0405225  0.729861  0.403925\n                0.473724   0.571418  0.558427   0.552183  0.561624\n            ], 5, 5, 1, 1),\n\n            dx_maxpool = reshape([\n                0.0  0.0  3.0  0.0  0.0\n                1.0  0.0  0.0  0.0  1.0\n                0.0  1.0  9.0  0.0  3.0\n                0.0  1.0  0.0  3.0  0.0\n                0.0  3.0  0.0  0.0  0.0\n            ], 5, 5, 1, 1),\n\n            dx_meanpool = reshape([\n                0.444444  0.666667  0.666667  0.666667  0.444444\n                0.666667  1.0       1.0       1.0       0.666667\n                0.666667  1.0       1.0       1.0       0.666667\n                0.666667  1.0       1.0       1.0       0.666667\n                0.444444  0.666667  0.666667  0.666667  0.444444\n            ], 5, 5, 1, 1)),\n        \"k3s2p1\" => (size = (3, 3),\n            stride = 2,\n            pad = 1,\n\n            x = reshape([\n                0.0299635  0.233456  0.596161   0.161514  0.0094027\n                0.389984   0.235158  0.579525   0.301893  0.561358\n                0.0830242  0.483759  0.914904   0.253871  0.820061\n                0.425287   0.53451   0.0405225  0.729861  0.403925\n                0.473724   0.571418  0.558427   0.552183  0.561624\n            ], 5, 5, 1, 1),\n\n            dx_maxpool = reshape([\n                0.0  0.0  1.0  0.0  0.0\n                1.0  0.0  0.0  0.0  1.0\n                0.0  0.0  1.0  0.0  1.0\n                0.0  1.0  0.0  2.0  0.0\n                0.0  1.0  0.0  0.0  0.0\n            ], 5, 5, 1, 1),\n\n            dx_meanpool = reshape([\n                0.111111  0.222222  0.111111  0.222222  0.111111\n                0.222222  0.444444  0.222222  0.444444  0.222222\n                0.111111  0.222222  0.111111  0.222222  0.111111\n                0.222222  0.444444  0.222222  0.444444  0.222222\n                0.111111  0.222222  0.111111  0.222222  0.111111\n            ], 5, 5, 1, 1))\n    ),\n    \"rank3\" => Dict(\n        # kernel size 2, stride 1, pad 0\n        \"k2s1p0\" => (size = (2, 2, 2),\n            stride = 1,\n            pad = 0,\n\n            x = reshape(cat([\n                    0.82584   0.416818  0.92668   0.471931\n                    0.798798  0.131608  0.344556  0.79681\n                    0.716898  0.320672  0.24453   0.288568\n                    0.261484  0.258469  0.121916  0.0685961\n                ],\n                [\n                    0.73934   0.16631    0.525109   0.0223458\n                    0.164918  0.790875   0.444085   0.469671\n                    0.116848  0.359845   0.0653075  0.804886\n                    0.525431  0.0402844  0.846814   0.84876\n                ],\n                [\n                    0.709245  0.325828  0.715952  0.719116\n                    0.576722  0.405659  0.770104  0.259131\n                    0.640221  0.28811   0.129229  0.97571\n                    0.953795  0.1316    0.94538   0.705337\n                ],dims=3), 4,4,3,1,1),\n\n            dx_maxpool = 
reshape(cat([\n                     1.0  0.0  2.0  0.0\n                     1.0  0.0  0.0  0.0\n                     1.0  0.0  0.0  0.0\n                     0.0  0.0  0.0  0.0\n                ],\n                [\n                     0.0  0.0  0.0  0.0\n                     0.0  5.0  0.0  0.0\n                     0.0  0.0  0.0  1.0\n                     0.0  0.0  1.0  1.0\n                ],\n                [\n                     0.0  0.0  0.0  0.0\n                     0.0  0.0  1.0  0.0\n                     0.0  0.0  0.0  2.0\n                     1.0  0.0  1.0  0.0\n                ],dims=3), 4,4,3,1,1),\n\n            dx_meanpool = reshape(cat([\n                     0.125  0.25  0.25  0.125\n                     0.25   0.5   0.5   0.25\n                     0.25   0.5   0.5   0.25\n                     0.125  0.25  0.25  0.125\n                ],\n                [\n                     0.25  0.5  0.5  0.25\n                     0.5   1.0  1.0  0.5\n                     0.5   1.0  1.0  0.5\n                     0.25  0.5  0.5  0.25\n                ],\n                [\n                     0.125  0.25  0.25  0.125\n                     0.25   0.5   0.5   0.25\n                     0.25   0.5   0.5   0.25\n                     0.125  0.25  0.25  0.125\n                ],dims=3), 4,4,3,1,1)),\n        \"k2s1p1\" => (size = (2, 2, 2),\n            stride = 1,\n            pad = 1,\n\n            x = reshape(cat([\n                    0.82584   0.416818  0.92668   0.471931\n                    0.798798  0.131608  0.344556  0.79681\n                    0.716898  0.320672  0.24453   0.288568\n                    0.261484  0.258469  0.121916  0.0685961\n                ],\n                [\n                    0.73934   0.16631    0.525109   0.0223458\n                    0.164918  0.790875   0.444085   0.469671\n                    0.116848  0.359845   0.0653075  0.804886\n                    0.525431  0.0402844  0.846814   0.84876\n                ],\n                [\n                    0.709245  0.325828  0.715952  0.719116\n                    0.576722  0.405659  0.770104  0.259131\n                    0.640221  0.28811   0.129229  0.97571\n                    0.953795  0.1316    0.94538   0.705337\n                ],dims=3), 4,4,3,1,1),\n\n            dx_maxpool = reshape(cat([\n                     8.0  0.0  8.0  2.0\n                     4.0  0.0  1.0  4.0\n                     4.0  1.0  0.0  2.0\n                     2.0  1.0  1.0  1.0\n                ],\n                [\n                     3.0  0.0  0.0  0.0\n                     0.0  5.0  0.0  0.0\n                     0.0  0.0  0.0  2.0\n                     2.0  0.0  2.0  5.0\n                ],\n                [\n                     4.0  0.0  2.0  6.0\n                     0.0  0.0  4.0  0.0\n                     3.0  0.0  0.0  8.0\n                     8.0  0.0  6.0  1.0\n                ],dims=3), 4,4,3,1,1),\n\n            dx_meanpool = reshape(cat([\n                     1.0  1.0  1.0  1.0\n                     1.0  1.0  1.0  1.0\n                     1.0  1.0  1.0  1.0\n                     1.0  1.0  1.0  1.0\n                ],\n                [\n                     1.0  1.0  1.0  1.0\n                     1.0  1.0  1.0  1.0\n                     1.0  1.0  1.0  1.0\n                     1.0  1.0  1.0  1.0\n                ],\n                [\n                     1.0  1.0  1.0  1.0\n                     1.0  1.0  1.0  1.0\n                     1.0  1.0  1.0  1.0\n                     1.0  1.0  1.0 
 1.0\n                ],dims=3), 4,4,3,1,1)),\n        \"k3s1p1\" => (size = (3, 3, 2),\n            stride = 1,\n            pad = 1,\n\n            x = reshape(cat([\n                    0.82584   0.416818  0.92668   0.471931\n                    0.798798  0.131608  0.344556  0.79681\n                    0.716898  0.320672  0.24453   0.288568\n                    0.261484  0.258469  0.121916  0.0685961\n                ],\n                [\n                    0.73934   0.16631    0.525109   0.0223458\n                    0.164918  0.790875   0.444085   0.469671\n                    0.116848  0.359845   0.0653075  0.804886\n                    0.525431  0.0402844  0.846814   0.84876\n                ],\n                [\n                    0.709245  0.325828  0.715952  0.719116\n                    0.576722  0.405659  0.770104  0.259131\n                    0.640221  0.28811   0.129229  0.97571\n                    0.953795  0.1316    0.94538   0.705337\n                ],dims=3), 4,4,3,1,1),\n\n            dx_maxpool = reshape(cat([\n                     4.0  0.0  12.0  0.0\n                     3.0  0.0   0.0  2.0\n                     3.0  1.0   0.0  1.0\n                     0.0  0.0   0.0  0.0\n                ],\n                [\n                     0.0  0.0  0.0  0.0\n                     0.0  5.0  0.0  0.0\n                     0.0  0.0  0.0  0.0\n                     0.0  0.0  2.0  4.0\n                ],\n                [\n                     2.0  0.0  0.0   0.0\n                     0.0  0.0  5.0   0.0\n                     0.0  0.0  0.0  12.0\n                     8.0  0.0  0.0   0.0\n                ],dims=3), 4,4,3,1,1),\n\n            dx_meanpool = reshape(cat([\n                     0.444444  0.666667  0.666667  0.444444\n                     0.666667  1.0       1.0       0.666667\n                     0.666667  1.0       1.0       0.666667\n                     0.444444  0.666667  0.666667  0.444444\n                ],\n                [\n                     0.444444  0.666667  0.666667  0.444444\n                     0.666667  1.0       1.0       0.666667\n                     0.666667  1.0       1.0       0.666667\n                     0.444444  0.666667  0.666667  0.444444\n                ],\n                [\n                     0.444444  0.666667  0.666667  0.444444\n                     0.666667  1.0       1.0       0.666667\n                     0.666667  1.0       1.0       0.666667\n                     0.444444  0.666667  0.666667  0.444444\n                ],dims=3), 4,4,3,1,1)),\n        \"k3s2p1\" => (size = (3, 3, 2),\n            stride = 2,\n            pad = 1,\n\n            x = reshape(cat([\n                    0.82584   0.416818  0.92668   0.471931\n                    0.798798  0.131608  0.344556  0.79681\n                    0.716898  0.320672  0.24453   0.288568\n                    0.261484  0.258469  0.121916  0.0685961\n                ],\n                [\n                    0.73934   0.16631    0.525109   0.0223458\n                    0.164918  0.790875   0.444085   0.469671\n                    0.116848  0.359845   0.0653075  0.804886\n                    0.525431  0.0402844  0.846814   0.84876\n                ],\n                [\n                    0.709245  0.325828  0.715952  0.719116\n                    0.576722  0.405659  0.770104  0.259131\n                    0.640221  0.28811   0.129229  0.97571\n                    0.953795  0.1316    0.94538   0.705337\n                ],dims=3), 4,4,3,1,1),\n\n            
dx_maxpool = reshape(cat([\n                     1.0  0.0  1.0  0.0\n                     1.0  0.0  0.0  1.0\n                     0.0  0.0  0.0  0.0\n                     0.0  0.0  0.0  0.0\n                ],\n                [\n                     0.0  0.0  0.0  0.0\n                     0.0  2.0  0.0  0.0\n                     0.0  0.0  0.0  0.0\n                     0.0  0.0  0.0  0.0\n                ],\n                [\n                     0.0  0.0  0.0  0.0\n                     0.0  0.0  0.0  0.0\n                     0.0  0.0  0.0  1.0\n                     1.0  0.0  0.0  0.0\n                ],dims=3), 4,4,3,1,1),\n\n            dx_meanpool = reshape(cat([\n                     0.0555556  0.111111  0.0555556  0.0555556\n                     0.111111   0.222222  0.111111   0.111111\n                     0.0555556  0.111111  0.0555556  0.0555556\n                     0.0555556  0.111111  0.0555556  0.0555556\n                ],\n                [\n                     0.0555556  0.111111  0.0555556  0.0555556\n                     0.111111   0.222222  0.111111   0.111111\n                     0.0555556  0.111111  0.0555556  0.0555556\n                     0.0555556  0.111111  0.0555556  0.0555556\n                ],\n                [\n                     0.0555556  0.111111  0.0555556  0.0555556\n                     0.111111   0.222222  0.111111   0.111111\n                     0.0555556  0.111111  0.0555556  0.0555556\n                     0.0555556  0.111111  0.0555556  0.0555556\n                ],dims=3), 4,4,3,1,1))\n    )\n)\n\n\n@testset \"more maxpool and meanpool tests\" begin\n    # issue #205\n    function check(config, T)\n        # CHECK DEFAULT\n        pdims = PoolDims(config.x, config.size; stride=config.stride, padding=config.pad)\n        x = T.(config.x)\n        y_maxpool = NNlib.maxpool(x, pdims)\n        y_meanpool = NNlib.meanpool(x, pdims)\n        dy = ones(T, size(y_maxpool)...) 
# size(y_maxpool) == size(y_meanpool)\n        @test isapprox(config.dx_maxpool, NNlib.∇maxpool(dy, y_maxpool, x, pdims), rtol=1e-5)\n        @test isapprox(config.dx_meanpool, NNlib.∇meanpool(dy, y_meanpool, x, pdims), rtol=1e-5)\n        # CHECK DIRECT\n        y_maxpool_dir = NNlib.maxpool_direct(x, pdims)\n        y_meanpool_dir = NNlib.meanpool_direct(x, pdims)\n        @test y_maxpool_dir ≈ y_maxpool  atol = 1e-6\n        @test isapprox(config.dx_maxpool, NNlib.∇maxpool_direct(dy, y_maxpool_dir, x, pdims), rtol=1e-5)\n        @test isapprox(config.dx_meanpool, NNlib.∇meanpool_direct(dy, y_meanpool_dir, x, pdims), rtol=1e-5)\n    end\n\n    for (rank_name, config_dict) in maxpool_answer_nature\n        for (setting_name, config) in config_dict\n            for T in (Float32, Float64)\n                check(config, T)\n            end\n        end\n    end\n\n    # issue #210\n    x, k = rand(Float32, 5, 2, 1, 3), (2, 1)\n    pdims1 = NNlib.PoolDims(x, k, padding=1, stride=1)\n    pdims2 = NNlib.PoolDims(x, k, padding=(1, 0, 0, 0), stride=1)\n    @test maxpool(x, pdims1) isa Array{Float32,4}\n    @test maxpool(x, pdims2) isa Array{Float32,4}\n\n    # issue #229\n    x = ones(Float32, 4, 4, 1, 1) .* -1\n    pool = meanpool(x, PoolDims(x, 2, padding=1))\n    valid = reshape([\n    -0.25,  -0.5,  -0.25,\n    -0.5,   -1.0,  -0.5,\n    -0.25,  -0.5,  -0.25], (3, 3, 1, 1))\n    @test all(pool .== valid)\n\n    # issue #484\n    # Description: some in-place pooling functions only accepted arrays with the same eltype.\n    # The strict method signatures were based on an assumption about the return type of `similar`.\n    # For ReverseDiff, this caused problems, e.g. with taking derivatives of pooling \n    # operations.\n    # Now, when explicitly calling an in-place pooling function, a different `yT` is allowed.\n    for xT in (Int32, Int64, Float16, Float32, Float64, BigFloat)\n        for (xsz, psz) in (     # test a few different data and kernel sizes\n            ((1,1), (1,1)),\n            ((1,2), (1,1)), ((1,2), (1,2)),\n            ((2,1), (1,1)), ((2,1), (2,1)),\n            ((2,2), (1,1)), ((2,2), (1,2)), ((2,2), (2,1)),\n        )\n            x = ones(xT, xsz..., 1, 1)\n            pdims = PoolDims(x, psz)\n            for yT in (Float16, Float32, Float64, BigFloat) \n                # `yT` is the target eltype and we do not test integer types here\n                # because those cannot always store the pooling results.\n                y = similar(x, yT, NNlib.output_size(pdims)..., NNlib.channels_out(pdims), size(x, 4))\n                @test maxpool!(y, x, pdims) isa Array{yT}\n                @test meanpool!(y, x, pdims) isa Array{yT}\n                @test lpnormpool!(y, x, pdims; p=2) isa Array{yT}\n                @test lpnormpool!(y, x, pdims; p=1.0) isa Array{yT}\n            end\n        end\n    end\n    \n    # This is how to test #484 with ReverseDiff:\n    x = reshape(Float32[ 1 2; 3 4 ], (2,2,1,1))\n    @test only(maxpool(x, (2,2))) == 4\n    # define typemin, because of https://github.com/JuliaDiff/ReverseDiff.jl/issues/225\n    Base.typemin(tr::Type{<:T}) where{V, T<:RD.TrackedReal{V, <:Any, <:Any}} = T(typemin(V))\n    @test RD.gradient(_x -> only(maxpool(_x,(2,2))), x)[:,:,1,1] == [0 0; 0 1]\n    @test only(meanpool(x, (2,2))) == 2.5\n    @test all(==(0.25), RD.gradient(_x -> only(meanpool(_x,(2,2))), x))\nend\n\n@testset \"AutoDiff: spatial_rank=$spatial_rank\" for spatial_rank in (1, 2)\n  x = rand(rng, repeat([10], spatial_rank)..., 3, 2)\n  pdims = PoolDims(x, 2)\n 
 gradtest(x -> maxpool(x, pdims), x; skip = spatial_rank==2)\n  gradtest(x -> meanpool(x, pdims), x)\n  gradtest(x -> sum(maxpool(x, pdims)), x, skip = spatial_rank==2)\n  gradtest(x -> sum(meanpool(x, pdims)), x)\n\n  #https://github.com/FluxML/NNlib.jl/issues/188\n  k = ntuple(_ -> 2, spatial_rank)  # Kernel size of pool in ntuple format\n  gradtest(x -> maxpool(x, k), x; skip = spatial_rank==2)\n  gradtest(x -> meanpool(x, k), x)\n  gradtest(x -> sum(maxpool(x, k)), x, skip = spatial_rank==2)\n  gradtest(x -> sum(meanpool(x, k)), x)\nend\n\n@static if Test_Enzyme\n\n@testset \"EnzymeRules: pooling! $pool spatial_rank=$spatial_rank \" for spatial_rank in (1, 2),\n                                                                                (pool, pool!) in ((maxpool, maxpool!), (meanpool, meanpool!))\n\n  x = rand(rng, repeat([10], spatial_rank)..., 3, 2)\n  pdims = PoolDims(x, 2)\n  y = pool(x, pdims)\n\n  for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),\n    Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),\n    Tsrc in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated)\n\n    Tret == EnzymeCore.Const && continue # ERROR\n    EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tsrc) || continue\n\n    EnzymeTestUtils.test_reverse(pool!, Tret, (y, Tdst), (x, Tsrc), (pdims, EnzymeCore.Const))\n  end\n\nend\n\nend"
  },
  {
    "path": "test/runtests.jl",
    "content": "using NNlib, Test, Statistics, Random\nusing ChainRulesCore, ChainRulesTestUtils\nusing Base.Broadcast: broadcasted\nimport EnzymeTestUtils\nusing EnzymeCore\nimport FiniteDifferences\nimport ForwardDiff\nimport Zygote\nusing Zygote: gradient\nusing StableRNGs\nusing Documenter\nusing Adapt\nusing ImageTransformations\nusing Interpolations: Constant\nusing KernelAbstractions\nusing FFTW\nimport ReverseDiff as RD        # used in `pooling.jl`\nimport Pkg\nusing SpecialFunctions\n\nconst Test_Enzyme = VERSION <= v\"1.12-\"\n\nDocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib, UnicodePlots); recursive=true)\n\n# ENV[\"NNLIB_TEST_CUDA\"] = \"true\" # uncomment to run CUDA tests\n# ENV[\"NNLIB_TEST_AMDGPU\"] = \"true\" # uncomment to run AMDGPU tests\n# ENV[\"NNLIB_TEST_METAL\"] = \"true\" # uncomment to run Metal tests\n# ENV[\"NNLIB_TEST_CPU\"] = \"false\" # uncomment to skip CPU tests\n\nconst rng = StableRNG(123)\ninclude(\"test_utils.jl\")\n\nmacro conditional_testset(name, skip_tests, expr)\n    esc(quote\n        @testset $name begin\n            if $name ∉ $skip_tests\n                $expr\n            else\n                @test_skip false\n            end\n        end\n    end)\nend\n\ncpu(x) = adapt(CPU(), x)\n\ninclude(\"testsuite/gather.jl\")\ninclude(\"testsuite/scatter.jl\")\ninclude(\"testsuite/upsample.jl\")\ninclude(\"testsuite/rotation.jl\")\ninclude(\"testsuite/spectral.jl\")\ninclude(\"testsuite/fold.jl\")\n\nfunction nnlib_testsuite(Backend; skip_tests = Set{String}())\n    @conditional_testset \"Upsample\" skip_tests begin\n        upsample_testsuite(Backend)\n    end\n    @conditional_testset \"rotation\" skip_tests begin\n        rotation_testsuite(Backend)\n    end\n    @conditional_testset \"Gather\" skip_tests begin\n        gather_testsuite(Backend)\n    end\n    @conditional_testset \"Scatter\" skip_tests begin\n        scatter_testsuite(Backend)\n    end\n    @conditional_testset \"Spectral\" skip_tests begin\n        spectral_testsuite(Backend)\n    end\n    @conditional_testset \"Fold\" skip_tests begin\n        fold_testsuite(Backend)\n    end\nend\n\n@testset verbose=true \"NNlib.jl\" begin\n\n    if get(ENV, \"NNLIB_TEST_CPU\", \"true\") == \"true\"\n        @testset \"CPU\" begin\n            @testset \"Doctests\" begin\n                doctest(NNlib, manual=false)\n            end\n\n            nnlib_testsuite(CPU)\n\n            if Threads.nthreads(:default) > 1\n                @test NNlib.should_use_spawn()\n                NNlib.@disallow_spawns begin\n                    @test NNlib.should_use_spawn() == false\n                end\n            else\n                @test NNlib.should_use_spawn() == false\n            end\n\n            @testset \"Activation Functions\" begin\n                include(\"activations.jl\")\n                include(\"bias_act.jl\")\n            end\n\n            @testset \"Attention\" begin\n                include(\"attention.jl\")\n            end\n\n            @testset \"Batched Multiplication\" begin\n                include(\"batchedmul.jl\")\n            end\n\n            @testset \"Convolution\" begin\n                include(\"conv.jl\")\n                include(\"conv_bias_act.jl\")\n            end\n\n            @testset \"CTC Loss\" begin\n                include(\"ctc.jl\")\n            end\n\n            @testset \"Dropout\" begin\n                include(\"dropout.jl\")\n            end\n\n            @testset \"Inference\" begin\n                include(\"inference.jl\")\n          
  end\n\n            @testset \"Pooling\" begin\n                include(\"pooling.jl\")\n            end\n\n            @testset \"Padding\" begin\n                include(\"padding.jl\")\n            end\n\n            @testset \"Softmax\" begin\n                include(\"softmax.jl\")\n            end\n\n            @testset \"Utilities\" begin\n                include(\"utils.jl\")\n            end\n\n            @testset \"Grid Sampling\" begin\n                include(\"sampling.jl\")\n            end\n\n            @testset \"Functions\" begin\n                include(\"functions.jl\")\n            end\n        end\n    else\n        @info \"Skipping CPU tests, set NNLIB_TEST_CPU=true to run them.\"\n    end\n\n    if get(ENV, \"NNLIB_TEST_CUDA\", \"false\") == \"true\"\n        Pkg.add([\"CUDA\", \"cuDNN\"])\n\n        using CUDA\n        if CUDA.functional()\n            @testset \"CUDA\" begin\n                nnlib_testsuite(CUDABackend; skip_tests=Set((\"Scatter\", \"Gather\")))\n\n                include(\"ext_cuda/runtests.jl\")\n            end\n        else\n            @info \"Insufficient version or CUDA not found; Skipping CUDA tests\"\n        end\n    else\n        @info \"Skipping CUDA tests, set NNLIB_TEST_CUDA=true to run them\"\n    end\n\n    if get(ENV, \"NNLIB_TEST_AMDGPU\", \"false\") == \"true\"\n        Pkg.add(\"AMDGPU\")\n\n        using AMDGPU\n        AMDGPU.versioninfo()\n        if AMDGPU.functional() && AMDGPU.functional(:MIOpen)\n            @testset \"AMDGPU\" begin\n                nnlib_testsuite(ROCBackend)\n                AMDGPU.synchronize(; blocking=false, stop_hostcalls=true)\n\n                include(\"ext_amdgpu/runtests.jl\")\n                AMDGPU.synchronize(; blocking=false, stop_hostcalls=true)\n            end\n        else\n            @info \"AMDGPU.jl package is not functional. Skipping AMDGPU tests.\"\n        end\n    else\n        @info \"Skipping AMDGPU tests, set NNLIB_TEST_AMDGPU=true to run them.\"\n    end\n\n    if get(ENV, \"NNLIB_TEST_METAL\", \"false\") == \"true\"\n        Pkg.add(\"Metal\")\n\n        using Metal\n        if Metal.functional()\n            @testset \"Metal\" begin\n                # nnlib_testsuite(MetalBackend)\n                include(\"ext_metal/runtests.jl\")\n            end\n        else\n            @info \"Insufficient version or Metal not found; Skipping Metal tests\"\n        end\n    else\n        @info \"Skipping Metal tests, set NNLIB_TEST_METAL=true to run them\"\n    end\nend\n"
  },
  {
    "path": "test/sampling.jl",
    "content": "@testset \"Known gradients\" begin\n    x = ones(Float64, (2, 2, 1, 1))\n    grid = Array{Float64}(undef, 2, 2, 2, 1)\n    grid[:, 1, 1, 1] .= (-1, -1)\n    grid[:, 2, 1, 1] .= (1, -1)\n    grid[:, 1, 2, 1] .= (-1, 1)\n    grid[:, 2, 2, 1] .= (1, 1)\n\n    ∇grid_true = Array{Float64}(undef, size(grid))\n    ∇grid_true[:, :, 1, 1] = [[0.0, 0.0] [-0.5, 0.0]]\n    ∇grid_true[:, :, 2, 1] = [[0.0, -0.5] [-0.5, -0.5]]\n\n    padding_mode = :zeros\n    sampled = grid_sample(x, grid; padding_mode=padding_mode)\n    @test x == sampled\n    @test eltype(sampled) == Float64\n    external_grad = ones(size(sampled))\n    ∇input, ∇grid = ∇grid_sample(external_grad, x, grid; padding_mode=padding_mode)\n    @test ∇input == x\n    @test ∇grid == ∇grid_true\n    @test eltype(∇input) == Float64\n    @test eltype(∇grid) == Float64\n\n    # ∇grid from FiniteDifferences is incorrent in case when 0-padding.\n    # gradtest(grid_sample, x, grid; fkwargs=(padding_mode=padding_mode,))\n\n    padding_mode = :border\n    fill!(∇grid_true, 0.0)\n    sampled = grid_sample(x, grid; padding_mode=padding_mode)\n    @test x == sampled\n    @test eltype(sampled) == Float64\n    external_grad = ones(size(sampled))\n    ∇input, ∇grid = ∇grid_sample(external_grad, x, grid; padding_mode=padding_mode)\n    @test ∇input == x\n    @test ∇grid == ∇grid_true\n    @test eltype(∇input) == Float64\n    @test eltype(∇grid) == Float64\n\n    gradtest(grid_sample, x, grid; fkwargs=(padding_mode=padding_mode,))\nend\n\n@testset \"Test out-of-bounds for different paddings\" begin\n    x = ones(Float64, (2, 2, 1, 1))\n    grid = Array{Float64}(undef, 2, 3, 2, 1)\n    grid[:, 1, 1, 1] .= (-3, -1)\n    grid[:, 2, 1, 1] .= (0, -1)\n    grid[:, 3, 1, 1] .= (3, -1)\n    grid[:, 1, 2, 1] .= (-1, 3)\n    grid[:, 2, 2, 1] .= (0, 1)\n    grid[:, 3, 2, 1] .= (1, 3)\n\n    # With 0-padding, out-of-bound values are will contribute nothing to\n    # the output values, because they are too far from any bound.\n    y = grid_sample(x, grid; padding_mode=:zeros)\n    y_true = reshape(Float64[[0, 1, 0] [0, 1, 0]], size(y))\n    @test y_true == y\n\n    # With border-padding, out-of-bound values simly become border values\n    # and the result should be all ones.\n    y = grid_sample(x, grid; padding_mode=:border)\n    y_true = ones(Float64, size(y))\n    @test y_true == y\nend\n\n@testset \"Known gradients 3D\" begin\n    x = ones(Float64, (2, 2, 2, 1, 1))  # 3D input with depth=2\n    grid = Array{Float64}(undef, 3, 2, 2, 2, 1)  # 3D grid with depth=2\n    grid[:, 1, 1, 1, 1] .= (-1, -1, -1)\n    grid[:, 2, 1, 1, 1] .= (1, -1, -1)\n    grid[:, 1, 2, 1, 1] .= (-1, 1, -1)\n    grid[:, 2, 2, 1, 1] .= (1, 1, -1)\n    grid[:, 1, 1, 2, 1] .= (-1, -1, 1)\n    grid[:, 2, 1, 2, 1] .= (1, -1, 1)\n    grid[:, 1, 2, 2, 1] .= (-1, 1, 1)\n    grid[:, 2, 2, 2, 1] .= (1, 1, 1)\n\n    ∇grid_true = Array{Float64}(undef, size(grid))\n    ∇grid_true[:, 1, 1, 1, 1] .= (0.0, 0.0, 0.0)\n    ∇grid_true[:, 2, 1, 1, 1] .= (-0.5, 0.0, 0.0)\n    ∇grid_true[:, 1, 2, 1, 1] .= (0.0, -0.5, 0.0)\n    ∇grid_true[:, 2, 2, 1, 1] .= (-0.5, -0.5, 0.0)\n    ∇grid_true[:, 1, 1, 2, 1] .= (0.0, 0.0, -0.5)\n    ∇grid_true[:, 2, 1, 2, 1] .= (-0.5, 0.0, -0.5)\n    ∇grid_true[:, 1, 2, 2, 1] .= (0.0, -0.5, -0.5)\n    ∇grid_true[:, 2, 2, 2, 1] .= (-0.5, -0.5, -0.5)\n\n    # ∇grid_true[:, :, :, 1, 1] = [\n    #     [[0.0, 0.0, 0.0], [-0.5, 0.0, 0.0]],\n    #     [[0.0, -0.5, 0.0], [-0.5, -0.5, 0.0]]\n    # ]\n    # ∇grid_true[:, :, :, 2, 1] = [\n    #     [[0.0, 0.0, -0.5], [-0.5, 0.0, 
-0.5]]\n    #     [[0.0, -0.5, -0.5], [-0.5, -0.5, -0.5]]\n    # ]\n\n    padding_mode = :zeros\n    sampled = grid_sample(x, grid; padding_mode=padding_mode)\n    @test x == sampled\n    @test eltype(sampled) == Float64\n    external_grad = ones(size(sampled))\n    ∇input, ∇grid = ∇grid_sample(external_grad, x, grid; padding_mode=padding_mode)\n    @test ∇input == x\n    @test ∇grid == ∇grid_true\n    @test eltype(∇input) == Float64\n    @test eltype(∇grid) == Float64\n\n    # ∇grid from FiniteDifferences is incorrect in case when 0-padding.\n    # gradtest(grid_sample, x, grid; fkwargs=(padding_mode=padding_mode,))\n\n    padding_mode = :border\n    fill!(∇grid_true, 0.0)\n    sampled = grid_sample(x, grid; padding_mode=padding_mode)\n    @test x == sampled\n    @test eltype(sampled) == Float64\n    external_grad = ones(size(sampled))\n    ∇input, ∇grid = ∇grid_sample(external_grad, x, grid; padding_mode=padding_mode)\n    @test ∇input == x\n    @test ∇grid == ∇grid_true\n    @test eltype(∇input) == Float64\n    @test eltype(∇grid) == Float64\n\n    gradtest(grid_sample, x, grid; fkwargs=(padding_mode=padding_mode,))\nend\n\n@testset \"Test out-of-bounds for different paddings 3D\" begin\n    x = ones(Float64, (2, 2, 2, 1, 1))  # 3D input with depth=2\n    grid = Array{Float64}(undef, 3, 2, 2, 2, 1)  # 3D grid with depth=2\n    grid[:, 1, 1, 1, 1] .= (-3, -1, -1)\n    grid[:, 2, 1, 1, 1] .= (0, -1, -1)\n    grid[:, 1, 2, 1, 1] .= (-1, 3, -1)\n    grid[:, 2, 2, 1, 1] .= (0, 1, -1)\n    grid[:, 1, 1, 2, 1] .= (-1, -1, 3)\n    grid[:, 2, 1, 2, 1] .= (0, -1, 3)\n    grid[:, 1, 2, 2, 1] .= (-1, 1, 3)\n    grid[:, 2, 2, 2, 1] .= (0, 1, 3)\n\n    # With 0-padding, out-of-bound values will contribute nothing to\n    # the output values, because they are too far from any bound.\n    y = grid_sample(x, grid; padding_mode=:zeros)\n    y_true = reshape(Float64[[0, 1] [0, 1] [0, 0] [0, 0]], size(y))\n    @test y_true == y\n\n    # With border-padding, out-of-bound values simply become border values\n    # and the result should be all ones.\n    y = grid_sample(x, grid; padding_mode=:border)\n    y_true = ones(Float64, size(y))\n    @test y_true == y\nend\n"
  },
  {
    "path": "test/softmax.jl",
    "content": "using Statistics: mean\nusing NNlib: ∇softmax_data, ∇logsoftmax_data\n\n@testset \"softmax integer input\" begin\n    @test softmax(Int[0, 0]) == [0.5, 0.5]\nend\n\n@testset \"softmax on different dims\" begin\n    xs = rand(fill(2, 5)...)\n    out = similar(xs)\n    for (fn!, fn) in [(softmax!, softmax), (logsoftmax!, logsoftmax)], i = 1:ndims(xs)\n        @test fn!(out, xs; dims = i) == fn(xs; dims = i)\n    end\nend\n\n@testset \"softmax\" begin\n    xs = rand(5, 5)\n    @test all(sum(softmax(xs), dims = 1) .≈ 1)\n    @test all(sum(softmax(xs; dims = 2), dims = 2) .≈ 1)\n    @test sum(softmax(vec(xs))) ≈ 1\n    @test log.(softmax(xs; dims = 2)) ≈ logsoftmax(xs; dims = 2)\n\n    xs = [-100_000.0, -100_000.0]\n    @test softmax(xs) ≈ [0.5, 0.5]\n    @test logsoftmax(xs) ≈ log.([0.5, 0.5])\n\n    xs = rand(5)\n    @test softmax(xs) ≈ exp.(xs) ./ sum(exp.(xs))\n    @test logsoftmax(xs) ≈ log.(softmax(xs))\n\n    xs = Float32[1, 2, 3000.0]\n    @test logsoftmax(xs) ≈ [-2999, -2998, 0]\n\n    xs = Float32[1 2 3; 1000 2000 3000]\n    @test logsoftmax(xs) ≈ [-999 -1998 -2997; 0 0 0.0]\n\n    y = logsoftmax(xs)\n    @test ∇logsoftmax_data(ones(Float32, size(xs)), y) ≈ Float32[1 1 1; -1 -1 -1]\n    \n    y = softmax(xs)\n    @test ∇softmax_data(ones(Float32, size(xs)), y) ≈ zeros(Float32, size(xs))\n\n    # These values precalculated using PyTorch's nn.LogSoftmax\n    xs = [\n        -0.238639 0.748142 -0.283194 -0.525461 -1.5348 -0.797842\n        0.690384 0.211427 0.254794 -0.213572 -0.314174 -0.372663\n        -1.146370 -0.577988 0.718952 0.919720 -0.620773 0.929977\n    ]\n    ys = [\n        0.237703 -0.621474 0.448193 0.546047 0.564185 0.632273\n        -0.930163 0.0519798 0.0549979 0.3799 -0.477112 0.437428\n        0.69246 0.569494 -0.503191 -0.925947 -0.0870738 -1.0697\n    ]\n    \n    y = logsoftmax(xs)\n    @test ∇logsoftmax_data(ones(size(xs)), y) ≈ ys rtol = 1e-6\n    \n    y = softmax(xs)\n    @test ∇softmax_data(ones(size(xs)), y) ≈ zeros(size(xs)) atol = 1e-6\nend\n\n@testset \"softmax with Inf, NaN\" begin\n    @test softmax(Float32[1 2; 3 Inf]) ≈    Float32[0.11920292 0.0; 0.880797 1.0]\n    @test softmax(Float32[1 -Inf; 3 Inf]) ≈ Float32[0.11920292 0.0; 0.880797 1.0]\n    @test softmax(Float32[1 Inf; 3 Inf]) ≈  Float32[0.11920292 0.5; 0.880797 0.5]\n\n    @test softmax(Float32[1 2; 3 NaN]) ≈    Float32[0.11920292 NaN; 0.880797 NaN] nans=true\n    @test softmax(Float32[1 2; 3 Inf]; dims=2) ≈ Float32[0.26894143 0.7310586; 0.0 1.0]\n    @test softmax(Float32[1 2; 3 Inf]; dims=(:)) ≈ Float32[0.0 0.0; 0.0 1.0]\n    @test softmax(Float32[1 2; 3 Inf]; dims=(1,2)) ≈ Float32[0.0 0.0; 0.0 1.0]\n\n    @test exp.(logsoftmax(Float32[1 2; 3 Inf])) ≈ softmax(Float32[1 2; 3 Inf])\n    @test exp.(logsoftmax(Float32[1 -Inf; 3 Inf])) ≈ softmax(Float32[1 -Inf; 3 Inf])\n    @test exp.(logsoftmax(Float32[1 Inf; 3 Inf])) ≈ softmax(Float32[1 Inf; 3 Inf])\nend\n\n@testset \"mutating softmax\" begin\n    map([\n        Float64[1 2 3; 5 6 7],\n        Float64[\n            -0.238639 0.748142 -0.283194 -0.525461 -1.5348 -0.797842\n            0.690384 0.211427 0.254794 -0.213572 -0.314174 -0.372663\n            -1.146370 -0.577988 0.718952 0.919720 -0.620773 0.929977\n        ],\n    ]) do xs\n        out = similar(xs)\n        softmax!(out, xs)\n        @test out ≈ softmax(xs) rtol = 1e-6\n        logsoftmax!(out, xs)\n        @test out ≈ logsoftmax(xs) rtol = 1e-6\n\n        @testset \"$fn(Float64, $(size(xs)))\" for fn in [zeros, ones, rand]\n            Δ = fn(Float64, size(xs))\n     
       y = softmax(xs) \n            ∇softmax!(out, Δ, xs, y)  # deprecated\n            @test out ≈ ∇softmax_data(Δ, y)  rtol = 1e-6\n            \n            y = logsoftmax(xs)\n            ∇logsoftmax!(out, Δ, xs, y)  # deprecated\n            @test out ≈ ∇logsoftmax_data(Δ, y)  rtol = 1e-6\n        end\n    end\nend\n\n@testset \"logsumexp\" begin\n    flogsoft(x; dims) = mean(x .- logsoftmax(x; dims = dims), dims = dims)\n\n    x = rand(3, 4)\n    @test logsumexp(x) ≈ flogsoft(x, dims = :)\n    @test logsumexp(x; dims = 1) ≈ flogsoft(x, dims = 1)\nend\n\n@testset \"AutoDiff\" begin\n    for f in (softmax, logsoftmax), d in (:, 1, 2)\n        gradtest(f, (3,4); fkwargs = (dims = d,), check_rrule = true)\n    end\n    gradtest(x -> softmax(x) .* (1:3), 3)\n    gradtest(x -> softmax(x) .* (1:3), (3,5), atol = 1e-4)\n    gradtest(x -> softmax(x, dims = 2) .* (1:3), (3,5), atol = 1e-4)\n\n    gradtest(x -> logsoftmax(x) .* (1:3), 3)\n    gradtest(x -> logsoftmax(x) .* (1:3), (3,5))\n    gradtest(x -> logsoftmax(x, dims = 2) .* (1:3), (3,5))\n\n    for d  in (:, 1, 2)\n        gradtest(logsumexp, (3,4), fkwargs = (dims = d,))\n    end\nend\n\n@testset \"Second derivatives\" begin\n    x = [1 2 3; 6 5 4]\n    H = Zygote.hessian_dual(x -> sum(sin, softmax(x)), x)\n    @test H ≈ Zygote.hessian_reverse(x -> sum(sin, softmax(x)), x)\n\n    H2 = Zygote.hessian_dual(x -> sum(sin, logsoftmax(x)), x)\n    @test H2 ≈ Zygote.hessian_reverse(x -> sum(sin, logsoftmax(x)), x)\n\n    H3 = Zygote.hessian_dual(x -> sum(sin, logsumexp(x)), x)\n    @test H3 ≈ Zygote.hessian_reverse(x -> sum(sin, logsumexp(x)), x)\nend\n"
  },
  {
    "path": "test/test_utils.jl",
    "content": "const IntOrTuple = Union{Int, NTuple{N,Int} where N}\n\ngradtest(f, dims::IntOrTuple...; kw...) =\n    gradtest(f, randn.(Ref(rng), Float64, dims)...; kw...) # julia v1.3 compat\n    # gradtest(f, randn.(rng, Float64, dims)...; kw...)\n\n\"\"\"\nCompare numerical gradient and automatic gradient\ngiven by Zygote. `f` has to be a scalar valued function.\n\nApplies also `ChainRulesTestUtils.test_rrule` if the rrule for `f` is explicitly defined.\n\"\"\"\nfunction gradtest(\n    f, xs...; atol = 1e-6, rtol = 1e-6, fkwargs = NamedTuple(),\n    check_rrule = false, fdm = :central, check_broadcast = false,\n    skip = false, broken = false,\n)\n    # TODO: revamp when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/pull/166\n    # is merged\n    if check_rrule\n        test_rrule(f, xs...; fkwargs = fkwargs)\n    end\n\n    if check_broadcast\n        length(fkwargs) > 0 && @warn(\"CHECK_BROADCAST: dropping keywords args\")\n        h = (xs...) -> sum(f.(xs...))\n    else\n        h = (xs...) -> sum(f(xs...; fkwargs...))\n    end\n\n    y_true = h(xs...)\n    if fdm == :central\n        fdm_obj = FiniteDifferences.central_fdm(5, 1)\n    elseif fdm == :forward\n        fdm_obj = FiniteDifferences.forward_fdm(5, 1)\n    elseif fdm == :backward\n        fdm_obj = FiniteDifferences.backward_fdm(5, 1)\n    end\n    # @show fdm fdm_obj\n\n    gs_fd = FiniteDifferences.grad(fdm_obj, h, xs...)\n\n    y_ad, pull = Zygote.pullback(h, xs...)\n    gs_ad = pull(one(y_ad))\n\n    @test y_true ≈ y_ad  atol = atol rtol = rtol\n    for (g_ad, g_fd) in zip(gs_ad, gs_fd)\n        if skip\n            @test_skip g_ad ≈ g_fd   atol = atol rtol = rtol\n        elseif broken\n            @test_broken g_ad ≈ g_fd   atol = atol rtol = rtol\n        else\n            @test g_ad ≈ g_fd   atol = atol rtol = rtol\n        end\n    end\n    return true\nend\n\n\"\"\"\n    gputest(f, xs...; checkgrad=true, atol=1e-6, kws...)\n\nCompare gradients computed on the device vs CPU.\n`xs...` should already be on the device.\n\"\"\"\nfunction gputest(f, xs...; checkgrad=true, atol=1e-6, kws...)\n    cpu_xs = map(x -> adapt(CPU(), x), xs)\n\n    cpu_y = f(cpu_xs...; kws...)\n    y = f(xs...; kws...)\n    @test collect(cpu_y) ≈ collect(y)\n\n    if checkgrad\n        cpu_grad = gradient((x...) -> sum(f(x...; kws...)), cpu_xs...)\n        gpu_grad = gradient((x...) -> sum(f(x...; kws...)), xs...)\n\n        for (cpu_g, gpu_g) in zip(cpu_grad, adapt(CPU(), gpu_grad))\n            if cpu_g === nothing\n                @test gpu_g === nothing\n            else\n                @test collect(cpu_g) ≈ collect(gpu_g) atol=atol\n            end\n        end\n    end\nend\n"
  },
  {
    "path": "test/testsuite/fold.jl",
    "content": "import NNlib\n\nfunction fold_testsuite(Backend)\n    device(x) = adapt(Backend(), x)\n    gradtest_fn = Backend == CPU ? gradtest : gputest\n\n    @testset \"unfold wrapper\" begin\n        x = device(rand(rng, 16, 16, 3, 10))\n        w = device(rand(rng, 5, 5, 3, 2))\n        @test size(NNlib.unfold(x, size(w))) == (144, 75, 10)\n        @test size(NNlib.unfold(x, size(w); pad=2)) == (256, 75, 10)\n        @test size(NNlib.unfold(x, size(w); stride=2)) == (36, 75, 10)\n        @test size(NNlib.unfold(x, size(w); dilation=2)) == (64, 75, 10)\n    end\n\n    @testset \"Inverses: spatial_rank=$spatial_rank\" for spatial_rank in (1, 2, 3)\n        x = device(rand(rng, repeat([8], spatial_rank)..., 3, 2))\n        w = device(rand(rng, repeat([3], spatial_rank)..., 3, 3))\n\n        cdims = DenseConvDims(x, w; padding=1)\n        y = NNlib.unfold(x, cdims)\n        z = NNlib.fold(y, size(x), cdims)\n\n        o = device(ones(eltype(x), size(x)...))\n        divisor = NNlib.fold(NNlib.unfold(o, cdims), size(x), cdims)\n        @test isapprox(z ./ divisor, x, rtol=1.0e-7)\n\n        # introduce stride\n        cdims = DenseConvDims(x, w; padding=1, stride=2)\n        y = NNlib.unfold(x, cdims)\n        z = NNlib.fold(y, size(x), cdims)\n        divisor = NNlib.fold(NNlib.unfold(o, cdims), size(x), cdims)\n        @test isapprox(z ./ divisor, x, rtol=1.0e-7)\n    end\n\n    @testset \"AutoDiff: spatial_rank=$spatial_rank\" for spatial_rank in (1, 2, 3)\n        x = device(rand(rng, repeat([5], spatial_rank)..., 3, 2))\n        w = device(rand(rng, repeat([3], spatial_rank)..., 3, 3))\n        cdims = DenseConvDims(x, w)\n\n        gradtest_fn(x -> NNlib.unfold(x, cdims), x)\n        Backend == CPU && test_rrule(NNlib.unfold, x, cdims)\n\n        y = NNlib.unfold(x, cdims)\n        gradtest_fn(y -> NNlib.fold(y, size(x), cdims), y)\n        Backend == CPU && test_rrule(NNlib.fold, y, size(x), cdims)\n    end\nend\n"
  },
  {
    "path": "test/testsuite/gather.jl",
    "content": "using NNlib: gather, gather!\nimport EnzymeTestUtils\nusing EnzymeCore\n\nfunction gather_testsuite(Backend)\n    device(x) = adapt(Backend(), x)\n    gradtest_fn = Backend == CPU ? gradtest : gputest\n    T = Float32\n\n    @testset \"gather scalar index\" begin\n        ## 1d src, 2d index of ints -> 2d output\n        src = device(T[3, 4, 5, 6, 7])\n        index = device([\n            1 2 3 4;\n            4 2 1 3;\n            3 5 5 3])\n        output = T[\n            3 4 5 6;\n            6 4 3 5;\n            5 7 7 5]\n\n        y = cpu(gather(src, index))\n        @test y isa Array{T,2}\n        @test size(y) == size(index)\n        @test y == output\n\n        dst = device(T.(zero(index)))\n        @test cpu(gather!(dst, src, index)) == output\n        dst = device(zeros(T, 3, 5))\n        @test_throws ArgumentError gather!(dst, src, index)\n\n        if Backend == CPU\n            index2 = [1 2 3 4;\n                      4 2 1 3;\n                      3 6 5 3]\n            @test_throws BoundsError gather!(T.(zero(index)), src, index2)\n        end\n\n        ## 1d src, 3d index of ints -> 3d output\n        src = device(T[3, 4, 5, 6, 7])\n        index = device([\n            1 2 3 4;\n            4 2 1 3;\n            3 5 5 3][:,:,1:1])\n        output = T[\n            3 4 5 6;\n            6 4 3 5;\n            5 7 7 5][:,:,1:1]\n\n        y = cpu(gather(src, index))\n        @test y isa Array{T,3}\n        @test size(y) == size(index)\n        @test y == output\n\n        ## 2d src, 2d index of ints -> 3d output\n        src = device(T[\n            3 5 7\n            4 6 8])\n        index = device([\n            1 2 3;\n            2 2 1;\n            3 1 3])\n\n        output = zeros(T, 2, 3, 3)\n        output[:,:,1] = [\n            3 5 7\n            4 6 8]\n        output[:,:,2] = [\n            5 5 3\n            6 6 4]\n        output[:,:,3] = [\n            7 3 7\n            8 4 8]\n\n        y = cpu(gather(src, index))\n        M = NNlib.typelength(eltype(index))\n        Nsrc = ndims(src)\n        @test y isa Array{T,3}\n        @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)\n        @test y == output\n    end\n\n    @testset \"gather tuple index\" begin\n        ## 2d src, 1d index of 2-tuples -> 1d output\n        src = device(T[\n            3 5 7\n            4 6 8])\n        index = device([(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)])\n        output = T[3, 5, 7, 4, 6, 8]\n\n        y = cpu(gather(src, index))\n        M = NNlib.typelength(eltype(index))\n        Nsrc = ndims(src)\n        @test y isa Array{T,1}\n        @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)\n        @test y == output\n\n        ## 3d src, 2d index of 2-tuples -> 3d output\n        n1, nsrc, nidx = 2, 3, 6\n        src = device(rand(T, n1, nsrc, nsrc))\n        index = device([\n            (rand(1:nsrc), rand(1:nsrc)) for i=1:nidx, j=1:nidx])\n\n        y = cpu(gather(src, index))\n        M = NNlib.typelength(eltype(index))\n        Nsrc = ndims(src)\n        @test y isa Array{T,3}\n        @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)\n    end\n\n    @testset \"gather cartesian index\" begin\n        ## 2d src, 1d index of 2-tuples -> 1d output\n        src = device(T[\n            3 5 7\n            4 6 8])\n        index = device(CartesianIndex.([(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)]))\n        output = T[3, 5, 7, 4, 6, 8]\n\n        y = cpu(gather(src, index))\n        M = NNlib.typelength(eltype(index))\n        Nsrc = ndims(src)\n        @test y isa Array{T,1}\n        @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)\n        @test y == output\n\n        ## 3d src, 2d index of 2-tuples -> 3d output\n        n1, nsrc, nidx = 2, 3, 6\n        src = device(rand(Float32, n1, nsrc, nsrc))\n        index = device([\n            CartesianIndex((rand(1:nsrc), rand(1:nsrc))) for i=1:nidx, j=1:nidx])\n\n        y = cpu(gather(src, index))\n        M = NNlib.typelength(eltype(index))\n        Nsrc = ndims(src)\n        @test y isa Array{T,3}\n        @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)\n    end\n\n    @testset \"gather gradient for scalar index\" begin\n        src = device(Float64[3, 4, 5, 6, 7])\n        idx = device([\n            1 2 3 4;\n            4 2 1 3;\n            3 5 5 3])\n        dst = device(Float64[\n            3 4 5 6;\n            6 4 3 5;\n            5 7 7 5])\n        Backend == CPU ?\n            gradtest_fn(xs -> gather!(dst, xs, idx), src) :\n            gradtest_fn((d, s, i) -> gather!(d, s, i), dst, src, idx)\n        Backend == CPU ?\n            gradtest_fn(xs -> gather(xs, idx), src) :\n            gradtest_fn((s, i) -> gather(s, i), src, idx)\n    end\n\n    @static if Test_Enzyme\n\n    @testset \"EnzymeRules: gather! gradient for scalar index\" begin\n        src = device(Float64[3, 4, 5, 6, 7])\n        idx = device([\n            1 2 3 4;\n            4 2 1 3;\n            3 5 5 3])\n        dst = gather(src, idx)\n        for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),\n            Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),\n            Tsrc in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated)\n\n            Tret == EnzymeCore.Const && continue # ERROR\n            EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tsrc) || continue\n\n            EnzymeTestUtils.test_reverse(gather!, Tret, (dst, Tdst), (src, Tsrc), (idx, EnzymeCore.Const))\n        end\n    end\n\n    end\n\n    @testset \"gather gradient for tuple index\" begin\n        src = device(Float64[\n            3 5 7\n            4 6 8])\n        idx = device([(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)])\n        dst = device(Float64[3, 5, 7, 4, 6, 8])\n        Backend == CPU ?\n            gradtest_fn(xs -> gather!(dst, xs, idx), src) :\n            gradtest_fn((d, s, i) -> gather!(d, s, i), dst, src, idx)\n        Backend == CPU ?\n            gradtest_fn(xs -> gather(xs, idx), src) :\n            gradtest_fn((s, i) -> gather(s, i), src, idx)\n    end\n\n    @testset \"gather(src, IJK...)\" begin\n        x = device(reshape([1:15;], 3, 5))\n        i, j = device([1,2]), device([2,4])\n        y = gather(x, i, j)\n        @test cpu(y) == [4, 11]\n        y = gather(x, device([1, 2]))\n        @test cpu(y) == [\n            1 4\n            2 5\n            3 6]\n    end\nend\n"
  },
  {
    "path": "test/testsuite/rotation.jl",
    "content": "function rotation_testsuite(Backend)\n    device(x) = adapt(Backend(), x)\n    gradtest_fn = Backend == CPU ? gradtest : gputest\n    T = Float64\n    atol = T == Float32 ? 1e-3 : 1e-6\n    rtol = T == Float32 ? 1f-3 : 1f-6\n    angles = deg2rad.([0, 0.0001, 35, 90, -90, -90.0123, 170, 180, 270, 360, 450, 1234.1234])\n\n    @testset \"imrotate\" begin\n        @testset \"Simple test\" begin\n            arr = device(zeros((6, 6, 1, 1)));\n            arr[3:4, 4, 1, 1] .= 1;\n            @test all(cpu(NNlib.imrotate(arr, deg2rad(45))) .≈ [0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.29289321881345254 0.585786437626905 0.0; 0.0 0.0 0.08578643762690495 1.0 0.2928932188134524 0.0; 0.0 0.0 0.0 0.08578643762690495 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0])\n        end\n\n\n        @testset \"Compare with ImageTransformations\" begin\n            for sz in [(51,51,1,1), (52,52,1,1), (51,52,1,1), (52,51,1,1)]\n                rotation_center = (sz[1:2] .+ 1) ./ 2\n                arr1 = device(zeros(T, sz))\n                arr1[15:40, 15:40, :, :] .= device(1 .+ randn((26, 26)))\n                arr2 = device(zeros(T, (sz[1], sz[2], sz[3], 3)))\n                arr2[15:40, 15:40, :, :] .= device(arr1[15:40, 15:40, :, :])\n\n                for method in [:nearest, :bilinear]\n                    @testset \"$method\" begin\n                        for angle in angles\n                            res1 = cpu(NNlib.imrotate(arr1, angle; method, rotation_center=rotation_center))\n                            res2 = cpu(NNlib.imrotate(arr2, angle; method, rotation_center=rotation_center))\n                            if method == :nearest\n                                res_IT = ImageTransformations.imrotate(cpu(arr1)[:, :, 1, 1], angle, axes(arr1)[1:2], method=Constant(), fillvalue=0)\n                            elseif method == :bilinear\n                                res_IT = ImageTransformations.imrotate(cpu(arr1)[:, :, 1, 1], angle, axes(arr1)[1:2], fillvalue=0)\n                            end\n                            if method == :nearest\n                                @test ≈(1 .+ res1[:, :, :, :], 1 .+ res_IT[:, :], rtol=0.5)\n                                @test ≈(1 .+ res1[:, :, :, :], 1 .+ res2[:, :,:, 1], rtol=0.5)\n                                @test ≈(1 .+ res1[:, :, :, :], 1 .+ res2[:, :,:, 2], rtol=0.5)\n                            else\n                                @test all(.≈(1 .+ res1[:, :, :, :], 1 .+ res_IT[:, :], rtol=rtol))\n                                @test all(.≈(1 .+ res1[:, :, :, :], 1 .+ res2[:, :,:, 1], rtol=rtol))\n                                @test all(.≈(1 .+ res1[:, :, :, :], 1 .+ res2[:, :,:, 2], rtol=rtol))\n                            end\n                        end\n                    end\n                end\n            end\n        end\n\n        @testset \"Compare for plausibility\" begin\n            @testset \"Special cases of rotation\" begin\n                arr = device(zeros(T, (10, 10, 1, 3)))\n                arr[6, 6, :, 1] .= 1\n                arr[6, 6, :, 2] .= 2\n                arr[6, 6, :, 3] .= 3\n\n                for method in [:bilinear, :nearest]\n                    @test all(.≈(arr , NNlib.imrotate(arr, deg2rad(0); method)))\n                    @test all(.≈(arr , NNlib.imrotate(arr, deg2rad(90); method)))\n                    @test all(.≈(arr , NNlib.imrotate(arr, deg2rad(180); method)))\n                    @test all(.≈(arr , NNlib.imrotate(arr, deg2rad(270); method)))\n                    @test all(.≈(arr , NNlib.imrotate(arr, deg2rad(360); method)))\n                end\n            end\n        end\n\n        @testset \"Test gradients\" begin\n            for method in [:nearest, :bilinear]\n                for angle in angles\n                    gradtest_fn(\n                        x -> NNlib.imrotate(x, angle; method),\n                        device(rand(T, 11,11,1,1)); atol)\n                    gradtest_fn(\n                        x -> NNlib.imrotate(x, angle; method),\n                        device(rand(T, 10,10,1,1)); atol)\n                end\n            end\n        end\n    end\nend\n"
  },
  {
    "path": "test/testsuite/scatter.jl",
    "content": "using NNlib: scatter, scatter!\n\ndsts = Dict(\n    0 => [3, 4, 5, 6, 7],\n    1 => [3 3 4 4 5;\n          5 5 6 6 7],\n)\nsrcs = Dict(\n    (0, true) => ones(Int, 3, 4),\n    (0, false) => ones(Int, 3) * collect(1:4)',\n    (1, true) => ones(Int, 2, 3, 4),\n    (1, false) => [1, 2] .* reshape(ones(Int, 3) * collect(1:4)', 1,3,4),\n)\nidxs = Dict(\n    :int => [1 2 3 4;\n             4 2 1 3;\n             3 5 5 3],\n    :tup => [(1,) (2,) (3,) (4,);\n             (4,) (2,) (1,) (3,);\n             (3,) (5,) (5,) (3,)],\n    :car => CartesianIndex.(\n            [(1,) (2,) (3,) (4,);\n             (4,) (2,) (1,) (3,);\n             (3,) (5,) (5,) (3,)]),\n)\nres = Dict(\n    (+, 0, true) => [5, 6, 9, 8, 9],\n    (+, 1, true) => [5 5 8 6 7;\n                     7 7 10 8 9],\n    (+, 0, false) => [4, 4, 12, 5, 5],\n    (+, 1, false) => [4 4 12 5 5;\n                      8 8 24 10 10],\n    (-, 0, true) => [1, 2, 1, 4, 5],\n    (-, 1, true) => [1 1 0 2 3;\n                     3 3 2 4 5],\n    (-, 0, false) => [-4, -4, -12, -5, -5],\n    (-, 1, false) => [-4 -4 -12 -5 -5;\n                      -8 -8 -24 -10 -10],\n    (max, 0, true) => [3, 4, 5, 6, 7],\n    (max, 1, true) => [3 3 4 4 5;\n                       5 5 6 6 7],\n    (max, 0, false) => [3, 2, 4, 4, 3],\n    (max, 1, false) => [3 2 4 4 3;\n                        6 4 8 8 6],\n    (min, 0, true) => [1, 1, 1, 1, 1],\n    (min, 1, true) => [1 1 1 1 1;\n                       1 1 1 1 1],\n    (min, 0, false) => [1, 2, 1, 1, 2],\n    (min, 1, false) => [1 2 1 1 2;\n                        2 4 2 2 4],\n    (*, 0, true) => [3, 4, 5, 6, 7],\n    (*, 1, true) => [3 3 4 4 5;\n                     5 5 6 6 7],\n    (*, 0, false) => [3, 4, 48, 4, 6],\n    (*, 1, false) => [3 4 48 4 6;\n                      12 16 768 16 24],\n    (/, 0, true) => [0.75, 1., 0.3125, 1.5, 1.75],\n    (/, 1, true) => [0.75 0.75 0.25 1. 1.25;\n                     1.25 1.25 0.375 1.5 1.75],\n    (/, 0, false) => [1//3, 1//4, 1//48, 1//4, 1//6],\n    (/, 1, false) => [1//3 1//4 1//48 1//4 1//6;\n                      1//12 1//16 1//768 1//16 1//24],\n    (mean, 0, true) => [4., 5., 6., 7., 8.],\n    (mean, 1, true) => [4. 4. 5. 5. 6.;\n                        6. 6. 7. 7. 8.],\n    (mean, 0, false) => [2, 2, 3, 2.5, 2.5],\n    (mean, 1, false) => [2. 2. 3. 2.5 2.5;\n                         4. 4. 6. 5. 5.],\n)\n\nfunction test_scatter(device, types, ops; pt, ops_skip_types)\n    for T in types, IT in (Int8, Int64)\n        PT = promote_type(T, pt)\n        @testset \"eltype $T - idx eltype $IT - $op\" for op in ops\n            skip_types = get(ops_skip_types, op, [])\n            for idx = values(idxs), dims = [0, 1]\n                # Tests with indices of different types.\n                eltype(idx) == Int && (idx = IT.(idx);)\n\n                idx = device(idx)\n                dst = device(dsts[dims])\n\n                mutated = true\n                target_y = res[(op, dims, mutated)]\n                src = device(srcs[(dims, mutated)])\n                if op == /\n                    src = src .* T(2)\n                end\n\n                @test cpu(scatter!(op, T.(dst), T.(src), idx)) == T.(target_y)\n                @test cpu(scatter!(op, T.(dst), src, idx)) == PT.(target_y)\n                if op == /\n                    @test cpu(scatter!(op, T.(dst), T.(src), idx)) == PT.(target_y)\n                else\n                    @test cpu(scatter!(op, copy(dst), T.(src), idx)) == PT.(target_y)\n                end\n\n                if T ∉ skip_types\n                    mutated = false\n                    src = device(srcs[(dims, mutated)])\n                    @test cpu(scatter(op, T.(src), idx)) == T.(res[(op, dims, mutated)])\n                end\n            end\n        end\n    end\nend\n\nfunction scatter_testsuite(Backend)\n    device(x) = adapt(Backend(), x)\n    gradtest_fn = Backend == CPU ? gradtest : gputest\n\n    ops_skip_types = Dict(\n        (+) => [],\n        (-) => [UInt8, UInt16, UInt32, UInt64, UInt128],\n        (*) => [UInt8, Int8],\n        max => [BigInt],\n        min => [BigInt])\n    types = if Backend == CPU\n        [UInt8,  UInt32, UInt64, Int32, Int64, Float16, Float32, Float64, BigFloat, Rational]\n    elseif Symbol(Backend) == :CUDABackend\n        [Int32, Int64, Float32, Float64]\n    else\n        # Need LLVM 15+ for atomic fmin/fmax:\n        # https://reviews.llvm.org/D127041\n        # But fmin/fmax can be done by reinterpreting an array to `UInt`.\n        [Int32, Int64, UInt32, UInt64]\n    end\n    ops = Backend == CPU ?\n        (+, -, max, min, *) :\n        (+, -, max, min)\n    test_scatter(device, types, ops; pt=Int, ops_skip_types)\n\n    types = Backend == CPU ?\n        [Float16, Float32, BigFloat, Rational] :\n        [Float32, Float64]\n    ops = if Backend == CPU\n        (/, mean)\n    elseif Symbol(Backend) == :CUDABackend\n        (*, /, mean)\n    else\n        # LLVM does not support atomic fmul/fdiv:\n        # https://llvm.org/docs/LangRef.html#atomicrmw-instruction\n        (mean,)\n    end\n    test_scatter(device, types, ops; pt=Float64, ops_skip_types=Dict())\n\n    if Backend == CPU\n        @testset \"scatter exceptions\" begin\n            idx = [1 2 3 4; 4 2 1 3; 6 7 8 9]\n            @test_throws AssertionError scatter!(+, copy(dsts[0]), srcs[(1, true)], idxs[:int])\n            @test_throws BoundsError scatter!(+, copy(dsts[1]), srcs[(1, true)], idx)\n        end\n    end\n\n    @testset \"∇scatter\" begin\n        T = Float64\n        fdm(op) = op == min ? :backward : :forward\n\n        @testset \"dstsize\" begin\n            idx = device([2, 2, 3, 4, 4])\n            src = device(ones(T, 3, 5))\n            y = scatter(+, src, idx, dstsize = (3, 6))\n            @test eltype(y) == T\n            @test size(y) == (3, 6)\n            Backend == CPU ?\n                gradtest_fn(x -> scatter(+, x, idx; dstsize=(3, 6)), src) :\n                gradtest_fn((x, i) -> scatter(+, x, i; dstsize=(3, 6)), src, idx)\n        end\n\n        @testset \"∂dst\" begin\n            ops = if Backend == CPU || Symbol(Backend) == :CUDABackend\n                (+, -, *, /, mean, max, min)\n            else\n                (+, -, mean, max, min)\n            end\n            for op in ops, i in (0, 1), IT in (Int8, Int64)\n                PT = ( # If not CPU and CUDA -> use Int64 for min/max.\n                    Backend != CPU &&\n                    Symbol(Backend) != :CUDABackend &&\n                    (op == max || op == min)) ? Int64 : T\n\n                src = device(srcs[(i, true)])\n                idx = device(IT.(idxs[:int]))\n                dst = device(PT.(dsts[i]))\n                Backend == CPU ?\n                    gradtest_fn(x -> scatter!(op, copy(x), src, idx), dst; fdm=fdm(op)) :\n                    gradtest_fn((x, s, i) -> scatter!(op, x, s, i), dst, src, idx)\n            end\n        end\n\n        @testset \"∂src\" begin\n            ops = if Backend == CPU || Symbol(Backend) == :CUDABackend\n                (+, -, *, /, mean, max, min)\n            else\n                (+, -, mean, max, min)\n            end\n            for op in ops, i in (0, 1), IT in (Int8, Int64)\n                PT = ( # If not CPU and CUDA -> use Int64 for min/max.\n                    Backend != CPU &&\n                    Symbol(Backend) != :CUDABackend &&\n                    (op == max || op == min)) ? Int64 : T\n                src = PT.(device(srcs[(i, false)]))\n                idx = device(IT.(idxs[:int]))\n                Backend == CPU ?\n                    gradtest_fn(xs -> scatter(op, xs, idx), src; fdm=fdm(op)) :\n                    gradtest_fn((xs, i) -> scatter(op, xs, i), src, idx)\n            end\n        end\n\n\n        @static if Test_Enzyme\n\n        @testset \"EnzymeRules\" begin\n            idx = device([2, 2, 3, 4, 4])\n            src = device(ones(T, 3, 5))\n\n            for op in (+, -)\n\n                dst = scatter(op, src, idx)\n\n                for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),\n                    Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),\n                    Tsrc in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated)\n\n                    Tret == EnzymeCore.Const && continue # ERROR\n                    EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tsrc) || continue\n\n                    EnzymeTestUtils.test_reverse(scatter!, Tret, (op, EnzymeCore.Const), (dst, Tdst), (src, Tsrc), (idx, EnzymeCore.Const))\n                end\n            end\n        end\n\n        end\n    end\nend\n"
  },
  {
    "path": "test/testsuite/spectral.jl",
    "content": "function spectral_testsuite(Backend)\n    cpu(x) = adapt(CPU(), x)\n    device(x) = adapt(Backend(), x)\n    gradtest_fn = Backend == CPU ? gradtest : gputest\n\n    @testset \"Window functions\" begin\n        for window_fn in (hann_window, hamming_window)\n            @inferred window_fn(10, Float32)\n            @inferred window_fn(10, Float64)\n\n            w = window_fn(10)\n            @test length(w) == 10\n            @test eltype(w) == Float32\n\n            wp = window_fn(10; periodic=false)\n            @test wp[1:5] ≈ reverse(wp[6:10])\n\n            @test window_fn(10; periodic=true) ≈ window_fn(10 + 1; periodic=false)[1:10]\n        end\n    end\n\n    @testset \"STFT\" for batch in ((), (3,))\n        @testset \"Grads\" begin\n            if Backend != CPU\n                x = rand(Float32, 16, batch...)\n                window = hann_window(16)\n\n                gradtest_fn(s -> abs.(stft(s; n_fft=16)), x)\n                gradtest_fn((s, w) -> abs.(stft(s; n_fft=16, window=w)), x, window)\n\n                x = rand(Float32, 2045, batch...)\n                n_fft = 256\n                window = hann_window(n_fft)\n                gradtest_fn((s, w) -> abs.(stft(s; n_fft, window=w)), x, window)\n                gradtest_fn((s, w) -> abs.(stft(s; n_fft, window=w, center=false)), x, window)\n                gradtest_fn((s, w) -> abs.(stft(s; n_fft, window=w, normalized=true)), x, window)\n            end\n        end\n\n        @testset \"Batch $batch\" begin\n            x = device(ones(Float32, 16, batch...))\n            # TODO fix type stability for pad_reflect\n            # @inferred stft(x; n_fft=16)\n\n            bd = ntuple(_ -> Colon(), length(batch))\n\n            y = stft(x; n_fft=16)\n            @test size(y) == (9, 5, batch...)\n            @test all(real(cpu(y))[1, :, bd...] .≈ 16)\n\n            xx = istft(y; n_fft=16)\n            @test size(xx) == (16, batch...)\n            @test cpu(x) ≈ cpu(xx)\n\n            # Test multiple hops.\n            x = device(rand(Float32, 2048, batch...))\n            y = stft(x; n_fft=1024)\n            xx = istft(y; n_fft=1024)\n            @test cpu(x) ≈ cpu(xx)\n\n            # Test odd sizes.\n            x = device(rand(Float32, 1111, batch...))\n            y = stft(x; n_fft=256)\n            xx = istft(y; n_fft=256, original_length=size(x, 1))\n            @test cpu(x) ≈ cpu(xx)\n\n            # Output from inverse is cropped on the right\n            # without knowing the original size.\n            xx = istft(y; n_fft=256)\n            @test length(xx) < length(x)\n            @test cpu(x)[[1:s for s in size(xx)]...] ≈ cpu(xx)\n\n            # Test different options.\n\n            # Normalized.\n            x = device(rand(Float32, 1234, batch...))\n            y = stft(x; n_fft=512, normalized=true)\n            xx = istft(y; n_fft=512, normalized=true, original_length=size(x, 1))\n            @test cpu(x) ≈ cpu(xx)\n\n            # With window.\n            window = device(hann_window(512))\n            y = stft(x; n_fft=512, window)\n            xx = istft(y; n_fft=512, window, original_length=size(x, 1))\n            @test cpu(x) ≈ cpu(xx)\n\n            # Hop.\n            for hop_length in (32, 33, 255, 256, 511, 512)\n                y = stft(x; n_fft=512, hop_length)\n                xx = istft(y; n_fft=512, hop_length, original_length=size(x, 1))\n                @test cpu(x) ≈ cpu(xx)\n            end\n\n            # N FFT.\n            for n_fft in (32, 33, 64, 65, 128, 129, 512)\n                y = stft(x; n_fft)\n                xx = istft(y; n_fft, original_length=size(x, 1))\n                @test cpu(x) ≈ cpu(xx)\n            end\n        end\n    end\n\n    @testset \"Spectrogram\" begin\n        x = device(rand(Float32, 1024))\n        window = device(hann_window(1024))\n\n        y = stft(x;\n            n_fft=1024, hop_length=128, window,\n            center=true, normalized=false)\n        spec = spectrogram(x;\n            n_fft=1024, hop_length=128, window,\n            center=true, normalized=false)\n        @test abs.(y).^2 ≈ spec\n\n        # Gradient with `0`s in spectrogram.\n        # We add small ϵ to spectrogram before computing power\n        # to prevent `NaN` in gradient due to `abs(0)`.\n        x = device(ones(Float32, 1024))\n        g = Zygote.gradient(x) do x\n            sum(spectrogram(x;\n                n_fft=1024, hop_length=128, window,\n                center=true, normalized=false))\n        end\n        @test !any(isnan.(g[1]))\n\n        # Batched.\n        x = device(rand(Float32, 1024, 3))\n        spec = spectrogram(x;\n            n_fft=1024, hop_length=128, window,\n            center=true, normalized=false)\n        for i in 1:3\n            y = stft(x[:, i];\n                n_fft=1024, hop_length=128, window,\n                center=true, normalized=false)\n            @test abs.(y).^2 ≈ spec[:, :, i]\n        end\n\n        if Backend != CPU\n            @testset \"Grads\" begin\n                for batch in ((), (3,))\n                    x = rand(Float32, 2045, batch...)\n                    n_fft = 256\n                    window = hann_window(n_fft)\n                    gradtest_fn((s, w) -> spectrogram(s; n_fft, hop_length=n_fft ÷ 4, window=w), x, window)\n                    gradtest_fn((s, w) -> spectrogram(s; n_fft, hop_length=n_fft ÷ 4, window=w, center=false), x, window)\n                    gradtest_fn((s, w) -> spectrogram(s; n_fft, hop_length=n_fft ÷ 4, window=w, normalized=true), x, window)\n                end\n            end\n        end\n    end\n\n    @testset \"Power to dB\" begin\n        x = device(rand(Float32, 1024))\n        window = device(hann_window(1024))\n        spec = spectrogram(x; pad=0, n_fft=1024, hop_length=128, window)\n\n        @test spec ≈ NNlib.db_to_power(NNlib.power_to_db(spec))\n        @inferred NNlib.power_to_db(spec)\n        @inferred NNlib.db_to_power(NNlib.power_to_db(spec))\n    end\nend\n"
  },
  {
    "path": "test/testsuite/upsample.jl",
    "content": "function upsample_testsuite(Backend)\n    device(x) = adapt(Backend(), x)\n    gradtest_fn = Backend == CPU ? gradtest : gputest\n    T = Float32 # TODO test against all supported eltypes for each backend.\n    atol = T == Float32 ? 1e-3 : 1e-6\n\n    @testset \"upsample_nearest, integer scale via reshape\" begin\n        x = device(reshape(T[1 2; 3 4], (2,2,1,1)))\n        @test cpu(upsample_nearest(x, (3,3)))[1,:] == [1,1,1, 2,2,2]\n\n        y = upsample_nearest(x, (2,3))\n        @test size(y) == (4,6,1,1)\n        y2 = upsample_nearest(x, size=(4,6))\n        @test cpu(y) ≈ cpu(y2)\n\n        @test cpu(∇upsample_nearest(y, (2,3)))[:, :, 1, 1] == [6 12; 18 24]\n        gradtest_fn(\n            x -> upsample_nearest(x, (2,3)),\n            device(rand(T, 2,2,1,1)); atol)\n        gradtest_fn(\n            x -> upsample_nearest(x, size=(4,6)),\n            device(rand(T, 2,2,1,1)); atol)\n\n        @test_throws ArgumentError ∇upsample_nearest(y, (2,4))\n        @test_throws ArgumentError upsample_nearest(x, (1,2,3,4,5))\n        @test_throws ArgumentError upsample_nearest(x, size=(3,4))\n    end\n\n    @testset \"Linear upsampling (1D)\" begin\n        x = T[1,2,3,4]\n        x = hcat(x,x,x)[:,:,:]\n\n        y = collect(1:1//3:4)\n        y = hcat(y,y,y)[:,:,:]\n\n        xd = device(x)\n        @test y ≈ cpu(upsample_linear(xd, 2.5))\n        @test y ≈ cpu(upsample_linear(xd; size=10))\n        gradtest_fn(x -> upsample_linear(x, 2.5), xd; atol)\n    end\n\n    @testset \"Bilinear upsampling (2D)\" begin\n        x = Float32[1 2; 3 4][:,:,:,:]\n        x = cat(x,x; dims=3)\n        x = cat(x,x; dims=4)\n\n        # this output matches the one of pytorch v1.5.0\n        # nn.UpsamplingBilinear2d(scale_factor=(3,2), align_corners=True)\n        # for above x\n        y_true = Float32[ 1//1  4//3   5//3   2//1;\n                          7//5 26//15 31//15 12//5;\n                          9//5 32//15 37//15 14//5;\n                         11//5 38//15 43//15 16//5;\n                         13//5 44//15 49//15 18//5;\n                          3//1 10//3  11//3   4//1][:,:,:,:]\n        y_true = cat(y_true, y_true; dims=3)\n        y_true = cat(y_true, y_true; dims=4)\n\n        xd = device(x)\n        y = upsample_bilinear(xd, (3, 2))\n        @test size(y) == size(y_true)\n        @test eltype(y) == Float32\n        @test cpu(y) ≈ y_true\n\n        gradtest_fn(x -> upsample_bilinear(x, (3, 2)), xd; atol)\n\n        # additional grad check, also compliant with pytorch\n        o = ones(Float32,6,4,2,1)\n        grad_true = 6*ones(Float32,2,2,2,1)\n        @test cpu(∇upsample_bilinear(device(o); size = (2,2))) ≈ grad_true\n\n        # CPU only tests.\n\n        y_true_2 = Rational{Int}[1//1  5//4  6//4  7//4 2//1;\n                                 3//2  7//4  8//4  9//4 5//2;\n                                 4//2  9//4 10//4 11//4 6//2;\n                                 5//2 11//4 12//4 13//4 7//2;\n                                 3//1 13//4 14//4 15//4 4//1][:,:,:,:]\n        y_true_2 = cat(y_true_2, y_true_2; dims=3)\n        y_true_2 = cat(y_true_2, y_true_2; dims=4)\n\n        # check for real-valued single-number argument and type stability for rationals\n        y_rational = upsample_bilinear(Rational{Int}.(x), 2.5)\n        @test eltype(y_rational) == Rational{Int}\n        @test y_rational == y_true_2\n\n        # check Integer support for forward pass\n        # grads are always assumed to be floats, so no extension there\n        x = UInt8[1 3; 3 5][:,:,:,:]\n        y_true_int = UInt8[1 2 3; 2 3 4; 3 4 5][:,:,:,:]\n        y = upsample_bilinear(x, 1.5)\n\n        @test eltype(y) == UInt8\n        @test y == y_true_int\n    end\n\n    @testset \"Trilinear upsampling (3D)\" begin\n        # Layout: WHDCN, where D is depth\n        # we generate data which is constant along W & H and differs in D\n        # then we upsample along all dimensions\n        x = ones(T, 3,3,3,1,1)\n        x[:,:,1,:,:] .= 1.\n        x[:,:,2,:,:] .= 2.\n        x[:,:,3,:,:] .= 3.\n\n        y_true = ones(T, 5,5,5,1,1)\n        y_true[:,:,1,:,:] .= 1.\n        y_true[:,:,2,:,:] .= 1.5\n        y_true[:,:,3,:,:] .= 2.\n        y_true[:,:,4,:,:] .= 2.5\n        y_true[:,:,5,:,:] .= 3.\n\n        xd = device(x)\n        y = upsample_trilinear(xd; size=(5,5,5))\n\n        @test size(y) == size(y_true)\n        @test eltype(y) == T\n        @test collect(y) ≈ collect(y_true)\n\n        gradtest_fn(\n            x -> upsample_trilinear(x, (2,2,2)), xd;\n            atol=(T == Float32) ? 1e-2 : 1e-5)\n\n        # This test only works when `align_corners=false`.\n        o = device(ones(Float32,8,8,8,1,1))\n        grad_true = 8 * ones(Float32,4,4,4,1,1)\n        @test cpu(∇upsample_trilinear(o; size=(4,4,4), align_corners=false)) ≈ grad_true\n    end\n\n    @testset \"pixel_shuffle\" begin\n        x = reshape(1:16, (2, 2, 4, 1))\n        # [:, :, 1, 1] =\n        #     1  3\n        #     2  4\n        # [:, :, 2, 1] =\n        #     5  7\n        #     6  8\n        # [:, :, 3, 1] =\n        #     9  11\n        #     10  12\n        # [:, :, 4, 1] =\n        #     13  15\n        #     14  16\n\n        y_true = [1  9 3 11\n                  5 13 7 15\n                  2 10 4 12\n                  6 14 8 16][:,:,:,:]\n\n        y = pixel_shuffle(device(x), 2)\n        @test size(y) == size(y_true)\n        @test y_true == cpu(y)\n\n        x = reshape(1:32, (2, 2, 8, 1))\n        y_true = zeros(Int, 4, 4, 2, 1)\n        y_true[:,:,1,1] .= [ 1   9  3  11\n                             5  13  7  15\n                             2  10  4  12\n                             6  14  8  16 ]\n\n        y_true[:,:,2,1] .= [ 17  25  19  27\n                             21  29  23  31\n                             18  26  20  28\n                             22  30  24  32]\n\n        y = pixel_shuffle(device(x), 2)\n        @test size(y) == size(y_true)\n        @test y_true == cpu(y)\n\n        x = reshape(1:4*3*27*2, (4,3,27,2))\n        y = pixel_shuffle(device(x), 3)\n        @test size(y) == (12, 9, 3, 2)\n\n        # batch dimension is preserved\n        x1 = x[:,:,:,[1]]\n        x2 = x[:,:,:,[2]]\n        y1 = pixel_shuffle(device(x1), 3)\n        y2 = pixel_shuffle(device(x2), 3)\n        @test cpu(cat(y1, y2, dims=4)) == cpu(y)\n\n        for d in [1, 2, 3]\n            r = rand(1:5)\n            n = rand(1:5)\n            c = rand(1:5)\n            insize = rand(1:5, d)\n            x = rand(insize..., r^d*c, n)\n            xd = device(x)\n\n            y = pixel_shuffle(xd, r)\n            @test size(y) == ((r .* insize)..., c, n)\n            gradtest_fn(x -> pixel_shuffle(x, r), xd)\n        end\n    end\n\n    @testset \"Complex-valued upsample\" begin\n        for (d, method) in zip([1, 2, 3], [upsample_linear, upsample_bilinear, upsample_trilinear])\n            for (k, interp) in zip((2, ntuple(_ -> 2,  d)), [method, upsample_nearest])\n                x = device(randn(Complex{Float32}, (4,8,12)[1:d]..., 1, 1))\n\n                upsize = (8, 16, 24)[1:d]\n                xup = interp(x, k)\n                @test size(xup)[1:d] == upsize\n                @test cpu(real(xup)) == cpu(interp(real(x), k))\n                @test cpu(imag(xup)) == cpu(interp(imag(x), k))\n\n                upsize = (8,24,48)[1:d]\n                xup = interp(x; size=upsize)\n                @test size(xup)[1:d] == upsize\n                @test cpu(real(xup)) == cpu(interp(real(x), size=upsize))\n                @test cpu(imag(xup)) == cpu(interp(imag(x), size=upsize))\n            end\n        end\n    end\nend\n"
  },
  {
    "path": "test/utils.jl",
    "content": "@testset \"within_gradient\" begin\n    @test NNlib.within_gradient([1.0]) === false\n    @test gradient(x -> NNlib.within_gradient(x) * x, 2.0) == (1.0,)\n    @test NNlib.within_gradient([ForwardDiff.Dual(1.0, 2)]) === true\nend\n\n@testset \"maximum_dims\" begin\n    ind1 = [1,2,3,4,5,6]\n    @test NNlib.maximum_dims(ind1) == (6,)\n    ind2 = [(3,4,5), (1,2,3), (2,3,9)]\n    @test NNlib.maximum_dims(ind2) == (3,4,9)\n    ind3 = [(3,4,5) (1,2,3) (2,3,9);\n            (4,6,2) (5,3,2) (4,4,4)]\n    @test NNlib.maximum_dims(ind3) == (5,6,9)\n    ind4 = CartesianIndex.(\n           [(3,4,5) (1,2,3) (2,3,9);\n            (4,6,2) (5,3,2) (4,4,4)])\n    @test NNlib.maximum_dims(ind4) == (5,6,9)\nend\n\n@testset \"reverse_indices\" begin\n    res = [\n        CartesianIndex.([(1,1), (2,3)]),\n        CartesianIndex.([(1,2), (2,2)]),\n        CartesianIndex.([(3,1), (1,3), (2,4), (3,4)]),\n        CartesianIndex.([(2,1), (1,4)]),\n        CartesianIndex.([(3,2), (3,3)])\n    ]\n    idx = [1 2 3 4;\n           4 2 1 3;\n           3 5 5 3]\n    @test NNlib.reverse_indices(idx) == res\n    @test NNlib.reverse_indices(idx) isa typeof(res)\n    idx = [(1,) (2,) (3,) (4,);\n           (4,) (2,) (1,) (3,);\n           (3,) (5,) (5,) (3,)]\n    @test NNlib.reverse_indices(idx) == res\n    @test NNlib.reverse_indices(idx) isa typeof(res)\n    idx = CartesianIndex.(\n        [(1,) (2,) (3,) (4,);\n        (4,) (2,) (1,) (3,);\n        (3,) (5,) (5,) (3,)])\n    @test NNlib.reverse_indices(idx) == res\n    @test NNlib.reverse_indices(idx) isa typeof(res)\nend\n"
  }
]