Full Code of HIPS/autograd for AI

master 994362fdbcc8 cached
120 files
426.4 KB
134.9k tokens
1282 symbols
Repository: HIPS/autograd
Branch: master
Commit: 994362fdbcc8
Files: 120
Total size: 426.4 KB

Directory structure:
gitextract_4gygwh8h/

├── .github/
│   └── workflows/
│       ├── check.yml
│       ├── publish.yml
│       └── test.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CONTRIBUTING.md
├── README.md
├── autograd/
│   ├── __init__.py
│   ├── builtins.py
│   ├── core.py
│   ├── differential_operators.py
│   ├── extend.py
│   ├── misc/
│   │   ├── __init__.py
│   │   ├── fixed_points.py
│   │   ├── flatten.py
│   │   ├── optimizers.py
│   │   └── tracers.py
│   ├── numpy/
│   │   ├── __init__.py
│   │   ├── fft.py
│   │   ├── linalg.py
│   │   ├── numpy_boxes.py
│   │   ├── numpy_jvps.py
│   │   ├── numpy_vjps.py
│   │   ├── numpy_vspaces.py
│   │   ├── numpy_wrapper.py
│   │   └── random.py
│   ├── scipy/
│   │   ├── __init__.py
│   │   ├── integrate.py
│   │   ├── linalg.py
│   │   ├── signal.py
│   │   ├── special.py
│   │   └── stats/
│   │       ├── __init__.py
│   │       ├── beta.py
│   │       ├── chi2.py
│   │       ├── dirichlet.py
│   │       ├── gamma.py
│   │       ├── multivariate_normal.py
│   │       ├── norm.py
│   │       ├── poisson.py
│   │       └── t.py
│   ├── test_util.py
│   ├── tracer.py
│   ├── util.py
│   └── wrap_util.py
├── benchmarks/
│   ├── __init__.py
│   ├── asv.conf.json.sample
│   ├── bench_core.py
│   ├── bench_mem.py
│   ├── bench_numpy_vjps.py
│   ├── bench_rnn.py
│   └── bench_util.py
├── conda_recipe/
│   └── conda.yaml
├── docs/
│   ├── tutorial.md
│   └── updateguide.md
├── examples/
│   ├── README.md
│   ├── __init__.py
│   ├── bayesian_neural_net.py
│   ├── bayesian_optimization.py
│   ├── black_box_svi.py
│   ├── convnet.py
│   ├── data.py
│   ├── data_mnist.py
│   ├── deep_gaussian_process.py
│   ├── define_gradient.py
│   ├── dot_graph.py
│   ├── fixed_points.py
│   ├── fluidsim/
│   │   ├── fluidsim.py
│   │   └── wing.py
│   ├── gaussian_process.py
│   ├── generative_adversarial_net.py
│   ├── gmm.py
│   ├── gplvm.py
│   ├── hmm_em.py
│   ├── ica.py
│   ├── logistic_regression.py
│   ├── lstm.py
│   ├── mixture_variational_inference.py
│   ├── natural_gradient_black_box_svi.py
│   ├── negative_binomial_maxlike.py
│   ├── neural_net.py
│   ├── neural_net_regression.py
│   ├── ode_net.py
│   ├── print_trace.py
│   ├── rkhs.py
│   ├── rnn.py
│   ├── rosenbrock.py
│   ├── sinusoid.py
│   ├── tanh.py
│   └── variational_autoencoder.py
├── license.txt
├── noxfile.py
├── pyproject.toml
└── tests/
    ├── _test_complexity.py
    ├── check_examples_run.sh
    ├── conftest.py
    ├── numpy_utils.py
    ├── profiling.py
    ├── test_binary_ops.py
    ├── test_builtins.py
    ├── test_complex.py
    ├── test_core.py
    ├── test_dict.py
    ├── test_direct.py
    ├── test_fft.py
    ├── test_graphs.py
    ├── test_jacobian.py
    ├── test_linalg.py
    ├── test_list.py
    ├── test_logic.py
    ├── test_misc.py
    ├── test_numpy.py
    ├── test_performance.py
    ├── test_scalar_ops.py
    ├── test_scipy.py
    ├── test_systematic.py
    ├── test_tests.py
    ├── test_truediv.py
    ├── test_tuple.py
    ├── test_vspaces.py
    └── test_wrappers.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/workflows/check.yml
================================================
name: Style and package checks

on:
  pull_request:
    branches:
    - master
  push:
    branches:
    - master
  workflow_dispatch:

env:
  PIP_DISABLE_PIP_VERSION_CHECK: "1"
  FORCE_COLOR: "3"

concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

jobs:
  check:
    name: ${{ matrix.session }}
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        session:
      # - lint
        - validate-package
    steps:
    - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
    - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0

    - uses: yezz123/setup-uv@ab6be5a42627f19dc36e57b548592a5e52cece4a # v4.1

    - name: Run ${{ matrix.session }}
      run: uvx nox -s ${{ matrix.session }}


================================================
FILE: .github/workflows/publish.yml
================================================
name: Publish

on:
  workflow_dispatch:
  release:
    types: [published]

env:
  PIP_DISABLE_PIP_VERSION_CHECK: '1'
  FORCE_COLOR: '3'

jobs:
  build:
    name: Build sdist and wheel
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
      name: Checkout repository

    - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
      with:
        python-version: "3.12"

    - name: Build sdist and wheel
      run: |
        pipx run build --outdir dist

    - name: Upload wheel and sdist artifacts
      uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
      with:
        name: artifacts
        path: ./dist/*
        if-no-files-found: error

  publish:
    needs: [build]
    name: Upload to PyPI
    runs-on: ubuntu-latest
    environment:
      name: release
      url: https://pypi.org/p/autograd
    permissions:
      id-token: write # mandatory for trusted publishing

    steps:
      - name: Download artifacts
        uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0
        with:
          path: dist
          merge-multiple: true

      - name: Sanity check artifacts
        run: ls -la dist/

      - name: Publish sdist and wheel to PyPI
        uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # v1.12.4
        with:
          packages-dir: dist/


================================================
FILE: .github/workflows/test.yml
================================================
name: CI

on:
  pull_request:
    branches:
      - master
  push:
    branches:
      - master
  workflow_dispatch:
  schedule:
    - cron: "0 4 * * *"

env:
  PIP_DISABLE_PIP_VERSION_CHECK: "1"
  FORCE_COLOR: "3"

concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

jobs:
  test:
    name: Regular tests / ${{ matrix.platform }} / Python ${{ matrix.python-version }}
    runs-on: ${{ matrix.platform }}
    strategy:
      fail-fast: false
      matrix:
        platform: [ubuntu-latest, ubuntu-22.04-arm, macos-15-intel, macos-latest, windows-latest]
        python-version:
          ["3.10", "3.11", "3.12", "3.13", "3.14", "pypy-3.10"]
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
      - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
        with:
          python-version: ${{ matrix.python-version }}
          allow-prereleases: true
      - uses: yezz123/setup-uv@ab6be5a42627f19dc36e57b548592a5e52cece4a # v4.1

      # On PyPy, we skip SciPy because we don't have wheels
      # available, see noxfile.py for more details.
      - name: Run tests
        run: uvx nox -s tests

  # In this job, we test against the NumPy nightly wheels hosted on
  # https://anaconda.org/scientific-python-nightly-wheels/numpy
  # on the latest Python version available across platforms, instead of
  # testing all Python versions and implementations on all platforms.
  # We do not test on PyPy.
  #
  # However, "nox -s nightly-tests" can be used locally anywhere, on
  # any Python version and implementation on any platform and we leave
  # it to the user to decide what Python version to test against, which
  # might or might not have a corresponding NumPy nightly wheel present.
  nightlies:
    name: Nightly tests / ${{ matrix.platform }} / Python ${{ matrix.python-version }}
    runs-on: ${{ matrix.platform }}
    strategy:
      fail-fast: false
      matrix:
        platform: [ubuntu-latest, ubuntu-22.04-arm, macos-15-intel, macos-latest, windows-latest]
        python-version: ["3.x"]

    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
      - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
        with:
          python-version: ${{ matrix.python-version }}
          allow-prereleases: true
      - uses: yezz123/setup-uv@ab6be5a42627f19dc36e57b548592a5e52cece4a # v4.1
      - name: Run tests against nightly wheels for NumPy and SciPy
        run: uvx nox -s nightly-tests


================================================
FILE: .gitignore
================================================
__pycache__/
*.py[cod]
*$py.class

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
coverage.*
*.cover
.hypothesis/
nosetests.xml
.pytest_cache/
junit-report.xml

# pyenv
.python-version

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# mypy
.mypy_cache/

# OS and IDE config files
.DS_Store
.idea/

# project-specific
data/
*.so
*.c
scratch/
examples/data

.asv/
asv.conf.json
benchmarks/asv.conf.js


================================================
FILE: .pre-commit-config.yaml
================================================
ci:
  autoupdate_commit_msg: "chore: update pre-commit hooks"
  autofix_commit_msg: "style: pre-commit fixes"

repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v6.0.0
    hooks:
      - id: check-added-large-files
      - id: check-case-conflict
      - id: check-merge-conflict
      - id: check-yaml
        exclude: conda_recipe/conda.yaml
      - id: debug-statements
      - id: end-of-file-fixer
      - id: mixed-line-ending
      - id: trailing-whitespace

  - repo: https://github.com/asottile/pyupgrade
    rev: v3.21.2
    hooks:
      - id: pyupgrade
        args: [--py310-plus]

  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: "v0.15.6"
    hooks:
      - id: ruff
        args: ["--fix", "--show-fixes"]
      - id: ruff-format

  - repo: https://github.com/pre-commit/pygrep-hooks
    rev: v1.10.0
    hooks:
      - id: python-check-blanket-type-ignore
        exclude: ^src/vector/backends/_numba_object.py$
      - id: rst-backticks
      - id: rst-directive-colons
      - id: rst-inline-touching-normal


================================================
FILE: CONTRIBUTING.md
================================================
# Contributing

Use [Nox](https://nox.thea.codes/en/stable/) to run tests and linting. Install it with:

```shell
pip install nox
```

`nox` will run all checks in an isolated virtual environment with Autograd and its dependencies, including its optional dependencies, installed.

## Run tests, linting, packaging checks

| Command                   | Description                                                                                                                                                     |
| ------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `nox --list`              | Lists all available Nox sessions, including selected ones                                                                                                       |
| `nox -s lint`             | Runs code style checks with pre-commit and pre-commit hooks as listed in `.pre-commit-config.yaml`. Accepts posargs to pass additional arguments to the linter. |
| `nox -s tests`            | Runs tests with your default Python interpreter. Accepts posargs to pass additional arguments and configuration to `pytest`.                                    |
| `nox -s nightly-tests`    | Similar to `nox -s tests`, except that it runs tests with nightly versions of dependencies (NumPy, SciPy, etc.).                                                |
| `nox -s validate-package` | Builds a source distribution and a wheel using `pypa/build` and checks the package with `twine` in strict mode.                                                 |
| `nox`                     | Runs all selected sessions, as listed in `nox.options.sessions` in `noxfile.py`.                                                                                |

Additionally, `nox` supports tags to run specific sessions, e.g., `nox --tags tests` runs all sessions tagged with `tests`.

Make sure all tests pass before you push your changes to GitHub.
GitHub Actions will run the tests across all supported Python versions.

## Using positional arguments (reformat, upload package, help)

You can pass additional arguments to the tools that Nox calls (`pytest`, `pre-commit`, etc.) by
separating them from the Nox arguments with a double hyphen `--`, e.g.,

- `nox -s tests -- tests/test_tuple.py` runs just the tests listed in `tests/test_tuple.py`.
- `nox -s lint -- --fix` runs the linter with the `--fix` flag.
- and so on.


================================================
FILE: README.md
================================================
# Autograd  [![Checks status][checks-badge]][checks-url] [![Tests status][tests-badge]][tests-url] [![Publish status][publish-badge]][publish-url] [![asv][asv-badge]](#)

[publish-badge]: https://github.com/HIPS/autograd/actions/workflows/publish.yml/badge.svg
[checks-badge]: https://github.com/HIPS/autograd/actions/workflows/check.yml/badge.svg
[tests-badge]: https://github.com/HIPS/autograd/actions/workflows/test.yml/badge.svg
[asv-badge]: http://img.shields.io/badge/benchmarked%20by-asv-green.svg?style=flat
[publish-url]: https://github.com/HIPS/autograd/actions/workflows/publish.yml
[checks-url]: https://github.com/HIPS/autograd/actions/workflows/check.yml
[tests-url]: https://github.com/HIPS/autograd/actions/workflows/test.yml

Autograd can automatically differentiate native Python and NumPy code. It can
handle a large subset of Python's features, including loops, ifs, recursion and
closures, and it can even take derivatives of derivatives of derivatives. It
supports reverse-mode differentiation (a.k.a. backpropagation), which means it
can efficiently take gradients of scalar-valued functions with respect to
array-valued arguments, as well as forward-mode differentiation, and the two can
be composed arbitrarily. The main intended application of Autograd is
gradient-based optimization. For more information, check out the
[tutorial](docs/tutorial.md) and the [examples directory](examples/).

Example use:

```python
>>> import autograd.numpy as np  # Thinly-wrapped numpy
>>> from autograd import grad    # The only autograd function you may ever need
>>>
>>> def tanh(x):                 # Define a function
...     return (1.0 - np.exp((-2 * x))) / (1.0 + np.exp(-(2 * x)))
...
>>> grad_tanh = grad(tanh)       # Obtain its gradient function
>>> grad_tanh(1.0)               # Evaluate the gradient at x = 1.0
np.float64(0.419974341614026)
>>> (tanh(1.0001) - tanh(0.9999)) / 0.0002  # Compare to finite differences
np.float64(0.41997434264973155)
```
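
The finite-difference comparison in the session above can be reproduced with the standard library alone. The sketch below is an illustration, not part of Autograd: the helper name `central_difference` and the step size `h` are assumptions chosen for this example. It evaluates a central difference of the same `tanh` formula and compares it with the closed-form derivative `1 - tanh(x)**2`.

```python
import math


def tanh(x):
    # Same formula as in the example above, using math.exp instead of autograd.numpy
    return (1.0 - math.exp(-2 * x)) / (1.0 + math.exp(-2 * x))


def central_difference(f, x, h=1e-4):
    # Hypothetical helper: symmetric difference quotient with O(h**2) truncation error
    return (f(x + h) - f(x - h)) / (2 * h)


numeric = central_difference(tanh, 1.0)
analytic = 1.0 - math.tanh(1.0) ** 2  # closed-form derivative of tanh

print(numeric, analytic)
```

The two printed values should agree to several decimal places. A check like this is a quick sanity test for any gradient function, whatever tool produced it.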

We can continue to differentiate as many times as we like, and use NumPy's
vectorization of scalar-valued functions across many different input values:

```python
>>> from autograd import elementwise_grad as egrad  # for functions that vectorize over inputs
>>> import matplotlib.pyplot as plt
>>> x = np.linspace(-7, 7, 700)
>>> plt.plot(x, tanh(x),
...          x, egrad(tanh)(x),                                     # first  derivative
...          x, egrad(egrad(tanh))(x),                              # second derivative
...          x, egrad(egrad(egrad(tanh)))(x),                       # third  derivative
...          x, egrad(egrad(egrad(egrad(tanh))))(x),)               # fourth derivative
>>> plt.show()
```

<img src="examples/tanh.png" width="600">

See the [tanh example file](examples/tanh.py) for the code.
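
The introduction notes that reverse-mode differentiation makes gradients of scalar-valued functions cheap even when there are many inputs. As a rough, scalar-only illustration of why, the toy sketch below records each operation's parents and local derivatives, then walks the graph once from the output back to the inputs. It is not Autograd's implementation, and the names `Var`, `exp`, `toposort` and `backward` are made up for this example.

```python
import math


class Var:
    """A scalar value that remembers how it was computed."""

    def __init__(self, value, parents=()):
        self.value = value
        # Each entry is (parent_var, d(self)/d(parent))
        self.parents = parents
        self.grad = 0.0

    def __add__(self, other):
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Var(self.value * other.value, ((self, other.value), (other, self.value)))


def exp(x):
    y = math.exp(x.value)
    return Var(y, ((x, y),))  # d/dx exp(x) = exp(x)


def toposort(end):
    # Depth-first post-order, reversed: every node comes before its parents
    order, seen = [], set()

    def visit(node):
        if id(node) in seen:
            return
        seen.add(id(node))
        for parent, _ in node.parents:
            visit(parent)
        order.append(node)

    visit(end)
    return reversed(order)


def backward(end):
    # Seed the output with gradient 1, then push gradients toward the inputs
    end.grad = 1.0
    for node in toposort(end):
        for parent, local in node.parents:
            parent.grad += node.grad * local


# f(a, b) = a * b + exp(a): a single backward pass yields both partial derivatives
a, b = Var(2.0), Var(3.0)
out = a * b + exp(a)
backward(out)
print(a.grad, b.grad)
```

Here `a.grad` is `b + exp(a)` and `b.grad` is `a`, both obtained from one traversal of the recorded graph. The cost of that traversal grows with the number of operations rather than the number of inputs.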

## Documentation

You can find a tutorial [here](docs/tutorial.md).

## End-to-end examples

* [Simple neural net](examples/neural_net.py)
* [Convolutional neural net](examples/convnet.py)
* [Recurrent neural net](examples/rnn.py)
* [LSTM](examples/lstm.py)
* [Neural Turing Machine](https://github.com/DoctorTeeth/diffmem/blob/512aadeefd6dbafc1bdd253a64b6be192a435dc3/ntm/ntm.py)
* [Backpropagating through a fluid simulation](examples/fluidsim/fluidsim.py)

<img src="examples/fluidsim/animated.gif" width="400">

* [Variational inference in Bayesian neural network](examples/bayesian_neural_net.py)
* [Gaussian process regression](examples/gaussian_process.py)
* [Sampyl, a pure Python MCMC package with HMC and NUTS](https://github.com/mcleonard/sampyl)

## How to install

Install Autograd using pip:

```shell
pip install autograd
```

Some features require SciPy, which you can install separately or as an
optional dependency along with Autograd:

```shell
pip install "autograd[scipy]"
```

## Authors and maintainers

Autograd was written by [Dougal Maclaurin](https://dougalmaclaurin.com),
[David Duvenaud](https://www.cs.toronto.edu/~duvenaud/),
[Matt Johnson](http://people.csail.mit.edu/mattjj/),
[Jamie Townsend](https://github.com/j-towns)
and many other contributors. The package is currently being maintained by
[Agriya Khetarpal](https://github.com/agriyakhetarpal),
[Fabian Joswig](https://github.com/fjosw) and
[Jamie Townsend](https://github.com/j-towns).
Please feel free to report any bugs or submit
feature requests. We'd also love to hear about your experiences with Autograd
in general. Drop us an email!

We want to thank Jasper Snoek and the rest of the HIPS group (led by Prof. Ryan
P. Adams) for helpful contributions and advice; Barak Pearlmutter for
foundational work on automatic differentiation and for guidance on our
implementation; and Analog Devices Inc. (Lyric Labs) and Samsung Advanced Institute
of Technology for their generous support.


================================================
FILE: autograd/__init__.py
================================================
from autograd.core import primitive_with_deprecation_warnings as primitive

from .builtins import dict, isinstance, list, tuple, type
from .differential_operators import (
    checkpoint,
    deriv,
    elementwise_grad,
    grad,
    grad_and_aux,
    grad_named,
    hessian,
    hessian_tensor_product,
    hessian_vector_product,
    holomorphic_grad,
    jacobian,
    make_ggnvp,
    make_hvp,
    make_jvp,
    make_vjp,
    multigrad_dict,
    tensor_jacobian_product,
    value_and_grad,
    vector_jacobian_product,
)


================================================
FILE: autograd/builtins.py
================================================
from .extend import (
    Box,
    SparseObject,
    VSpace,
    defjvp,
    defjvp_argnum,
    defvjp,
    defvjp_argnum,
    notrace_primitive,
    primitive,
    vspace,
)
from .util import subvals

isinstance_ = isinstance
isinstance = notrace_primitive(isinstance)

type_ = type
type = notrace_primitive(type)

tuple_, list_, dict_ = tuple, list, dict


@primitive
def container_take(A, idx):
    return A[idx]


def grad_container_take(ans, A, idx):
    return lambda g: container_untake(g, idx, vspace(A))


defvjp(container_take, grad_container_take)
defjvp(container_take, "same")


class SequenceBox(Box):
    __slots__ = []
    __getitem__ = container_take

    def __len__(self):
        return len(self._value)

    def __add__(self, other):
        return sequence_extend_right(self, *other)

    def __radd__(self, other):
        return sequence_extend_left(self, *other)

    def __contains__(self, elt):
        return elt in self._value

    def index(self, elt):
        return self._value.index(elt)


SequenceBox.register(tuple_)
SequenceBox.register(list_)


class DictBox(Box):
    __slots__ = []
    __getitem__ = container_take

    def __len__(self):
        return len(self._value)

    def __iter__(self):
        return self._value.__iter__()

    def __contains__(self, elt):
        return elt in self._value

    def items(self):
        return list(self.iteritems())

    def keys(self):
        return list(self.iterkeys())

    def values(self):
        return list(self.itervalues())

    def iteritems(self):
        return ((k, self[k]) for k in self)

    def iterkeys(self):
        return iter(self)

    def itervalues(self):
        return (self[k] for k in self)

    def get(self, k, d=None):
        return self[k] if k in self else d


DictBox.register(dict_)


@primitive
def container_untake(x, idx, vs):
    if isinstance(idx, slice):
        accum = lambda result: [elt_vs._mut_add(a, b) for elt_vs, a, b in zip(vs.shape[idx], result, x)]
    else:
        accum = lambda result: vs.shape[idx]._mut_add(result, x)

    def mut_add(A):
        return vs._subval(A, idx, accum(A[idx]))

    return SparseObject(vs, mut_add)


defvjp(container_untake, lambda ans, x, idx, _: lambda g: container_take(g, idx))
defjvp(container_untake, "same")


@primitive
def sequence_extend_right(seq, *elts):
    return seq + type(seq)(elts)


def grad_sequence_extend_right(argnum, ans, args, kwargs):
    seq, elts = args[0], args[1:]
    return lambda g: g[: len(seq)] if argnum == 0 else g[len(seq) + argnum - 1]


defvjp_argnum(sequence_extend_right, grad_sequence_extend_right)


@primitive
def sequence_extend_left(seq, *elts):
    return type(seq)(elts) + seq


def grad_sequence_extend_left(argnum, ans, args, kwargs):
    seq, elts = args[0], args[1:]
    return lambda g: g[len(elts) :] if argnum == 0 else g[argnum - 1]


defvjp_argnum(sequence_extend_left, grad_sequence_extend_left)


@primitive
def make_sequence(seq_type, *args):
    return seq_type(args)


defvjp_argnum(make_sequence, lambda argnum, *args: lambda g: g[argnum - 1])


def fwd_grad_make_sequence(argnum, g, ans, seq_type, *args, **kwargs):
    return container_untake(g, argnum - 1, vspace(ans))


defjvp_argnum(make_sequence, fwd_grad_make_sequence)


class TupleMeta(type_):
    def __instancecheck__(self, instance):
        return isinstance(instance, tuple_)


class tuple(tuple_, metaclass=TupleMeta):
    def __new__(cls, xs):
        return make_sequence(tuple_, *xs)


class ListMeta(type_):
    def __instancecheck__(self, instance):
        return isinstance(instance, list_)


class list(list_, metaclass=ListMeta):
    def __new__(cls, xs):
        return make_sequence(list_, *xs)


class DictMeta(type_):
    def __instancecheck__(self, instance):
        return isinstance(instance, dict_)


class dict(dict_, metaclass=DictMeta):
    def __new__(cls, *args, **kwargs):
        result = dict_(*args, **kwargs)
        if result:
            return _make_dict(result.keys(), list(result.values()))
        return result


@primitive
def _make_dict(keys, vals):
    return dict_(zip(keys, vals))


defvjp(_make_dict, lambda ans, keys, vals: lambda g: list(g[key] for key in keys), argnums=(1,))


class ContainerVSpace(VSpace):
    def __init__(self, value):
        self.shape = value
        self.shape = self._map(vspace)

    @property
    def size(self):
        return sum(self._values(self._map(lambda vs: vs.size)))

    def zeros(self):
        return self._map(lambda vs: vs.zeros())

    def ones(self):
        return self._map(lambda vs: vs.ones())

    def randn(self):
        return self._map(lambda vs: vs.randn())

    def standard_basis(self):
        zero = self.zeros()
        for i, vs in self._kv_pairs(self.shape):
            for x in vs.standard_basis():
                yield self._subval(zero, i, x)

    def _add(self, xs, ys):
        return self._map(lambda vs, x, y: vs._add(x, y), xs, ys)

    def _mut_add(self, xs, ys):
        return self._map(lambda vs, x, y: vs._mut_add(x, y), xs, ys)

    def _scalar_mul(self, xs, a):
        return self._map(lambda vs, x: vs._scalar_mul(x, a), xs)

    def _inner_prod(self, xs, ys):
        return sum(self._values(self._map(lambda vs, x, y: vs._inner_prod(x, y), xs, ys)))

    def _covector(self, xs):
        return self._map(lambda vs, x: vs._covector(x), xs)


class SequenceVSpace(ContainerVSpace):
    def _values(self, x):
        return x

    def _kv_pairs(self, x):
        return enumerate(x)

    def _map(self, f, *args):
        return self.seq_type(map(f, self.shape, *args))

    def _subval(self, xs, idx, x):
        return self.seq_type(subvals(xs, [(idx, x)]))


class ListVSpace(SequenceVSpace):
    seq_type = list_


class TupleVSpace(SequenceVSpace):
    seq_type = tuple_


class DictVSpace(ContainerVSpace):
    def _values(self, x):
        return x.values()

    def _kv_pairs(self, x):
        return x.items()

    def _map(self, f, *args):
        return {k: f(vs, *[x[k] for x in args]) for k, vs in self.shape.items()}

    def _subval(self, xs, idx, x):
        d = dict(xs.items())
        d[idx] = x
        return d


ListVSpace.register(list_)
TupleVSpace.register(tuple_)
DictVSpace.register(dict_)


class NamedTupleVSpace(SequenceVSpace):
    def _map(self, f, *args):
        return self.seq_type(*map(f, self.shape, *args))

    def _subval(self, xs, idx, x):
        return self.seq_type(*subvals(xs, [(idx, x)]))


================================================
FILE: autograd/core.py
================================================
from functools import reduce
from itertools import count

from .tracer import Box, Node, getval, isbox, primitive, toposort, trace
from .util import func, subval

# -------------------- reverse mode --------------------


def make_vjp(fun, x):
    start_node = VJPNode.new_root()
    end_value, end_node = trace(start_node, fun, x)
    if end_node is None:

        def vjp(g):
            return vspace(x).zeros()
    else:

        def vjp(g):
            return backward_pass(g, end_node)

    return vjp, end_value


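def backward_pass(g, end_node):
    # Map each node to (accumulated output gradient, flag saying whether it may be mutated in place).
    outgrads = {end_node: (g, False)}
    for node in toposort(end_node):
        outgrad = outgrads.pop(node)
        # Propagate this node's output gradient to each of its parents.
        ingrads = node.vjp(outgrad[0])
        for parent, ingrad in zip(node.parents, ingrads):
            outgrads[parent] = add_outgrads(outgrads.get(parent), ingrad)
    # The gradient popped for the last node in the traversal is the result.
    return outgrad[0]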


class VJPNode(Node):
    __slots__ = ["parents", "vjp"]

    def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
        self.parents = parents
        try:
            vjpmaker = primitive_vjps[fun]
        except KeyError:
            fun_name = getattr(fun, "__name__", fun)
            raise NotImplementedError(f"VJP of {fun_name} wrt argnums {parent_argnums} not defined")
        self.vjp = vjpmaker(parent_argnums, value, args, kwargs)

    def initialize_root(self):
        self.parents = []
        self.vjp = lambda g: ()


primitive_vjps = {}


def defvjp_argnums(fun, vjpmaker):
    primitive_vjps[fun] = vjpmaker


def defvjp_argnum(fun, vjpmaker):
    def vjp_argnums(argnums, *args):
        vjps = [vjpmaker(argnum, *args) for argnum in argnums]
        return lambda g: (vjp(g) for vjp in vjps)

    defvjp_argnums(fun, vjp_argnums)


def defvjp(fun, *vjpmakers, **kwargs):
    argnums = kwargs.get("argnums", count())
    vjps_dict = {
        argnum: translate_vjp(vjpmaker, fun, argnum) for argnum, vjpmaker in zip(argnums, vjpmakers)
    }

    def vjp_argnums(argnums, ans, args, kwargs):
        L = len(argnums)
        # These first two cases are just optimizations
        if L == 1:
            argnum = argnums[0]
            try:
                vjpfun = vjps_dict[argnum]
            except KeyError:
                raise NotImplementedError(f"VJP of {fun.__name__} wrt argnum {argnum} not defined")
            vjp = vjpfun(ans, *args, **kwargs)
            return lambda g: (vjp(g),)
        elif L == 2:
            argnum_0, argnum_1 = argnums
            try:
                vjp_0_fun = vjps_dict[argnum_0]
                vjp_1_fun = vjps_dict[argnum_1]
            except KeyError:
                raise NotImplementedError(
                    f"VJP of {fun.__name__} wrt argnums {argnum_0}, {argnum_1} not defined"
                )
            vjp_0 = vjp_0_fun(ans, *args, **kwargs)
            vjp_1 = vjp_1_fun(ans, *args, **kwargs)
            return lambda g: (vjp_0(g), vjp_1(g))
        else:
            vjps = [vjps_dict[argnum](ans, *args, **kwargs) for argnum in argnums]
            return lambda g: (vjp(g) for vjp in vjps)

    defvjp_argnums(fun, vjp_argnums)


def translate_vjp(vjpfun, fun, argnum):
    if vjpfun is None:
        return lambda ans, *args, **kwargs: lambda g: vspace(args[argnum]).zeros()
    elif callable(vjpfun):
        return vjpfun
    else:
        raise Exception(f"Bad VJP '{vjpfun}' for '{fun.__name__}'")


# -------------------- forward mode --------------------


def make_jvp(fun, x):
    def jvp(g):
        start_node = JVPNode.new_root(g)
        end_value, end_node = trace(start_node, fun, x)
        if end_node is None:
            return end_value, vspace(end_value).zeros()
        else:
            return end_value, end_node.g

    return jvp


class JVPNode(Node):
    __slots__ = ["g"]

    def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
        parent_gs = [parent.g for parent in parents]
        try:
            jvpmaker = primitive_jvps[fun]
        except KeyError:
            name = getattr(fun, "__name__", fun)
            raise NotImplementedError(f"JVP of {name} wrt argnums {parent_argnums} not defined")
        self.g = jvpmaker(parent_argnums, parent_gs, value, args, kwargs)

    def initialize_root(self, g):
        self.g = g


primitive_jvps = {}


def defjvp_argnums(fun, jvpmaker):
    primitive_jvps[fun] = jvpmaker


def defjvp_argnum(fun, jvpmaker):
    def jvp_argnums(argnums, gs, ans, args, kwargs):
        return sum_outgrads(jvpmaker(argnum, g, ans, args, kwargs) for argnum, g in zip(argnums, gs))

    defjvp_argnums(fun, jvp_argnums)


def defjvp(fun, *jvpfuns, **kwargs):
    argnums = kwargs.get("argnums", count())
    jvps_dict = {argnum: translate_jvp(jvpfun, fun, argnum) for argnum, jvpfun in zip(argnums, jvpfuns)}

    def jvp_argnums(argnums, gs, ans, args, kwargs):
        return sum_outgrads(jvps_dict[argnum](g, ans, *args, **kwargs) for argnum, g in zip(argnums, gs))

    defjvp_argnums(fun, jvp_argnums)


def translate_jvp(jvpfun, fun, argnum):
    if jvpfun is None:
        return lambda g, ans, *a, **k: vspace(ans).zeros()
    elif jvpfun == "same":
        return lambda g, ans, *args, **kwargs: fun(*subval(args, argnum, g), **kwargs)
    elif callable(jvpfun):
        return jvpfun
    else:
        raise Exception(f"Bad JVP '{jvpfun}' for '{fun.__name__}'")


def def_linear(fun):
    """Flags that a function is linear wrt all args"""
    defjvp_argnum(fun, lambda argnum, g, ans, args, kwargs: fun(*subval(args, argnum, g), **kwargs))


# -------------------- vector behavior --------------------


def add_outgrads(prev_g_flagged, g):
    sparse = type(g) in sparse_object_types
    if prev_g_flagged:
        vs = vspace(g)
        prev_g, mutable = prev_g_flagged
        if mutable:
            if sparse:
                return sparse_add(vs, prev_g, g), True
            else:
                return vs.mut_add(prev_g, g), True
        else:
            if sparse:
                prev_g_mutable = vs.mut_add(None, prev_g)
                return sparse_add(vs, prev_g_mutable, g), True
            else:
                return vs.add(prev_g, g), True
    else:
        if sparse:
            return sparse_add(vspace(g), None, g), True
        else:
            return g, False


def sum_outgrads(gs):
    return reduce(add_outgrads, gs, None)[0]


@primitive
def sparse_add(vs, x_prev, x_new):
    x_prev = x_prev if x_prev is not None else vs.zeros()
    return x_new.mut_add(x_prev)


class VSpace:
    __slots__ = []
    mappings = {}
    iscomplex = False

    def __init__(self, value):
        pass

    def zeros(self):
        assert False, repr(self)

    def ones(self):
        assert False, repr(self)

    def standard_basis(self):
        assert False, repr(self)

    def randn(self):
        assert False, repr(self)

    @primitive
    def mut_add(self, x_prev, x_new):
        x_prev = x_prev if x_prev is not None else self.zeros()
        return self._mut_add(x_prev, x_new)

    @primitive
    def add(self, x_prev, x_new):
        return self._add(x_prev, x_new)

    @primitive
    def scalar_mul(self, x, a):
        return self._scalar_mul(x, a)

    @primitive
    def inner_prod(self, x, y):
        return self._inner_prod(x, y)

    @primitive
    def covector(self, x):
        return self._covector(x)

    def _add(self, x, y):
        return x + y

    def _mut_add(self, x, y):
        x += y
        return x

    def _scalar_mul(self, x, a):
        return x * a

    def _inner_prod(self, x, y):
        assert False

    def _covector(self, x):
        return x

    def __eq__(self, other):
        return type(self) == type(other) and self.__dict__ == other.__dict__

    def __repr__(self):
        return f"{type(self).__name__}_{self.__dict__}"

    @classmethod
    def register(cls, value_type, vspace_maker=None):
        if vspace_maker:
            VSpace.mappings[value_type] = vspace_maker
        else:
            VSpace.mappings[value_type] = cls


def vspace(value):
    try:
        return VSpace.mappings[type(value)](value)
    except KeyError:
        if isbox(value):
            return vspace(getval(value))
        else:
            raise TypeError(
                "Can't find vector space for value {} of type {}. Valid types are {}".format(
                    value, type(value), VSpace.mappings.keys()
                )
            )


class SparseBox(Box):
    __slots__ = []


class SparseObject:
    __slots__ = ["vs", "mut_add"]

    def __init__(self, vs, mut_add):
        self.vs = vs
        self.mut_add = mut_add


VSpace.register(SparseObject, lambda x: x.vs)
SparseBox.register(SparseObject)
sparse_object_types = {SparseObject, SparseBox}

# -------------------- core reverse mode grads --------------------

# VJP makers are called as vjpmaker(ans, *args), so the first argument is `ans`.
identity_vjp = lambda ans, *args: lambda g: g
defvjp(sparse_add, None, identity_vjp, identity_vjp)
defvjp(func(VSpace.add), None, identity_vjp, identity_vjp)
defvjp(func(VSpace.mut_add), None, identity_vjp, identity_vjp)
defvjp(
    func(VSpace.inner_prod),
    None,
    lambda ans, vs, x, y: lambda g: vs.covector(vs.scalar_mul(y, g)),
    lambda ans, vs, x, y: lambda g: vs.covector(vs.scalar_mul(x, g)),
)
defvjp(func(VSpace.covector), None, lambda ans, vs, x: lambda g: vs.covector(g))
defvjp(
    func(VSpace.scalar_mul),
    None,
    lambda ans, vs, x, a: lambda g: vs.covector(vs.scalar_mul(vs.covector(g), a)),
    lambda ans, vs, x, a: lambda g: vs.inner_prod(g, vs.covector(x)),
)

# -------------------- core forward mode grads --------------------

identity_jvp = lambda g, *args, **kwargs: g
defjvp(sparse_add, None, identity_jvp, identity_jvp)
defjvp(func(VSpace.mut_add), None, identity_jvp, identity_jvp)
defjvp(func(VSpace.add), None, identity_jvp, identity_jvp)
defjvp(func(VSpace.scalar_mul), None, "same", "same")
defjvp(func(VSpace.inner_prod), None, "same", "same")
defjvp(func(VSpace.covector), None, "same")

# -------------------- deprecation warnings -----------------------

import warnings

deprecated_defvjp_message = """
The {} method is deprecated. See the update guide and tutorial:
https://github.com/HIPS/autograd/blob/master/docs/updateguide.md
https://github.com/HIPS/autograd/blob/master/docs/tutorial.md"""


def deprecated_defvjp(primitive_fun):
    deprecation_msg = deprecated_defvjp_message.format("defvjp")
    vjpfuns = {}

    def defvjp_unstaged(vjpmaker, argnum=0):
        warnings.warn(deprecation_msg)

        def staged_vjpmaker(ans, *args, **kwargs):
            def vjp(g):
                vs, gvs = vspace(args[argnum]), vspace(g)
                return vjpmaker(g, ans, vs, gvs, *args, **kwargs)

            return vjp

        vjpfuns[argnum] = staged_vjpmaker
        argnums, vjpmakers = zip(*[(argnum, vjpfuns[argnum]) for argnum in sorted(vjpfuns.keys())])
        defvjp(primitive_fun, *vjpmakers, argnums=argnums)

    return defvjp_unstaged


def deprecated_defvjp_is_zero(primitive_fun):
    deprecation_msg = deprecated_defvjp_message.format("defvjp_is_zero")
    zero_vjps = [set()]

    def defvjp_is_zero(argnums=(0,)):
        warnings.warn(deprecation_msg)
        zero_vjps[0] |= set(argnums)
        nones = [None] * len(zero_vjps[0])
        defvjp(primitive_fun, *nones, argnums=sorted(zero_vjps[0]))

    return defvjp_is_zero


def deprecated_defgrad(primitive_fun):
    deprecation_msg = deprecated_defvjp_message.format("defgrad")
    gradfuns = {}

    def defgrad(gradfun, argnum=0):
        warnings.warn(deprecation_msg)
        gradfuns[argnum] = gradfun
        argnums, vjpmakers = zip(*[(argnum, gradfuns[argnum]) for argnum in sorted(gradfuns.keys())])
        defvjp(primitive_fun, *vjpmakers, argnums=argnums)

    return defgrad


primitive_ = primitive


def primitive_with_deprecation_warnings(f_raw):
    f_wrapped = primitive_(f_raw)
    f_wrapped.defvjp = deprecated_defvjp(f_wrapped)
    f_wrapped.defvjp_is_zero = deprecated_defvjp_is_zero(f_wrapped)
    f_wrapped.defgrad = deprecated_defgrad(f_wrapped)
    return f_wrapped


primitive = primitive_with_deprecation_warnings


================================================
FILE: autograd/differential_operators.py
================================================
"""Convenience functions built on top of `make_vjp`."""

from collections import OrderedDict

try:
    from inspect import getfullargspec as _getargspec  # Python 3
except ImportError:
    from inspect import getargspec as _getargspec  # Python 2
import warnings

import autograd.numpy as np

from .builtins import tuple as atuple
from .core import make_jvp as _make_jvp
from .core import make_vjp as _make_vjp
from .extend import defvjp_argnum, primitive, vspace
from .wrap_util import unary_to_nary

make_vjp = unary_to_nary(_make_vjp)
make_jvp = unary_to_nary(_make_jvp)


@unary_to_nary
def grad(fun, x):
    """
    Returns a function which computes the gradient of `fun` with respect to
    positional argument number `argnum`. The returned function takes the same
    arguments as `fun`, but returns the gradient instead. The function `fun`
    should be scalar-valued. The gradient has the same type as the argument."""
    vjp, ans = _make_vjp(fun, x)
    if not vspace(ans).size == 1:
        raise TypeError(
            "Grad only applies to real scalar-output functions. "
            "Try jacobian, elementwise_grad or holomorphic_grad."
        )
    return vjp(vspace(ans).ones())


@unary_to_nary
def elementwise_grad(fun, x):
    """
    Returns a function that computes the sum of each column of the Jacobian of
    `fun`, in one pass. If the Jacobian is diagonal, then this is the diagonal
    of the Jacobian.
    """
    vjp, ans = _make_vjp(fun, x)
    if vspace(ans).iscomplex:
        raise TypeError("Elementwise_grad only applies to real-output functions.")
    return vjp(vspace(ans).ones())


@unary_to_nary
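def deriv(fun, x):
    """
    Returns a function which computes the forward-mode derivative of `fun`
    with respect to positional argument number `argnum`, by pushing an
    all-ones tangent through the JVP. For a scalar input this is the ordinary
    derivative."""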
    return _make_jvp(fun, x)(vspace(x).ones())[1]


@unary_to_nary
def jacobian(fun, x):
    """
    Returns a function which computes the Jacobian of `fun` with respect to
    positional argument number `argnum`, which must be a scalar or array. Unlike
    `grad` it is not restricted to scalar-output functions, but also it cannot
    take derivatives with respect to some argument types (like lists or dicts).
    If the input to `fun` has shape (in1, in2, ...) and the output has shape
    (out1, out2, ...) then the Jacobian has shape (out1, out2, ..., in1, in2, ...).
    """
    vjp, ans = _make_vjp(fun, x)
    ans_vspace = vspace(ans)
    jacobian_shape = ans_vspace.shape + vspace(x).shape
    grads = map(vjp, ans_vspace.standard_basis())
    return np.reshape(np.stack(grads), jacobian_shape)


@unary_to_nary
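def holomorphic_grad(fun, x):
    """
    Returns a function which computes the gradient of a holomorphic `fun` as
    the gradient of its real part. Warns if the input is not complex."""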
    if not vspace(x).iscomplex:
        warnings.warn("Input to holomorphic_grad is not complex")
    return grad(lambda x: np.real(fun(x)))(x)


def grad_named(fun, argname):
    """Takes gradients with respect to a named argument.
    Doesn't work on *args or **kwargs."""
    arg_index = _getargspec(fun).args.index(argname)
    return grad(fun, arg_index)


@unary_to_nary
def hessian(fun, x):
    "Returns a function that computes the exact Hessian."
    return jacobian(jacobian(fun))(x)


@unary_to_nary
def make_hvp(fun, x):
    """Builds a function for evaluating the Hessian-vector product at a point,
    which may be useful when evaluating many Hessian-vector products at the same
    point while caching the results of the forward pass."""
    return _make_vjp(grad(fun), x)


def hessian_tensor_product(fun, argnum=0):
    """Builds a function that returns the exact Hessian-tensor product.
    The returned function has arguments (*args, tensor, **kwargs), and for
    vectors takes roughly 4x as long to evaluate as the original function."""
    fun_grad = grad(fun, argnum)

    def vector_dot_grad(*args, **kwargs):
        args, vector = args[:-1], args[-1]
        return np.tensordot(fun_grad(*args, **kwargs), vector, np.ndim(vector))

    return grad(vector_dot_grad, argnum)


hessian_vector_product = hessian_tensor_product


def tensor_jacobian_product(fun, argnum=0):
    """Builds a function that returns the exact tensor-Jacobian product, that
    is the Jacobian matrix left-multiplied by tensor. The returned function
    has arguments (*args, tensor, **kwargs)."""

    def vector_dot_fun(*args, **kwargs):
        args, vector = args[:-1], args[-1]
        return np.tensordot(vector, fun(*args, **kwargs), axes=np.ndim(vector))

    return jacobian(vector_dot_fun, argnum)


vector_jacobian_product = tensor_jacobian_product


@unary_to_nary
def make_jvp_reversemode(fun, x):
    """Builds a function for evaluating the Jacobian-vector product at a
    point. Roughly 1.5x more FLOPs than forward-mode, plus memory requirements
    that scale with the number of primitives applied in the evaluation of f, as
    well as other overheads. See j-towns.github.io/2017/06/12/A-new-trick.html."""
    vjp, y = _make_vjp(fun, x)
    vjp_vjp, _ = _make_vjp(vjp, vspace(y).zeros())
    return vjp_vjp  # vjp_vjp is just jvp by linearity


# TODO(mattjj): update this function using make_jvp and const_graph
def make_ggnvp(f, g=lambda x: 1.0 / 2 * np.sum(x**2, axis=-1), f_argnum=0):
    """Builds a function for evaluating generalized-Gauss-Newton-vector products
    at a point. Slightly more expensive than mixed-mode."""

    @unary_to_nary
    def _make_ggnvp(f, x):
        f_vjp, f_x = _make_vjp(f, x)
        g_hvp, grad_g_x = _make_vjp(grad(g), f_x)
        f_jvp, _ = _make_vjp(f_vjp, vspace(grad_g_x).zeros())

        def ggnvp(v):
            return f_vjp(g_hvp(f_jvp(v)))

        return ggnvp

    return _make_ggnvp(f, f_argnum)


@unary_to_nary
def value_and_grad(fun, x):
    """Returns a function that returns both value and gradient. Suitable for use
    in scipy.optimize."""
    vjp, ans = _make_vjp(fun, x)
    if not vspace(ans).size == 1:
        raise TypeError(
            "value_and_grad only applies to real scalar-output "
            "functions. Try jacobian, elementwise_grad or "
            "holomorphic_grad."
        )
    return ans, vjp(vspace(ans).ones())


@unary_to_nary
def grad_and_aux(fun, x):
    """Builds a function that returns the gradient of the first output and the
    (unmodified) second output of a function that returns two outputs."""
    vjp, (ans, aux) = _make_vjp(lambda x: atuple(fun(x)), x)
    return vjp((vspace(ans).ones(), vspace(aux).zeros())), aux


def multigrad_dict(fun):
    """Takes gradients wrt all arguments simultaneously and
    returns a dict mapping 'argname' to 'gradval'."""

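    try:
        from inspect import signature  # Python 3 standard library
    except ImportError:
        from funcsigs import signature  # Python 2 backport

    sig = signature(fun)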

    def select(preds, lst):
        idx = lambda item: next((i for i, pred in enumerate(preds) if pred(item)), len(preds))
        results = [[] for _ in preds] + [[]]
        for item in lst:
            results[idx(item)].append(item)
        return results

    is_var_pos = lambda name: sig.parameters[name].kind == sig.parameters[name].VAR_POSITIONAL
    is_var_kwd = lambda name: sig.parameters[name].kind == sig.parameters[name].VAR_KEYWORD
    var_pos, var_kwd, argnames = select([is_var_pos, is_var_kwd], sig.parameters)

    todict = lambda dct: {key: dct[key] for key in dct}

    def apply_defaults(arguments):
        defaults = {
            name: param.default for name, param in sig.parameters.items() if param.default is not param.empty
        }
        return OrderedDict(
            (name, arguments[name] if name in arguments else defaults[name]) for name in sig.parameters
        )

    def gradfun(*args, **kwargs):
        bindings = sig.bind(*args, **kwargs)

        args = lambda dct: tuple(dct[var_pos[0]]) if var_pos else ()
        kwargs = lambda dct: todict(dct[var_kwd[0]]) if var_kwd else {}
        others = lambda dct: tuple(dct[argname] for argname in argnames if argname not in var_kwd + var_pos)

        newfun = lambda dct: fun(*(others(dct) + args(dct)), **kwargs(dct))

        argdict = apply_defaults(bindings.arguments)
        grad_dict = grad(newfun)(dict(argdict))
        return OrderedDict((argname, grad_dict[argname]) for argname in argdict)

    return gradfun


def checkpoint(fun):
    """Returns a checkpointed version of `fun`, where intermediate values
    computed during the forward pass of `fun` are discarded and then recomputed
    for the backward pass. Useful to save memory, effectively trading off time
    and memory. See e.g. arxiv.org/abs/1604.06174.
    """

    def wrapped_grad(argnum, ans, args, kwargs):
        return make_vjp(fun, argnum)(*args, **kwargs)[0]

    wrapped = primitive(fun)
    defvjp_argnum(wrapped, wrapped_grad)
    return wrapped


================================================
FILE: autograd/extend.py
================================================
# Exposes API for extending autograd
from .core import (
    JVPNode,
    SparseObject,
    VJPNode,
    VSpace,
    def_linear,
    defjvp,
    defjvp_argnum,
    defjvp_argnums,
    defvjp,
    defvjp_argnum,
    defvjp_argnums,
    vspace,
)
from .tracer import Box, notrace_primitive, primitive, register_notrace


================================================
FILE: autograd/misc/__init__.py
================================================
from .flatten import flatten
from .tracers import const_graph


================================================
FILE: autograd/misc/fixed_points.py
================================================
from autograd import make_vjp
from autograd.builtins import tuple
from autograd.extend import defvjp, primitive, vspace


@primitive
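def fixed_point(f, a, x0, distance, tol):
    """Iterates x <- f(a)(x) starting from x0 until distance(x, x_prev) <= tol,
    then returns the final x."""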
    _f = f(a)
    x, x_prev = _f(x0), x0
    while distance(x, x_prev) > tol:
        x, x_prev = _f(x), x
    return x


def fixed_point_vjp(ans, f, a, x0, distance, tol):
    def rev_iter(params):
        a, x_star, x_star_bar = params
        vjp_x, _ = make_vjp(f(a))(x_star)
        vs = vspace(x_star)
        return lambda g: vs.add(vjp_x(g), x_star_bar)

    vjp_a, _ = make_vjp(lambda x, y: f(x)(y))(a, ans)
    return lambda g: vjp_a(fixed_point(rev_iter, tuple((a, ans, g)), vspace(x0).zeros(), distance, tol))


defvjp(fixed_point, None, fixed_point_vjp, None)


================================================
FILE: autograd/misc/flatten.py
================================================
"""
Handy functions for flattening nested containers containing numpy
arrays. The main purpose is to make examples and optimizers simpler.
"""

import autograd.numpy as np
from autograd import make_vjp
from autograd.builtins import type


def flatten(value):
    """Flattens any nesting of tuples, lists, or dicts, with numpy arrays or
    scalars inside. Returns 1D numpy array and an unflatten function.
    Doesn't preserve mixed numeric types (e.g. floats and ints). Assumes dict
    keys are sortable."""
    unflatten, flat_value = make_vjp(_flatten)(value)
    return flat_value, unflatten


def _flatten(value):
    t = type(value)
    if t in (list, tuple):
        return _concatenate(map(_flatten, value))
    elif t is dict:
        return _concatenate(_flatten(value[k]) for k in sorted(value))
    else:
        return np.ravel(value)


def _concatenate(lst):
    lst = list(lst)
    return np.concatenate(lst) if lst else np.array([])


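def flatten_func(func, example):
    """Returns a version of `func` that maps flat 1D arrays to flat 1D arrays,
    together with the unflatten function and the flattened `example`."""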
    _ex, unflatten = flatten(example)
    _func = lambda _x, *args: flatten(func(unflatten(_x), *args))[0]
    return _func, unflatten, _ex


================================================
FILE: autograd/misc/optimizers.py
================================================
"""Some standard gradient-based stochastic optimizers.

These are just standard routines that don't make any use of autograd,
though you could take gradients of these functions too if you want
to do meta-optimization.

These routines can optimize functions whose inputs are structured
objects, such as dicts of numpy arrays."""

import autograd.numpy as np
from autograd.misc import flatten
from autograd.wrap_util import wraps


def unflatten_optimizer(optimize):
    """Takes an optimizer that operates on flat 1D numpy arrays and returns a
    wrapped version that handles trees of nested containers (lists/tuples/dicts)
    with arrays/scalars at the leaves."""

    @wraps(optimize)
    def _optimize(grad, x0, callback=None, *args, **kwargs):
        _x0, unflatten = flatten(x0)
        _grad = lambda x, i: flatten(grad(unflatten(x), i))[0]
        if callback:
            _callback = lambda x, i, g: callback(unflatten(x), i, unflatten(g))
        else:
            _callback = None
        return unflatten(optimize(_grad, _x0, _callback, *args, **kwargs))

    return _optimize


@unflatten_optimizer
def sgd(grad, x, callback=None, num_iters=200, step_size=0.1, mass=0.9):
    """Stochastic gradient descent with momentum.
    grad() must have signature grad(x, i), where i is the iteration number."""
    velocity = np.zeros(len(x))
    for i in range(num_iters):
        g = grad(x, i)
        if callback:
            callback(x, i, g)
        velocity = mass * velocity - (1.0 - mass) * g
        x = x + step_size * velocity
    return x


@unflatten_optimizer
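def rmsprop(grad, x, callback=None, num_iters=100, step_size=0.1, gamma=0.9, eps=10**-8):
    """Root mean square propagation (RMSprop): scales each step by a running
    average of squared gradients. See Tieleman & Hinton's lecture notes;
    closely related to Adagrad."""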
    avg_sq_grad = np.ones(len(x))
    for i in range(num_iters):
        g = grad(x, i)
        if callback:
            callback(x, i, g)
        avg_sq_grad = avg_sq_grad * gamma + g**2 * (1 - gamma)
        x = x - step_size * g / (np.sqrt(avg_sq_grad) + eps)
    return x


@unflatten_optimizer
def adam(grad, x, callback=None, num_iters=100, step_size=0.001, b1=0.9, b2=0.999, eps=10**-8):
    """Adam as described in http://arxiv.org/pdf/1412.6980.pdf.
    It's basically RMSprop with momentum and some correction terms."""
    m = np.zeros(len(x))
    v = np.zeros(len(x))
    for i in range(num_iters):
        g = grad(x, i)
        if callback:
            callback(x, i, g)
        m = (1 - b1) * g + b1 * m  # First  moment estimate.
        v = (1 - b2) * (g**2) + b2 * v  # Second moment estimate.
        mhat = m / (1 - b1 ** (i + 1))  # Bias correction.
        vhat = v / (1 - b2 ** (i + 1))
        x = x - step_size * mhat / (np.sqrt(vhat) + eps)
    return x


================================================
FILE: autograd/misc/tracers.py
================================================
from functools import partial
from itertools import repeat

from autograd.tracer import Node, trace
from autograd.util import subvals, toposort
from autograd.wrap_util import wraps


class ConstGraphNode(Node):
    __slots__ = ["parents", "partial_fun"]

    def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
        args = subvals(args, zip(parent_argnums, repeat(None)))

        def partial_fun(partial_args):
            return fun(*subvals(args, zip(parent_argnums, partial_args)), **kwargs)

        self.parents = parents
        self.partial_fun = partial_fun

    def initialize_root(self):
        self.parents = []


def const_graph_unary(fun):
    graph = []
    _fun = [fun]  # Allow fun to be freed, since it may have bound args

    def maybe_cached_fun(x):
        if graph:
            _graph = graph[0]
            vals = {_graph[0]: x}
            for node in _graph[1:]:
                vals[node] = node.partial_fun([vals[p] for p in node.parents])
            return vals[node]
        else:
            start_node = ConstGraphNode.new_root()
            end_value, end_node = trace(start_node, _fun.pop(), x)
            if end_node is None:
                raise Exception("Output is independent of input")
            graph.append(list(toposort(end_node))[::-1])
            return end_value

    return maybe_cached_fun


def const_graph(fun, *args, **kwargs):
    partial_fun = partial(fun, *args, **kwargs)
    unary_fun = lambda args: partial_fun(*args)
    maybe_cached_unary_fun = const_graph_unary(unary_fun)

    @wraps(fun)
    def _fun(*args):
        return maybe_cached_unary_fun(args)

    return _fun


class FullGraphNode(Node):
    __slots__ = ["value", "recipe"]

    def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
        self.value = value
        # Materialize the pairs: a bare zip is a one-shot iterator in Python 3.
        self.recipe = (fun, args, kwargs, list(zip(parent_argnums, parents)))

    def initialize_root(self):
        self.value = None
        self.recipe = (lambda x: x, (), {}, [])


def full_graph(fun, *args, **kwargs):
    unary_fun = lambda args: fun(*args, **kwargs)
    start_node = FullGraphNode.new_root()
    end_value, end_node = trace(start_node, unary_fun, args)
    return end_node


================================================
FILE: autograd/numpy/__init__.py
================================================
from . import fft, linalg, numpy_boxes, numpy_jvps, numpy_vjps, numpy_vspaces, random
from .numpy_wrapper import *
from .numpy_wrapper import numpy_version as __version__


================================================
FILE: autograd/numpy/fft.py
================================================
import numpy.fft as ffto

from autograd.extend import defvjp, primitive, vspace

from . import numpy_wrapper as anp
from .numpy_vjps import match_complex
from .numpy_wrapper import wrap_namespace

wrap_namespace(ffto.__dict__, globals())


# TODO: make fft gradient work for a repeated axis,
# e.g. by replacing fftn with repeated calls to 1d fft along each axis
def fft_grad(get_args, fft_fun, ans, x, *args, **kwargs):
    axes, s, norm = get_args(x, *args, **kwargs)
    check_no_repeated_axes(axes)
    vs = vspace(x)
    return lambda g: match_complex(x, truncate_pad(fft_fun(g, *args, **kwargs), vs.shape))


defvjp(fft, lambda *args, **kwargs: fft_grad(get_fft_args, fft, *args, **kwargs))
defvjp(ifft, lambda *args, **kwargs: fft_grad(get_fft_args, ifft, *args, **kwargs))

defvjp(fft2, lambda *args, **kwargs: fft_grad(get_fft_args, fft2, *args, **kwargs))
defvjp(ifft2, lambda *args, **kwargs: fft_grad(get_fft_args, ifft2, *args, **kwargs))

defvjp(fftn, lambda *args, **kwargs: fft_grad(get_fft_args, fftn, *args, **kwargs))
defvjp(ifftn, lambda *args, **kwargs: fft_grad(get_fft_args, ifftn, *args, **kwargs))


def rfft_grad(get_args, irfft_fun, ans, x, *args, **kwargs):
    axes, s, norm = get_args(x, *args, **kwargs)
    vs = vspace(x)
    gvs = vspace(ans)
    check_no_repeated_axes(axes)
    if s is None:
        s = [vs.shape[i] for i in axes]
    check_even_shape(s)

    # s is the full fft shape
    # gs is the compressed shape
    gs = list(s)
    gs[-1] = gs[-1] // 2 + 1
    fac = make_rfft_factors(axes, gvs.shape, gs, s, norm)

    def vjp(g):
        g = anp.conj(g / fac)
        r = match_complex(x, truncate_pad((irfft_fun(g, *args, **kwargs)), vs.shape))
        return r

    return vjp


def irfft_grad(get_args, rfft_fun, ans, x, *args, **kwargs):
    axes, gs, norm = get_args(x, *args, **kwargs)
    vs = vspace(x)
    gvs = vspace(ans)
    check_no_repeated_axes(axes)
    if gs is None:
        gs = [gvs.shape[i] for i in axes]
    check_even_shape(gs)

    # gs is the full fft shape
    # s is the compressed shape
    s = list(gs)
    s[-1] = s[-1] // 2 + 1

    def vjp(g):
        r = match_complex(x, truncate_pad((rfft_fun(g, *args, **kwargs)), vs.shape))
        fac = make_rfft_factors(axes, vs.shape, s, gs, norm)
        r = anp.conj(r) * fac
        return r

    return vjp


defvjp(rfft, lambda *args, **kwargs: rfft_grad(get_fft_args, irfft, *args, **kwargs))

defvjp(irfft, lambda *args, **kwargs: irfft_grad(get_fft_args, rfft, *args, **kwargs))

defvjp(rfft2, lambda *args, **kwargs: rfft_grad(get_fft2_args, irfft2, *args, **kwargs))

defvjp(irfft2, lambda *args, **kwargs: irfft_grad(get_fft2_args, rfft2, *args, **kwargs))

defvjp(rfftn, lambda *args, **kwargs: rfft_grad(get_fftn_args, irfftn, *args, **kwargs))

defvjp(irfftn, lambda *args, **kwargs: irfft_grad(get_fftn_args, rfftn, *args, **kwargs))

defvjp(
    fftshift, lambda ans, x, axes=None: lambda g: match_complex(x, anp.conj(ifftshift(anp.conj(g), axes)))
)
defvjp(
    ifftshift, lambda ans, x, axes=None: lambda g: match_complex(x, anp.conj(fftshift(anp.conj(g), axes)))
)


@primitive
def truncate_pad(x, shape):
    # truncate/pad x to have the appropriate shape
    slices = [slice(n) for n in shape]
    pads = tuple(
        zip(anp.zeros(len(shape), dtype=int), anp.maximum(0, anp.array(shape) - anp.array(x.shape)))
    )
    return anp.pad(x, pads, "constant")[tuple(slices)]


defvjp(truncate_pad, lambda ans, x, shape: lambda g: match_complex(x, truncate_pad(g, vspace(x).shape)))


## TODO: could be made less stringent, to fail only when repeated axis has different values of s
def check_no_repeated_axes(axes):
    axes_set = set(axes)
    if len(axes) != len(axes_set):
        raise NotImplementedError("FFT gradient for repeated axes not implemented.")


def check_even_shape(shape):
    if shape[-1] % 2 != 0:
        raise NotImplementedError("Real FFT gradient for odd-length last axes is not implemented.")


def get_fft_args(a, d=None, axis=-1, norm=None, *args, **kwargs):
    axes = [axis]
    if d is not None:
        d = [d]
    return axes, d, norm


def get_fft2_args(a, s=None, axes=(-2, -1), norm=None, *args, **kwargs):
    return axes, s, norm


def get_fftn_args(a, s=None, axes=None, norm=None, *args, **kwargs):
    if axes is None:
        axes = list(range(a.ndim))
    return axes, s, norm


def make_rfft_factors(axes, resshape, facshape, normshape, norm):
    """Make the compression factors and compute the normalization
    for irfft and rfft.
    """
    N = 1.0
    for n in normshape:
        N = N * n

    # In-place modification is fine because we produce a constant
    # that doesn't go into autograd.
    # For the same reason we could have used numpy rather than anp,
    # but anp is already imported, so we use it instead.
    fac = anp.zeros(resshape)
    fac[...] = 2
    index = [slice(None)] * len(resshape)
    if facshape[-1] <= resshape[axes[-1]]:
        index[axes[-1]] = (0, facshape[-1] - 1)
    else:
        index[axes[-1]] = (0,)
    fac[tuple(index)] = 1
    if norm is None:
        fac /= N
    return fac


================================================
FILE: autograd/numpy/linalg.py
================================================
from functools import partial

import numpy.linalg as npla

from autograd.extend import defjvp, defvjp

from . import numpy_wrapper as anp
from .numpy_wrapper import wrap_namespace

wrap_namespace(npla.__dict__, globals())

# Some formulas are from
# "An extended collection of matrix derivative results
#  for forward and reverse mode algorithmic differentiation"
# by Mike Giles
# https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf


# transpose by swapping last two dimensions
def T(x):
    return anp.swapaxes(x, -1, -2)


_dot = partial(anp.einsum, "...ij,...jk->...ik")

# batched diag
_diag = lambda a: anp.eye(a.shape[-1]) * a


# batched diagonal, similar to matrix_diag in tensorflow
def _matrix_diag(a):
    reps = anp.array(a.shape)
    reps[:-1] = 1
    reps[-1] = a.shape[-1]
    newshape = list(a.shape) + [a.shape[-1]]
    return _diag(anp.tile(a, reps).reshape(newshape))


# add two dimensions to the end of x
def add2d(x):
    return anp.reshape(x, anp.shape(x) + (1, 1))


defvjp(det, lambda ans, x: lambda g: add2d(g) * add2d(ans) * T(inv(x)))
defvjp(slogdet, lambda ans, x: lambda g: add2d(g[1]) * T(inv(x)))


def grad_inv(ans, x):
    return lambda g: -_dot(_dot(T(ans), g), T(ans))


defvjp(inv, grad_inv)


def grad_pinv(ans, x):
    # https://mathoverflow.net/questions/25778/analytical-formula-for-numerical-derivative-of-the-matrix-pseudo-inverse
    return lambda g: T(
        -_dot(_dot(ans, T(g)), ans)
        + _dot(_dot(_dot(ans, T(ans)), g), anp.eye(x.shape[-2]) - _dot(x, ans))
        + _dot(_dot(_dot(anp.eye(ans.shape[-2]) - _dot(ans, x), g), T(ans)), ans)
    )


defvjp(pinv, grad_pinv)


def grad_solve(argnum, ans, a, b):
    updim = lambda x: x if x.ndim == a.ndim else x[..., None]
    if argnum == 0:
        return lambda g: -_dot(updim(solve(T(a), g)), T(updim(ans)))
    else:
        return lambda g: solve(T(a), g)


defvjp(solve, partial(grad_solve, 0), partial(grad_solve, 1))


def norm_vjp(ans, x, ord=None, axis=None):
    def check_implemented():
        matrix_norm = (x.ndim == 2 and axis is None) or isinstance(axis, tuple)

        if matrix_norm:
            if not (ord is None or ord == "fro" or ord == "nuc"):
                raise NotImplementedError(f"Gradient of matrix norm not implemented for ord={ord}")
        elif not (ord is None or ord > 1):
            raise NotImplementedError(f"Gradient of norm not implemented for ord={ord}")

    if axis is None:
        expand = lambda a: a
    elif isinstance(axis, tuple):
        row_axis, col_axis = axis
        if row_axis > col_axis:
            row_axis = row_axis - 1
        expand = lambda a: anp.expand_dims(anp.expand_dims(a, row_axis), col_axis)
    else:
        expand = lambda a: anp.expand_dims(a, axis=axis)

    if ord == "nuc":
        if axis is None:
            roll = lambda a: a
            unroll = lambda a: a
        else:
            row_axis, col_axis = axis
            if row_axis > col_axis:
                row_axis = row_axis - 1
            # Roll matrix axes to the back
            roll = lambda a: anp.rollaxis(anp.rollaxis(a, col_axis, a.ndim), row_axis, a.ndim - 1)
            # Roll matrix axes to their original position
            unroll = lambda a: anp.rollaxis(anp.rollaxis(a, a.ndim - 2, row_axis), a.ndim - 1, col_axis)

    check_implemented()

    def vjp(g):
        if ord in (None, 2, "fro"):
            return expand(g / ans) * anp.conj(x)
        elif ord == "nuc":
            x_rolled = roll(x)
            u, s, vt = svd(x_rolled, full_matrices=False)
            uvt_rolled = _dot(u, vt)
            # Roll the matrix axes back to their correct positions
            uvt = unroll(uvt_rolled)
            g = expand(g)
            return g * anp.conj(uvt)
        else:
            # see https://en.wikipedia.org/wiki/Norm_(mathematics)#p-norm
            return expand(g / ans ** (ord - 1)) * anp.conj(x) * anp.abs(x) ** (ord - 2)

    return vjp


defvjp(norm, norm_vjp)


def norm_jvp(g, ans, x, ord=None, axis=None):
    def check_implemented():
        matrix_norm = (x.ndim == 2 and axis is None) or isinstance(axis, tuple)

        if matrix_norm:
            if not (ord is None or ord == "fro" or ord == "nuc"):
                raise NotImplementedError(f"Gradient of matrix norm not implemented for ord={ord}")
        elif not (ord is None or ord > 1):
            raise NotImplementedError(f"Gradient of norm not implemented for ord={ord}")

    if axis is None:
        contract = lambda a: anp.sum(a)
    else:
        contract = partial(anp.sum, axis=axis)

    if ord == "nuc":
        if axis is None:
            roll = lambda a: a
            unroll = lambda a: a
        else:
            row_axis, col_axis = axis
            if row_axis > col_axis:
                row_axis = row_axis - 1
            # Roll matrix axes to the back
            roll = lambda a: anp.rollaxis(anp.rollaxis(a, col_axis, a.ndim), row_axis, a.ndim - 1)
            # Roll matrix axes to their original position
            unroll = lambda a: anp.rollaxis(anp.rollaxis(a, a.ndim - 2, row_axis), a.ndim - 1, col_axis)

    check_implemented()
    if ord in (None, 2, "fro"):
        return contract(g * anp.conj(x)) / ans
    elif ord == "nuc":
        x_rolled = roll(x)
        u, s, vt = svd(x_rolled, full_matrices=False)
        uvt_rolled = _dot(u, vt)
        # Roll the matrix axes back to their correct positions
        uvt = unroll(uvt_rolled)
        return contract(g * anp.conj(uvt))
    else:
        # see https://en.wikipedia.org/wiki/Norm_(mathematics)#p-norm
        return contract(g * anp.conj(x) * anp.abs(x) ** (ord - 2)) / ans ** (ord - 1)


defjvp(norm, norm_jvp)
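

# Illustrative sketch (hypothetical helper, not part of autograd's API): a pure-Python
# finite-difference check of the p-norm formula used in the last branch of norm_vjp and
# norm_jvp above, assuming real x and p > 1:
#   d||x||_p / dx_i = x_i * |x_i| ** (p - 2) / ||x||_p ** (p - 1)
# The name `_pnorm_grad_check` and its default arguments are assumptions for illustration.
def _pnorm_grad_check(x=(1.0, -2.0, 3.0), p=3.0, eps=1e-6):
    def pnorm(v):
        return sum(abs(vi) ** p for vi in v) ** (1.0 / p)

    ans = pnorm(x)
    # Closed form, mirroring `anp.conj(x) * anp.abs(x) ** (ord - 2) / ans ** (ord - 1)`.
    analytic = [xi * abs(xi) ** (p - 2) / ans ** (p - 1) for xi in x]
    # Central differences, one coordinate at a time.
    numeric = []
    for i in range(len(x)):
        hi = [xj + eps * (i == j) for j, xj in enumerate(x)]
        lo = [xj - eps * (i == j) for j, xj in enumerate(x)]
        numeric.append((pnorm(hi) - pnorm(lo)) / (2 * eps))
    # Largest absolute disagreement between the two estimates.
    return max(abs(a - n) for a, n in zip(analytic, numeric))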


def grad_eigh(ans, x, UPLO="L"):
    """Gradient for eigenvalues and vectors of a symmetric matrix."""
    N = x.shape[-1]
    w, v = ans  # Eigenvalues, eigenvectors.
    vc = anp.conj(v)

    def vjp(g):
        wg, vg = g  # Gradient w.r.t. eigenvalues, eigenvectors.
        w_repeated = anp.repeat(w[..., anp.newaxis], N, axis=-1)

        # Eigenvalue part
        vjp_temp = _dot(vc * wg[..., anp.newaxis, :], T(v))

        # Add eigenvector part only if non-zero backward signal is present.
        # This can avoid NaN results for degenerate cases if the function depends
        # on the eigenvalues only.
        if anp.any(vg):
            off_diag = anp.ones((N, N)) - anp.eye(N)
            F = off_diag / (T(w_repeated) - w_repeated + anp.eye(N))
            vjp_temp += _dot(_dot(vc, F * _dot(T(v), vg)), T(v))

        # eigh reads only the lower or the upper triangle of the matrix, so the gradient
        # is folded onto that triangle. reps lets the mask broadcast over leading dimensions.
        reps = anp.array(x.shape)
        reps[-2:] = 1

        if UPLO == "L":
            tri = anp.tile(anp.tril(anp.ones(N), -1), reps)
        elif UPLO == "U":
            tri = anp.tile(anp.triu(anp.ones(N), 1), reps)

        return anp.real(vjp_temp) * anp.eye(vjp_temp.shape[-1]) + (vjp_temp + anp.conj(T(vjp_temp))) * tri

    return vjp


defvjp(eigh, grad_eigh)


# See https://arxiv.org/pdf/1701.00392.pdf, Eq. (4.77).
# Note: the formula in Sec. 3.1 of https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf is incomplete.
def grad_eig(ans, x):
    """Gradient for eigenvalues and eigenvectors of a general square (complex-valued) matrix."""
    e, u = ans  # eigenvalues as 1d array, eigenvectors in columns
    n = e.shape[-1]

    def vjp(g):
        ge, gu = g
        ge = _matrix_diag(ge)
        f = 1 / (e[..., anp.newaxis, :] - e[..., :, anp.newaxis] + 1.0e-20)
        f -= _diag(f)
        ut = anp.swapaxes(u, -1, -2)
        r1 = f * _dot(ut, gu)
        r2 = -f * (_dot(_dot(ut, anp.conj(u)), anp.real(_dot(ut, gu)) * anp.eye(n)))
        r = _dot(_dot(inv(ut), ge + r1 + r2), ut)
        if not anp.iscomplexobj(x):
            # The formula above can yield a complex derivative even for real input, because
            # it allows imaginary perturbations of x. Real input only admits real
            # perturbations, so keep just the real part.
            r = anp.real(r)
        return r

    return vjp


defvjp(eig, grad_eig)


def grad_cholesky(L, A):
    # Based on Iain Murray's note http://arxiv.org/abs/1602.07527
    # scipy's dtrtrs wrapper, solve_triangular, doesn't broadcast along leading
    # dimensions, so we just call a generic LU solve instead of directly using
    # backsubstitution (also, we factor twice...)
    solve_trans = lambda a, b: solve(T(a), b)
    phi = lambda X: anp.tril(X) / (1.0 + anp.eye(X.shape[-1]))

    def conjugate_solve(L, X):
        # X -> L^{-T} X L^{-1}
        return solve_trans(L, T(solve_trans(L, T(X))))

    def vjp(g):
        S = conjugate_solve(L, phi(anp.einsum("...ki,...kj->...ij", L, g)))
        return (S + T(S)) / 2.0

    return vjp


defvjp(cholesky, grad_cholesky)


# https://j-towns.github.io/papers/svd-derivative.pdf
# https://arxiv.org/abs/1909.02659
def grad_svd(usv_, a, full_matrices=True, compute_uv=True):
    def vjp(g):
        usv = usv_

        if not compute_uv:
            s = usv

            # Need U and V so do the whole svd anyway...
            usv = svd(a, full_matrices=False)
            u = usv[0]
            v = anp.conj(T(usv[2]))

            return _dot(anp.conj(u) * g[..., anp.newaxis, :], T(v))

        elif full_matrices:
            raise NotImplementedError("Gradient of svd not implemented for full_matrices=True")

        else:
            u = usv[0]
            s = usv[1]
            v = anp.conj(T(usv[2]))

            m, n = a.shape[-2:]

            k = anp.min((m, n))
            # broadcastable identity array with shape (1, 1, ..., 1, k, k)
            i = anp.reshape(anp.eye(k), anp.concatenate((anp.ones(a.ndim - 2, dtype=int), (k, k))))

            f = 1 / (s[..., anp.newaxis, :] ** 2 - s[..., :, anp.newaxis] ** 2 + i)

            gu = g[0]
            gs = g[1]
            gv = anp.conj(T(g[2]))

            utgu = _dot(T(u), gu)
            vtgv = _dot(T(v), gv)
            t1 = (f * (utgu - anp.conj(T(utgu)))) * s[..., anp.newaxis, :]
            t1 = t1 + i * gs[..., :, anp.newaxis]
            t1 = t1 + s[..., :, anp.newaxis] * (f * (vtgv - anp.conj(T(vtgv))))

            if anp.iscomplexobj(u):
                t1 = t1 + 1j * anp.imag(_diag(utgu)) / s[..., anp.newaxis, :]

            t1 = _dot(_dot(anp.conj(u), t1), T(v))

            if m < n:
                i_minus_vvt = anp.reshape(
                    anp.eye(n), anp.concatenate((anp.ones(a.ndim - 2, dtype=int), (n, n)))
                ) - _dot(v, anp.conj(T(v)))
                t1 = t1 + anp.conj(_dot(_dot(u / s[..., anp.newaxis, :], T(gv)), i_minus_vvt))

                return t1

            elif m == n:
                return t1

            elif m > n:
                i_minus_uut = anp.reshape(
                    anp.eye(m), anp.concatenate((anp.ones(a.ndim - 2, dtype=int), (m, m)))
                ) - _dot(u, anp.conj(T(u)))
                t1 = t1 + T(_dot(_dot(v / s[..., anp.newaxis, :], T(gu)), i_minus_uut))

                return t1

    return vjp


defvjp(svd, grad_svd)


================================================
FILE: autograd/numpy/numpy_boxes.py
================================================
import numpy as np

from autograd.builtins import SequenceBox
from autograd.extend import Box, primitive

from . import numpy_wrapper as anp

Box.__array_priority__ = 90.0


class ArrayBox(Box):
    __slots__ = []
    __array_priority__ = 100.0

    @primitive
    def __getitem__(A, idx):
        return A[idx]

    # Constants w.r.t. float data just pass through
    shape = property(lambda self: self._value.shape)
    ndim = property(lambda self: self._value.ndim)
    size = property(lambda self: self._value.size)
    dtype = property(lambda self: self._value.dtype)
    T = property(lambda self: anp.transpose(self))

    def __array_namespace__(self, *, api_version: str | None = None):
        return anp

    def __len__(self):
        return len(self._value)

    def astype(self, *args, **kwargs):
        return anp._astype(self, *args, **kwargs)

    def __neg__(self):
        return anp.negative(self)

    def __add__(self, other):
        return anp.add(self, other)

    def __sub__(self, other):
        return anp.subtract(self, other)

    def __mul__(self, other):
        return anp.multiply(self, other)

    def __pow__(self, other):
        return anp.power(self, other)

    def __div__(self, other):
        return anp.divide(self, other)

    def __mod__(self, other):
        return anp.mod(self, other)

    def __truediv__(self, other):
        return anp.true_divide(self, other)

    def __matmul__(self, other):
        return anp.matmul(self, other)

    def __radd__(self, other):
        return anp.add(other, self)

    def __rsub__(self, other):
        return anp.subtract(other, self)

    def __rmul__(self, other):
        return anp.multiply(other, self)

    def __rpow__(self, other):
        return anp.power(other, self)

    def __rdiv__(self, other):
        return anp.divide(other, self)

    def __rmod__(self, other):
        return anp.mod(other, self)

    def __rtruediv__(self, other):
        return anp.true_divide(other, self)

    def __rmatmul__(self, other):
        return anp.matmul(other, self)

    def __eq__(self, other):
        return anp.equal(self, other)

    def __ne__(self, other):
        return anp.not_equal(self, other)

    def __gt__(self, other):
        return anp.greater(self, other)

    def __ge__(self, other):
        return anp.greater_equal(self, other)

    def __lt__(self, other):
        return anp.less(self, other)

    def __le__(self, other):
        return anp.less_equal(self, other)

    def __abs__(self):
        return anp.abs(self)

    def __hash__(self):
        return id(self)


ArrayBox.register(np.ndarray)
for type_ in [
    float,
    np.longdouble,
    np.float64,
    np.float32,
    np.float16,
    complex,
    np.clongdouble,
    np.complex64,
    np.complex128,
]:
    ArrayBox.register(type_)

# These numpy.ndarray methods are just refs to an equivalent numpy function
nondiff_methods = [
    "all",
    "any",
    "argmax",
    "argmin",
    "argpartition",
    "argsort",
    "nonzero",
    "searchsorted",
    "round",
]
diff_methods = [
    "clip",
    "compress",
    "cumprod",
    "cumsum",
    "diagonal",
    "max",
    "mean",
    "min",
    "prod",
    "ptp",
    "ravel",
    "repeat",
    "reshape",
    "squeeze",
    "std",
    "sum",
    "swapaxes",
    "take",
    "trace",
    "transpose",
    "var",
]
for method_name in nondiff_methods + diff_methods:
    setattr(ArrayBox, method_name, anp.__dict__[method_name])

# Flatten has no function, only a method.
setattr(ArrayBox, "flatten", anp.__dict__["ravel"])

if np.lib.NumpyVersion(np.__version__) >= "2.0.0":
    SequenceBox.register(np.linalg._linalg.EigResult)
    SequenceBox.register(np.linalg._linalg.EighResult)
    SequenceBox.register(np.linalg._linalg.QRResult)
    SequenceBox.register(np.linalg._linalg.SlogdetResult)
    SequenceBox.register(np.linalg._linalg.SVDResult)
elif np.lib.NumpyVersion(np.__version__) >= "1.25.0":
    SequenceBox.register(np.linalg.linalg.EigResult)
    SequenceBox.register(np.linalg.linalg.EighResult)
    SequenceBox.register(np.linalg.linalg.QRResult)
    SequenceBox.register(np.linalg.linalg.SlogdetResult)
    SequenceBox.register(np.linalg.linalg.SVDResult)


================================================
FILE: autograd/numpy/numpy_jvps.py
================================================
import numpy as onp

from autograd.extend import JVPNode, def_linear, defjvp, defjvp_argnum, register_notrace, vspace

from ..util import func
from . import numpy_wrapper as anp
from .numpy_boxes import ArrayBox
from .numpy_vjps import (
    balanced_eq,
    dot_adjoint_0,
    dot_adjoint_1,
    match_complex,
    nograd_functions,
    replace_zero,
    tensordot_adjoint_0,
    tensordot_adjoint_1,
    untake,
)

for fun in nograd_functions:
    register_notrace(JVPNode, fun)

defjvp(func(ArrayBox.__getitem__), "same")
defjvp(untake, "same")

defjvp_argnum(anp.array_from_args, lambda argnum, g, ans, args, kwargs: untake(g, argnum - 2, vspace(ans)))
defjvp(
    anp._array_from_scalar_or_array,
    None,
    None,
    lambda g, ans, args, kwargs, _: anp._array_from_scalar_or_array(args, kwargs, g),
)

# ----- Functions that are constant w.r.t. continuous inputs -----
defjvp(anp.nan_to_num, lambda g, ans, x: anp.where(anp.isfinite(x), g, 0.0))

# ----- Binary ufuncs (linear) -----
def_linear(anp.multiply)

# ----- Binary ufuncs -----
defjvp(anp.add, lambda g, ans, x, y: broadcast(g, ans), lambda g, ans, x, y: broadcast(g, ans))
defjvp(anp.subtract, lambda g, ans, x, y: broadcast(g, ans), lambda g, ans, x, y: broadcast(-g, ans))
defjvp(anp.divide, "same", lambda g, ans, x, y: -g * x / y**2)
defjvp(
    anp.maximum,
    lambda g, ans, x, y: g * balanced_eq(x, ans, y),
    lambda g, ans, x, y: g * balanced_eq(y, ans, x),
)
defjvp(
    anp.minimum,
    lambda g, ans, x, y: g * balanced_eq(x, ans, y),
    lambda g, ans, x, y: g * balanced_eq(y, ans, x),
)
defjvp(
    anp.fmax,
    lambda g, ans, x, y: g * balanced_eq(x, ans, y),
    lambda g, ans, x, y: g * balanced_eq(y, ans, x),
)
defjvp(
    anp.fmin,
    lambda g, ans, x, y: g * balanced_eq(x, ans, y),
    lambda g, ans, x, y: g * balanced_eq(y, ans, x),
)
defjvp(anp.logaddexp, lambda g, ans, x, y: g * anp.exp(x - ans), lambda g, ans, x, y: g * anp.exp(y - ans))
defjvp(anp.logaddexp2, lambda g, ans, x, y: g * 2 ** (x - ans), lambda g, ans, x, y: g * 2 ** (y - ans))
defjvp(anp.true_divide, "same", lambda g, ans, x, y: -g * x / y**2)
defjvp(anp.mod, lambda g, ans, x, y: broadcast(g, ans), lambda g, ans, x, y: -g * anp.floor(x / y))
defjvp(anp.remainder, lambda g, ans, x, y: broadcast(g, ans), lambda g, ans, x, y: -g * anp.floor(x / y))
defjvp(
    anp.power,
    lambda g, ans, x, y: g * y * x ** anp.where(y, y - 1, 1.0),
    lambda g, ans, x, y: g * anp.log(replace_zero(x, 1.0)) * ans,
)
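

# Illustrative sketch (hypothetical helper, not part of autograd): a scalar, pure-Python
# version of the first anp.power JVP above. The `anp.where(y, y - 1, 1.0)` term appears to
# guard the y == 0 case: the factor y already zeroes the tangent there, so the exponent is
# swapped for a harmless 1.0 instead of evaluating x ** -1 at x == 0.
def _power_jvp_wrt_x(g, x, y):
    exponent = y - 1 if y else 1.0  # scalar analogue of anp.where(y, y - 1, 1.0)
    return g * y * x**exponent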
defjvp(anp.arctan2, lambda g, ans, x, y: g * y / (x**2 + y**2), lambda g, ans, x, y: g * -x / (x**2 + y**2))

# ----- Simple grads (linear) -----
defjvp(anp.negative, "same")
defjvp(anp.rad2deg, "same")
defjvp(anp.degrees, "same")
defjvp(anp.deg2rad, "same")
defjvp(anp.radians, "same")
defjvp(anp.reshape, "same")
defjvp(anp.roll, "same")
defjvp(anp.array_split, "same")
defjvp(anp.split, "same")
defjvp(anp.vsplit, "same")
defjvp(anp.hsplit, "same")
defjvp(anp.dsplit, "same")
defjvp(anp.ravel, "same")
defjvp(anp.expand_dims, "same")
defjvp(anp.squeeze, "same")
defjvp(anp.diag, "same")
defjvp(anp.diagonal, "same")
defjvp(anp.make_diagonal, "same")
defjvp(anp.flipud, "same")
defjvp(anp.fliplr, "same")
defjvp(anp.rot90, "same")
defjvp(anp.trace, "same")
defjvp(anp.full, "same", argnums=(1,))
defjvp(anp.triu, "same")
defjvp(anp.tril, "same")
defjvp(anp.swapaxes, "same")
defjvp(anp.rollaxis, "same")
defjvp(anp.moveaxis, "same")
defjvp(anp.broadcast_to, "same")
def_linear(anp.cross)

# ----- Simple grads -----
np_abs_jvp = lambda g, ans, x: anp.real(g * replace_zero(anp.conj(x), 0.0)) / replace_zero(ans, 1.0)
defjvp(anp.abs, np_abs_jvp)
defjvp(anp.absolute, np_abs_jvp)
defjvp(anp.fabs, lambda g, ans, x: anp.sign(x) * g)  # fabs doesn't take complex numbers.
defjvp(anp.reciprocal, lambda g, ans, x: -g / x**2)
defjvp(anp.exp, lambda g, ans, x: ans * g)
defjvp(anp.exp2, lambda g, ans, x: ans * anp.log(2) * g)
defjvp(anp.expm1, lambda g, ans, x: (ans + 1) * g)
defjvp(anp.log, lambda g, ans, x: g / x)
defjvp(anp.log2, lambda g, ans, x: g / x / anp.log(2))
defjvp(anp.log10, lambda g, ans, x: g / x / anp.log(10))
defjvp(anp.log1p, lambda g, ans, x: g / (x + 1))
defjvp(anp.sin, lambda g, ans, x: g * anp.cos(x))
defjvp(anp.cos, lambda g, ans, x: -g * anp.sin(x))
defjvp(anp.tan, lambda g, ans, x: g / anp.cos(x) ** 2)
defjvp(anp.arcsin, lambda g, ans, x: g / anp.sqrt(1 - x**2))
defjvp(anp.arccos, lambda g, ans, x: -g / anp.sqrt(1 - x**2))
defjvp(anp.arctan, lambda g, ans, x: g / (1 + x**2))
defjvp(anp.sinh, lambda g, ans, x: g * anp.cosh(x))
defjvp(anp.cosh, lambda g, ans, x: g * anp.sinh(x))
defjvp(anp.tanh, lambda g, ans, x: g / anp.cosh(x) ** 2)
defjvp(anp.arcsinh, lambda g, ans, x: g / anp.sqrt(x**2 + 1))
defjvp(anp.arccosh, lambda g, ans, x: g / anp.sqrt(x**2 - 1))
defjvp(anp.arctanh, lambda g, ans, x: g / (1 - x**2))
defjvp(anp.square, lambda g, ans, x: g * 2 * x)
defjvp(anp.sqrt, lambda g, ans, x: g * 0.5 * x**-0.5)
defjvp(
    anp.sinc,
    lambda g, ans, x: g * (anp.cos(anp.pi * x) * anp.pi * x - anp.sin(anp.pi * x)) / (anp.pi * x**2),
)
defjvp(anp.clip, lambda g, ans, x, a_min, a_max: g * anp.logical_and(ans != a_min, ans != a_max))
defjvp(anp.real_if_close, lambda g, ans, x: match_complex(ans, g))
defjvp(anp.real, lambda g, ans, x: anp.real(g))
defjvp(anp.imag, lambda g, ans, x: match_complex(ans, -1j * g))
np_conj_jvp = lambda g, ans, x: anp.conj(g)
defjvp(anp.conj, np_conj_jvp)
defjvp(anp.conjugate, np_conj_jvp)
defjvp(anp.angle, lambda g, ans, x: match_complex(ans, g * anp.conj(x * 1j) / anp.abs(x) ** 2))
defjvp(
    anp.where,
    None,
    lambda g, ans, c, x=None, y=None: anp.where(c, g, anp.zeros(anp.shape(g))),
    lambda g, ans, c, x=None, y=None: anp.where(c, anp.zeros(g.shape), g),
)

# ----- Trickier grads -----
defjvp(anp.kron, "same", "same")
defjvp(anp.diff, "same")
defjvp(anp.gradient, "same")
defjvp(anp.repeat, "same")
defjvp(anp.tile, "same")
defjvp(anp.transpose, "same")
defjvp(anp.sum, "same")
defjvp(anp.mean, "same")
defjvp(
    anp.prod, lambda g, ans, x, axis=None, keepdims=False: ans * anp.sum(g / x, axis=axis, keepdims=keepdims)
)
defjvp(
    anp.linspace,
    lambda g, ans, start, stop, *args, **kwargs: anp.linspace(g, 0, *args, **kwargs),
    lambda g, ans, start, stop, *args, **kwargs: anp.linspace(0, g, *args, **kwargs),
)


def forward_grad_np_var(g, ans, x, axis=None, ddof=0, keepdims=False):
    if axis is None:
        num_reps = anp.size(g)
    elif isinstance(axis, int):
        num_reps = anp.shape(g)[axis]
    elif isinstance(axis, tuple):
        num_reps = anp.prod(anp.array(anp.shape(g))[list(axis)])

    x_minus_mean = anp.conj(x - anp.mean(x, axis=axis, keepdims=True))
    return 2.0 * anp.sum(anp.real(g * x_minus_mean), axis=axis, keepdims=keepdims) / (num_reps - ddof)


defjvp(anp.var, forward_grad_np_var)
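

# Illustrative sketch (hypothetical helper, plain Python): checks the real-valued, axis=None
# case of the formula in forward_grad_np_var, 2 * sum(g * (x - mean(x))) / (N - ddof),
# against a central difference of the variance along the tangent direction g.
# The name `_var_jvp_check` and its default inputs are assumptions for illustration.
def _var_jvp_check(x=(1.0, 4.0, 2.0, 8.0), g=(0.5, -1.0, 2.0, 0.25), ddof=1, eps=1e-6):
    n = len(x)

    def var(v):
        m = sum(v) / n
        return sum((vi - m) ** 2 for vi in v) / (n - ddof)

    mean = sum(x) / n
    analytic = 2.0 * sum(gi * (xi - mean) for gi, xi in zip(g, x)) / (n - ddof)
    # Directional derivative of var at x along g, by central differences.
    plus = [xi + eps * gi for xi, gi in zip(x, g)]
    minus = [xi - eps * gi for xi, gi in zip(x, g)]
    numeric = (var(plus) - var(minus)) / (2 * eps)
    return abs(analytic - numeric)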


def forward_grad_np_std(g, ans, x, axis=None, ddof=0, keepdims=False):
    if axis is None:
        num_reps = anp.size(g)
    elif isinstance(axis, int):
        num_reps = anp.shape(g)[axis]
    elif isinstance(axis, tuple):
        num_reps = anp.prod(anp.array(anp.shape(g))[list(axis)])

    if num_reps <= 1:
        return anp.zeros_like(ans)
    x_minus_mean = anp.conj(x - anp.mean(x, axis=axis, keepdims=True))
    return anp.sum(anp.real(g * x_minus_mean), axis=axis, keepdims=keepdims) / ((num_reps - ddof) * ans)


defjvp(anp.std, forward_grad_np_std)


def fwd_grad_chooser(g, ans, x, axis=None, keepdims=False):
    if anp.isscalar(x):
        return g
    if not keepdims:
        if isinstance(axis, int):
            ans = anp.expand_dims(ans, axis)
        elif isinstance(axis, tuple):
            for ax in sorted(axis):
                ans = anp.expand_dims(ans, ax)
    chosen_locations = x == ans
    return anp.sum((g * chosen_locations), axis=axis, keepdims=keepdims) / anp.sum(
        chosen_locations, axis=axis, keepdims=keepdims
    )


defjvp(anp.max, fwd_grad_chooser)
defjvp(anp.min, fwd_grad_chooser)
defjvp(anp.amax, fwd_grad_chooser)
defjvp(anp.amin, fwd_grad_chooser)
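

# Illustrative sketch (hypothetical helper, plain Python): the 1-D, axis=None case of
# fwd_grad_chooser for max. When several entries tie for the maximum, the tangent is
# averaged over all tied positions rather than taken from a single one.
def _max_jvp_with_ties(x, g):
    ans = max(x)
    chosen = [xi == ans for xi in x]  # plays the role of `chosen_locations`
    return sum(gi for gi, c in zip(g, chosen) if c) / sum(chosen)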

defjvp(anp.cumsum, "same")

def_linear(anp.inner)
def_linear(anp.matmul)
def_linear(anp.dot)
def_linear(anp.tensordot)
def_linear(anp.outer)

def_linear(dot_adjoint_0)
def_linear(dot_adjoint_1)

def_linear(tensordot_adjoint_0)
def_linear(tensordot_adjoint_1)


def fwd_grad_concatenate_args(argnum, g, ans, axis_args, kwargs):
    result = []
    for i in range(1, len(axis_args)):
        if i == argnum:
            result.append(g)
        else:
            result.append(anp.zeros_like(axis_args[i]))
    return anp.concatenate_args(axis_args[0], *result)


defjvp_argnum(anp.concatenate_args, fwd_grad_concatenate_args)


def fwd_grad_sort(g, ans, x, axis=-1, kind="quicksort", order=None):
    sort_perm = anp.argsort(x, axis, kind, order)
    return g[sort_perm]


defjvp(anp.sort, fwd_grad_sort)
if onp.lib.NumpyVersion(onp.__version__) < "2.0.0":
    defjvp(anp.msort, lambda g, ans, x: fwd_grad_sort(g, ans, x, axis=0))


def fwd_grad_partition(g, ans, x, kth, axis=-1, kind="introselect", order=None):
    partition_perm = anp.argpartition(x, kth, axis, kind, order)
    return g[partition_perm]


defjvp(anp.partition, fwd_grad_partition)


def atleast_jvpmaker(fun):
    def jvp(g, ans, *arys):
        if len(arys) > 1:
            raise NotImplementedError("Can't handle multiple arguments yet.")
        return fun(g)

    return jvp


defjvp(anp.atleast_1d, atleast_jvpmaker(anp.atleast_1d))
defjvp(anp.atleast_2d, atleast_jvpmaker(anp.atleast_2d))
defjvp(anp.atleast_3d, atleast_jvpmaker(anp.atleast_3d))

def_linear(anp.einsum)


# TODO(mattjj): can we call np.broadcast_to or a related function instead?
def broadcast(x, target):
    target_shape, target_ndim, target_dtype, target_iscomplex = anp.metadata(target)
    while anp.ndim(x) < target_ndim:
        x = anp.expand_dims(x, 0)
    for axis, size in enumerate(anp.shape(x)):
        if size == 1:
            x = anp.repeat(x, target_shape[axis], axis=axis)
    if target_iscomplex and not anp.iscomplexobj(x):
        x = x + 0j  # TODO(mattjj): this might promote the dtype
    return x


defjvp(anp.pad, lambda g, ans, array, width, mode, **kwargs: anp.pad(g, width, mode))


================================================
FILE: autograd/numpy/numpy_vjps.py
================================================
from functools import partial

import numpy as onp

from autograd.extend import SparseObject, VJPNode, defvjp, defvjp_argnum, primitive, register_notrace, vspace

from ..util import func
from . import numpy_wrapper as anp
from .numpy_boxes import ArrayBox

# ----- Non-differentiable functions -----

nograd_functions = [
    anp.floor,
    anp.ceil,
    anp.round,
    anp.rint,
    anp.around,
    anp.fix,
    anp.trunc,
    anp.all,
    anp.any,
    anp.argmax,
    anp.argmin,
    anp.argpartition,
    anp.argsort,
    anp.argwhere,
    anp.nonzero,
    anp.flatnonzero,
    anp.count_nonzero,
    anp.searchsorted,
    anp.sign,
    anp.ndim,
    anp.shape,
    anp.floor_divide,
    anp.logical_and,
    anp.logical_or,
    anp.logical_not,
    anp.logical_xor,
    anp.isfinite,
    anp.isinf,
    anp.isnan,
    anp.isneginf,
    anp.isposinf,
    anp.allclose,
    anp.isclose,
    anp.array_equal,
    anp.array_equiv,
    anp.greater,
    anp.greater_equal,
    anp.less,
    anp.less_equal,
    anp.equal,
    anp.not_equal,
    anp.iscomplexobj,
    anp.iscomplex,
    anp.size,
    anp.isscalar,
    anp.isreal,
    anp.zeros_like,
    anp.ones_like,
    anp.empty_like,
    anp.full_like,
    anp.result_type,
]

for fun in nograd_functions:
    register_notrace(VJPNode, fun)

# ----- Functions that are constant w.r.t. continuous inputs -----

defvjp(anp.nan_to_num, lambda ans, x: lambda g: anp.where(anp.isfinite(x), g, 0.0))

# ----- Binary ufuncs -----

defvjp(
    anp.add, lambda ans, x, y: unbroadcast_f(x, lambda g: g), lambda ans, x, y: unbroadcast_f(y, lambda g: g)
)
defvjp(
    anp.multiply,
    lambda ans, x, y: unbroadcast_f(x, lambda g: y * g),
    lambda ans, x, y: unbroadcast_f(y, lambda g: x * g),
)
defvjp(
    anp.subtract,
    lambda ans, x, y: unbroadcast_f(x, lambda g: g),
    lambda ans, x, y: unbroadcast_f(y, lambda g: -g),
)
defvjp(
    anp.divide,
    lambda ans, x, y: unbroadcast_f(x, lambda g: g / y),
    lambda ans, x, y: unbroadcast_f(y, lambda g: -g * x / y**2),
)
defvjp(
    anp.maximum,
    lambda ans, x, y: unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)),
    lambda ans, x, y: unbroadcast_f(y, lambda g: g * balanced_eq(y, ans, x)),
)
defvjp(
    anp.minimum,
    lambda ans, x, y: unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)),
    lambda ans, x, y: unbroadcast_f(y, lambda g: g * balanced_eq(y, ans, x)),
)
defvjp(
    anp.fmax,
    lambda ans, x, y: unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)),
    lambda ans, x, y: unbroadcast_f(y, lambda g: g * balanced_eq(y, ans, x)),
)
defvjp(
    anp.fmin,
    lambda ans, x, y: unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)),
    lambda ans, x, y: unbroadcast_f(y, lambda g: g * balanced_eq(y, ans, x)),
)
defvjp(
    anp.logaddexp,
    lambda ans, x, y: unbroadcast_f(x, lambda g: g * anp.exp(x - ans)),
    lambda ans, x, y: unbroadcast_f(y, lambda g: g * anp.exp(y - ans)),
)
defvjp(
    anp.logaddexp2,
    lambda ans, x, y: unbroadcast_f(x, lambda g: g * 2 ** (x - ans)),
    lambda ans, x, y: unbroadcast_f(y, lambda g: g * 2 ** (y - ans)),
)
defvjp(
    anp.true_divide,
    lambda ans, x, y: unbroadcast_f(x, lambda g: g / y),
    lambda ans, x, y: unbroadcast_f(y, lambda g: -g * x / y**2),
)
defvjp(
    anp.mod,
    lambda ans, x, y: unbroadcast_f(x, lambda g: g),
    lambda ans, x, y: unbroadcast_f(y, lambda g: -g * anp.floor(x / y)),
)
defvjp(
    anp.remainder,
    lambda ans, x, y: unbroadcast_f(x, lambda g: g),
    lambda ans, x, y: unbroadcast_f(y, lambda g: -g * anp.floor(x / y)),
)
defvjp(
    anp.power,
    lambda ans, x, y: unbroadcast_f(x, lambda g: g * y * x ** anp.where(y, y - 1, 1.0)),
    lambda ans, x, y: unbroadcast_f(y, lambda g: g * anp.log(replace_zero(x, 1.0)) * ans),
)
defvjp(
    anp.arctan2,
    lambda ans, x, y: unbroadcast_f(x, lambda g: g * y / (x**2 + y**2)),
    lambda ans, x, y: unbroadcast_f(y, lambda g: g * -x / (x**2 + y**2)),
)
defvjp(
    anp.hypot,
    lambda ans, x, y: unbroadcast_f(x, lambda g: g * x / ans),
    lambda ans, x, y: unbroadcast_f(y, lambda g: g * y / ans),
)

# ----- Simple grads -----

defvjp(anp.negative, lambda ans, x: lambda g: -g)
np_abs_vjp = lambda ans, x: lambda g: g * replace_zero(anp.conj(x), 0.0) / replace_zero(ans, 1.0)
defvjp(anp.abs, np_abs_vjp)
defvjp(anp.absolute, np_abs_vjp)
defvjp(anp.fabs, lambda ans, x: lambda g: anp.sign(x) * g)  # fabs doesn't take complex numbers.
defvjp(anp.reciprocal, lambda ans, x: lambda g: -g / x**2)
defvjp(anp.exp, lambda ans, x: lambda g: ans * g)
defvjp(anp.exp2, lambda ans, x: lambda g: ans * anp.log(2) * g)
defvjp(anp.expm1, lambda ans, x: lambda g: (ans + 1) * g)
defvjp(anp.log, lambda ans, x: lambda g: g / x)
defvjp(anp.log2, lambda ans, x: lambda g: g / x / anp.log(2))
defvjp(anp.log10, lambda ans, x: lambda g: g / x / anp.log(10))
defvjp(anp.log1p, lambda ans, x: lambda g: g / (x + 1))
defvjp(anp.sin, lambda ans, x: lambda g: g * anp.cos(x))
defvjp(anp.cos, lambda ans, x: lambda g: -g * anp.sin(x))
defvjp(anp.tan, lambda ans, x: lambda g: g / anp.cos(x) ** 2)
defvjp(anp.arcsin, lambda ans, x: lambda g: g / anp.sqrt(1 - x**2))
defvjp(anp.arccos, lambda ans, x: lambda g: -g / anp.sqrt(1 - x**2))
defvjp(anp.arctan, lambda ans, x: lambda g: g / (1 + x**2))
defvjp(anp.sinh, lambda ans, x: lambda g: g * anp.cosh(x))
defvjp(anp.cosh, lambda ans, x: lambda g: g * anp.sinh(x))
defvjp(anp.tanh, lambda ans, x: lambda g: g / anp.cosh(x) ** 2)
defvjp(anp.arcsinh, lambda ans, x: lambda g: g / anp.sqrt(x**2 + 1))
defvjp(anp.arccosh, lambda ans, x: lambda g: g / anp.sqrt(x**2 - 1))
defvjp(anp.arctanh, lambda ans, x: lambda g: g / (1 - x**2))
defvjp(anp.rad2deg, lambda ans, x: lambda g: g / anp.pi * 180.0)
defvjp(anp.degrees, lambda ans, x: lambda g: g / anp.pi * 180.0)
defvjp(anp.deg2rad, lambda ans, x: lambda g: g * anp.pi / 180.0)
defvjp(anp.radians, lambda ans, x: lambda g: g * anp.pi / 180.0)
defvjp(anp.square, lambda ans, x: lambda g: g * 2 * x)
defvjp(anp.sqrt, lambda ans, x: lambda g: g * 0.5 * x**-0.5)
defvjp(
    anp.sinc,
    lambda ans, x: lambda g: g * (anp.cos(anp.pi * x) * anp.pi * x - anp.sin(anp.pi * x)) / (anp.pi * x**2),
)
defvjp(anp.reshape, lambda ans, x, shape, order=None: lambda g: anp.reshape(g, anp.shape(x), order=order))
defvjp(anp.roll, lambda ans, x, shift, axis=None: lambda g: anp.roll(g, -shift, axis=axis))
defvjp(anp.array_split, lambda ans, ary, idxs, axis=0: lambda g: anp.concatenate(g, axis=axis))
defvjp(anp.split, lambda ans, ary, idxs, axis=0: lambda g: anp.concatenate(g, axis=axis))
defvjp(anp.vsplit, lambda ans, ary, idxs: lambda g: anp.concatenate(g, axis=0))
defvjp(anp.hsplit, lambda ans, ary, idxs: lambda g: anp.concatenate(g, axis=1))
defvjp(anp.dsplit, lambda ans, ary, idxs: lambda g: anp.concatenate(g, axis=2))
defvjp(anp.ravel, lambda ans, x, order=None: lambda g: anp.reshape(g, anp.shape(x), order=order))
defvjp(anp.expand_dims, lambda ans, x, axis: lambda g: anp.reshape(g, anp.shape(x)))
defvjp(anp.squeeze, lambda ans, x, axis=None: lambda g: anp.reshape(g, anp.shape(x)))
defvjp(anp.diag, lambda ans, x, k=0: lambda g: anp.diag(g, k))
defvjp(anp.flipud, lambda ans, x: lambda g: anp.flipud(g))
defvjp(anp.fliplr, lambda ans, x: lambda g: anp.fliplr(g))
defvjp(anp.rot90, lambda ans, x, k=1: lambda g: anp.rot90(g, -k))
defvjp(
    anp.trace,
    lambda ans, x, offset=0: (
        lambda g: anp.einsum("ij,...->ij...", anp.eye(x.shape[0], x.shape[1], k=offset), g)
    ),
)
defvjp(anp.full, lambda ans, shape, fill_value, dtype=None: lambda g: anp.sum(g), argnums=(1,))
defvjp(anp.triu, lambda ans, x, k=0: lambda g: anp.triu(g, k=k))
defvjp(anp.tril, lambda ans, x, k=0: lambda g: anp.tril(g, k=k))
defvjp(anp.clip, lambda ans, x, a_min, a_max: lambda g: g * anp.logical_and(ans != a_min, ans != a_max))
defvjp(anp.swapaxes, lambda ans, x, axis1, axis2: lambda g: anp.swapaxes(g, axis2, axis1))
defvjp(anp.moveaxis, lambda ans, a, source, destination: lambda g: anp.moveaxis(g, destination, source))
defvjp(anp.real_if_close, lambda ans, x: lambda g: match_complex(x, g))
defvjp(anp.real, lambda ans, x: lambda g: match_complex(x, g))
defvjp(anp.imag, lambda ans, x: lambda g: match_complex(x, -1j * g))
np_conj_vjp = lambda ans, x: lambda g: anp.conj(g)
defvjp(anp.conj, np_conj_vjp)
defvjp(anp.conjugate, np_conj_vjp)
defvjp(anp.angle, lambda ans, x: lambda g: match_complex(x, g * anp.conj(x * 1j) / anp.abs(x) ** 2))
defvjp(
    anp.where,
    None,
    lambda ans, c, x=None, y=None: lambda g: anp.where(c, g, anp.zeros(g.shape)),
    lambda ans, c, x=None, y=None: lambda g: anp.where(c, anp.zeros(g.shape), g),
)
defvjp(
    anp.cross,
    lambda ans, a, b, axisa=-1, axisb=-1, axisc=-1, axis=None: (
        lambda g: anp.cross(b, g, axisb, axisc, axisa, axis)
    ),
    lambda ans, a, b, axisa=-1, axisb=-1, axisc=-1, axis=None: (
        lambda g: anp.cross(g, a, axisc, axisa, axisb, axis)
    ),
)
defvjp(
    anp.linspace,
    lambda ans, start, stop, num: lambda g: anp.dot(anp.linspace(1.0, 0.0, num), g),
    lambda ans, start, stop, num: lambda g: anp.dot(anp.linspace(0.0, 1.0, num), g),
)
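

# Illustrative sketch (hypothetical helper, plain Python): why the linspace VJPs above are dot
# products. With the default endpoint=True and num >= 2, linspace is linear in its endpoints:
#   ans[k] = start * (1 - t[k]) + stop * t[k],  t[k] = k / (num - 1),
# so the cotangents are sum((1 - t) * g) for start and sum(t * g) for stop.
def _linspace_vjp_sketch(g):
    num = len(g)  # assumes num >= 2
    t = [k / (num - 1) for k in range(num)]
    d_start = sum((1.0 - tk) * gk for tk, gk in zip(t, g))
    d_stop = sum(tk * gk for tk, gk in zip(t, g))
    return d_start, d_stop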

defvjp(
    anp._astype,
    lambda ans, A, dtype, order="K", casting="unsafe", subok=True, copy=True: (
        lambda g: anp._astype(g, A.dtype)
    ),
)


# ----- Trickier grads -----
def grad_rollaxis(ans, a, axis, start=0):
    if axis < 0:
        raise NotImplementedError(
            "Gradient of rollaxis not implemented for axis < 0. Please use moveaxis instead."
        )
    elif start < 0:
        raise NotImplementedError(
            "Gradient of rollaxis not implemented for start < 0. Please use moveaxis instead."
        )
    return lambda g: anp.rollaxis(g, start - 1, axis) if start > axis else anp.rollaxis(g, start, axis + 1)


defvjp(anp.rollaxis, grad_rollaxis)


def grad_diff(ans, a, n=1, axis=-1):
    nd = anp.ndim(a)
    ans_shape = anp.shape(ans)
    sl1 = [slice(None)] * nd
    sl1[axis] = slice(None, 1)

    sl2 = [slice(None)] * nd
    sl2[axis] = slice(-1, None)

    def undiff(g):
        if g.shape[axis] > 0:
            return anp.concatenate((-g[tuple(sl1)], -anp.diff(g, axis=axis), g[tuple(sl2)]), axis=axis)
        shape = list(ans_shape)
        shape[axis] = 1
        return anp.zeros(shape)

    def helper(g, n):
        if n == 0:
            return g
        return helper(undiff(g), n - 1)

    return lambda g: helper(g, n)


defvjp(anp.diff, grad_diff)
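

# Illustrative sketch (hypothetical helper, plain Python): the 1-D, n=1 case of `undiff` above
# is the adjoint (transpose) of diff. The check is the defining identity
#   <diff(a), g> == <a, undiff(g)>.
# The name `_diff_adjoint_check` and its default inputs are assumptions for illustration.
def _diff_adjoint_check(a=(1.0, 4.0, 9.0, 16.0), g=(0.5, -2.0, 3.0)):
    def diff(v):
        return [v[i + 1] - v[i] for i in range(len(v) - 1)]

    # Same construction as in grad_diff: (-g[:1], -diff(g), g[-1:]) concatenated.
    undiff = [-g[0]] + [-d for d in diff(g)] + [g[-1]]
    lhs = sum(d * gi for d, gi in zip(diff(a), g))
    rhs = sum(ai * ui for ai, ui in zip(a, undiff))
    return lhs, rhs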


def grad_gradient(ans, x, *vargs, **kwargs):
    axis = kwargs.pop("axis", None)
    if vargs or kwargs:
        raise NotImplementedError("The only optional argument currently supported for np.gradient is axis.")
    if axis is None:
        axis = range(x.ndim)
    elif type(axis) is int:
        axis = [axis]
    else:
        axis = list(axis)

    x_dtype = x.dtype
    x_shape = x.shape
    nd = x.ndim

    def vjp(g):
        if anp.ndim(g) == nd:
            # add axis if gradient was along one axis only
            g = g[anp.newaxis]

        # accumulate gradient
        out = anp.zeros(x_shape, dtype=x_dtype)

        for i, a in enumerate(axis):
            # swap gradient axis to the front
            g_swap = anp.swapaxes(g[i], 0, a)[:, anp.newaxis]

            out_axis = anp.concatenate(
                (
                    -g_swap[0] - 0.5 * g_swap[1],
                    g_swap[0] - 0.5 * g_swap[2],
                    (-1.0) * anp.gradient(g_swap, axis=0)[2:-2, 0],
                    0.5 * g_swap[-3] - g_swap[-1],
                    0.5 * g_swap[-2] + g_swap[-1],
                ),
                axis=0,
            )

            out = out + anp.swapaxes(out_axis, 0, a)

        return out

    return vjp


defvjp(anp.gradient, grad_gradient)


def grad_repeat(ans, x, repeats, axis=None):
    shape = anp.shape(x)

    def vjp(g):
        if axis is None:  # If axis is None, np.repeat() repeats the flattened array.
            expanded = anp.reshape(g, (anp.prod(shape),) + (repeats,))
            return anp.reshape(anp.sum(expanded, axis=1, keepdims=False), shape)
        else:
            if shape[axis] == 1:  # For this common case, the logic is simple.
                return anp.sum(g, axis=axis, keepdims=True)
            else:
                expanded = anp.reshape(g, shape[0 : axis + 1] + (repeats,) + shape[axis + 1 :])
                return anp.sum(expanded, axis=axis + 1, keepdims=False)

    return vjp


defvjp(anp.repeat, grad_repeat)


def grad_tile(ans, x, reps):
    reps = [reps] if anp.isscalar(reps) else reps
    x_shape = anp.shape(x)

    def vjp(g):
        for axis, rep in enumerate(reps):
            g = sum(anp.split(g, rep, axis))
        return anp.reshape(g, x_shape)

    return vjp


defvjp(anp.tile, grad_tile)


def grad_kron(argnum, ans, orig_A, orig_B):
    # kron has different promotion rules than dot. The reshapes are necessary if
    # and only if (1) orig_B is 1D or (2) orig_A and/or orig_B are 0D.
    orig_A_shape = anp.shape(orig_A)
    orig_B_shape = anp.shape(orig_B)

    def vjp(G):
        A, B = anp.atleast_2d(orig_A), anp.atleast_2d(orig_B)
        shape = list(A.shape + B.shape)
        n = anp.ndim(A)
        shape[n - 1], shape[n] = shape[n], shape[n - 1]
        reshaped_G = anp.swapaxes(anp.reshape(G, shape), n - 1, n)
        if argnum == 0:
            return match_complex(
                orig_A, anp.reshape(anp.tensordot(reshaped_G, B, axes=anp.ndim(B)), orig_A_shape)
            )
        else:
            return match_complex(
                orig_B, anp.reshape(anp.tensordot(A, reshaped_G, axes=anp.ndim(A)), orig_B_shape)
            )

    return vjp


defvjp(anp.kron, partial(grad_kron, 0), partial(grad_kron, 1))


def grad_transpose(ans, x, axes=None):
    if axes is not None:
        axes = anp.argsort(axes)
    return lambda g: anp.transpose(g, axes)


defvjp(anp.transpose, grad_transpose)


def repeat_to_match_shape(g, shape, dtype, axis, keepdims):
    """Returns the array g repeated along axis to match the given shape and dtype.
    Also returns the number of repetitions of the array."""
    if shape == ():
        return g, 1
    axis = list(axis) if isinstance(axis, tuple) else axis
    new_shape = onp.array(shape)
    new_shape[axis] = 1
    num_reps = onp.prod(onp.array(shape)[axis])
    # Can't use broadcast_to because of numpy bug: https://github.com/numpy/numpy/issues/9165
    # return anp.broadcast_to(anp.reshape(g, new_shape), shape), num_reps
    return anp.reshape(g, new_shape) + onp.zeros(shape, dtype=dtype), num_reps
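# Illustrative sketch (an added example, not part of the upstream source):
# for x of shape (2, 3) reduced with axis=0, the incoming gradient has shape
# (3,). The reshape-plus-zeros trick above then amounts to:
#
#     import numpy as onp
#     g = onp.array([1.0, 2.0, 3.0])
#     repeated = onp.reshape(g, (1, 3)) + onp.zeros((2, 3))  # every row equals g
#     num_reps = 2                                           # size of reduced axis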


def grad_broadcast_to(ans, x, new_shape):
    old_shape = anp.shape(x)
    assert anp.shape(ans) == new_shape
    assert len(old_shape) == len(new_shape), "Can't handle extra leading dims"
    broadcast_axes = tuple(
        onp.where(onp.logical_and(onp.array(old_shape) == 1, onp.array(new_shape) > 1))[0]
    )
    return lambda g: anp.sum(g, axis=broadcast_axes, keepdims=True)


defvjp(anp.broadcast_to, grad_broadcast_to)


def grad_np_sum(ans, x, axis=None, keepdims=False, dtype=None):
    shape, dtype = anp.shape(x), anp.result_type(x)
    return lambda g: repeat_to_match_shape(g, shape, dtype, axis, keepdims)[0]


defvjp(anp.sum, grad_np_sum)


def grad_np_mean(ans, x, axis=None, keepdims=False):
    shape, dtype = anp.shape(x), anp.result_type(x)

    def vjp(g):
        g_repeated, num_reps = repeat_to_match_shape(g, shape, dtype, axis, keepdims)
        return g_repeated / num_reps

    return vjp


defvjp(anp.mean, grad_np_mean)


def grad_np_prod(ans, x, axis=None, keepdims=False):  # TODO: Support tuples of axes.
    shape, dtype = anp.shape(x), anp.result_type(x)

    def vjp(g):
        g_repeated, _ = repeat_to_match_shape(g * ans, shape, dtype, axis, keepdims)
        return g_repeated / x

    return vjp


defvjp(anp.prod, grad_np_prod)


def grad_np_var(ans, x, axis=None, ddof=0, keepdims=False):
    shape, _, dtype, iscomplex = anp.metadata(x)

    def vjp(g):
        if iscomplex:
            g = g + 0j
        g_repeated, num_reps = repeat_to_match_shape(g, shape, dtype, axis, keepdims)
        x_minus_mean = anp.conj(x - anp.mean(x, axis=axis, keepdims=True))
        return 2.0 * g_repeated * x_minus_mean / (num_reps - ddof)

    return vjp


defvjp(anp.var, grad_np_var)


def grad_np_std(ans, x, axis=None, ddof=0, keepdims=False):
    shape, _, dtype, iscomplex = anp.metadata(x)

    def vjp(g):
        if iscomplex:
            g = g + 0j
        g_repeated, num_reps = repeat_to_match_shape(g, shape, dtype, axis, keepdims)
        if num_reps <= 1:  # Avoid division by zero.
            return g_repeated * 0.0
        else:
            g_repeated, num_reps = repeat_to_match_shape(g / ans, shape, dtype, axis, keepdims)
            x_minus_mean = anp.conj(x - anp.mean(x, axis=axis, keepdims=True))
            return g_repeated * x_minus_mean / (num_reps - ddof)

    return vjp


defvjp(anp.std, grad_np_std)


def grad_chooser(ans, x, axis=None, keepdims=None):
    shape, dtype = anp.shape(x), anp.result_type(x)

    def vjp(g):
        """Builds gradient of functions that choose a single item, such as min or max."""
        g_repeated, _ = repeat_to_match_shape(g, shape, dtype, axis, keepdims)
        argmax_locations = x == repeat_to_match_shape(ans, shape, dtype, axis, keepdims)[0]
        return g_repeated * argmax_locations / onp.sum(argmax_locations, axis=axis, keepdims=True)

    return vjp


defvjp(anp.max, grad_chooser)
defvjp(anp.min, grad_chooser)
defvjp(anp.amax, grad_chooser)
defvjp(anp.amin, grad_chooser)


def reverse_axis(x, axis):
    x = x.swapaxes(axis, 0)
    x = x[::-1, ...]
    return x.swapaxes(0, axis)


def grad_np_cumsum(ans, x, axis=None):
    def vjp(g):
        if axis:  # Nonzero axis; axis=0 and axis=None are both handled by the branch below.
            return reverse_axis(anp.cumsum(reverse_axis(g, axis), axis), axis)
        else:
            return anp.reshape(anp.cumsum(g[::-1], axis)[::-1], x.shape)

    return vjp


defvjp(anp.cumsum, grad_np_cumsum)


def grad_inner(argnum, ans, A, B):
    A_ndim, B_ndim = anp.ndim(A), anp.ndim(B)
    if A_ndim == 0 or B_ndim == 0:
        axes = ([], [])
    else:
        axes = ([A_ndim - 1], [B_ndim - 1])
    if argnum == 0:
        return lambda G: tensordot_adjoint_0(B, G, axes, A_ndim, B_ndim)
    elif argnum == 1:
        return lambda G: tensordot_adjoint_1(A, G, axes, A_ndim, B_ndim)


defvjp(anp.inner, partial(grad_inner, 0), partial(grad_inner, 1))


def matmul_adjoint_0(B, G, A_meta, B_ndim):
    if anp.ndim(G) == 0:  # A_ndim == B_ndim == 1
        return unbroadcast(G * B, A_meta)
    _, A_ndim, _, _ = A_meta
    if A_ndim == 1:
        G = anp.expand_dims(G, anp.ndim(G) - 1)
    if B_ndim == 1:  # The result we need is an outer product
        B = anp.expand_dims(B, 0)
        G = anp.expand_dims(G, anp.ndim(G))
    else:  # We need to swap the last two axes of B
        B = anp.swapaxes(B, B_ndim - 2, B_ndim - 1)
    result = anp.matmul(G, B)
    return unbroadcast(result, A_meta)


def matmul_adjoint_1(A, G, A_ndim, B_meta):
    if anp.ndim(G) == 0:  # A_ndim == B_ndim == 1
        return unbroadcast(G * A, B_meta)
    _, B_ndim, _, _ = B_meta
    B_is_vec = B_ndim == 1
    if B_is_vec:
        G = anp.expand_dims(G, anp.ndim(G))
    if A_ndim == 1:  # The result we need is an outer product
        A = anp.expand_dims(A, 1)
        G = anp.expand_dims(G, anp.ndim(G) - 1)
    else:  # We need to swap the last two axes of A
        A = anp.swapaxes(A, A_ndim - 2, A_ndim - 1)
    result = anp.matmul(A, G)
    if B_is_vec:
        result = anp.squeeze(result, anp.ndim(G) - 1)
    return unbroadcast(result, B_meta)


def matmul_vjp_0(ans, A, B):
    A_meta = anp.metadata(A)
    B_ndim = anp.ndim(B)
    return lambda g: matmul_adjoint_0(B, g, A_meta, B_ndim)


def matmul_vjp_1(ans, A, B):
    A_ndim = anp.ndim(A)
    B_meta = anp.metadata(B)
    return lambda g: matmul_adjoint_1(A, g, A_ndim, B_meta)


defvjp(anp.matmul, matmul_vjp_0, matmul_vjp_1)


@primitive
def dot_adjoint_0(B, G, A_meta, B_meta):
    _, A_ndim, A_dtype, _ = A_meta
    _, B_ndim, _, _ = B_meta
    if B_ndim == 0 or B_ndim == 1 or A_ndim == 0:
        contract_num = max(0, B_ndim - (A_ndim != 0))
        out = onp.tensordot(G, B, contract_num)
    else:
        out = onp.tensordot(G, onp.swapaxes(B, -1, -2), B_ndim - 1)
    return onp.asarray(out, dtype=A_dtype)


@primitive
def dot_adjoint_1(A, G, A_meta, B_meta):
    _, A_ndim, _, _ = A_meta
    _, B_ndim, B_dtype, _ = B_meta
    needs_transpose = B_ndim > 1 and A_ndim != 0
    swap = (lambda x: onp.swapaxes(x, -1, -2)) if needs_transpose else (lambda x: x)
    if A_ndim == 0 or A_ndim == 1 or B_ndim == 0:
        contract_num = max(0, A_ndim - (B_ndim != 0))
        out = swap(onp.tensordot(G, A, contract_num))
    else:
        out = swap(onp.tensordot(G, A, [range(-A_ndim - B_ndim + 2, -B_ndim + 1), range(A_ndim - 1)]))
    return onp.asarray(out, dtype=B_dtype)


def dot_vjp_0(ans, A, B):
    A_meta, B_meta = anp.metadata(A), anp.metadata(B)
    return lambda g: match_complex(A, dot_adjoint_0(B, g, A_meta, B_meta))


def dot_vjp_1(ans, A, B):
    A_meta, B_meta = anp.metadata(A), anp.metadata(B)
    return lambda g: match_complex(B, dot_adjoint_1(A, g, A_meta, B_meta))


defvjp(anp.dot, dot_vjp_0, dot_vjp_1)

defvjp(
    dot_adjoint_0,
    lambda ans, B, g, An, Bn: lambda A: match_complex(B, dot_adjoint_1(A, g, An, Bn)),
    lambda ans, B, g, An, Bn: lambda A: match_complex(g, anp.dot(A, B)),
)

defvjp(
    dot_adjoint_1,
    lambda ans, A, g, An, Bn: lambda B: match_complex(A, dot_adjoint_0(B, g, An, Bn)),
    lambda ans, A, g, An, Bn: lambda B: match_complex(g, anp.dot(A, B)),
)


@primitive
def tensordot_adjoint_0(B, G, axes, A_ndim, B_ndim):
    # The adjoint of the operator
    # A |--> np.tensordot(A, B, axes)
    if B_ndim == 0:
        return G * B

    G_axes = onp.arange(onp.ndim(G))
    if type(axes) is int:
        axes = max(axes, 0)
        B_axes = onp.arange(B_ndim)
        return onp.tensordot(G, B, [G_axes[A_ndim - axes :], B_axes[axes:]])
    else:
        axes0 = [axes[0]] if type(axes[0]) is int else axes[0]
        axes1 = [axes[1]] if type(axes[1]) is int else axes[1]
        axes = [axes0, axes1]
        A_axes = onp.arange(A_ndim)
        B_axes = onp.arange(B_ndim)
        summed_axes = [
            onp.asarray(axes[0], dtype="int64") % A_ndim,
            onp.asarray(axes[1], dtype="int64") % B_ndim,
        ]
        other_axes = [onp.delete(A_axes, summed_axes[0]), onp.delete(B_axes, summed_axes[1])]
        out = onp.tensordot(G, B, [G_axes[len(other_axes[0]) :], other_axes[1]])
        perm = onp.argsort(onp.concatenate((other_axes[0], summed_axes[0][onp.argsort(summed_axes[1])])))
        return onp.transpose(out, perm)


@primitive
def tensordot_adjoint_1(A, G, axes, A_ndim, B_ndim):
    # The adjoint of the operator
    # B |--> np.tensordot(A, B, axes)
    if A_ndim == 0:
        return G * A

    G_axes = onp.arange(onp.ndim(G))
    if type(axes) is int:
        axes = max(axes, 0)
        A_axes = onp.arange(A_ndim)
        return onp.tensordot(A, G, [A_axes[: A_ndim - axes], G_axes[: A_ndim - axes]])
    else:
        axes0 = [axes[0]] if type(axes[0]) is int else axes[0]
        axes1 = [axes[1]] if type(axes[1]) is int else axes[1]
        axes = [axes0, axes1]
        A_axes = onp.arange(A_ndim)
        B_axes = onp.arange(B_ndim)
        summed_axes = [
            onp.asarray(axes[0], dtype="int64") % A_ndim,
            onp.asarray(axes[1], dtype="int64") % B_ndim,
        ]
        other_axes = [onp.delete(A_axes, summed_axes[0]), onp.delete(B_axes, summed_axes[1])]
        out = onp.tensordot(A, G, [other_axes[0], G_axes[: len(other_axes[0])]])
        perm = onp.argsort(onp.concatenate((summed_axes[1][onp.argsort(summed_axes[0])], other_axes[1])))
        return onp.transpose(out, perm)


def tensordot_vjp_0(ans, A, B, axes=2):
    A_ndim, B_ndim = anp.ndim(A), anp.ndim(B)
    return lambda G: match_complex(A, tensordot_adjoint_0(B, G, axes, A_ndim, B_ndim))


def tensordot_vjp_1(ans, A, B, axes=2):
    A_ndim, B_ndim = anp.ndim(A), anp.ndim(B)
    return lambda G: match_complex(B, tensordot_adjoint_1(A, G, axes, A_ndim, B_ndim))


defvjp(anp.tensordot, tensordot_vjp_0, tensordot_vjp_1)
defvjp(
    tensordot_adjoint_0,
    lambda ans, B, G, axes, An, Bn: lambda A: match_complex(B, tensordot_adjoint_1(A, G, axes, An, Bn)),
    lambda ans, B, G, axes, An, Bn: lambda A: match_complex(G, anp.tensordot(A, B, axes)),
)
defvjp(
    tensordot_adjoint_1,
    lambda ans, A, G, axes, An, Bn: lambda B: match_complex(A, tensordot_adjoint_0(B, G, axes, An, Bn)),
    lambda ans, A, G, axes, An, Bn: lambda B: match_complex(G, anp.tensordot(A, B, axes)),
)
defvjp(
    anp.outer,
    lambda ans, a, b: lambda g: match_complex(a, anp.dot(g, b.T)),
    lambda ans, a, b: lambda g: match_complex(b, anp.dot(a.T, g)),
)


def grad_concatenate_args(argnum, ans, axis_args, kwargs):
    axis, args = axis_args[0], axis_args[1:]
    sizes = [anp.shape(a)[axis] for a in args[:argnum]]
    start = sum(sizes[:-1])
    idxs = [slice(None)] * ans.ndim
    idxs[axis] = slice(start, start + sizes[-1])
    return lambda g: g[tuple(idxs)]


defvjp_argnum(anp.concatenate_args, grad_concatenate_args)


def wrapped_reshape(x, *args, **kwargs):
    # The reshape method can be called like A.reshape((5,4)) or A.reshape(5,4).
    # The reshape function doesn't support both ways, so we have to wrap it.
    if isinstance(args[0], int):
        return anp.reshape(x, args, **kwargs)
    else:
        return anp.reshape(x, *args, **kwargs)


setattr(ArrayBox, "reshape", wrapped_reshape)


def grad_sort(ans, x, axis=-1, kind="quicksort", order=None):
    # TODO: Cast input with np.asanyarray()
    if len(x.shape) > 1:
        raise NotImplementedError("Gradient of sort not implemented for multi-dimensional arrays.")
    sort_perm = anp.argsort(x, axis, kind, order)
    return lambda g: unpermuter(g, sort_perm)


defvjp(anp.sort, grad_sort)
if onp.lib.NumpyVersion(onp.__version__) < "2.0.0":
    defvjp(anp.msort, grad_sort)  # Until multi-D is allowed, these are the same.


def grad_partition(ans, x, kth, axis=-1, kind="introselect", order=None):
    # TODO: Cast input with np.asanyarray()
    if len(x.shape) > 1:
        raise NotImplementedError("Gradient of partition not implemented for multi-dimensional arrays.")
    partition_perm = anp.argpartition(x, kth, axis, kind, order)
    return lambda g: unpermuter(g, partition_perm)


defvjp(anp.partition, grad_partition)


def unpermuter(g, permutation):
    unsort = anp.zeros(len(permutation), dtype=int)
    unsort[permutation] = list(range(len(permutation)))
    return g[unsort]
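# Illustrative sketch (an added example, not part of the upstream source):
# unpermuter inverts the permutation returned by argsort, so each gradient
# entry is routed back to the input position it came from:
#
#     import numpy as onp
#     x = onp.array([3.0, 1.0, 2.0])
#     perm = onp.argsort(x)                # [1, 2, 0]; x[perm] == [1., 2., 3.]
#     unsort = onp.zeros(len(perm), dtype=int)
#     unsort[perm] = onp.arange(len(perm)) # [2, 0, 1]
#     g = onp.array([10.0, 20.0, 30.0])    # gradient w.r.t. the sorted output
#     g[unsort]                            # [30., 10., 20.], gradient w.r.t. x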


def grad_reshape_list(ans, *arys):
    if len(arys) > 1:
        raise NotImplementedError("Can't handle multiple arguments yet.")
    return lambda g: anp.reshape(g, anp.shape(arys[0]))


defvjp(anp.atleast_1d, grad_reshape_list)
defvjp(anp.atleast_2d, grad_reshape_list)
defvjp(anp.atleast_3d, grad_reshape_list)


def grad_einsum(argnum, ans, operands_, kwargs):
    result_meta = anp.metadata(operands_[argnum])

    def vjp(g):
        operands = operands_
        if isinstance(operands[0], str):  # using "ijk" convention.
            in_subs, out_subs, _ = anp.parse_einsum_input(*operands)
            string, operands = operands[0], operands[1:]

            in_subs_list = in_subs.split(",")
            op_num = argnum - 1
            subs_wrt = in_subs_list[op_num]
            rest_of_ops = operands[:op_num] + operands[op_num + 1 :]
            rest_of_subs = in_subs_list[:op_num] + in_subs_list[op_num + 1 :]

            # subscripts that only appear in subs_wrt (and not in other subscript lists
            # or in the output) are implicitly being summed out, as if contracted
            # against a tensor of ones. we make that tensor of ones explicit to handle
            # the necessary vjp broadcasting inside einsum.
            other_named_subs = set("".join([out_subs] + rest_of_subs))
            naked_summed = [(i, sub) for i, sub in enumerate(subs_wrt) if sub not in other_named_subs]
            if naked_summed:
                naked_summed_dims, ones_subs = zip(*naked_summed)
                ones_subs = "".join(ones_subs)
                ones = onp.ones(onp.array(operands[op_num].shape)[list(naked_summed_dims)])
                new_input_subs = ",".join([out_subs, ones_subs] + rest_of_subs)
                new_operands = (g, ones) + rest_of_ops
            else:
                new_input_subs = ",".join([out_subs] + rest_of_subs)
                new_operands = (g,) + rest_of_ops

            new_subscripts = new_input_subs + "->" + subs_wrt
            return unbroadcast(anp.einsum(new_subscripts, *new_operands), result_meta)
        else:  # using (op0, sublist0, op1, sublist1, ..., sublistout) convention
            if len(operands) % 2 == 0:
                raise NotImplementedError("Need sublistout argument")
            operands = list(operands)
            rest_of_ops = (
                [operands[-1]] + operands[:argnum] + operands[(argnum + 2) : -1] + [operands[argnum + 1]]
            )
            return unbroadcast_einsum(anp.einsum(g, *rest_of_ops), result_meta, operands[argnum + 1])

    return vjp


defvjp_argnum(anp.einsum, grad_einsum)

defvjp(
    anp.diagonal,
    lambda ans, A, offset=0, axis1=0, axis2=1: lambda g: anp.make_diagonal(g, offset, axis1, axis2),
)
defvjp(
    anp.make_diagonal,
    lambda ans, D, offset=0, axis1=0, axis2=1: lambda g: anp.diagonal(g, offset, axis1, axis2),
)


def match_complex(target, x):
    target_iscomplex = anp.iscomplexobj(target)
    x_iscomplex = anp.iscomplexobj(x)
    if x_iscomplex and not target_iscomplex:
        return anp.real(x)
    elif not x_iscomplex and target_iscomplex:
        return x + 0j
    else:
        return x


def unbroadcast(x, target_meta, broadcast_idx=0):
    target_shape, target_ndim, dtype, target_iscomplex = target_meta
    while anp.ndim(x) > target_ndim:
        x = anp.sum(x, axis=broadcast_idx)
    for axis, size in enumerate(target_shape):
        if size == 1:
            x = anp.sum(x, axis=axis, keepdims=True)
    if anp.iscomplexobj(x) and not target_iscomplex:
        x = anp.real(x)
    return x
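# Illustrative sketch (an added example, not part of the upstream source):
# when an operand of shape (1, 3) was broadcast against shape (4, 3), the
# gradient arrives with shape (4, 3) and unbroadcast sums it back down:
#
#     import numpy as onp
#     g = onp.ones((4, 3))
#     g_x = onp.sum(g, axis=0, keepdims=True)  # the size-1 axis is summed
#     # g_x.shape == (1, 3); every entry equals 4.0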


def unbroadcast_f(target, f):
    target_meta = anp.metadata(target)
    return lambda g: unbroadcast(f(g), target_meta)


def unbroadcast_einsum(x, target_meta, subscript):
    if Ellipsis not in subscript:
        return x
    elif subscript[0] == Ellipsis:
        return unbroadcast(x, target_meta, 0)
    elif subscript[-1] == Ellipsis:
        return unbroadcast(x, target_meta, -1)
    else:
        return unbroadcast(x, target_meta, subscript.index(Ellipsis))


def balanced_eq(x, z, y):
    return (x == z) / (1.0 + (x == y))


def replace_zero(x, val):
    return anp.where(x, x, val)


# ----- extra functions used internally  -----


def array_from_args_gradmaker(argnum, ans, args, kwargs):
    return lambda g: g[argnum - 2]


defvjp_argnum(anp.array_from_args, array_from_args_gradmaker)


def array_from_scalar_or_array_gradmaker(ans, array_args, array_kwargs, scarray):
    ndmin = array_kwargs.get("ndmin", 0)
    scarray_ndim = anp.ndim(scarray)
    if ndmin > scarray_ndim:
        return lambda g: anp.squeeze(g, axis=tuple(range(ndmin - scarray_ndim)))
    else:
        return lambda g: g


defvjp(anp._array_from_scalar_or_array, array_from_scalar_or_array_gradmaker, argnums=(2, 3))


@primitive
def untake(x, idx, vs):
    if isinstance(idx, list) and (len(idx) == 0 or not isinstance(idx[0], slice)):
        idx = onp.array(idx, dtype="int64")

    def mut_add(A):
        onp.add.at(A, idx, x)
        return A

    return SparseObject(vs, mut_add)


defvjp(func(ArrayBox.__getitem__), lambda ans, A, idx: lambda g: untake(g, idx, vspace(A)))
defvjp(untake, lambda ans, x, idx, _: lambda g: g[idx])


def _unpad(array, width):
    if anp.isscalar(width):
        width = [[width, width]]
    elif anp.shape(width) == (1,):
        width = [anp.concatenate((width, width))]
    elif anp.shape(width) == (2,):
        width = [width]
    if anp.shape(width)[0] == 1:
        width = anp.repeat(width, anp.ndim(array), 0)
    idxs = tuple(slice(l, -u or None) for l, u in width)
    return array[idxs]
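# Illustrative sketch (an added example, not part of the upstream source):
# a scalar pad width of 1 on a 2-D array is expanded to [[1, 1], [1, 1]], so
# _unpad slices one element off each side of every axis:
#
#     import numpy as onp
#     padded = onp.pad(onp.ones((2, 3)), 1, mode="constant")  # shape (4, 5)
#     padded[1:-1, 1:-1]                                      # shape (2, 3)
#
# A zero upper width maps to slice(l, None), since `-0 or None` is None.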


def pad_vjp(ans, array, pad_width, mode, **kwargs):
    assert mode == "constant", "Only constant mode padding is supported."
    return lambda g: _unpad(g, pad_width)


defvjp(anp.pad, pad_vjp)


================================================
FILE: autograd/numpy/numpy_vspaces.py
================================================
import numpy as np

from autograd.builtins import NamedTupleVSpace
from autograd.extend import VSpace


class ArrayVSpace(VSpace):
    def __init__(self, value):
        value = np.asarray(value)
        self.shape = value.shape
        self.dtype = value.dtype

    @property
    def size(self):
        return np.prod(self.shape)

    @property
    def ndim(self):
        return len(self.shape)

    def zeros(self):
        return np.zeros(self.shape, dtype=self.dtype)

    def ones(self):
        return np.ones(self.shape, dtype=self.dtype)

    def standard_basis(self):
        for idxs in np.ndindex(*self.shape):
            vect = np.zeros(self.shape, dtype=self.dtype)
            vect[idxs] = 1
            yield vect

    def randn(self):
        return np.array(np.random.randn(*self.shape)).astype(self.dtype)

    def _inner_prod(self, x, y):
        return np.dot(np.ravel(x), np.ravel(y))


class ComplexArrayVSpace(ArrayVSpace):
    iscomplex = True

    @property
    def size(self):
        return np.prod(self.shape) * 2

    def ones(self):
        return np.ones(self.shape, dtype=self.dtype) + 1.0j * np.ones(self.shape, dtype=self.dtype)

    def standard_basis(self):
        for idxs in np.ndindex(*self.shape):
            for v in [1.0, 1.0j]:
                vect = np.zeros(self.shape, dtype=self.dtype)
                vect[idxs] = v
                yield vect

    def randn(self):
        return np.array(np.random.randn(*self.shape)).astype(self.dtype) + 1.0j * np.array(
            np.random.randn(*self.shape)
        ).astype(self.dtype)

    def _inner_prod(self, x, y):
        return np.real(np.dot(np.conj(np.ravel(x)), np.ravel(y)))

    def _covector(self, x):
        return np.conj(x)


VSpace.register(np.ndarray, lambda x: ComplexArrayVSpace(x) if np.iscomplexobj(x) else ArrayVSpace(x))

for type_ in [float, np.longdouble, np.float64, np.float32, np.float16]:
    ArrayVSpace.register(type_)

for type_ in [complex, np.clongdouble, np.complex64, np.complex128]:
    ComplexArrayVSpace.register(type_)


if np.lib.NumpyVersion(np.__version__) >= "2.0.0":

    class EigResultVSpace(NamedTupleVSpace):
        seq_type = np.linalg._linalg.EigResult

    class EighResultVSpace(NamedTupleVSpace):
        seq_type = np.linalg._linalg.EighResult

    class QRResultVSpace(NamedTupleVSpace):
        seq_type = np.linalg._linalg.QRResult

    class SlogdetResultVSpace(NamedTupleVSpace):
        seq_type = np.linalg._linalg.SlogdetResult

    class SVDResultVSpace(NamedTupleVSpace):
        seq_type = np.linalg._linalg.SVDResult

    EigResultVSpace.register(np.linalg._linalg.EigResult)
    EighResultVSpace.register(np.linalg._linalg.EighResult)
    QRResultVSpace.register(np.linalg._linalg.QRResult)
    SlogdetResultVSpace.register(np.linalg._linalg.SlogdetResult)
    SVDResultVSpace.register(np.linalg._linalg.SVDResult)
elif np.lib.NumpyVersion(np.__version__) >= "1.25.0":

    class EigResultVSpace(NamedTupleVSpace):
        seq_type = np.linalg.linalg.EigResult

    class EighResultVSpace(NamedTupleVSpace):
        seq_type = np.linalg.linalg.EighResult

    class QRResultVSpace(NamedTupleVSpace):
        seq_type = np.linalg.linalg.QRResult

    class SlogdetResultVSpace(NamedTupleVSpace):
        seq_type = np.linalg.linalg.SlogdetResult

    class SVDResultVSpace(NamedTupleVSpace):
        seq_type = np.linalg.linalg.SVDResult

    EigResultVSpace.register(np.linalg.linalg.EigResult)
    EighResultVSpace.register(np.linalg.linalg.EighResult)
    QRResultVSpace.register(np.linalg.linalg.QRResult)
    SlogdetResultVSpace.register(np.linalg.linalg.SlogdetResult)
    SVDResultVSpace.register(np.linalg.linalg.SVDResult)


================================================
FILE: autograd/numpy/numpy_wrapper.py
================================================
import warnings

import numpy as _np

import autograd.builtins as builtins
from autograd.extend import notrace_primitive, primitive

if _np.lib.NumpyVersion(_np.__version__) >= "2.0.0":
    from numpy._core.einsumfunc import _parse_einsum_input
else:
    from numpy.core.einsumfunc import _parse_einsum_input

numpy_version = _np.__version__

notrace_functions = [_np.ndim, _np.shape, _np.iscomplexobj, _np.result_type]


def wrap_intdtype(cls):
    class IntdtypeSubclass(cls):
        __new__ = notrace_primitive(cls.__new__)

    return IntdtypeSubclass


def wrap_namespace(old, new):
    unchanged_types = {float, int, type(None), type}
    int_types = {_np.int8, _np.int16, _np.int32, _np.int64, _np.integer}
    for name, obj in old.items():
        if obj in notrace_functions:
            new[name] = notrace_primitive(obj)
        elif callable(obj) and type(obj) is not type:
            new[name] = primitive(obj)
        elif type(obj) is type and obj in int_types:
            new[name] = wrap_intdtype(obj)
        elif type(obj) in unchanged_types:
            new[name] = obj


wrap_namespace(_np.__dict__, globals())

# ----- Special treatment of list-input functions -----


@primitive
def concatenate_args(axis, *args):
    return _np.concatenate(args, axis).view(ndarray)


concatenate = lambda arr_list, axis=0: concatenate_args(axis, *arr_list)
vstack = row_stack = lambda tup: concatenate([atleast_2d(_m) for _m in tup], axis=0)


def hstack(tup):
    arrs = [atleast_1d(_m) for _m in tup]
    if arrs[0].ndim == 1:
        return concatenate(arrs, 0)
    return concatenate(arrs, 1)


def column_stack(tup):
    arrays = []
    for v in tup:
        arr = array(v)
        if arr.ndim < 2:
            arr = array(arr, ndmin=2).T
        arrays.append(arr)
    return concatenate(arrays, 1)


def array(A, *args, **kwargs):
    t = builtins.type(A)
    if t in (list, tuple):
        return array_from_args(args, kwargs, *map(array, A))
    else:
        return _array_from_scalar_or_array(args, kwargs, A)


def wrap_if_boxes_inside(raw_array, slow_op_name=None):
    if raw_array.dtype is _np.dtype("O"):
        if slow_op_name:
            warnings.warn(f"{slow_op_name} is slow for array inputs. np.concatenate() is faster.")
        return array_from_args((), {}, *raw_array.ravel()).reshape(raw_array.shape)
    else:
        return raw_array


@primitive
def _array_from_scalar_or_array(array_args, array_kwargs, scalar):
    return _np.array(scalar, *array_args, **array_kwargs)


@primitive
def array_from_args(array_args, array_kwargs, *args):
    return _np.array(args, *array_args, **array_kwargs)


def select(condlist, choicelist, default=0):
    raw_array = _np.select(list(condlist), list(choicelist), default=default)
    return array(list(raw_array.ravel())).reshape(raw_array.shape)


def stack(arrays, axis=0):
    # this code is basically copied from numpy/core/shape_base.py's stack
    # we need it here because we want to re-implement stack in terms of the
    # primitives defined in this file

    arrays = [array(arr) for arr in arrays]
    if not arrays:
        raise ValueError("need at least one array to stack")

    shapes = {arr.shape for arr in arrays}
    if len(shapes) != 1:
        raise ValueError("all input arrays must have the same shape")

    result_ndim = arrays[0].ndim + 1
    if not -result_ndim <= axis < result_ndim:
        raise IndexError(f"axis {axis} out of bounds [-{result_ndim}, {result_ndim})")
    if axis < 0:
        axis += result_ndim

    sl = (slice(None),) * axis + (None,)
    return concatenate([arr[sl] for arr in arrays], axis=axis)


def append(arr, values, axis=None):
    # this code is basically copied from numpy/lib/function_base.py's append
    arr = array(arr)
    if axis is None:
        if ndim(arr) != 1:
            arr = ravel(arr)
        values = ravel(array(values))
        axis = ndim(arr) - 1
    return concatenate((arr, values), axis=axis)


# ----- Enable functions called using [] ----


class r_class:
    def __getitem__(self, args):
        raw_array = _np.r_[args]
        return wrap_if_boxes_inside(raw_array, slow_op_name="r_")


r_ = r_class()


class c_class:
    def __getitem__(self, args):
        raw_array = _np.c_[args]
        return wrap_if_boxes_inside(raw_array, slow_op_name="c_")


c_ = c_class()


# ----- misc -----
@primitive
def make_diagonal(D, offset=0, axis1=0, axis2=1):
    # Numpy doesn't offer a complement to np.diagonal: a function to create new
    # diagonal arrays with extra dimensions. We need such a function for the
    # gradient of np.diagonal and it's also quite handy to have. So here it is.
    if not (offset == 0 and axis1 == -1 and axis2 == -2):
        raise NotImplementedError("Currently make_diagonal only supports offset=0, axis1=-1, axis2=-2")

    # We use a trick: calling np.diagonal returns a view on the original array,
    # so we can modify it in-place. (only valid for numpy version >= 1.10.)
    new_array = _np.zeros(D.shape + (D.shape[-1],))
    new_array_diag = _np.diagonal(new_array, offset=0, axis1=-1, axis2=-2)
    new_array_diag.flags.writeable = True
    new_array_diag[:] = D
    return new_array


@notrace_primitive
def metadata(A):
    return _np.shape(A), _np.ndim(A), _np.result_type(A), _np.iscomplexobj(A)


@notrace_primitive
def parse_einsum_input(*args):
    return _parse_einsum_input(args)


if _np.lib.NumpyVersion(_np.__version__) >= "2.0.0":
    # Wrapped above
    _astype = astype
else:

    @primitive
    def _astype(A, dtype, order="K", casting="unsafe", subok=True, copy=True):
        return A.astype(dtype, order, casting, subok, copy)


================================================
FILE: autograd/numpy/random.py
================================================
import numpy.random as npr

from .numpy_wrapper import wrap_namespace

wrap_namespace(npr.__dict__, globals())


================================================
FILE: autograd/scipy/__init__.py
================================================
from . import integrate, signal, special, stats


================================================
FILE: autograd/scipy/integrate.py
================================================
import scipy.integrate

import autograd.numpy as np
from autograd import make_vjp
from autograd.builtins import tuple
from autograd.extend import defvjp_argnums, primitive
from autograd.misc import flatten

odeint = primitive(scipy.integrate.odeint)


def grad_odeint(yt, func, y0, t, func_args, **kwargs):
    # Extended from "Scalable Inference of Ordinary Differential
    # Equation Models of Biochemical Processes", Sec. 2.4.2
    # Fabian Froehlich, Carolin Loos, Jan Hasenauer, 2017
    # https://arxiv.org/abs/1711.08079

    T, D = np.shape(yt)
    flat_args, unflatten = flatten(func_args)

    def flat_func(y, t, flat_args):
        return func(y, t, *unflatten(flat_args))

    def unpack(x):
        #      y,      vjp_y,      vjp_t,    vjp_args
        return x[0:D], x[D : 2 * D], x[2 * D], x[2 * D + 1 :]

    def augmented_dynamics(augmented_state, t, flat_args):
        # Original system augmented with vjp_y, vjp_t and vjp_args.
        y, vjp_y, _, _ = unpack(augmented_state)
        vjp_all, dy_dt = make_vjp(flat_func, argnum=(0, 1, 2))(y, t, flat_args)
        vjp_y, vjp_t, vjp_args = vjp_all(-vjp_y)
        return np.hstack((dy_dt, vjp_y, vjp_t, vjp_args))

    def vjp_all(g):
        vjp_y = g[-1, :]
        vjp_t0 = 0
        time_vjp_list = []
        vjp_args = np.zeros(np.size(flat_args))

        for i in range(T - 1, 0, -1):
            # Compute effect of moving measurement time.
            vjp_cur_t = np.dot(func(yt[i, :], t[i], *func_args), g[i, :])
            time_vjp_list.append(vjp_cur_t)
            vjp_t0 = vjp_t0 - vjp_cur_t

            # Run augmented system backwards to the previous observation.
            aug_y0 = np.hstack((yt[i, :], vjp_y, vjp_t0, vjp_args))
            aug_ans = odeint(
                augmented_dynamics, aug_y0, np.array([t[i], t[i - 1]]), tuple((flat_args,)), **kwargs
            )
            _, vjp_y, vjp_t0, vjp_args = unpack(aug_ans[1])

            # Add gradient from current output.
            vjp_y = vjp_y + g[i - 1, :]

        time_vjp_list.append(vjp_t0)
        vjp_times = np.hstack(time_vjp_list)[::-1]

        return None, vjp_y, vjp_times, unflatten(vjp_args)

    return vjp_all


def argnums_unpack(all_vjp_builder):
    # A generic autograd helper function.  Takes a function that
    # builds vjps for all arguments, and wraps it to return only required vjps.
    def build_selected_vjps(argnums, ans, combined_args, kwargs):
        vjp_func = all_vjp_builder(ans, *combined_args, **kwargs)

        def chosen_vjps(g):  # Returns whichever vjps were asked for.
            all_vjps = vjp_func(g)
            return [all_vjps[argnum] for argnum in argnums]

        return chosen_vjps

    return build_selected_vjps


defvjp_argnums(odeint, argnums_unpack(grad_odeint))


================================================
FILE: autograd/scipy/linalg.py
================================================
from functools import partial

import scipy.linalg

import autograd.numpy as anp
from autograd.extend import defjvp, defjvp_argnums, defvjp, defvjp_argnums
from autograd.numpy.numpy_wrapper import wrap_namespace

wrap_namespace(scipy.linalg.__dict__, globals())  # populates module namespace


def _vjp_sqrtm(ans, A, disp=True, blocksize=64):
    assert disp, "sqrtm vjp not implemented for disp=False"
    ans_transp = anp.transpose(ans)

    def vjp(g):
        return anp.real(solve_sylvester(ans_transp, ans_transp, g))

    return vjp


defvjp(sqrtm, _vjp_sqrtm)


def _flip(a, trans):
    if anp.iscomplexobj(a):
        return "H" if trans in ("N", 0) else "N"
    else:
        return "T" if trans in ("N", 0) else "N"


def grad_solve_triangular(ans, a, b, trans=0, lower=False, **kwargs):
    tri = anp.tril if (lower ^ (_flip(a, trans) == "N")) else anp.triu
    transpose = lambda x: x if _flip(a, trans) != "N" else x.T
    al2d = lambda x: x if x.ndim > 1 else x[..., None]

    def vjp(g):
        v = al2d(solve_triangular(a, g, trans=_flip(a, trans), lower=lower))
        return -transpose(tri(anp.dot(v, al2d(ans).T)))

    return vjp


defvjp(
    solve_triangular,
    grad_solve_triangular,
    lambda ans, a, b, trans=0, lower=False, **kwargs: (
        lambda g: solve_triangular(a, g, trans=_flip(a, trans), lower=lower)
    ),
)


def grad_solve_banded(argnum, ans, l_and_u, a, b):
    updim = lambda x: x if x.ndim == a.ndim else x[..., None]

    def transpose_banded(l_and_u, a):
        # Compute the transpose of a banded matrix.
        # The transpose is itself a banded matrix.

        num_rows = a.shape[0]

        shifts = anp.arange(-l_and_u[1], l_and_u[0] + 1)

        T_a = anp.roll(a[:1, :], shifts[0])
        for rr in range(1, num_rows):
            T_a = anp.vstack([T_a, anp.flipud(anp.roll(a[rr : rr + 1, :], shifts[rr]))])
        T_a = anp.flipud(T_a)

        T_l_and_u = anp.flip(l_and_u)

        return T_l_and_u, T_a

    def banded_dot(l_and_u, uu, vv):
        # Compute tensor product of vectors uu and vv.
        # Tensor product elements are restricted to the bands specified by l_and_u.

        # TODO: replace the brute-force ravel() by smarter dimension handling of uu and vv

        # main diagonal
        banded_uv = anp.ravel(uu) * anp.ravel(vv)

        # stack the sub-diagonals below the main diagonal
        for rr in range(1, l_and_u[0] + 1):
            banded_uv_rr = anp.hstack([anp.ravel(uu)[rr:] * anp.ravel(vv)[:-rr], anp.zeros(rr)])
            banded_uv = anp.vstack([banded_uv, banded_uv_rr])

        # stack the super-diagonals above the main diagonal
        for rr in range(1, l_and_u[1] + 1):
            banded_uv_rr = anp.hstack([anp.zeros(rr), anp.ravel(uu)[:-rr] * anp.ravel(vv)[rr:]])
            banded_uv = anp.vstack([banded_uv_rr, banded_uv])

        return banded_uv

    T_l_and_u, T_a = transpose_banded(l_and_u, a)

    if argnum == 1:
        return lambda g: (
            -banded_dot(l_and_u, updim(solve_banded(T_l_and_u, T_a, g)), anp.transpose(updim(ans)))
        )
    elif argnum == 2:
        return lambda g: solve_banded(T_l_and_u, T_a, g)


defvjp(solve_banded, partial(grad_solve_banded, 1), partial(grad_solve_banded, 2), argnums=[1, 2])


def _jvp_sqrtm(dA, ans, A, disp=True, blocksize=64):
    assert disp, "sqrtm jvp not implemented for disp=False"
    return solve_sylvester(ans, ans, dA)


defjvp(sqrtm, _jvp_sqrtm)
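

# Illustrative sketch (not part of autograd): why the sqrtm JVP is a Sylvester
# solve. Differentiating S @ S = A gives S @ dS + dS @ S = dA, which is the
# equation solve_sylvester(S, S, dA) solves. The matrices A and dA below are
# made-up test values, and the sketch assumes SciPy is installed.

```python
import numpy as np
from scipy.linalg import solve_sylvester, sqrtm

A = np.array([[4.0, 1.0], [1.0, 3.0]])  # symmetric positive definite
dA = np.array([[0.3, -0.2], [0.1, 0.5]])  # arbitrary perturbation direction

S = np.real(sqrtm(A))
# Tangent of the square root: solve S @ dS + dS @ S = dA for dS.
dS = solve_sylvester(S, S, dA)

# Central finite difference of sqrtm along dA, for comparison.
eps = 1e-6
dS_numeric = (np.real(sqrtm(A + eps * dA)) - np.real(sqrtm(A - eps * dA))) / (2 * eps)
max_error = np.max(np.abs(dS - dS_numeric))
```

# The VJP earlier in this file solves the same kind of equation with the
# transposed square root.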


def _jvp_sylvester(argnums, dms, ans, args, _):
    a, b, q = args
    if 0 in argnums:
        da = dms[0]
        db = dms[1] if 1 in argnums else 0
    else:
        da = 0
        db = dms[0] if 1 in argnums else 0
    dq = dms[-1] if 2 in argnums else 0
    rhs = dq - anp.dot(da, ans) - anp.dot(ans, db)
    return solve_sylvester(a, b, rhs)


defjvp_argnums(solve_sylvester, _jvp_sylvester)


def _vjp_sylvester(argnums, ans, args, _):
    a, b, q = args

    def vjp(g):
        vjps = []
        q_vjp = solve_sylvester(anp.transpose(a), anp.transpose(b), g)
        if 0 in argnums:
            vjps.append(-anp.dot(q_vjp, anp.transpose(ans)))
        if 1 in argnums:
            vjps.append(-anp.dot(anp.transpose(ans), q_vjp))
        if 2 in argnums:
            vjps.append(q_vjp)
        return tuple(vjps)

    return vjp


defvjp_argnums(solve_sylvester, _vjp_sylvester)


================================================
FILE: autograd/scipy/signal.py
================================================
from functools import partial

import numpy as npo  # original numpy
from numpy.lib.stride_tricks import as_strided

import autograd.numpy as np
from autograd.extend import defvjp, primitive


@primitive
def convolve(A, B, axes=None, dot_axes=[(), ()], mode="full"):
    assert mode in ["valid", "full"], f"Mode {mode} not yet implemented"
    if axes is None:
        axes = [list(range(A.ndim)), list(range(A.ndim))]
    wrong_order = any([B.shape[ax_B] < A.shape[ax_A] for ax_A, ax_B in zip(*axes)])
    if wrong_order:
        if mode == "valid" and not all([B.shape[ax_B] <= A.shape[ax_A] for ax_A, ax_B in zip(*axes)]):
            raise Exception("One array must be larger than the other along all convolved dimensions")
        elif mode != "full" or B.size <= A.size:  # Tie breaker
            i1 = B.ndim - len(dot_axes[1]) - len(axes[1])  # B ignore
            i2 = i1 + A.ndim - len(dot_axes[0]) - len(axes[0])  # A ignore
            i3 = i2 + len(axes[0])
            ignore_B = list(range(i1))
            ignore_A = list(range(i1, i2))
            conv = list(range(i2, i3))
            return convolve(B, A, axes=axes[::-1], dot_axes=dot_axes[::-1], mode=mode).transpose(
                ignore_A + ignore_B + conv
            )

    if mode == "full":
        B = pad_to_full(B, A, axes[::-1])
    B_view_shape = list(B.shape)
    B_view_strides = list(B.strides)
    flipped_idxs = [slice(None)] * A.ndim
    for ax_A, ax_B in zip(*axes):
        B_view_shape.append(abs(B.shape[ax_B] - A.shape[ax_A]) + 1)
        B_view_strides.append(B.strides[ax_B])
        B_view_shape[ax_B] = A.shape[ax_A]
        flipped_idxs[ax_A] = slice(None, None, -1)
    B_view = as_strided(B, B_view_shape, B_view_strides)
    A_view = A[tuple(flipped_idxs)]
    all_axes = [list(axes[i]) + list(dot_axes[i]) for i in [0, 1]]
    return einsum_tensordot(A_view, B_view, all_axes)


def einsum_tensordot(A, B, axes, reverse=False):
    # Does tensor dot product using einsum, which shouldn't require a copy.
    A_axnums = list(range(A.ndim))
    B_axnums = list(range(A.ndim, A.ndim + B.ndim))
    sum_axnum = A.ndim + B.ndim
    for i_sum, (i_A, i_B) in enumerate(zip(*axes)):
        A_axnums[i_A] = sum_axnum + i_sum
        B_axnums[i_B] = sum_axnum + i_sum
    return npo.einsum(A, A_axnums, B, B_axnums)


def pad_to_full(A, B, axes):
    A_pad = [(0, 0)] * A.ndim
    for ax_A, ax_B in zip(*axes):
        A_pad[ax_A] = (B.shape[ax_B] - 1,) * 2
    return npo.pad(A, A_pad, mode="constant")


def parse_axes(A_shape, B_shape, conv_axes, dot_axes, mode):
    A_ndim, B_ndim = len(A_shape), len(B_shape)
    if conv_axes is None:
        conv_axes = (
            tuple(range(A_ndim)),
            tuple(range(A_ndim)),
        )
    axes = {
        "A": {
            "conv": tuple(conv_axes[0]),
            "dot": tuple(dot_axes[0]),
            "ignore": tuple(i for i in range(A_ndim) if i not in conv_axes[0] and i not in dot_axes[0]),
        },
        "B": {
            "conv": tuple(conv_axes[1]),
            "dot": tuple(dot_axes[1]),
            "ignore": tuple(i for i in range(B_ndim) if i not in conv_axes[1] and i not in dot_axes[1]),
        },
    }
    assert len(axes["A"]["dot"]) == len(axes["B"]["dot"])
    assert len(axes["A"]["conv"]) == len(axes["B"]["conv"])
    i1 = len(axes["A"]["ignore"])
    i2 = i1 + len(axes["B"]["ignore"])
    i3 = i2 + len(axes["A"]["conv"])
    axes["out"] = {
        "ignore_A": tuple(range(i1)),
        "ignore_B": tuple(range(i1, i2)),
        "conv": tuple(range(i2, i3)),
    }
    conv_shape = (
        compute_conv_size(A_shape[i], B_shape[j], mode) for i, j in zip(axes["A"]["conv"], axes["B"]["conv"])
    )
    shapes = {
        "A": {s: tuple(A_shape[i] for i in ax) for s, ax in axes["A"].items()},
        "B": {s: tuple(B_shape[i] for i in ax) for s, ax in axes["B"].items()},
    }
    shapes["out"] = {
        "ignore_A": shapes["A"]["ignore"],
        "ignore_B": shapes["B"]["ignore"],
        "conv": conv_shape,
    }
    return axes, shapes


def compute_conv_size(A_size, B_size, mode):
    if mode == "full":
        return A_size + B_size - 1
    elif mode == "same":
        return A_size
    elif mode == "valid":
        return abs(A_size - B_size) + 1
    else:
        raise Exception(f"Mode {mode} not recognized")


def flipped_idxs(ndim, axes):
    new_idxs = [slice(None)] * ndim
    for ax in axes:
        new_idxs[ax] = slice(None, None, -1)
    return tuple(new_idxs)


def grad_convolve(argnum, ans, A, B, axes=None, dot_axes=[(), ()], mode="full"):
    assert mode in ["valid", "full"], f"Grad for mode {mode} not yet implemented"
    axes, shapes = parse_axes(A.shape, B.shape, axes, dot_axes, mode)
    if argnum == 0:
        X, Y = A, B
        _X_, _Y_ = "A", "B"
        ignore_Y = "ignore_B"
    elif argnum == 1:
        X, Y = B, A
        _X_, _Y_ = "B", "A"
        ignore_Y = "ignore_A"
    else:
        raise NotImplementedError(f"Can't take grad of convolve w.r.t. arg {argnum}")

    if mode == "full":
        new_mode = "valid"
    else:
        if any([x_size > y_size for x_size, y_size in zip(shapes[_X_]["conv"], shapes[_Y_]["conv"])]):
            new_mode = "full"
        else:
            new_mode = "valid"

    def vjp(g):
        result = convolve(
            g,
            Y[flipped_idxs(Y.ndim, axes[_Y_]["conv"])],
            axes=[axes["out"]["conv"], axes[_Y_]["conv"]],
            dot_axes=[axes["out"][ignore_Y], axes[_Y_]["ignore"]],
            mode=new_mode,
        )
        new_order = npo.argsort(axes[_X_]["ignore"] + axes[_X_]["dot"] + axes[_X_]["conv"])
        return np.transpose(result, new_order)

    return vjp


defvjp(convolve, partial(grad_convolve, 0), partial(grad_convolve, 1))
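

# Illustrative sketch (not part of autograd): the 1-D special case of the rule
# implemented by grad_convolve. For a "full" convolution, the VJP with respect
# to A is a "valid" convolution of the upstream gradient g with B reversed.
# It uses np.convolve rather than the convolve primitive above, and the arrays
# are made-up test values.

```python
import numpy as np

A = np.array([1.0, -2.0, 0.5])
B = np.array([0.3, 1.0, -1.0, 2.0])
g = np.arange(1.0, 7.0)  # upstream gradient, same length as np.convolve(A, B)

# VJP w.r.t. A: "valid" convolution of g with the flipped B.
grad_A = np.convolve(g, B[::-1], mode="valid")

# Reference: L(A) = g . convolve(A, B) is linear in A, so dL/dA[k] = L(e_k).
reference = np.array([np.dot(g, np.convolve(e_k, B)) for e_k in np.eye(len(A))])
```

# Flipping one operand and switching "full" to "valid" is the same move that
# grad_convolve makes through flipped_idxs and new_mode.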


================================================
FILE: autograd/scipy/special.py
================================================
import scipy.special

import autograd.numpy as np
from autograd.extend import defjvp, defvjp, primitive
from autograd.numpy.numpy_vjps import repeat_to_match_shape, unbroadcast_f

### Beta function ###
beta = primitive(scipy.special.beta)
betainc = primitive(scipy.special.betainc)
betaln = primitive(scipy.special.betaln)

defvjp(
    beta,
    lambda ans, a, b: unbroadcast_f(a, lambda g: g * ans * (psi(a) - psi(a + b))),
    lambda ans, a, b: unbroadcast_f(b, lambda g: g * ans * (psi(b) - psi(a + b))),
)
defvjp(
    betainc,
    lambda ans, a, b, x: unbroadcast_f(
        x, lambda g: g * np.power(x, a - 1) * np.power(1 - x, b - 1) / beta(a, b)
    ),
    argnums=[2],
)
defvjp(
    betaln,
    lambda ans, a, b: unbroadcast_f(a, lambda g: g * (psi(a) - psi(a + b))),
    lambda ans, a, b: unbroadcast_f(b, lambda g: g * (psi(b) - psi(a + b))),
)

### Gamma functions ###
polygamma = primitive(scipy.special.polygamma)
psi = primitive(scipy.special.psi)  # psi(x) is just polygamma(0, x)
digamma = primitive(scipy.special.digamma)  # digamma is another name for psi.
gamma = primitive(scipy.special.gamma)
gammaln = primitive(scipy.special.gammaln)
gammainc = primitive(scipy.special.gammainc)
gammaincc = primitive(scipy.special.gammaincc)
gammasgn = primitive(scipy.special.gammasgn)
rgamma = primitive(scipy.special.rgamma)
multigammaln = primitive(scipy.special.multigammaln)

defvjp(gammasgn, None)
defvjp(polygamma, None, lambda ans, n, x: lambda g: g * polygamma(n + 1, x))
defvjp(psi, lambda ans, x: lambda g: g * polygamma(1, x))
defvjp(digamma, lambda ans, x: lambda g: g * polygamma(1, x))
defvjp(gamma, lambda ans, x: lambda g: g * ans * psi(x))
defvjp(gammaln, lambda ans, x: lambda g: g * psi(x))
defvjp(rgamma, lambda ans, x: lambda g: g * psi(x) / -gamma(x))
defvjp(
    multigammaln,
    lambda ans, a, d: lambda g: g * np.sum(digamma(np.expand_dims(a, -1) - np.arange(d) / 2.0), -1),
    None,
)


def make_gammainc_vjp_arg1(sign):
    def gammainc_vjp_arg1(ans, a, x):
        coeffs = sign * np.exp(-x) * np.power(x, a - 1) / gamma(a)
        return unbroadcast_f(x, lambda g: g * coeffs)

    return gammainc_vjp_arg1


defvjp(gammainc, make_gammainc_vjp_arg1(1), argnums=[1])
defvjp(gammaincc, make_gammainc_vjp_arg1(-1), argnums=[1])

### Bessel functions ###

j0 = primitive(scipy.special.j0)
y0 = primitive(scipy.special.y0)
j1 = primitive(scipy.special.j1)
y1 = primitive(scipy.special.y1)
jn = primitive(scipy.special.jn)
yn = primitive(scipy.special.yn)

defvjp(j0, lambda ans, x: lambda g: -g * j1(x))
defvjp(y0, lambda ans, x: lambda g: -g * y1(x))
defvjp(j1, lambda ans, x: lambda g: g * (j0(x) - jn(2, x)) / 2.0)
defvjp(y1, lambda ans, x: lambda g: g * (y0(x) - yn(2, x)) / 2.0)
defvjp(jn, None, lambda ans, n, x: lambda g: g * (jn(n - 1, x) - jn(n + 1, x)) / 2.0)
defvjp(yn, None, lambda ans, n, x: lambda g: g * (yn(n - 1, x) - yn(n + 1, x)) / 2.0)


### Modified Bessel functions of the first kind (ive is exponentially scaled) ###
i0 = primitive(scipy.special.i0)
i1 = primitive(scipy.special.i1)
iv = primitive(scipy.special.iv)
ive = primitive(scipy.special.ive)

defvjp(i0, lambda ans, x: lambda g: g * i1(x))
defvjp(i1, lambda ans, x: lambda g: g * (i0(x) + iv(2, x)) / 2.0)
defvjp(iv, None, lambda ans, n, x: lambda g: g * (iv(n - 1, x) + iv(n + 1, x)) / 2.0)
defvjp(ive, None, lambda ans, n, x: lambda g: g * (ans * (n / x - np.sign(x)) + ive(n + 1, x)))

### Error Function ###
inv_root_pi = 0.56418958354775627928
erf = primitive(scipy.special.erf)
erfc = primitive(scipy.special.erfc)

defvjp(erf, lambda ans, x: lambda g: 2.0 * g * inv_root_pi * np.exp(-(x**2)))
defvjp(erfc, lambda ans, x: lambda g: -2.0 * g * inv_root_pi * np.exp(-(x**2)))
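

# Illustrative sketch (not part of autograd): a standard-library check of the
# derivative used in the erf VJP, d/dx erf(x) = 2 / sqrt(pi) * exp(-x**2).
# The helper names and the test point x = 0.7 are hypothetical.

```python
import math

INV_ROOT_PI = 1.0 / math.sqrt(math.pi)  # same constant as inv_root_pi above


def erf_derivative(x):
    """Analytic derivative of erf, as used in the VJP."""
    return 2.0 * INV_ROOT_PI * math.exp(-(x**2))


def central_difference(f, x, eps=1e-6):
    """Symmetric finite-difference estimate of f'(x)."""
    return (f(x + eps / 2) - f(x - eps / 2)) / eps


x = 0.7
analytic = erf_derivative(x)
numeric = central_difference(math.erf, x)
```

# The erfc rule is the same expression with the opposite sign, since
# erfc(x) = 1 - erf(x).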


### Inverse error function ###
root_pi = 1.7724538509055159
erfinv = primitive(scipy.special.erfinv)
erfcinv = primitive(scipy.special.erfcinv)

defvjp(erfinv, lambda ans, x: lambda g: g * root_pi / 2 * np.exp(erfinv(x) ** 2))
defvjp(erfcinv, lambda ans, x: lambda g: -g * root_pi / 2 * np.exp(erfcinv(x) ** 2))

### Logit and Expit ###
logit = primitive(scipy.special.logit)
expit = primitive(scipy.special.expit)

defvjp(logit, lambda ans, x: lambda g: g / (x * (1 - x)))
defvjp(expit, lambda ans, x: lambda g: g * ans * (1 - ans))

### logsumexp ###
logsumexp = primitive(scipy.special.logsumexp)


def make_grad_logsumexp(ans, x, axis=None, b=1.0, keepdims=False):
    shape, dtype = np.shape(x), np.result_type(x)

    def vjp(g):
        g_repeated, _ = repeat_to_match_shape(g, shape, dtype, axis, keepdims)
        ans_repeated, _ = repeat_to_match_shape(ans, shape, dtype, axis, keepdims)
        return g_repeated * b * np.exp(x - ans_repeated)

    return vjp


defvjp(logsumexp, make_grad_logsumexp)


def fwd_grad_logsumexp(g, ans, x, axis=None, b=1.0, keepdims=False):
    if not keepdims:
        if isinstance(axis, int):
            ans = np.expand_dims(ans, axis)
        elif isinstance(axis, tuple):
            for ax in sorted(axis):
                ans = np.expand_dims(ans, ax)
    return np.sum(g * b * np.exp(x - ans), axis=axis, keepdims=keepdims)


defjvp(logsumexp, fwd_grad_logsumexp)
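

# Illustrative sketch (not part of autograd): with b = 1 and axis = None, the
# logsumexp gradient exp(x - ans) is the softmax of x. The pure-Python helpers
# below are hypothetical stand-ins for the array versions above.

```python
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x))) over a list of floats."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def logsumexp_grad(xs):
    """Gradient of logsumexp: exp(x_i - ans) for each input."""
    ans = logsumexp(xs)
    return [math.exp(x - ans) for x in xs]


xs = [0.5, -1.0, 2.0]
grad = logsumexp_grad(xs)

# Central finite differences, one coordinate at a time.
eps = 1e-6
numeric = []
for i in range(len(xs)):
    hi = xs[:i] + [xs[i] + eps / 2] + xs[i + 1 :]
    lo = xs[:i] + [xs[i] - eps / 2] + xs[i + 1 :]
    numeric.append((logsumexp(hi) - logsumexp(lo)) / eps)
```

# The gradient entries are positive and sum to one, as expected of a softmax.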


================================================
FILE: autograd/scipy/stats/__init__.py
================================================
from . import beta, chi2, gamma, norm, poisson, t

# Try block needed in case the user has an
# old version of scipy without multivariate normal.
try:
    from . import multivariate_normal
except AttributeError:
    pass

try:
    from . import dirichlet
except AttributeError:
    pass


================================================
FILE: autograd/scipy/stats/beta.py
================================================
import scipy.stats

import autograd.numpy as np
from autograd.extend import defvjp, primitive
from autograd.numpy.numpy_vjps import unbroadcast_f
from autograd.scipy.special import beta, psi

cdf = primitive(scipy.stats.beta.cdf)
logpdf = primitive(scipy.stats.beta.logpdf)
pdf = primitive(scipy.stats.beta.pdf)


def grad_beta_logpdf_arg0(x, a, b):
    return (1 + a * (x - 1) + x * (b - 2)) / (x * (x - 1))


def grad_beta_logpdf_arg1(x, a, b):
    return np.log(x) - psi(a) + psi(a + b)


def grad_beta_logpdf_arg2(x, a, b):
    return np.log1p(-x) - psi(b) + psi(a + b)


defvjp(
    cdf,
    lambda ans, x, a, b: unbroadcast_f(
        x, lambda g: g * np.power(x, a - 1) * np.power(1 - x, b - 1) / beta(a, b)
    ),
    argnums=[0],
)
defvjp(
    logpdf,
    lambda ans, x, a, b: unbroadcast_f(x, lambda g: g * grad_beta_logpdf_arg0(x, a, b)),
    lambda ans, x, a, b: unbroadcast_f(a, lambda g: g * grad_beta_logpdf_arg1(x, a, b)),
    lambda ans, x, a, b: unbroadcast_f(b, lambda g: g * grad_beta_logpdf_arg2(x, a, b)),
)
defvjp(
    pdf,
    lambda ans, x, a, b: unbroadcast_f(x, lambda g: g * ans * grad_beta_logpdf_arg0(x, a, b)),
    lambda ans, x, a, b: unbroadcast_f(a, lambda g: g * ans * grad_beta_logpdf_arg1(x, a, b)),
    lambda ans, x, a, b: unbroadcast_f(b, lambda g: g * ans * grad_beta_logpdf_arg2(x, a, b)),
)


================================================
FILE: autograd/scipy/stats/chi2.py
================================================
import scipy.stats

import autograd.numpy as np
from autograd.extend import defvjp, primitive
from autograd.numpy.numpy_vjps import unbroadcast_f
from autograd.scipy.special import gamma

cdf = primitive(scipy.stats.chi2.cdf)
logpdf = primitive(scipy.stats.chi2.logpdf)
pdf = primitive(scipy.stats.chi2.pdf)


def grad_chi2_logpdf(x, df):
    return np.where(df % 1 == 0, (df - x - 2) / (2 * x), 0)


defvjp(
    cdf,
    lambda ans, x, df: unbroadcast_f(
        x, lambda g: g * np.power(2.0, -df / 2) * np.exp(-x / 2) * np.power(x, df / 2 - 1) / gamma(df / 2)
    ),
    argnums=[0],
)
defvjp(logpdf, lambda ans, x, df: unbroadcast_f(x, lambda g: g * grad_chi2_logpdf(x, df)), argnums=[0])
defvjp(pdf, lambda ans, x, df: unbroadcast_f(x, lambda g: g * ans * grad_chi2_logpdf(x, df)), argnums=[0])


================================================
FILE: autograd/scipy/stats/dirichlet.py
================================================
import scipy.stats

import autograd.numpy as np
from autograd.extend import defvjp, primitive
from autograd.scipy.special import digamma

rvs = primitive(scipy.stats.dirichlet.rvs)
pdf = primitive(scipy.stats.dirichlet.pdf)
logpdf = primitive(scipy.stats.dirichlet.logpdf)

defvjp(
    logpdf,
    lambda ans, x, alpha: lambda g: g * (alpha - 1) / x,
    lambda ans, x, alpha: lambda g: g * (digamma(np.sum(alpha)) - digamma(alpha) + np.log(x)),
)

# Same as log pdf, but multiplied by the pdf (ans).
defvjp(
    pdf,
    lambda ans, x, alpha: lambda g: g * ans * (alpha - 1) / x,
    lambda ans, x, alpha: lambda g: g * ans * (digamma(np.sum(alpha)) - digamma(alpha) + np.log(x)),
)


================================================
FILE: autograd/scipy/stats/gamma.py
================================================
import scipy.stats

import autograd.numpy as np
from autograd.extend import defvjp, primitive
from autograd.numpy.numpy_vjps import unbroadcast_f
from autograd.scipy.special import gamma, psi

cdf = primitive(scipy.stats.gamma.cdf)
logpdf = primitive(scipy.stats.gamma.logpdf)
pdf = primitive(scipy.stats.gamma.pdf)


def grad_gamma_logpdf_arg0(x, a):
    return (a - x - 1) / x


def grad_gamma_logpdf_arg1(x, a):
    return np.log(x) - psi(a)


defvjp(
    cdf,
    lambda ans, x, a: unbroadcast_f(x, lambda g: g * np.exp(-x) * np.power(x, a - 1) / gamma(a)),
    argnums=[0],
)
defvjp(
    logpdf,
    lambda ans, x, a: unbroadcast_f(x, lambda g: g * grad_gamma_logpdf_arg0(x, a)),
    lambda ans, x, a: unbroadcast_f(a, lambda g: g * grad_gamma_logpdf_arg1(x, a)),
)
defvjp(
    pdf,
    lambda ans, x, a: unbroadcast_f(x, lambda g: g * ans * grad_gamma_logpdf_arg0(x, a)),
    lambda ans, x, a: unbroadcast_f(a, lambda g: g * ans * grad_gamma_logpdf_arg1(x, a)),
)


================================================
FILE: autograd/scipy/stats/multivariate_normal.py
================================================
import scipy.stats

import autograd.numpy as np
from autograd.extend import defvjp, primitive
from autograd.numpy.numpy_vjps import unbroadcast_f

pdf = primitive(scipy.stats.multivariate_normal.pdf)
logpdf = primitive(scipy.stats.multivariate_normal.logpdf)
entropy = primitive(scipy.stats.multivariate_normal.entropy)

# With thanks to Eric Bresch.
# Some formulas are from
# "An extended collection of matrix derivative results
#  for forward and reverse mode algorithmic differentiation"
# by Mike Giles
# https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf


def generalized_outer_product(x):
    if np.ndim(x) == 1:
        return np.outer(x, x)
    return np.matmul(x, np.swapaxes(x, -1, -2))


def covgrad(x, mean, cov, allow_singular=False):
    if allow_singular:
        raise NotImplementedError(
            "The multivariate normal pdf is not differentiable w.r.t. a singular covariance matrix"
        )
    J = np.linalg.inv(cov)
    solved = np.matmul(J, np.expand_dims(x - mean, -1))
    return 1.0 / 2 * (generalized_outer_product(solved) - J)


def solve(allow_singular):
    if allow_singular:
        return lambda A, x: np.dot(np.linalg.pinv(A), x)
    else:
        return np.linalg.solve


defvjp(
    logpdf,
    lambda ans, x, mean, cov, allow_singular=False: unbroadcast_f(
        x, lambda g: -np.expand_dims(np.atleast_1d(g), 1) * solve(allow_singular)(cov, (x - mean).T).T
    ),
    lambda ans, x, mean, cov, allow_singular=False: unbroadcast_f(
        mean, lambda g: np.expand_dims(np.atleast_1d(g), 1) * solve(allow_singular)(cov, (x - mean).T).T
    ),
    lambda ans, x, mean, cov, allow_singular=False: unbroadcast_f(
        cov, lambda g: np.reshape(g, np.shape(g) + (1, 1)) * covgrad(x, mean, cov, allow_singular)
    ),
)

# Same as log pdf, but multiplied by the pdf (ans).
defvjp(
    pdf,
    lambda ans, x, mean, cov, allow_singular=False: unbroadcast_f(
        x, lambda g: -np.expand_dims(np.atleast_1d(ans * g), 1) * solve(allow_singular)(cov, (x - mean).T).T
    ),
    lambda ans, x, mean, cov, allow_singular=False: unbroadcast_f(
        mean,
        lambda g: np.expand_dims(np.atleast_1d(ans * g), 1) * solve(allow_singular)(cov, (x - mean).T).T,
    ),
    lambda ans, x, mean, cov, allow_singular=False: unbroadcast_f(
        cov, lambda g: np.reshape(ans * g, np.shape(g) + (1, 1)) * covgrad(x, mean, cov, allow_singular)
    ),
)

defvjp(entropy, None, lambda ans, mean, cov: unbroadcast_f(cov, lambda g: 0.5 * g * np.linalg.inv(cov).T))


================================================
FILE: autograd/scipy/stats/norm.py
================================================
"""Gradients of the normal distribution."""

import scipy.stats

import autograd.numpy as anp
from autograd.extend import defvjp, primitive
from autograd.numpy.numpy_vjps import unbroadcast_f

pdf = primitive(scipy.stats.norm.pdf)
cdf = primitive(scipy.stats.norm.cdf)
sf = primitive(scipy.stats.norm.sf)
logpdf = primitive(scipy.stats.norm.logpdf)
logcdf = primitive(scipy.stats.norm.logcdf)
logsf = primitive(scipy.stats.norm.logsf)

defvjp(
    pdf,
    lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(x, lambda g: -g * ans * (x - loc) / scale**2),
    lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(loc, lambda g: g * ans * (x - loc) / scale**2),
    lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(
        scale, lambda g: g * ans * (((x - loc) / scale) ** 2 - 1.0) / scale
    ),
)

defvjp(
    cdf,
    lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(x, lambda g: g * pdf(x, loc, scale)),
    lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(loc, lambda g: -g * pdf(x, loc, scale)),
    lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(
        scale, lambda g: -g * pdf(x, loc, scale) * (x - loc) / scale
    ),
)

defvjp(
    logpdf,
    lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(x, lambda g: -g * (x - loc) / scale**2),
    lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(loc, lambda g: g * (x - loc) / scale**2),
    lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(
        scale, lambda g: g * (-1.0 / scale + (x - loc) ** 2 / scale**3)
    ),
)
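
# Illustrative sketch (not part of autograd): a standard-library check of the
# scale derivative used in the logpdf VJP above,
# d/d(scale) logpdf = -1/scale + (x - loc)**2 / scale**3.
# The helper names and test values are hypothetical.

```python
import math


def norm_logpdf(x, loc=0.0, scale=1.0):
    """Log-density of the normal distribution."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def grad_logpdf_scale(x, loc, scale):
    """Analytic derivative of norm_logpdf with respect to scale."""
    return -1.0 / scale + (x - loc) ** 2 / scale**3


x, loc, scale = 1.3, 0.4, 2.0
eps = 1e-6
numeric = (norm_logpdf(x, loc, scale + eps / 2) - norm_logpdf(x, loc, scale - eps / 2)) / eps
analytic = grad_logpdf_scale(x, loc, scale)
```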

defvjp(
    logcdf,
    lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(
        x, lambda g: g * anp.exp(logpdf(x, loc, scale) - logcdf(x, loc, scale))
    ),
    lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(
        loc, lambda g: -g * anp.exp(logpdf(x, loc, scale) - logcdf(x, loc, scale))
    ),
    lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(
        scale, lambda g: -g * anp.exp(logpdf(x, loc, scale) - logcdf(x, loc, scale)) * (x - loc) / scale
    ),
)

defvjp(
    logsf,
    lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(
        x, lambda g: -g * anp.exp(logpdf(x, loc, scale) - logsf(x, loc, scale))
    ),
    lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(
        loc, lambda g: g * anp.exp(logpdf(x, loc, scale) - logsf(x, loc, scale))
    ),
    lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(
        scale, lambda g: g * anp.exp(logpdf(x, loc, scale) - logsf(x, loc, scale)) * (x - loc) / scale
    ),
)

defvjp(
    sf,
    lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(x, lambda g: -g * pdf(x, loc, scale)),
    lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(loc, lambda g: g * pdf(x, loc, scale)),
    lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(
        scale, lambda g: g * pdf(x, loc, scale) * (x - loc) / scale
    ),
)


================================================
FILE: autograd/scipy/stats/poisson.py
================================================
import scipy.stats

import autograd.numpy as np
from autograd.extend import defvjp, primitive
from autograd.numpy.numpy_vjps import unbroadcast_f

cdf = primitive(scipy.stats.poisson.cdf)
logpmf = primitive(scipy.stats.poisson.logpmf)
pmf = primitive(scipy.stats.poisson.pmf)


def grad_poisson_logpmf(k, mu):
    return np.where(k % 1 == 0, k / mu - 1, 0)


defvjp(cdf, lambda ans, k, mu: unbroadcast_f(mu, lambda g: g * -pmf(np.floor(k), mu)), argnums=[1])
defvjp(logpmf, lambda ans, k, mu: unbroadcast_f(mu, lambda g: g * grad_poisson_logpmf(k, mu)), argnums=[1])
defvjp(
    pmf, lambda ans, k, mu: unbroadcast_f(mu, lambda g: g * ans * grad_poisson_logpmf(k, mu)), argnums=[1]
)


================================================
FILE: autograd/scipy/stats/t.py
================================================
"""Gradients of the univariate t distribution."""

import scipy.stats

import autograd.numpy as np
from autograd.extend import defvjp, primitive
from autograd.numpy.numpy_vjps import unbroadcast_f
from autograd.scipy.special import psi

pdf = primitive(scipy.stats.t.pdf)
cdf = primitive(scipy.stats.t.cdf)
logpdf = primitive(scipy.stats.t.logpdf)
logcdf = primitive(scipy.stats.t.logcdf)


def grad_tlogpdf_diff(diff, df):
    return -diff * (1.0 + df) / (diff**2 + df)


def grad_tlogpdf_x(x, df, loc, scale):
    return grad_tlogpdf_diff((x - loc) / scale, df) / scale


def grad_tlogpdf_loc(x, df, loc, scale):
    return -grad_tlogpdf_diff((x - loc) / scale, df) / scale


def grad_tlogpdf_scale(x, df, loc, scale):
    diff = x - loc
    return -(df * (scale**2 - diff**2)) / (scale * (df * scale**2 + diff**2))


def grad_tlogpdf_df(x, df, loc, scale):
    y = (x - loc) / scale
    return 0.5 * (
        (y**2 * (df + 1)) / (df * (y**2 + df))
        - np.log(y**2 / df + 1)
        - 1.0 / df
        - psi(df / 2.0)
        + psi((df + 1) / 2.0)
    )


defvjp(
    pdf,
    lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f(
        x, lambda g: g * ans * grad_tlogpdf_x(x, df, loc, scale)
    ),
    lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f(
        df, lambda g: g * ans * grad_tlogpdf_df(x, df, loc, scale)
    ),
    lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f(
        loc, lambda g: g * ans * grad_tlogpdf_loc(x, df, loc, scale)
    ),
    lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f(
        scale, lambda g: g * ans * grad_tlogpdf_scale(x, df, loc, scale)
    ),
)

defvjp(
    cdf,
    lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f(x, lambda g: g * pdf(x, df, loc, scale)),
    lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f(loc, lambda g: -g * pdf(x, df, loc, scale)),
    argnums=(0, 2),
)

defvjp(
    logpdf,
    lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f(x, lambda g: g * grad_tlogpdf_x(x, df, loc, scale)),
    lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f(
        df, lambda g: g * grad_tlogpdf_df(x, df, loc, scale)
    ),
    lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f(
        loc, lambda g: g * grad_tlogpdf_loc(x, df, loc, scale)
    ),
    lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f(
        scale, lambda g: g * grad_tlogpdf_scale(x, df, loc, scale)
    ),
)

defvjp(
    logcdf,
    lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f(
        x, lambda g: g * np.exp(logpdf(x, df, loc, scale) - logcdf(x, df, loc, scale))
    ),
    lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f(
        loc, lambda g: -g * np.exp(logpdf(x, df, loc, scale) - logcdf(x, df, loc, scale))
    ),
    argnums=(0, 2),
)


================================================
FILE: autograd/test_util.py
================================================
from itertools import product

from .core import make_jvp, make_vjp, vspace
from .wrap_util import get_name, unary_to_nary

TOL = 1e-6
RTOL = 1e-6


def scalar_close(a, b):
    return abs(a - b) < TOL or abs(a - b) / abs(a + b) < RTOL


EPS = 1e-6


def make_numerical_jvp(f, x):
    y = f(x)
    x_vs, y_vs = vspace(x), vspace(y)

    def jvp(v):
        # (f(x + v*eps/2) - f(x - v*eps/2)) / eps
        f_x_plus = f(x_vs.add(x, x_vs.scalar_mul(v, EPS / 2)))
        f_x_minus = f(x_vs.add(x, x_vs.scalar_mul(v, -EPS / 2)))
        neg_f_x_minus = y_vs.scalar_mul(f_x_minus, -1.0)
        return y_vs.scalar_mul(y_vs.add(f_x_plus, neg_f_x_minus), 1.0 / EPS)

    return jvp


def check_vjp(f, x):
    vjp, y = make_vjp(f, x)
    jvp = make_numerical_jvp(f, x)
    x_vs, y_vs = vspace(x), vspace(y)
    x_v, y_v = x_vs.randn(), y_vs.randn()

    vjp_y = x_vs.covector(vjp(y_vs.covector(y_v)))
    assert vspace(vjp_y) == x_vs
    vjv_exact = x_vs.inner_prod(x_v, vjp_y)
    vjv_numeric = y_vs.inner_prod(y_v, jvp(x_v))
    assert scalar_close(vjv_numeric, vjv_exact), (
        "Derivative (VJP) check of {} failed with arg {}:\nanalytic: {}\nnumeric:  {}".format(
            get_name(f), x, vjv_exact, vjv_numeric
        )
    )


def check_jvp(f, x):
    jvp = make_jvp(f, x)
    jvp_numeric = make_numerical_jvp(f, x)
    x_v = vspace(x).randn()
    check_equivalent(jvp(x_v)[1], jvp_numeric(x_v))


def check_equivalent(x, y):
    x_vs, y_vs = vspace(x), vspace(y)
    assert x_vs == y_vs, f"VSpace mismatch:\nx: {x_vs}\ny: {y_vs}"
    v = x_vs.randn()
    assert scalar_close(x_vs.inner_prod(x, v), x_vs.inner_prod(y, v)), f"Value mismatch:\nx: {x}\ny: {y}"


@unary_to_nary
def check_grads(f, x, modes=["fwd", "rev"], order=2):
    assert all(m in ["fwd", "rev"] for m in modes)
    if "fwd" in modes:
        check_jvp(f, x)
        if order > 1:
            grad_f = lambda x, v: make_jvp(f, x)(v)[1]
            grad_f.__name__ = f"jvp_{get_name(f)}"
            v = vspace(x).randn()
            check_grads(grad_f, (0, 1), modes, order=order - 1)(x, v)
    if "rev" in modes:
        check_vjp(f, x)
        if order > 1:
            grad_f = lambda x, v: make_vjp(f, x)[0](v)
            grad_f.__name__ = f"vjp_{get_name(f)}"
            v = vspace(f(x)).randn()
            check_grads(grad_f, (0, 1), modes, order=order - 1)(x, v)


def combo_check(fun, *args, **kwargs):
    # Tests all combinations of args and kwargs given.
    _check_grads = lambda f: check_grads(f, *args, **kwargs)

    def _combo_check(*args, **kwargs):
        kwarg_key_vals = [[(k, x) for x in xs] for k, xs in kwargs.items()]
        for _args in product(*args):
            for _kwargs in product(*kwarg_key_vals):
                _check_grads(fun)(*_args, **dict(_kwargs))

    return _combo_check


================================================
FILE: autograd/tracer.py
================================================
import warnings
from collections import defaultdict
from contextlib import contextmanager

from .util import subvals, toposort
from .wrap_util import wraps


def trace(start_node, fun, x):
    with trace_stack.new_trace() as t:
        start_box = new_box(x, t, start_node)
        end_box = fun(start_box)
        if isbox(end_box) and end_box._trace == start_box._trace:
            return end_box._value, end_box._node
        else:
            warnings.warn("Output seems independent of input.")
            return end_box, None


class Node:
    __slots__ = []

    def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
        assert False

    def initialize_root(self, *args, **kwargs):
        assert False

    @classmethod
    def new_root(cls, *args, **kwargs):
        root = cls.__new__(cls)
        root.initialize_root(*args, **kwargs)
        return root


def primitive(f_raw):
    """
    Wraps a function so that its gradient can be specified and its invocation
    can be recorded. For examples, see the docs."""

    @wraps(f_raw)
    def f_wrapped(*args, **kwargs):
        boxed_args, trace, node_constructor = find_top_boxed_args(args)
        if boxed_args:
            argvals = subvals(args, [(argnum, box._value) for argnum, box in boxed_args])
            if f_wrapped in notrace_primitives[node_constructor]:
                return f_wrapped(*argvals, **kwargs)
            parents = tuple(box._node for _, box in boxed_args)
            argnums = tuple(argnum for argnum, _ in boxed_args)
            ans = f_wrapped(*argvals, **kwargs)
            node = node_constructor(ans, f_wrapped, argvals, kwargs, argnums, parents)
            return new_box(ans, trace, node)
        else:
            return f_raw(*args, **kwargs)

    f_wrapped.fun = f_raw
    f_wrapped._is_autograd_primitive = True
    return f_wrapped


notrace_primitives = defaultdict(set)


def register_notrace(trace_type, primitive_fun):
    notrace_primitives[trace_type].add(primitive_fun)


def notrace_primitive(f_raw):
    @wraps(f_raw)
    def f_wrapped(*args, **kwargs):
        argvals = map(getval, args)
        return f_raw(*argvals, **kwargs)

    f_wrapped._is_primitive = True
    return f_wrapped


def find_top_boxed_args(args):
    top_trace = -1
    top_boxes = []
    top_node_type = None
    for argnum, arg in enumerate(args):
        if isbox(arg):
            trace = arg._trace
            if trace > top_trace:
                top_boxes = [(argnum, arg)]
                top_trace = trace
                top_node_type = type(arg._node)
            elif trace == top_trace:
                top_boxes.append((argnum, arg))
    return top_boxes, top_trace, top_node_type


class TraceStack:
    def __init__(self):
        self.top = -1

    @contextmanager
    def new_trace(self):
        self.top += 1
        yield self.top
        self.top -= 1


trace_stack = TraceStack()


class Box:
    type_mappings = {}
    types = set()

    __slots__ = ["_value", "_trace", "_node"]

    def __init__(self, value, trace, node):
        self._value = value
        self._node = node
        self._trace = trace

    def __bool__(self):
        return bool(self._value)

    __nonzero__ = __bool__

    def __str__(self):
        return f"Autograd {type(self).__name__} with value {str(self._value)}"

    @classmethod
    def register(cls, value_type):
        Box.types.add(cls)
        Box.type_mappings[value_type] = cls
        Box.type_mappings[cls] = cls


box_type_mappings = Box.type_mappings


def new_box(value, trace, node):
    try:
        return box_type_mappings[type(value)](value, trace, node)
    except KeyError:
        raise TypeError(f"Can't differentiate w.r.t. type {type(value)}")


box_types = Box.types
isbox = lambda x: type(x) in box_types  # almost 3X faster than isinstance(x, Box)
getval = lambda x: getval(x._value) if isbox(x) else x


================================================
FILE: autograd/util.py
================================================
import operator


def subvals(x, ivs):
    x_ = list(x)
    for i, v in ivs:
        x_[i] = v
    return tuple(x_)


def subval(x, i, v):
    x_ = list(x)
    x_[i] = v
    return tuple(x_)


def func(f):
    return f


def toposort(end_node, parents=operator.attrgetter("parents")):
    child_counts = {}
    stack = [end_node]
    while stack:
        node = stack.pop()
        if node in child_counts:
            child_counts[node] += 1
        else:
            child_counts[node] = 1
            stack.extend(parents(node))

    childless_nodes = [end_node]
    while childless_nodes:
        node = childless_nodes.pop()
        yield node
        for parent in parents(node):
            if child_counts[parent] == 1:
                childless_nodes.append(parent)
            else:
                child_counts[parent] -= 1


# -------------------- deprecation warnings -----------------------

import warnings

deprecation_msg = """
The quick_grad_check function is deprecated. See the update guide:
https://github.com/HIPS/autograd/blob/master/docs/updateguide.md"""


def quick_grad_check(
    fun, arg0, extra_args=(), kwargs={}, verbose=True, eps=1e-4, rtol=1e-4, atol=1e-6, rs=None
):
    warnings.warn(deprecation_msg)
    from autograd.test_util import check_grads

    fun_ = lambda arg0: fun(arg0, *extra_args, **kwargs)
    check_grads(fun_, modes=["rev"], order=1)(arg0)


================================================
FILE: autograd/wrap_util.py
================================================
from .util import subvals


def unary_to_nary(unary_operator):
    @wraps(unary_operator)
    def nary_operator(fun, argnum=0, *nary_op_args, **nary_op_kwargs):
        assert type(argnum) in (int, tuple, list), argnum

        @wrap_nary_f(fun, unary_operator, argnum)
        def nary_f(*args, **kwargs):
            @wraps(fun)
            def unary_f(x):
                if isinstance(argnum, int):
                    subargs = subvals(args, [(argnum, x)])
                else:
                    subargs = subvals(args, zip(argnum, x))
                return fun(*subargs, **kwargs)

            if isinstance(argnum, int):
                x = args[argnum]
            else:
                x = tuple(args[i] for i in argnum)
            return unary_operator(unary_f, x, *nary_op_args, **nary_op_kwargs)

        return nary_f

    return nary_operator


def wraps(fun, namestr="{fun}", docstr="{doc}", **kwargs):
    def _wraps(f):
        try:
            f.__name__ = namestr.format(fun=get_name(fun), **kwargs)
            f.__doc__ = docstr.format(fun=get_name(fun), doc=get_doc(fun), **kwargs)
        except BaseException:
            pass
        return f

    return _wraps


def wrap_nary_f(fun, op, argnum):
    namestr = "{op}_of_{fun}_wrt_argnum_{argnum}"
    docstr = """\
    {op} of function {fun} with respect to argument number {argnum}. Takes the
    same arguments as {fun} but returns the {op}.
    """
    return wraps(fun, namestr, docstr, op=get_name(op), argnum=argnum)


get_name = lambda f: getattr(f, "__name__", "[unknown name]")
get_doc = lambda f: getattr(f, "__doc__", "")


================================================
FILE: benchmarks/__init__.py
================================================


================================================
FILE: benchmarks/asv.conf.json.sample
================================================
{
    "version": 1,
    "project": "autograd",
    "project_url": "http://github.com/hips/autograd",
    "branches": ["master"],
    "dvcs": "git",
    "environment_type": "virtualenv",
    "install_timeout": 600,
    "repo"          : "..",
    "benchmark_dir" : ".",
    "env_dir"       : "../.asv/env",
    "results_dir"   : "../.asv/results",
    "html_dir"      : "../.asv/html"
}


================================================
FILE: benchmarks/bench_core.py
================================================
import numpy as onp

import autograd.numpy as np
from autograd import grad

try:
    from autograd.core import VJPNode, backward_pass, vspace
    from autograd.tracer import new_box, trace

    MASTER_BRANCH = False
except ImportError:
    from autograd.core import backward_pass, forward_pass, new_progenitor, vspace

    MASTER_BRANCH = True


## SHORT FUNCTION
def f_short(x):
    return x**2


def time_short_fun():
    f_short(2.0)


def time_short_forward_pass():
    if MASTER_BRANCH:
        forward_pass(f_short, (2.0,), {})
    else:
        start_node = VJPNode.new_root()
        trace(start_node, f_short, x)


def time_short_backward_pass():
    if MASTER_BRANCH:
        backward_pass(1.0, short_end_node, short_start_node)
    else:
        backward_pass(1.0, short_end_node)


def time_short_grad():
    grad(f_short)(2.0)


## LONG FUNCTION
def f_long(x):
    for i in range(50):
        x = np.sin(x)
    return x


def time_long_fun():
    f_long(2.0)


def time_long_forward_pass():
    if MASTER_BRANCH:
        forward_pass(f_long, (2.0,), {})
    else:
        start_node = VJPNode.new_root()
        trace(start_node, f_long, x)


def time_long_backward_pass():
    if MASTER_BRANCH:
        backward_pass(1.0, long_end_node, long_start_node)
    else:
        backward_pass(1.0, long_end_node)


def time_long_grad():
    grad(f_long)(2.0)


## 'PEARLMUTTER TEST' FUNCTION
def fan_out_fan_in(x):
    for i in range(10**4):
        x = (x + x) / 2.0
    return np.sum(x)


def time_fan_out_fan_in_fun():
    fan_out_fan_in(2.0)


def time_fan_out_fan_in_forward_pass():
    if MASTER_BRANCH:
        forward_pass(fan_out_fan_in, (2.0,), {})
    else:
        start_node = VJPNode.new_root()
        trace(start_node, fan_out_fan_in, x)


def time_fan_out_fan_in_backward_pass():
    if MASTER_BRANCH:
        backward_pass(1.0, fan_end_node, fan_start_node)
    else:
        backward_pass(1.0, fan_end_node)


def time_fan_out_fan_in_grad():
    grad(fan_out_fan_in)(2.0)


## UNIT BENCHMARKS
def time_vspace_float():
    vspace(1.0)


A = np.array([[1.0, 2.0, 3.0]])


def time_vspace_array():
    vspace(A)


def time_new_box_float():
    new_box(1.0, 0, start_node)


def time_new_box_array():
    new_box(A, 0, start_node)


def time_exp_call():
    onp.exp(2.0)


def time_exp_primitive_call_unboxed():
    np.exp(2.0)


def time_exp_primitive_call_boxed():
    if MASTER_BRANCH:
        np.exp(progenitor)
    else:
        np.exp(start_box)


def time_no_autograd_control():
    # Test whether the benchmarking machine is running slowly independent of autograd
    A = np.random.randn(200, 200)
    np.dot(A, A)


if MASTER_BRANCH:
    short_start_node, short_end_node = forward_pass(f_short, (2.0,), {})
    long_start_node, long_end_node = forward_pass(f_long, (2.0,), {})
    fan_start_node, fan_end_node = forward_pass(fan_out_fan_in, (2.0,), {})
    progenitor = new_progenitor(2.0)
else:
    x = 2.0
    start_node = VJPNode.new_root()
    start_box = new_box(x, 0, start_node)
    _, short_end_node = trace(VJPNode.new_root(), f_short, x)
    _, long_end_node = trace(VJPNode.new_root(), f_long, x)
    _, fan_end_node = trace(VJPNode.new_root(), fan_out_fan_in, x)


================================================
FILE: benchmarks/bench_mem.py
================================================
import autograd.numpy as np
from autograd import grad


def peakmem_needless_nodes():
    N, M = 1000, 100

    def fun(x):
        for i in range(M):
            x = x + 1
        return np.sum(x)

    grad(fun)(np.zeros((N, N)))


================================================
FILE: benchmarks/bench_numpy_vjps.py
================================================
import autograd.numpy as np
import autograd.numpy.random as npr
from autograd import make_vjp

dot_0 = lambda a, b, g: make_vjp(np.dot, argnum=0)(a, b)[0](g)
dot_1 = lambda a, b, g: make_vjp(np.dot, argnum=1)(a, b)[0](g)

dot_0_0 = lambda a, b, g: make_vjp(dot_0, argnum=0)(a, b, g)[0](a)
dot_0_1 = lambda a, b, g: make_vjp(dot_0, argnum=1)(a, b, g)[0](a)
dot_0_2 = lambda a, b, g: make_vjp(dot_0, argnum=2)(a, b, g)[0](a)

dot_1_0 = lambda a, b, g: make_vjp(dot_1, argnum=0)(a, b, g)[0](b)
dot_1_1 = lambda a, b, g: make_vjp(dot_1, argnum=1)(a, b, g)[0](b)
dot_1_2 = lambda a, b, g: make_vjp(dot_1, argnum=2)(a, b, g)[0](b)

a = npr.randn(2, 3, 4, 5)
b = npr.randn(2, 3, 5, 4)
g = npr.randn(2, 3, 4, 2, 3, 4)


def time_dot_0():
    dot_0(a, b, g)


def time_dot_1():
    dot_1(a, b, g)


def time_dot_0_0():
    dot_0_0(a, b, g)


def time_dot_0_1():
    dot_0_1(a, b, g)


def time_dot_0_2():
    dot_0_2(a, b, g)


def time_dot_1_0():
    dot_1_0(a, b, g)


def time_dot_1_1():
    dot_1_1(a, b, g)


def time_dot_1_2():
    dot_1_2(a, b, g)


tensordot_0 = lambda A, B, G: make_vjp(np.tensordot, argnum=0)(A, B, 2)[0](G)
tensordot_1 = lambda A, B, G: make_vjp(np.tensordot, argnum=1)(A, B, 2)[0](G)

tensordot_0_0 = lambda A, B, G: make_vjp(tensordot_0, argnum=0)(A, B, G)[0](A)
tensordot_0_1 = lambda A, B, G: make_vjp(tensordot_0, argnum=1)(A, B, G)[0](A)
tensordot_0_2 = lambda A, B, G: make_vjp(tensordot_0, argnum=2)(A, B, G)[0](A)

tensordot_1_0 = lambda A, B, G: make_vjp(tensordot_1, argnum=0)(A, B, G)[0](B)
tensordot_1_1 = lambda A, B, G: make_vjp(tensordot_1, argnum=1)(A, B, G)[0](B)
tensordot_1_2 = lambda A, B, G: make_vjp(tensordot_1, argnum=2)(A, B, G)[0](B)

A = npr.randn(2, 3, 5, 4)
B = npr.randn(5, 4, 2, 3)
G = npr.randn(2, 3, 2, 3)


def time_tensordot_0():
    tensordot_0(A, B, G)


def time_tensordot_1():
    tensordot_1(A, B, G)


def time_tensordot_0_0():
    tensordot_0_0(A, B, G)


def time_tensordot_0_1():
    tensordot_0_1(A, B, G)


def time_tensordot_0_2():
    tensordot_0_2(A, B, G)


def time_tensordot_1_0():
    tensordot_1_0(A, B, G)


def time_tensordot_1_1():
    tensordot_1_1(A, B, G)


def time_tensordot_1_2():
    tensordot_1_2(A, B, G)


================================================
FILE: benchmarks/bench_rnn.py
================================================
# Write the benchmarking functions here.
# See "Writing benchmarks" in the asv docs for more information.
# http://asv.readthedocs.io/en/latest/writing_benchmarks.html
import autograd.numpy as np
from autograd import grad


class RNNSuite:
    """
    Checking speed on a vanilla RNN.
    """

    # NOTE: this is run each time we run a benchmark.
    # Might want to switch to setup_cache, which has to return an object which is loaded and unpacked in setup().
    def setup(self):
        self.batch_size = 16
        self.dtype = "float32"
        self.D = 2**10
        self.x = 0.01 * np.random.randn(self.batch_size, self.D).astype(self.dtype)
        self.W1 = 0.01 * np.random.randn(self.D, self.D).astype(self.dtype)
        self.b1 = 0.01 * np.random.randn(self.D).astype(self.dtype)
        self.Wout = 0.01 * np.random.randn(self.D, 1).astype(self.dtype)
        self.bout = 0.01 * np.random.randn(1).astype(self.dtype)
        self.l = (np.random.rand(self.batch_size, 1) > 0.5).astype(self.dtype)
        self.n = 50

        def autograd_rnn(params, x, label, n):
            W, b, Wout, bout = params
            h1 = x
            for i in range(n):
                h1 = np.tanh(np.dot(h1, W) + b)
            logit = np.dot(h1, Wout) + bout
            loss = -np.sum(label * logit - (logit + np.log(1 + np.exp(-logit))))
            return loss

        self.fn = autograd_rnn
        self.grad_fn = grad(self.fn)

    def rnn_grad(self):
        self.grad_fn((self.W1, self.b1, self.Wout, self.bout), self.x, self.l, self.n)

    def time_rnn_grad(self):
        self.rnn_grad()

    def peakmem_rnn_grad(self):
        self.rnn_grad()

    def time_manual_rnn_grad(self):
        self.manual_rnn_grad()

    def peakmem_manual_rnn_grad(self):
        self.manual_rnn_grad()

    def manual_rnn_grad(self):
        def repeat_to_match_shape(g, A, axis=None):
            gout = np.empty_like(A)
            if np.ndim(gout) == 0:
                gout = g
            else:
                gout = np.ones_like(A) * g
            return gout

        def sum_to_match_shape(sum_this, to_match_this):
            sum_this = np.sum(sum_this, axis=tuple(range(0, np.ndim(sum_this) - np.ndim(to_match_this))))
            for axis, size in enumerate(np.shape(to_match_this)):
                if size == 1:
                    sum_this = np.sum(sum_this, axis=axis, keepdims=True)
            return sum_this

        def grad_dot_A(g, A, B):
            ga = np.dot(g, B.T)
            ga = np.reshape(ga, np.shape(A))
            return ga

        def grad_dot_B(g, A, B):
            gb = np.dot(A.T, g)
            gb = np.reshape(gb, np.shape(B))
            return gb

        def _rnn_grad(x, W, b, Wout, bout, label, n):
            h1__1_stack, h1__1 = [], None
            h1__0_stack, h1__0 = [], None
            out_stack, out = [], None
            h1_stack = []
            h1 = x
            _for1 = list(range(n))

            for i in _for1:
                h1__1_stack.append(h1__1)
                h1__1 = np.dot(h1, W)
                h1__0_stack.append(h1__0)
                h1__0 = h1__1 + b
                h1_stack.append(h1)
                h1 = np.tanh(h1__0)
            out__0 = np.dot(h1, Wout)
            out = out__0 + bout
            loss__2 = label * out
            loss__7 = -out
            loss__6 = np.exp(loss__7)
            loss__5 = 1 + loss__6
            loss__4 = np.log(loss__5)
            loss__3 = out + loss__4
            loss__1 = loss__2 - loss__3

            # Begin Backward Pass
            g_loss = 1
            g_h1__0 = 0
            g_h1__1 = 0
            g_b = 0
            g_W = 0

            # Reverse of: loss = -loss__0
            g_loss__0 = -g_loss

            # Reverse of: loss__0 = np.sum(loss__1)
            g_loss__1 = repeat_to_match_shape(g_loss__0, loss__1)

            # Reverse of: loss__1 = loss__2 - loss__3
            g_loss__2 = sum_to_match_shape(g_loss__1, loss__2)
            g_loss__3 = sum_to_match_shape(-g_loss__1, loss__3)

            # Reverse of: loss__3 = out + loss__4
            g_out = sum_to_match_shape(g_loss__3, out)
            g_loss__4 = sum_to_match_shape(g_loss__3, loss__4)

            # Reverse of: loss__4 = np.log(loss__5)
            g_loss__5 = g_loss__4 / loss__5

            # Reverse of: loss__5 = 1 + loss__6
            g_loss__6 = sum_to_match_shape(g_loss__5, loss__6)

            # Reverse of: loss__6 = np.exp(loss__7)
            g_loss__7 = g_loss__6 * np.exp(loss__7)

            # Reverse of: loss__7 = -out
            g_out += -g_loss__7
            g_out += sum_to_match_shape(g_loss__2 * label, out)

            # Reverse of: out = out__0 + bout
            g_out__0 = sum_to_match_shape(g_out, out__0)
            g_bout = sum_to_match_shape(g_out, bout)

            # Reverse of: out__0 = np.dot(h1, Wout)
            g_h1 = grad_dot_A(g_out__0, h1, Wout)
            g_Wout = grad_dot_B(g_out__0, h1, Wout)
            _for1 = reversed(_for1)
            for i in _for1:
                h1 = h1_stack.pop()
                tmp_g0 = g_h1 / np.cosh(h1__0) ** 2.0
                g_h1 = 0
                g_h1__0 += tmp_g0
                h1__0 = h1__0_stack.pop()
                tmp_g1 = sum_to_match_shape(g_h1__0, h1__1)
                tmp_g2 = sum_to_match_shape(g_h1__0, b)
                g_h1__0 = 0
                g_h1__1 += tmp_g1
                g_b += tmp_g2
                h1__1 = h1__1_stack.pop()
                tmp_g3 = grad_dot_A(g_h1__1, h1, W)
                tmp_g4 = grad_dot_B(g_h1__1, h1, W)
                g_h1__1 = 0
                g_h1 += tmp_g3
                g_W += tmp_g4
            return g_W, g_b, g_Wout, g_bout

        _rnn_grad(self.x, self.W1, self.b1, self.Wout, self.bout, self.l, self.n)
        pass


================================================
FILE: benchmarks/bench_util.py
================================================
import autograd.numpy as np
import autograd.numpy.random as npr
from autograd import grad

try:
    from autograd.misc.flatten import flatten
except ImportError:
    from autograd.util import flatten


def time_flatten():
    val = {
        "k": npr.random((4, 4)),
        "k2": npr.random((3, 3)),
        "k3": 3.0,
        "k4": [1.0, 4.0, 7.0, 9.0],
        "k5": np.array([4.0, 5.0, 6.0]),
        "k6": np.array([[7.0, 8.0], [9.0, 10.0]]),
    }

    vect, unflatten = flatten(val)
    val_recovered = unflatten(vect)
    vect_2, _ = flatten(val_recovered)


# def time_vspace_flatten():
#     val = {'k':  npr.random((4, 4)),
#            'k2': npr.random((3, 3)),
#            'k3': 3.0,
#            'k4': [1.0, 4.0, 7.0, 9.0],
#            'k5': np.array([4., 5., 6.]),
#            'k6': np.array([[7., 8.], [9., 10.]])}

#     vspace_flatten(val)


def time_grad_flatten():
    val = {
        "k": npr.random((4, 4)),
        "k2": npr.random((3, 3)),
        "k3": 3.0,
        "k4": [1.0, 4.0, 7.0, 9.0],
        "k5": np.array([4.0, 5.0, 6.0]),
        "k6": np.array([[7.0, 8.0], [9.0, 10.0]]),
    }

    vect, unflatten = flatten(val)

    def fun(vec):
        v = unflatten(vec)
        return np.sum(v["k5"]) + np.sum(v["k6"])

    grad(fun)(vect)


================================================
FILE: conda_recipe/conda.yaml
================================================
package:
  name: autograd
  # there are ways to derive version from other sources; for now, it's hard-coded
  version: 1.1.1

source:
  {% if not environ.get('BINSTAR_PLATFORM', None) %}
  git_url: ../
  {% else %}
  # we're building on binstar, we already have the repo; treat as local path
  path: ../
  {% endif %}

requirements:
  build:
    - python
    - hatch
    - hatchling
    - future
    - numpy >=1.9

  run:
    - python
    - future
    - numpy >=1.9

build:
  script: pip install . --no-deps

test:
  # Python imports
  imports:
    - autograd
    - autograd.numpy

about:
  home: https://github.com/HIPS/autograd
  license: MIT
  summary: 'Efficiently computes derivatives of numpy code.'


================================================
FILE: docs/tutorial.md
================================================
# Autograd tutorial

## Motivation

Imagine you want to test out a new machine learning model for your data. This
usually means coming up with some loss function to capture how well your model
fits the data and optimizing that loss with respect to the model parameters. If
there are many model parameters (neural nets can have millions) then you need
gradients. You then have two options: derive and code them up yourself, or
implement your model using the syntactic and semantic constraints of a system
like [Theano](http://deeplearning.net/software/theano/) or
[TensorFlow](https://github.com/tensorflow/tensorflow).

We want to provide a third way: just write down the loss function using a
standard numerical library like Numpy, and Autograd will give you its gradient.

## How to use Autograd

Autograd's `grad` function takes in a function, and gives you a function that computes its derivative.
Your function must have a scalar-valued output (i.e. a float).
This covers the common case when you want to use gradients to optimize something.

Autograd works on ordinary Python and Numpy code containing all the usual control structures, including `while` loops, `if` statements, and closures.  Here's a simple example of using an open-ended loop to compute the sine function:

```python
import autograd.numpy as np   # Thinly-wrapped version of Numpy
from autograd import grad

def taylor_sine(x):  # Taylor approximation to sine function
    ans = currterm = x
    i = 0
    while np.abs(currterm) > 0.001:
        currterm = -currterm * x**2 / ((2 * i + 3) * (2 * i + 2))
        ans = ans + currterm
        i += 1
    return ans

grad_sine = grad(taylor_sine)
print("Gradient of sin(pi) is", grad_sine(np.pi))
```

## Complete example: logistic regression

A common use case for automatic differentiation is to train a probabilistic model.
Here we present a very simple (but complete) example of specifying and training
a logistic regression model for binary classification:

```python
import autograd.numpy as np
from autograd import grad

def sigmoid(x):
    return 0.5 * (np.tanh(x / 2.) + 1)

def logistic_predictions(weights, inputs):
    # Outputs probability of a label being true according to logistic model.
    return sigmoid(np.dot(inputs, weights))

def training_loss(weights):
    # Training loss is the negative log-likelihood of the training labels.
    preds = logistic_predictions(weights, inputs)
    label_probabilities = preds * targets + (1 - preds) * (1 - targets)
    return -np.sum(np.log(label_probabilities))

# Build a toy dataset.
inputs = np.array([[0.52, 1.12,  0.77],
                   [0.88, -1.08, 0.15],
                   [0.52, 0.06, -1.30],
                   [0.74, -2.49, 1.39]])
targets = np.array([True, True, False, True])

# Define a function that returns gradients of training loss using Autograd.
training_gradient_fun = grad(training_loss)

# Optimize weights using gradient descent.
weights = np.array([0.0, 0.0, 0.0])
print("Initial loss:", training_loss(weights))
for i in range(100):
    weights -= training_gradient_fun(weights) * 0.01

print("Trained loss:", training_loss(weights))
```

Python syntax is pretty good for specifying probabilistic models.  The biggest
win is that it becomes a lot easier to modify a model and rapidly iterate.

For more complex examples, see our [examples directory](../examples/), which includes:
* [a simple neural net](../examples/neural_net.py)
* [a convolutional neural net](../examples/convnet.py)
* [a recurrent neural net](../examples/rnn.py)
* [a long short-term memory (LSTM)](../examples/lstm.py)
* [backpropagating through a fluid simulation](../examples/fluidsim/fluidsim.py)


## What's going on under the hood?

To compute the gradient, Autograd first has to record every transformation that was applied to the input as it was turned into the output of your function.
To do this, Autograd wraps functions (using the function `primitive`) so that when they're called, they add themselves to a list of operations performed.
Autograd's core has a table mapping these wrapped primitives to their corresponding gradient functions (or, more precisely, their vector-Jacobian product functions).
To flag the variables we're taking the gradient with respect to, we wrap them using the `Box` class.
You should never have to think about the `Box` class, but you might notice it when printing out debugging info.

After the function is evaluated, Autograd has a graph specifying all operations that were performed on the inputs with respect to which we want to differentiate.
This is the computational graph of the function evaluation.
To compute the derivative, we simply apply the rules of differentiation to each node in the graph.
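The record-then-differentiate idea described above can be illustrated with a toy, pure-Python sketch. This is not Autograd's implementation: `ToyBox`, `toy_sin`, and `toy_grad` are hypothetical names, each box stores a local derivative per parent rather than a VJP function, and only scalars with `+` and `*` are handled.

```python
import math


class ToyBox:
    """Hypothetical stand-in for a boxed value: a number plus the parents it came from."""

    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # tuple of (parent_box, local_derivative)

    @staticmethod
    def lift(x):
        # Wrap plain numbers so they can take part in the recorded graph.
        return x if isinstance(x, ToyBox) else ToyBox(x)

    def __add__(self, other):
        other = ToyBox.lift(other)
        return ToyBox(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        other = ToyBox.lift(other)
        return ToyBox(self.value * other.value, ((self, other.value), (other, self.value)))


def toy_sin(x):
    # A "primitive": computes the value and records how to differentiate it.
    return ToyBox(math.sin(x.value), ((x, math.cos(x.value)),))


def toy_grad(fun):
    def gradfun(x):
        start = ToyBox(x)
        end = fun(start)  # forward pass builds the graph as a side effect

        # Depth-first post-order: every parent is listed before its children.
        order, seen = [], set()

        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent, _ in node.parents:
                    visit(parent)
                order.append(node)

        visit(end)

        # Backward pass: walk children before parents, accumulating gradients.
        grads = {id(end): 1.0}
        for node in reversed(order):
            g = grads.get(id(node), 0.0)
            for parent, local in node.parents:
                grads[id(parent)] = grads.get(id(parent), 0.0) + g * local
        return grads.get(id(start), 0.0)

    return gradfun


# d/dx [sin(x) * x + x] = cos(x) * x + sin(x) + 1
toy_derivative = toy_grad(lambda x: toy_sin(x) * x + x)(1.0)
print(toy_derivative)
```

The sketch uses one fixed derivative per edge, which is enough for scalars; a vector-Jacobian product function per edge generalizes the same backward walk to arrays.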

### Reverse mode differentiation

Given a function made up of several nested function calls, there are several ways to compute its derivative.

For example, given L(x) = F(G(H(x))), the chain rule says that its gradient is dL/dx = dF/dG * dG/dH * dH/dx.  If we evaluate this product from right to left, (dF/dG * (dG/dH * dH/dx)), in the same order in which the computations themselves were performed, this is called forward-mode differentiation.
If we evaluate this product from left to right, ((dF/dG * dG/dH) * dH/dx), in the reverse of the order in which the computations were performed, this is called reverse-mode differentiation.

Compared to finite differences or forward-mode, reverse-mode differentiation is by far the more practical method for differentiating functions that take in a large vector and output a single number.
In the machine learning community, reverse-mode differentiation is known as 'backpropagation', since the gradients propagate backwards through the function.
It's particularly nice since you don't need to instantiate the intermediate Jacobian matrices explicitly, and instead only rely on applying a sequence of matrix-free vector-Jacobian product functions (VJPs).
Because Autograd supports higher derivatives as well, Hessian-vector products (a form of second-derivative) are also available and efficient to compute.
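To make the cost difference concrete, here is a small pure-Python sketch that multiplies three illustrative Jacobians in both orders and counts scalar multiplications. The matrices and the `matmul` helper are made up for this illustration, and the right-to-left ordering is modelled as forming the full intermediate Jacobian. For a scalar-valued L, dF/dG is a single row, so the left-to-right ordering only ever performs vector-matrix products.

```python
def matmul(a, b):
    """Multiply matrices stored as lists of rows; also return the scalar-multiplication count."""
    rows, inner, cols = len(a), len(b), len(b[0])
    out = [[sum(a[i][k] * b[k][j] for k in range(inner)) for j in range(cols)]
           for i in range(rows)]
    return out, rows * inner * cols


n = 4
# Illustrative integer Jacobians for L(x) = F(G(H(x))), with x in R^n and scalar L.
dF_dG = [[j + 1 for j in range(n)]]                         # 1 x n
dG_dH = [[i + j for j in range(n)] for i in range(n)]       # n x n
dH_dx = [[i - j + 1 for j in range(n)] for i in range(n)]   # n x n

# Left to right (reverse mode): two vector-matrix products.
left, cost_a = matmul(dF_dG, dG_dH)
reverse_result, cost_b = matmul(left, dH_dx)
reverse_cost = cost_a + cost_b

# Right to left: a full matrix-matrix product comes first.
right, cost_c = matmul(dG_dH, dH_dx)
forward_result, cost_d = matmul(dF_dG, right)
forward_cost = cost_c + cost_d

print(reverse_result == forward_result)  # same gradient either way
print(reverse_cost, forward_cost)        # 32 vs 80 multiplications for n = 4
```

The counts grow as 2n^2 and n^3 + n^2 respectively, so the gap widens quickly as the input dimension increases.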

### How can you support ifs, while loops and recursion?

Some autodiff packages (such as [TensorFlow](https://github.com/tensorflow/tensorflow)) work by having you specify a graph of the computation that your function performs, including all the control flow (such as `if` statements and `for` loops), and then turn that graph into another one that computes gradients.
This has some benefits (such as allowing compile-time optimizations), but it requires you to express control flow in a limited mini-language that those packages know how to handle.  (For example, the `tf.while_loop` and `tf.cond` operations in TensorFlow.)

In contrast, Autograd doesn't have to know about any ifs, branches, loops or recursion that were used to decide which operations were called.  To compute the gradient of a particular input, one only needs to know which continuous transforms were applied to that particular input, not which other transforms might have been applied.
Since Autograd keeps track of the relevant operations on each function call separately, it's not a problem that all the Python control flow operations are invisible to Autograd.  In fact, it greatly simplifies the implementation.


## What can Autograd differentiate?

The main constraint is that any function that operates on a `Box` must be marked as `primitive` and have its gradient implemented.
This is taken care of for most functions in the Numpy library, and it's easy to write your own gradients.

The input can be a scalar, complex number, vector, tuple, a tuple of vectors, a tuple of tuples, etc.

When using the `grad` function, the output must be a scalar, but the functions `elementwise_grad` and `jacobian` handle functions with vector-valued outputs.


## Supported and unsupported parts of numpy/scipy

Numpy has [a lot of features](http://docs.scipy.org/doc/numpy/reference/). We've done our best to support most of them. So far, we've implemented gradients for:
* most of the [mathematical operations](../autograd/numpy/numpy_vjps.py)
* most of the [array and matrix manipulation routines](../autograd/numpy/numpy_vjps.py)
* some [linear algebra](../autograd/numpy/linalg.py) functions
* most of the [fast fourier transform](../autograd/numpy/fft.py) routines
* full support for complex numbers
* [N-dimensional convolutions](../autograd/scipy/signal.py)
* Some scipy routines, including [`scipy.stats.norm`](../autograd/scipy/stats/norm.py)

Some things remain to be implemented. For example, we support indexing (`x = A[i, j, :]`) but not assignment (`A[i,j] = x`) in arrays that are being differentiated with respect to.
Assignment is hard to support because it requires keeping copies of the overwritten data, and so even when you write code that looks like it's performing assignment, the system would have to be making copies behind the scenes, often defeating the purpose of in-place operations.

Similarly, we don't support the syntax `A.dot(B)`; use the equivalent `np.dot(A, B)` instead.
The reason we don't support the first way is that subclassing `ndarray` raises a host of issues.
As another consequence of not subclassing `ndarray`, some subclass checks can break; for example, `isinstance(x, np.ndarray)` can return `False`.
However, those `isinstance` checks will work if you instead use Autograd's provided one, writing `from autograd.builtins import isinstance`.

In-place modification of arrays not being differentiated with respect to (for example, `A[i] = x` or `A += B`) won't raise an error, but be careful.
It's easy to accidentally change something without Autograd knowing about it.
This can be a problem because Autograd keeps references to variables used in the forward pass if they will be needed on the reverse pass.
Making copies would be too slow.

Lists and dicts can be used freely - like control flow, Autograd usually doesn't even need to know about them.
The exception is passing in a list to a primitive function, such as `autograd.numpy.sum`.
This requires special care, since the list contents need to be examined for boxes.
We do support passing lists to `autograd.numpy.array` and `autograd.numpy.concatenate`, but in other cases, you may need to explicitly construct an array using `autograd.numpy.array` before passing a list or tuple argument into a primitive.
An alternative is to use the `list`, `dict`, and `tuple` classes in `autograd.builtins`, which should work just like the Python builtins while also ensuring boxes don't get hidden inside those containers.
Remember, these issues typically only come up when you're passing a `list` or `tuple` to a primitive function; when passing around lists or tuples in your own (non-primitive) functions, you can put boxed values inside lists, tuples, or dicts without having to worry about it.

#### TL;DR: Do use
* [Most](../autograd/numpy/numpy_vjps.py) of numpy's functions
* [Most](../autograd/numpy/numpy_boxes.py) numpy.ndarray methods
* [Some](../autograd/scipy/) scipy functions
* Indexing and slicing of arrays `x = A[3, :, 2:4]`
* Explicit array creation from lists `A = np.array([x, y])`

#### Don't use
* Assignment to arrays `A[0,0] = x`
* Implicit casting of lists to arrays `A = np.sum([x, y])`, use `A = np.sum(np.array([x, y]))` instead.
* `A.dot(B)` notation (use `np.dot(A, B)` instead)
* In-place operations (such as `a += b`, use `a = a + b` instead)
* Some isinstance checks, like `isinstance(x, np.ndarray)` or `isinstance(x, tuple)`, without first doing `from autograd.builtins import isinstance, tuple`.

Luckily, it's easy to check gradients numerically if you're worried that something's wrong.

## Extend Autograd by defining your own primitives

What if Autograd doesn't support a function you need to take the gradient of?
This can happen if your code depends on external library calls or C code.
It can sometimes even be a good idea to provide the gradient of a pure Python function for speed or numerical stability.

For example, let's add the gradient of a numerically stable version of `log(sum(exp(x)))`.
This function is included in `scipy.special` and already supported, but let's make our own version.

First, we define our function using standard Python, using `@primitive` as a decorator:

```python
import autograd.numpy as np
from autograd.extend import primitive, defvjp

@primitive
def logsumexp(x):
    """Numerically stable log(sum(exp(x)))"""
    max_x = np.max(x)
    return max_x + np.log(np.sum(np.exp(x - max_x)))
```

`@primitive` tells Autograd not to look inside the function, but instead to treat it as a black box whose gradient can be specified later.
Functions with this decorator can contain anything that Python knows how to execute, including calls to other languages.

Next, we write a function that specifies the gradient of the primitive `logsumexp`:

```python
def logsumexp_vjp(ans, x):
    x_shape = x.shape
    return lambda g: np.full(x_shape, g) * np.exp(x - np.full(x_shape, ans))
```

`logsumexp_vjp` returns a vector-Jacobian product (VJP) operator, which is a function that right-multiplies its argument `g` by the Jacobian matrix of `logsumexp` (without explicitly forming the matrix's coefficients).
`g` will be the gradient of the final objective with respect to `ans` (the output of `logsumexp`).
The calculation can depend on both the input (`x`) and the output (`ans`) of the original function.
If you want to be able to take higher-order derivatives, then the code inside the VJP function must be itself differentiable by Autograd, which usually just means you write it in terms of other primitives which themselves have VJPs (like Numpy functions).

The final step is to tell Autograd about `logsumexp`'s vector-Jacobian product function:
```python
defvjp(logsumexp, logsumexp_vjp)
```

Now we can use `logsumexp` anywhere, including inside of a larger function that we want to differentiate:

```python
from autograd import grad

def example_func(y):
    z = y**2
    lse = logsumexp(z)
    return np.sum(lse)

grad_of_example = grad(example_func)
print("Gradient: ", grad_of_example(np.array([1.5, 6.7, 1e-10])))
```

This example can be found as a Python script [here](../examples/define_gradient.py).

## Complex numbers

Autograd supports complex arrays and scalars using a convention described as follows.
Consider a complex-to-complex function, `f`,
expressed in terms of real-to-real components, `u` and `v`:

```python
def f(z):
    x, y = real(z), imag(z)
    return u(x, y) + v(x, y) * 1j
```

We define `grad` of `f` as

```python
def grad_f(z):
    x, y = real(z), imag(z)
    return grad(u, 0)(x, y) - 1j * grad(u, 1)(x, y)
```

(The second argument of `grad` specifies which argument we're differentiating with respect to.)
So we throw out v, the imaginary part of f, entirely.

Our convention covers three important cases:
  * If `f` is holomorphic, we get the usual complex derivative
    (since `grad(u, 0) == grad(v, 1)` and `grad(u, 1) == - grad(v, 0)`).
  * If `f` is a real-valued loss function of a complex parameter, `z`,
    we get a result that we can use in a gradient-based optimizer,
    by taking steps in the direction of the complex conjugate of `grad(f)(z)`.
  * If `f` is a real-to-real function that happens to use complex primitives internally,
    some of which must necessarily be non-holomorphic
    (maybe you use FFTs to implement convolutions for example)
    then we get the same result that a purely real implementation would have given.

Our convention doesn't handle the case where `f` is a non-holomorphic function
and you're interested in all of du/dx, du/dy, dv/dx and dv/dy.
But then the answer would have to contain four real values
and there would be no way to express it as a single complex number.

We define primitive vector-Jacobian products of complex functions like this

```python
def f_vjp(g, z):
    x, y = real(z), imag(z)
    g_x, g_y = real(g), imag(g)
    return (        g_x * grad(u, 0)(x, y)
             - 1j * g_x * grad(u, 1)(x, y)
             -      g_y * grad(v, 0)(x, y)
             + 1j * g_y * grad(v, 1)(x, y))
```

For holomorphic primitives, this is just the regular complex derivative multiplied by `g`,
so most simple math primitives don't need to be changed from their real implementations.
For non-holomorphic primitives, it preserves all four real partial derivatives as if we
were treating complex numbers as real 2-tuples
(though it throws a couple of negative signs in there).
Chapter 4 of [Dougal's PhD thesis](https://dougalmaclaurin.com/phd-thesis.pdf)
goes into a bit more detail about how we define the primitive vector-Jacobian products.

## Autograd Lecture
For more information on automatic differentiation, autograd's implementation, and advanced automatic differentiation techniques, see a [talk by Matt at the Deep Learning Summer School, Montreal 2017](https://videolectures.net/videos/deeplearning2017_johnson_automatic_differentiation/).

## Support

Autograd was written by
[Dougal Maclaurin](https://dougalmaclaurin.com),
[David Duvenaud](http://mlg.eng.cam.ac.uk/duvenaud/), and
[Matthew Johnson](http://www.mit.edu/~mattjj/)
and we're actively developing it. Please
feel free to submit any bugs or feature requests. We'd also love to hear about
your experiences with Autograd in general. Drop us an email!


================================================
FILE: docs/updateguide.md
================================================
# Autograd v1.2 update guide

Autograd v1.2 changed the interface for defining custom vector-Jacobian
products (VJPs). Luckily the change only affects users writing custom VJPs, and
should only require minor updates to the custom VJP code.

This guide is meant to explain why we made these changes (and others) in
Autograd v1.2, and to summarize everything you need to know to update your
custom VJP code.

SYMBOL INDEX (1282 symbols across 93 files)

FILE: autograd/builtins.py
  function container_take (line 25) | def container_take(A, idx):
  function grad_container_take (line 29) | def grad_container_take(ans, A, idx):
  class SequenceBox (line 37) | class SequenceBox(Box):
    method __len__ (line 41) | def __len__(self):
    method __add__ (line 44) | def __add__(self, other):
    method __radd__ (line 47) | def __radd__(self, other):
    method __contains__ (line 50) | def __contains__(self, elt):
    method index (line 53) | def index(self, elt):
  class DictBox (line 61) | class DictBox(Box):
    method __len__ (line 65) | def __len__(self):
    method __iter__ (line 68) | def __iter__(self):
    method __contains__ (line 71) | def __contains__(self, elt):
    method items (line 74) | def items(self):
    method keys (line 77) | def keys(self):
    method values (line 80) | def values(self):
    method iteritems (line 83) | def iteritems(self):
    method iterkeys (line 86) | def iterkeys(self):
    method itervalues (line 89) | def itervalues(self):
    method get (line 92) | def get(self, k, d=None):
  function container_untake (line 100) | def container_untake(x, idx, vs):
  function sequence_extend_right (line 117) | def sequence_extend_right(seq, *elts):
  function grad_sequence_extend_right (line 121) | def grad_sequence_extend_right(argnum, ans, args, kwargs):
  function sequence_extend_left (line 130) | def sequence_extend_left(seq, *elts):
  function grad_sequence_extend_left (line 134) | def grad_sequence_extend_left(argnum, ans, args, kwargs):
  function make_sequence (line 143) | def make_sequence(seq_type, *args):
  function fwd_grad_make_sequence (line 150) | def fwd_grad_make_sequence(argnum, g, ans, seq_type, *args, **kwargs):
  class TupleMeta (line 157) | class TupleMeta(type(tuple_)):
    method __instancecheck__ (line 158) | def __instancecheck__(self, instance):
  class tuple (line 162) | class tuple(tuple_, metaclass=TupleMeta):
    method __new__ (line 163) | def __new__(cls, xs):
  class ListMeta (line 167) | class ListMeta(type_):
    method __instancecheck__ (line 168) | def __instancecheck__(self, instance):
  class list (line 172) | class list(list_, metaclass=ListMeta):
    method __new__ (line 173) | def __new__(cls, xs):
  class DictMeta (line 177) | class DictMeta(type_):
    method __instancecheck__ (line 178) | def __instancecheck__(self, instance):
  class dict (line 182) | class dict(dict_, metaclass=DictMeta):
    method __new__ (line 183) | def __new__(cls, *args, **kwargs):
  function _make_dict (line 191) | def _make_dict(keys, vals):
  class ContainerVSpace (line 198) | class ContainerVSpace(VSpace):
    method __init__ (line 199) | def __init__(self, value):
    method size (line 204) | def size(self):
    method zeros (line 207) | def zeros(self):
    method ones (line 210) | def ones(self):
    method randn (line 213) | def randn(self):
    method standard_basis (line 216) | def standard_basis(self):
    method _add (line 222) | def _add(self, xs, ys):
    method _mut_add (line 225) | def _mut_add(self, xs, ys):
    method _scalar_mul (line 228) | def _scalar_mul(self, xs, a):
    method _inner_prod (line 231) | def _inner_prod(self, xs, ys):
    method _covector (line 234) | def _covector(self, xs):
  class SequenceVSpace (line 238) | class SequenceVSpace(ContainerVSpace):
    method _values (line 239) | def _values(self, x):
    method _kv_pairs (line 242) | def _kv_pairs(self, x):
    method _map (line 245) | def _map(self, f, *args):
    method _subval (line 248) | def _subval(self, xs, idx, x):
  class ListVSpace (line 252) | class ListVSpace(SequenceVSpace):
  class TupleVSpace (line 256) | class TupleVSpace(SequenceVSpace):
  class DictVSpace (line 260) | class DictVSpace(ContainerVSpace):
    method _values (line 261) | def _values(self, x):
    method _kv_pairs (line 264) | def _kv_pairs(self, x):
    method _map (line 267) | def _map(self, f, *args):
    method _subval (line 270) | def _subval(self, xs, idx, x):
  class NamedTupleVSpace (line 281) | class NamedTupleVSpace(SequenceVSpace):
    method _map (line 282) | def _map(self, f, *args):
    method _subval (line 285) | def _subval(self, xs, idx, x):

FILE: autograd/core.py
  function make_vjp (line 10) | def make_vjp(fun, x):
  function backward_pass (line 25) | def backward_pass(g, end_node):
  class VJPNode (line 35) | class VJPNode(Node):
    method __init__ (line 38) | def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
    method initialize_root (line 47) | def initialize_root(self):
  function defvjp_argnums (line 55) | def defvjp_argnums(fun, vjpmaker):
  function defvjp_argnum (line 59) | def defvjp_argnum(fun, vjpmaker):
  function defvjp (line 67) | def defvjp(fun, *vjpmakers, **kwargs):
  function translate_vjp (line 101) | def translate_vjp(vjpfun, fun, argnum):
  function make_jvp (line 113) | def make_jvp(fun, x):
  class JVPNode (line 125) | class JVPNode(Node):
    method __init__ (line 128) | def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
    method initialize_root (line 137) | def initialize_root(self, g):
  function defjvp_argnums (line 144) | def defjvp_argnums(fun, jvpmaker):
  function defjvp_argnum (line 148) | def defjvp_argnum(fun, jvpmaker):
  function defjvp (line 155) | def defjvp(fun, *jvpfuns, **kwargs):
  function translate_jvp (line 165) | def translate_jvp(jvpfun, fun, argnum):
  function def_linear (line 176) | def def_linear(fun):
  function add_outgrads (line 184) | def add_outgrads(prev_g_flagged, g):
  function sum_outgrads (line 207) | def sum_outgrads(gs):
  function sparse_add (line 212) | def sparse_add(vs, x_prev, x_new):
  class VSpace (line 217) | class VSpace:
    method __init__ (line 222) | def __init__(self, value):
    method zeros (line 225) | def zeros(self):
    method ones (line 228) | def ones(self):
    method standard_basis (line 231) | def standard_basis(self):
    method randn (line 234) | def randn(self):
    method mut_add (line 238) | def mut_add(self, x_prev, x_new):
    method add (line 243) | def add(self, x_prev, x_new):
    method scalar_mul (line 247) | def scalar_mul(self, x, a):
    method inner_prod (line 251) | def inner_prod(self, x, y):
    method covector (line 255) | def covector(self, x):
    method _add (line 258) | def _add(self, x, y):
    method _mut_add (line 261) | def _mut_add(self, x, y):
    method _scalar_mul (line 265) | def _scalar_mul(self, x, a):
    method _inner_prod (line 268) | def _inner_prod(self, x, y):
    method _covector (line 271) | def _covector(self, x):
    method __eq__ (line 274) | def __eq__(self, other):
    method __repr__ (line 277) | def __repr__(self):
    method register (line 281) | def register(cls, value_type, vspace_maker=None):
  function vspace (line 288) | def vspace(value):
  class SparseBox (line 302) | class SparseBox(Box):
  class SparseObject (line 306) | class SparseObject:
    method __init__ (line 309) | def __init__(self, vs, mut_add):
  function deprecated_defvjp (line 358) | def deprecated_defvjp(primitive_fun):
  function deprecated_defvjp_is_zero (line 379) | def deprecated_defvjp_is_zero(primitive_fun):
  function deprecated_defgrad (line 392) | def deprecated_defgrad(primitive_fun):
  function primitive_with_deprecation_warnings (line 408) | def primitive_with_deprecation_warnings(f_raw):

FILE: autograd/differential_operators.py
  function grad (line 24) | def grad(fun, x):
  function elementwise_grad (line 40) | def elementwise_grad(fun, x):
  function deriv (line 53) | def deriv(fun, x):
  function jacobian (line 58) | def jacobian(fun, x):
  function holomorphic_grad (line 75) | def holomorphic_grad(fun, x):
  function grad_named (line 81) | def grad_named(fun, argname):
  function hessian (line 89) | def hessian(fun, x):
  function make_hvp (line 95) | def make_hvp(fun, x):
  function hessian_tensor_product (line 102) | def hessian_tensor_product(fun, argnum=0):
  function tensor_jacobian_product (line 118) | def tensor_jacobian_product(fun, argnum=0):
  function make_jvp_reversemode (line 134) | def make_jvp_reversemode(fun, x):
  function make_ggnvp (line 145) | def make_ggnvp(f, g=lambda x: 1.0 / 2 * np.sum(x**2, axis=-1), f_argnum=0):
  function value_and_grad (line 164) | def value_and_grad(fun, x):
  function grad_and_aux (line 178) | def grad_and_aux(fun, x):
  function multigrad_dict (line 185) | def multigrad_dict(fun):
  function checkpoint (line 230) | def checkpoint(fun):

FILE: autograd/misc/fixed_points.py
  function fixed_point (line 7) | def fixed_point(f, a, x0, distance, tol):
  function fixed_point_vjp (line 15) | def fixed_point_vjp(ans, f, a, x0, distance, tol):

FILE: autograd/misc/flatten.py
  function flatten (line 11) | def flatten(value):
  function _flatten (line 20) | def _flatten(value):
  function _concatenate (line 30) | def _concatenate(lst):
  function flatten_func (line 35) | def flatten_func(func, example):

FILE: autograd/misc/optimizers.py
  function unflatten_optimizer (line 15) | def unflatten_optimizer(optimize):
  function sgd (line 34) | def sgd(grad, x, callback=None, num_iters=200, step_size=0.1, mass=0.9):
  function rmsprop (line 48) | def rmsprop(grad, x, callback=None, num_iters=100, step_size=0.1, gamma=...
  function adam (line 61) | def adam(grad, x, callback=None, num_iters=100, step_size=0.001, b1=0.9,...

FILE: autograd/misc/tracers.py
  class ConstGraphNode (line 9) | class ConstGraphNode(Node):
    method __init__ (line 12) | def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
    method initialize_root (line 21) | def initialize_root(self):
  function const_graph_unary (line 25) | def const_graph_unary(fun):
  function const_graph (line 47) | def const_graph(fun, *args, **kwargs):
  class FullGraphNode (line 59) | class FullGraphNode(Node):
    method __init__ (line 62) | def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
    method initialize_root (line 66) | def initialize_root(self):
  function full_graph (line 71) | def full_graph(fun, *args, **kwargs):

FILE: autograd/numpy/fft.py
  function fft_grad (line 14) | def fft_grad(get_args, fft_fun, ans, x, *args, **kwargs):
  function rfft_grad (line 31) | def rfft_grad(get_args, irfft_fun, ans, x, *args, **kwargs):
  function irfft_grad (line 54) | def irfft_grad(get_args, rfft_fun, ans, x, *args, **kwargs):
  function truncate_pad (line 98) | def truncate_pad(x, shape):
  function check_no_repeated_axes (line 111) | def check_no_repeated_axes(axes):
  function check_even_shape (line 117) | def check_even_shape(shape):
  function get_fft_args (line 122) | def get_fft_args(a, d=None, axis=-1, norm=None, *args, **kwargs):
  function get_fft2_args (line 129) | def get_fft2_args(a, s=None, axes=(-2, -1), norm=None, *args, **kwargs):
  function get_fftn_args (line 133) | def get_fftn_args(a, s=None, axes=None, norm=None, *args, **kwargs):
  function make_rfft_factors (line 139) | def make_rfft_factors(axes, resshape, facshape, normshape, norm):

FILE: autograd/numpy/linalg.py
  function T (line 20) | def T(x):
  function _matrix_diag (line 31) | def _matrix_diag(a):
  function add2d (line 40) | def add2d(x):
  function grad_inv (line 48) | def grad_inv(ans, x):
  function grad_pinv (line 55) | def grad_pinv(ans, x):
  function grad_solve (line 67) | def grad_solve(argnum, ans, a, b):
  function norm_vjp (line 78) | def norm_vjp(ans, x, ord=None, axis=None):
  function norm_jvp (line 134) | def norm_jvp(g, ans, x, ord=None, axis=None):
  function grad_eigh (line 180) | def grad_eigh(ans, x, UPLO="L"):
  function grad_eig (line 221) | def grad_eig(ans, x):
  function grad_cholesky (line 247) | def grad_cholesky(L, A):
  function grad_svd (line 271) | def grad_svd(usv_, a, full_matrices=True, compute_uv=True):

FILE: autograd/numpy/numpy_boxes.py
  class ArrayBox (line 11) | class ArrayBox(Box):
    method __getitem__ (line 16) | def __getitem__(A, idx):
    method __array_namespace__ (line 26) | def __array_namespace__(self, *, api_version: str | None = None):
    method __len__ (line 29) | def __len__(self):
    method astype (line 32) | def astype(self, *args, **kwargs):
    method __neg__ (line 35) | def __neg__(self):
    method __add__ (line 38) | def __add__(self, other):
    method __sub__ (line 41) | def __sub__(self, other):
    method __mul__ (line 44) | def __mul__(self, other):
    method __pow__ (line 47) | def __pow__(self, other):
    method __div__ (line 50) | def __div__(self, other):
    method __mod__ (line 53) | def __mod__(self, other):
    method __truediv__ (line 56) | def __truediv__(self, other):
    method __matmul__ (line 59) | def __matmul__(self, other):
    method __radd__ (line 62) | def __radd__(self, other):
    method __rsub__ (line 65) | def __rsub__(self, other):
    method __rmul__ (line 68) | def __rmul__(self, other):
    method __rpow__ (line 71) | def __rpow__(self, other):
    method __rdiv__ (line 74) | def __rdiv__(self, other):
    method __rmod__ (line 77) | def __rmod__(self, other):
    method __rtruediv__ (line 80) | def __rtruediv__(self, other):
    method __rmatmul__ (line 83) | def __rmatmul__(self, other):
    method __eq__ (line 86) | def __eq__(self, other):
    method __ne__ (line 89) | def __ne__(self, other):
    method __gt__ (line 92) | def __gt__(self, other):
    method __ge__ (line 95) | def __ge__(self, other):
    method __lt__ (line 98) | def __lt__(self, other):
    method __le__ (line 101) | def __le__(self, other):
    method __abs__ (line 104) | def __abs__(self):
    method __hash__ (line 107) | def __hash__(self):

FILE: autograd/numpy/numpy_jvps.py
  function forward_grad_np_var (line 173) | def forward_grad_np_var(g, ans, x, axis=None, ddof=0, keepdims=False):
  function forward_grad_np_std (line 188) | def forward_grad_np_std(g, ans, x, axis=None, ddof=0, keepdims=False):
  function fwd_grad_chooser (line 205) | def fwd_grad_chooser(g, ans, x, axis=None, keepdims=False):
  function fwd_grad_concatenate_args (line 240) | def fwd_grad_concatenate_args(argnum, g, ans, axis_args, kwargs):
  function fwd_grad_sort (line 253) | def fwd_grad_sort(g, ans, x, axis=-1, kind="quicksort", order=None):
  function fwd_grad_partition (line 263) | def fwd_grad_partition(g, ans, x, kth, axis=-1, kind="introselect", orde...
  function atleast_jvpmaker (line 271) | def atleast_jvpmaker(fun):
  function broadcast (line 288) | def broadcast(x, target):

FILE: autograd/numpy/numpy_vjps.py
  function grad_rollaxis (line 255) | def grad_rollaxis(ans, a, axis, start=0):
  function grad_diff (line 270) | def grad_diff(ans, a, n=1, axis=-1):
  function grad_gradient (line 297) | def grad_gradient(ans, x, *vargs, **kwargs):
  function grad_repeat (line 345) | def grad_repeat(ans, x, repeats, axis=None):
  function grad_tile (line 365) | def grad_tile(ans, x, reps):
  function grad_kron (line 380) | def grad_kron(argnum, ans, orig_A, orig_B):
  function grad_transpose (line 407) | def grad_transpose(ans, x, axes=None):
  function repeat_to_match_shape (line 416) | def repeat_to_match_shape(g, shape, dtype, axis, keepdims):
  function grad_broadcast_to (line 430) | def grad_broadcast_to(ans, x, new_shape):
  function grad_np_sum (line 443) | def grad_np_sum(ans, x, axis=None, keepdims=False, dtype=None):
  function grad_np_mean (line 451) | def grad_np_mean(ans, x, axis=None, keepdims=False):
  function grad_np_prod (line 464) | def grad_np_prod(ans, x, axis=None, keepdims=False):  # TODO: Support tu...
  function grad_np_var (line 477) | def grad_np_var(ans, x, axis=None, ddof=0, keepdims=False):
  function grad_np_std (line 493) | def grad_np_std(ans, x, axis=None, ddof=0, keepdims=False):
  function grad_chooser (line 515) | def grad_chooser(ans, x, axis=None, keepdims=None):
  function reverse_axis (line 533) | def reverse_axis(x, axis):
  function grad_np_cumsum (line 539) | def grad_np_cumsum(ans, x, axis=None):
  function grad_inner (line 552) | def grad_inner(argnum, ans, A, B):
  function matmul_adjoint_0 (line 567) | def matmul_adjoint_0(B, G, A_meta, B_ndim):
  function matmul_adjoint_1 (line 582) | def matmul_adjoint_1(A, G, A_ndim, B_meta):
  function matmul_vjp_0 (line 600) | def matmul_vjp_0(ans, A, B):
  function matmul_vjp_1 (line 606) | def matmul_vjp_1(ans, A, B):
  function dot_adjoint_0 (line 616) | def dot_adjoint_0(B, G, A_meta, B_meta):
  function dot_adjoint_1 (line 628) | def dot_adjoint_1(A, G, A_meta, B_meta):
  function dot_vjp_0 (line 641) | def dot_vjp_0(ans, A, B):
  function dot_vjp_1 (line 646) | def dot_vjp_1(ans, A, B):
  function tensordot_adjoint_0 (line 667) | def tensordot_adjoint_0(B, G, axes, A_ndim, B_ndim):
  function tensordot_adjoint_1 (line 695) | def tensordot_adjoint_1(A, G, axes, A_ndim, B_ndim):
  function tensordot_vjp_0 (line 722) | def tensordot_vjp_0(ans, A, B, axes=2):
  function tensordot_vjp_1 (line 727) | def tensordot_vjp_1(ans, A, B, axes=2):
  function grad_concatenate_args (line 750) | def grad_concatenate_args(argnum, ans, axis_args, kwargs):
  function wrapped_reshape (line 762) | def wrapped_reshape(x, *args, **kwargs):
  function grad_sort (line 774) | def grad_sort(ans, x, axis=-1, kind="quicksort", order=None):
  function grad_partition (line 787) | def grad_partition(ans, x, kth, axis=-1, kind="introselect", order=None):
  function unpermuter (line 798) | def unpermuter(g, permutation):
  function grad_reshape_list (line 804) | def grad_reshape_list(ans, *arys):
  function grad_einsum (line 815) | def grad_einsum(argnum, ans, operands_, kwargs):
  function match_complex (line 872) | def match_complex(target, x):
  function unbroadcast (line 883) | def unbroadcast(x, target_meta, broadcast_idx=0):
  function unbroadcast_f (line 895) | def unbroadcast_f(target, f):
  function unbroadcast_einsum (line 900) | def unbroadcast_einsum(x, target_meta, subscript):
  function balanced_eq (line 911) | def balanced_eq(x, z, y):
  function replace_zero (line 915) | def replace_zero(x, val):
  function array_from_args_gradmaker (line 922) | def array_from_args_gradmaker(argnum, ans, args, kwargs):
  function array_from_scalar_or_array_gradmaker (line 929) | def array_from_scalar_or_array_gradmaker(ans, array_args, array_kwargs, ...
  function untake (line 942) | def untake(x, idx, vs):
  function _unpad (line 957) | def _unpad(array, width):
  function pad_vjp (line 970) | def pad_vjp(ans, array, pad_width, mode, **kwargs):

FILE: autograd/numpy/numpy_vspaces.py
  class ArrayVSpace (line 7) | class ArrayVSpace(VSpace):
    method __init__ (line 8) | def __init__(self, value):
    method size (line 14) | def size(self):
    method ndim (line 18) | def ndim(self):
    method zeros (line 21) | def zeros(self):
    method ones (line 24) | def ones(self):
    method standard_basis (line 27) | def standard_basis(self):
    method randn (line 33) | def randn(self):
    method _inner_prod (line 36) | def _inner_prod(self, x, y):
  class ComplexArrayVSpace (line 40) | class ComplexArrayVSpace(ArrayVSpace):
    method size (line 44) | def size(self):
    method ones (line 47) | def ones(self):
    method standard_basis (line 50) | def standard_basis(self):
    method randn (line 57) | def randn(self):
    method _inner_prod (line 62) | def _inner_prod(self, x, y):
    method _covector (line 65) | def _covector(self, x):
  class EigResultVSpace (line 80) | class EigResultVSpace(NamedTupleVSpace):
  class EighResultVSpace (line 83) | class EighResultVSpace(NamedTupleVSpace):
  class QRResultVSpace (line 86) | class QRResultVSpace(NamedTupleVSpace):
  class SlogdetResultVSpace (line 89) | class SlogdetResultVSpace(NamedTupleVSpace):
  class SVDResultVSpace (line 92) | class SVDResultVSpace(NamedTupleVSpace):
  class EigResultVSpace (line 102) | class EigResultVSpace(NamedTupleVSpace):
  class EighResultVSpace (line 105) | class EighResultVSpace(NamedTupleVSpace):
  class QRResultVSpace (line 108) | class QRResultVSpace(NamedTupleVSpace):
  class SlogdetResultVSpace (line 111) | class SlogdetResultVSpace(NamedTupleVSpace):
  class SVDResultVSpace (line 114) | class SVDResultVSpace(NamedTupleVSpace):

FILE: autograd/numpy/numpy_wrapper.py
  function wrap_intdtype (line 18) | def wrap_intdtype(cls):
  function wrap_namespace (line 25) | def wrap_namespace(old, new):
  function concatenate_args (line 45) | def concatenate_args(axis, *args):
  function hstack (line 53) | def hstack(tup):
  function column_stack (line 60) | def column_stack(tup):
  function array (line 70) | def array(A, *args, **kwargs):
  function wrap_if_boxes_inside (line 78) | def wrap_if_boxes_inside(raw_array, slow_op_name=None):
  function _array_from_scalar_or_array (line 88) | def _array_from_scalar_or_array(array_args, array_kwargs, scalar):
  function array_from_args (line 93) | def array_from_args(array_args, array_kwargs, *args):
  function select (line 97) | def select(condlist, choicelist, default=0):
  function stack (line 102) | def stack(arrays, axis=0):
  function append (line 125) | def append(arr, values, axis=None):
  class r_class (line 139) | class r_class:
    method __getitem__ (line 140) | def __getitem__(self, args):
  class c_class (line 148) | class c_class:
    method __getitem__ (line 149) | def __getitem__(self, args):
  function make_diagonal (line 159) | def make_diagonal(D, offset=0, axis1=0, axis2=1):
  function metadata (line 176) | def metadata(A):
  function parse_einsum_input (line 181) | def parse_einsum_input(*args):
  function _astype (line 191) | def _astype(A, dtype, order="K", casting="unsafe", subok=True, copy=True):

FILE: autograd/scipy/integrate.py
  function grad_odeint (line 12) | def grad_odeint(yt, func, y0, t, func_args, **kwargs):
  function argnums_unpack (line 65) | def argnums_unpack(all_vjp_builder):

FILE: autograd/scipy/linalg.py
  function _vjp_sqrtm (line 12) | def _vjp_sqrtm(ans, A, disp=True, blocksize=64):
  function _flip (line 25) | def _flip(a, trans):
  function grad_solve_triangular (line 32) | def grad_solve_triangular(ans, a, b, trans=0, lower=False, **kwargs):
  function grad_solve_banded (line 53) | def grad_solve_banded(argnum, ans, l_and_u, a, b):
  function _jvp_sqrtm (line 107) | def _jvp_sqrtm(dA, ans, A, disp=True, blocksize=64):
  function _jvp_sylvester (line 115) | def _jvp_sylvester(argnums, dms, ans, args, _):
  function _vjp_sylvester (line 131) | def _vjp_sylvester(argnums, ans, args, _):

FILE: autograd/scipy/signal.py
  function convolve (line 11) | def convolve(A, B, axes=None, dot_axes=[(), ()], mode="full"):
  function einsum_tensordot (line 46) | def einsum_tensordot(A, B, axes, reverse=False):
  function pad_to_full (line 57) | def pad_to_full(A, B, axes):
  function parse_axes (line 64) | def parse_axes(A_shape, B_shape, conv_axes, dot_axes, mode):
  function compute_conv_size (line 108) | def compute_conv_size(A_size, B_size, mode):
  function flipped_idxs (line 119) | def flipped_idxs(ndim, axes):
  function grad_convolve (line 126) | def grad_convolve(argnum, ans, A, B, axes=None, dot_axes=[(), ()], mode=...

FILE: autograd/scipy/special.py
  function make_gammainc_vjp_arg1 (line 56) | def make_gammainc_vjp_arg1(sign):
  function make_grad_logsumexp (line 123) | def make_grad_logsumexp(ans, x, axis=None, b=1.0, keepdims=False):
  function fwd_grad_logsumexp (line 137) | def fwd_grad_logsumexp(g, ans, x, axis=None, b=1.0, keepdims=False):

FILE: autograd/scipy/stats/beta.py
  function grad_beta_logpdf_arg0 (line 13) | def grad_beta_logpdf_arg0(x, a, b):
  function grad_beta_logpdf_arg1 (line 17) | def grad_beta_logpdf_arg1(x, a, b):
  function grad_beta_logpdf_arg2 (line 21) | def grad_beta_logpdf_arg2(x, a, b):

FILE: autograd/scipy/stats/chi2.py
  function grad_chi2_logpdf (line 13) | def grad_chi2_logpdf(x, df):

FILE: autograd/scipy/stats/gamma.py
  function grad_gamma_logpdf_arg0 (line 13) | def grad_gamma_logpdf_arg0(x, a):
  function grad_gamma_logpdf_arg1 (line 17) | def grad_gamma_logpdf_arg1(x, a):

FILE: autograd/scipy/stats/multivariate_normal.py
  function generalized_outer_product (line 19) | def generalized_outer_product(x):
  function covgrad (line 25) | def covgrad(x, mean, cov, allow_singular=False):
  function solve (line 35) | def solve(allow_singular):

FILE: autograd/scipy/stats/poisson.py
  function grad_poisson_logpmf (line 12) | def grad_poisson_logpmf(k, mu):

FILE: autograd/scipy/stats/t.py
  function grad_tlogpdf_diff (line 16) | def grad_tlogpdf_diff(diff, df):
  function grad_tlogpdf_x (line 20) | def grad_tlogpdf_x(x, df, loc, scale):
  function grad_tlogpdf_loc (line 24) | def grad_tlogpdf_loc(x, df, loc, scale):
  function grad_tlogpdf_scale (line 28) | def grad_tlogpdf_scale(x, df, loc, scale):
  function grad_tlogpdf_df (line 33) | def grad_tlogpdf_df(x, df, loc, scale):

FILE: autograd/test_util.py
  function scalar_close (line 10) | def scalar_close(a, b):
  function make_numerical_jvp (line 17) | def make_numerical_jvp(f, x):
  function check_vjp (line 31) | def check_vjp(f, x):
  function check_jvp (line 48) | def check_jvp(f, x):
  function check_equivalent (line 55) | def check_equivalent(x, y):
  function check_grads (line 63) | def check_grads(f, x, modes=["fwd", "rev"], order=2):
  function combo_check (line 81) | def combo_check(fun, *args, **kwargs):

FILE: autograd/tracer.py
  function trace (line 9) | def trace(start_node, fun, x):
  class Node (line 20) | class Node:
    method __init__ (line 23) | def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
    method initialize_root (line 26) | def initialize_root(self, *args, **kwargs):
    method new_root (line 30) | def new_root(cls, *args, **kwargs):
  function primitive (line 36) | def primitive(f_raw):
  function register_notrace (line 64) | def register_notrace(trace_type, primitive_fun):
  function notrace_primitive (line 68) | def notrace_primitive(f_raw):
  function find_top_boxed_args (line 78) | def find_top_boxed_args(args):
  class TraceStack (line 94) | class TraceStack:
    method __init__ (line 95) | def __init__(self):
    method new_trace (line 99) | def new_trace(self):
  class Box (line 108) | class Box:
    method __init__ (line 114) | def __init__(self, value, trace, node):
    method __bool__ (line 119) | def __bool__(self):
    method __str__ (line 124) | def __str__(self):
    method register (line 128) | def register(cls, value_type):
  function new_box (line 137) | def new_box(value, trace, node):

FILE: autograd/util.py
  function subvals (line 4) | def subvals(x, ivs):
  function subval (line 11) | def subval(x, i, v):
  function func (line 17) | def func(f):
  function toposort (line 21) | def toposort(end_node, parents=operator.attrgetter("parents")):
  function quick_grad_check (line 52) | def quick_grad_check(

FILE: autograd/wrap_util.py
  function unary_to_nary (line 4) | def unary_to_nary(unary_operator):
  function wraps (line 30) | def wraps(fun, namestr="{fun}", docstr="{doc}", **kwargs):
  function wrap_nary_f (line 42) | def wrap_nary_f(fun, op, argnum):

FILE: benchmarks/bench_core.py
  function f_short (line 18) | def f_short(x):
  function time_short_fun (line 22) | def time_short_fun():
  function time_short_forward_pass (line 26) | def time_short_forward_pass():
  function time_short_backward_pass (line 34) | def time_short_backward_pass():
  function time_short_grad (line 41) | def time_short_grad():
  function f_long (line 46) | def f_long(x):
  function time_long_fun (line 52) | def time_long_fun():
  function time_long_forward_pass (line 56) | def time_long_forward_pass():
  function time_long_backward_pass (line 64) | def time_long_backward_pass():
  function time_long_grad (line 71) | def time_long_grad():
  function fan_out_fan_in (line 76) | def fan_out_fan_in(x):
  function time_fan_out_fan_in_fun (line 82) | def time_fan_out_fan_in_fun():
  function time_fan_out_fan_in_forward_pass (line 86) | def time_fan_out_fan_in_forward_pass():
  function time_fan_out_fan_in_backward_pass (line 94) | def time_fan_out_fan_in_backward_pass():
  function time_fan_out_fan_in_grad (line 101) | def time_fan_out_fan_in_grad():
  function time_vspace_float (line 106) | def time_vspace_float():
  function time_vspace_array (line 113) | def time_vspace_array():
  function time_new_box_float (line 117) | def time_new_box_float():
  function time_new_box_array (line 121) | def time_new_box_array():
  function time_exp_call (line 125) | def time_exp_call():
  function time_exp_primitive_call_unboxed (line 129) | def time_exp_primitive_call_unboxed():
  function time_exp_primitive_call_boxed (line 133) | def time_exp_primitive_call_boxed():
  function time_no_autograd_control (line 140) | def time_no_autograd_control():

FILE: benchmarks/bench_mem.py
  function peakmem_needless_nodes (line 5) | def peakmem_needless_nodes():

FILE: benchmarks/bench_numpy_vjps.py
  function time_dot_0 (line 21) | def time_dot_0():
  function time_dot_1 (line 25) | def time_dot_1():
  function time_dot_0_0 (line 29) | def time_dot_0_0():
  function time_dot_0_1 (line 33) | def time_dot_0_1():
  function time_dot_0_2 (line 37) | def time_dot_0_2():
  function time_dot_1_0 (line 41) | def time_dot_1_0():
  function time_dot_1_1 (line 45) | def time_dot_1_1():
  function time_dot_1_2 (line 49) | def time_dot_1_2():
  function time_tensordot_0 (line 69) | def time_tensordot_0():
  function time_tensordot_1 (line 73) | def time_tensordot_1():
  function time_tensordot_0_0 (line 77) | def time_tensordot_0_0():
  function time_tensordot_0_1 (line 81) | def time_tensordot_0_1():
  function time_tensordot_0_2 (line 85) | def time_tensordot_0_2():
  function time_tensordot_1_0 (line 89) | def time_tensordot_1_0():
  function time_tensordot_1_1 (line 93) | def time_tensordot_1_1():
  function time_tensordot_1_2 (line 97) | def time_tensordot_1_2():

FILE: benchmarks/bench_rnn.py
  class RNNSuite (line 8) | class RNNSuite:
    method setup (line 15) | def setup(self):
    method rnn_grad (line 39) | def rnn_grad(self):
    method time_rnn_grad (line 42) | def time_rnn_grad(self):
    method peakmem_rnn_grad (line 45) | def peakmem_rnn_grad(self):
    method time_manual_rnn_grad (line 48) | def time_manual_rnn_grad(self):
    method peakmem_manual_rnn_grad (line 51) | def peakmem_manual_rnn_grad(self):
    method manual_rnn_grad (line 54) | def manual_rnn_grad(self):

FILE: benchmarks/bench_util.py
  function time_flatten (line 11) | def time_flatten():
  function time_grad_flatten (line 37) | def time_grad_flatten():

FILE: examples/bayesian_neural_net.py
  function make_nn_funs (line 9) | def make_nn_funs(layer_sizes, L2_reg, noise_variance, nonlinearity=np.ta...
  function build_toy_dataset (line 42) | def build_toy_dataset(n_data=40, noise_std=0.1):
  function callback (line 75) | def callback(params, t, g):

FILE: examples/bayesian_optimization.py
  function probability_of_improvement (line 14) | def probability_of_improvement(mean, std, max_so_far):
  function expected_new_max (line 18) | def expected_new_max(mean, std, max_so_far):
  function init_covariance_params (line 26) | def init_covariance_params(num_params):
  function defaultmax (line 30) | def defaultmax(x, default=-np.inf):
  function bayesian_optimize (line 36) | def bayesian_optimize(func, domain_min, domain_max, num_iters=20, callba...
  function example_function (line 110) | def example_function(x):
  function callback (line 121) | def callback(X, y, predict_func, acquisition_function, next_point, new_v...

FILE: examples/black_box_svi.py
  function black_box_variational_inference (line 11) | def black_box_variational_inference(logprob, D, num_samples):
  function log_density (line 41) | def log_density(x, t):
  function plot_isocontours (line 51) | def plot_isocontours(ax, func, xlimits=[-2, 2], ylimits=[-4, 2], numtick...
  function callback (line 67) | def callback(params, t, g):

FILE: examples/convnet.py
  class WeightsParser (line 14) | class WeightsParser:
    method __init__ (line 17) | def __init__(self):
    method add_weights (line 21) | def add_weights(self, name, shape):
    method get (line 26) | def get(self, vect, name):
  function make_batches (line 31) | def make_batches(N_total, N_batch):
  function logsumexp (line 40) | def logsumexp(X, axis, keepdims=False):
  function make_nn_funs (line 45) | def make_nn_funs(input_shape, layer_specs, L2_reg):
  class conv_layer (line 72) | class conv_layer:
    method __init__ (line 73) | def __init__(self, kernel_shape, num_filters):
    method forward_pass (line 77) | def forward_pass(self, inputs, param_vector):
    method build_weights_dict (line 86) | def build_weights_dict(self, input_shape):
    method conv_output_shape (line 94) | def conv_output_shape(self, A, B):
  class maxpool_layer (line 98) | class maxpool_layer:
    method __init__ (line 99) | def __init__(self, pool_shape):
    method build_weights_dict (line 102) | def build_weights_dict(self, input_shape):
    method forward_pass (line 110) | def forward_pass(self, inputs, param_vector):
  class full_layer (line 120) | class full_layer:
    method __init__ (line 121) | def __init__(self, size):
    method build_weights_dict (line 124) | def build_weights_dict(self, input_shape):
    method forward_pass (line 132) | def forward_pass(self, inputs, param_vector):
  class tanh_layer (line 140) | class tanh_layer(full_layer):
    method nonlinearity (line 141) | def nonlinearity(self, x):
  class softmax_layer (line 145) | class softmax_layer(full_layer):
    method nonlinearity (line 146) | def nonlinearity(self, x):
  function print_perf (line 195) | def print_perf(epoch, W):

FILE: examples/data.py
  function load_mnist (line 9) | def load_mnist():
  function plot_images (line 22) | def plot_images(
  function save_images (line 58) | def save_images(images, filename, **kwargs):
  function make_pinwheel (line 68) | def make_pinwheel(radial_std, tangential_std, num_classes, num_per_class...

FILE: examples/data_mnist.py
  function download (line 10) | def download(url, filename):
  function mnist (line 18) | def mnist():

FILE: examples/deep_gaussian_process.py
  function build_step_function_dataset (line 10) | def build_step_function_dataset(D=1, n_data=40, noise_std=0.1):
  function build_deep_gp (line 18) | def build_deep_gp(input_dimension, hidden_dimension, covariance_function):
  function plot_gp (line 84) | def plot_gp(ax, X, y, pred_mean, pred_cov, plot_xs):
  function callback (line 105) | def callback(params):

FILE: examples/define_gradient.py
  function logsumexp (line 17) | def logsumexp(x):
  function logsumexp_vjp (line 29) | def logsumexp_vjp(ans, x):
  function example_func (line 47) | def example_func(y):

FILE: examples/dot_graph.py
  class GraphNode (line 11) | class GraphNode(Node):
    method __init__ (line 13) | def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
    method initialize_root (line 19) | def initialize_root(self, x):
    method __repr__ (line 22) | def __repr__(self):
  function trace_graph (line 26) | def trace_graph(f, x):
  function graph_to_dotfile (line 38) | def graph_to_dotfile(graph):
  function fun (line 67) | def fun(x):

FILE: examples/fixed_points.py
  function newton_sqrt_iter (line 6) | def newton_sqrt_iter(a):
  function grad_descent_sqrt_iter (line 10) | def grad_descent_sqrt_iter(a):
  function sqrt (line 14) | def sqrt(a, guess=10.0):
  function distance (line 19) | def distance(x, y):

FILE: examples/fluidsim/fluidsim.py
  function project (line 16) | def project(vx, vy):
  function advect (line 46) | def advect(f, vx, vy):
  function simulate (line 71) | def simulate(vx, vy, smoke, num_time_steps, ax=None, render=False):
  function plot_matrix (line 85) | def plot_matrix(ax, mat, t, render=False):
  function distance_from_target_image (line 108) | def distance_from_target_image(smoke):
  function convert_param_vector_to_matrices (line 111) | def convert_param_vector_to_matrices(params):
  function objective (line 116) | def objective(params):
  function callback (line 127) | def callback(params):

FILE: examples/fluidsim/wing.py
  function occlude (line 16) | def occlude(f, occlusion):
  function project (line 20) | def project(vx, vy, occlusion):
  function advect (line 47) | def advect(f, vx, vy):
  function make_continuous (line 72) | def make_continuous(f, occlusion):
  function sigmoid (line 89) | def sigmoid(x):
  function simulate (line 93) | def simulate(vx, vy, num_time_steps, occlusion, ax=None, render=False):
  function plot_matrix (line 122) | def plot_matrix(ax, r, g, b, t, render=False):
  function drag (line 146) | def drag(vx):
  function lift (line 149) | def lift(vy):
  function objective (line 152) | def objective(params):
  function callback (line 163) | def callback(weights):

FILE: examples/gaussian_process.py
  function make_gp_funs (line 11) | def make_gp_funs(cov_func, num_cov_params):
  function rbf_covariance (line 42) | def rbf_covariance(kernel_params, x, xp):
  function build_toy_dataset (line 49) | def build_toy_dataset(D=1, n_data=20, noise_std=0.1):
  function callback (line 72) | def callback(params):

FILE: examples/generative_adversarial_net.py
  function relu (line 16) | def relu(x):
  function sigmoid (line 20) | def sigmoid(x):
  function logsigmoid (line 24) | def logsigmoid(x):
  function init_random_params (line 28) | def init_random_params(scale, layer_sizes, rs=npr.RandomState(0)):
  function batch_normalize (line 40) | def batch_normalize(activations):
  function neural_net_predict (line 45) | def neural_net_predict(params, inputs):
  function generate_from_noise (line 58) | def generate_from_noise(gen_params, num_samples, noise_dim, rs):
  function gan_objective (line 64) | def gan_objective(gen_params, dsc_params, real_data, num_samples, noise_...
  function adam_minimax (line 74) | def adam_minimax(
  function batch_indices (line 143) | def batch_indices(iter):
  function objective (line 150) | def objective(gen_params, dsc_params, iter):
  function print_perf (line 159) | def print_perf(gen_params, dsc_params, iter, gen_gradient, dsc_gradient):

FILE: examples/gmm.py
  function init_gmm_params (line 17) | def init_gmm_params(num_components, D, scale, rs=npr.RandomState(0)):
  function log_normalize (line 25) | def log_normalize(x):
  function unpack_gmm_params (line 29) | def unpack_gmm_params(params):
  function gmm_log_likelihood (line 34) | def gmm_log_likelihood(params, data):
  function plot_ellipse (line 42) | def plot_ellipse(ax, mean, cov_sqrt, alpha, num_points=100):
  function plot_gaussian_mixture (line 49) | def plot_gaussian_mixture(params, ax):
  function objective (line 60) | def objective(params):
  function callback (line 69) | def callback(flattened_params):

FILE: examples/gplvm.py
  function unpack_params (line 39) | def unpack_params(params):
  function objective (line 44) | def objective(params):
  function callback (line 58) | def callback(params):

FILE: examples/hmm_em.py
  function EM (line 11) | def EM(init_params, data, callback=None):
  function normalize (line 32) | def normalize(a):
  function log_partition_function (line 39) | def log_partition_function(natural_params, data):
  function initialize_hmm_parameters (line 52) | def initialize_hmm_parameters(num_states, num_outputs):
  function build_dataset (line 59) | def build_dataset(filename, max_lines=-1):

FILE: examples/ica.py
  function make_ica_funs (line 11) | def make_ica_funs(observed_dimension, latent_dimension):
  function color_scatter (line 42) | def color_scatter(ax, xs, ys):
  function unpack_params (line 59) | def unpack_params(params):
  function objective (line 95) | def objective(params):
  function callback (line 99) | def callback(params):

FILE: examples/logistic_regression.py
  function sigmoid (line 6) | def sigmoid(x):
  function logistic_predictions (line 10) | def logistic_predictions(weights, inputs):
  function training_loss (line 15) | def training_loss(weights):

FILE: examples/lstm.py
  function init_lstm_params (line 16) | def init_lstm_params(input_size, state_size, output_size, param_scale=0....
  function lstm_predict (line 31) | def lstm_predict(params, inputs):
  function lstm_log_likelihood (line 56) | def lstm_log_likelihood(params, inputs, targets):
  function print_training_prediction (line 74) | def print_training_prediction(weights):
  function training_loss (line 82) | def training_loss(params, iter):
  function callback (line 85) | def callback(weights, iter, gradient):

FILE: examples/mixture_variational_inference.py
  function diag_gaussian_log_density (line 17) | def diag_gaussian_log_density(x, mu, log_std):
  function unpack_gaussian_params (line 21) | def unpack_gaussian_params(params):
  function variational_log_density_gaussian (line 28) | def variational_log_density_gaussian(params, x):
  function sample_diag_gaussian (line 33) | def sample_diag_gaussian(params, num_samples, rs):
  function variational_lower_bound (line 39) | def variational_lower_bound(params, t, logprob, sampler, log_density, nu...
  function init_gaussian_var_params (line 50) | def init_gaussian_var_params(D, mean_mean=-1, log_std_mean=-5, scale=0.1...
  function log_normalize (line 56) | def log_normalize(x):
  function build_mog_bbsvi (line 60) | def build_mog_bbsvi(logprob, num_samples, k=10, rs=npr.RandomState(0)):
  function log_density (line 118) | def log_density(x, t):
  function objective (line 130) | def objective(params, t):
  function plot_isocontours (line 134) | def plot_isocontours(ax, func, xlimits=[-2, 2], ylimits=[-4, 2], numtick...
  function callback (line 151) | def callback(params, t, g):

FILE: examples/natural_gradient_black_box_svi.py
  function log_density (line 18) | def log_density(x, t):
  function fisher_diag (line 47) | def fisher_diag(lam):
  function optimize_and_lls (line 55) | def optimize_and_lls(optfun):

FILE: examples/negative_binomial_maxlike.py
  function newton (line 13) | def newton(f, x0):
  function negbin_loglike (line 18) | def negbin_loglike(r, p, x):
  function negbin_sample (line 23) | def negbin_sample(r, p, size):
  function fit_maxlike (line 28) | def fit_maxlike(x, r_guess):

FILE: examples/neural_net.py
  function init_random_params (line 13) | def init_random_params(scale, layer_sizes, rs=npr.RandomState(0)):
  function neural_net_predict (line 25) | def neural_net_predict(params, inputs):
  function l2_norm (line 36) | def l2_norm(params):
  function log_posterior (line 42) | def log_posterior(params, inputs, targets, L2_reg):
  function accuracy (line 48) | def accuracy(params, inputs, targets):
  function batch_indices (line 72) | def batch_indices(iter):
  function objective (line 77) | def objective(params, iter):
  function print_perf (line 86) | def print_perf(params, iter, gradient):

FILE: examples/neural_net_regression.py
  function init_random_params (line 11) | def init_random_params(scale, layer_sizes, rs=npr.RandomState(0)):
  function nn_predict (line 22) | def nn_predict(params, inputs, nonlinearity=np.tanh):
  function log_gaussian (line 29) | def log_gaussian(params, scale):
  function logprob (line 34) | def logprob(weights, inputs, targets, noise_scale=0.1):
  function build_toy_dataset (line 39) | def build_toy_dataset(n_data=80, noise_std=0.1):
  function objective (line 56) | def objective(weights, t):
  function callback (line 66) | def callback(params, t, g):

FILE: examples/ode_net.py
  function func (line 21) | def func(y, t0, A):
  function nn_predict (line 25) | def nn_predict(inputs, t, params):
  function init_nn_params (line 32) | def init_nn_params(scale, layer_sizes, rs=npr.RandomState(0)):
  function ode_pred (line 44) | def ode_pred(params, y0, t):
  function L1_loss (line 48) | def L1_loss(pred, targets):
  function train_loss (line 59) | def train_loss(params, iter):
  function callback (line 71) | def callback(params, iter, g):

FILE: examples/print_trace.py
  class PrintNode (line 9) | class PrintNode(Node):
    method __init__ (line 10) | def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
    method initialize_root (line 18) | def initialize_root(self, x):
  function make_varname_generator (line 24) | def make_varname_generator():
  function print_trace (line 30) | def print_trace(f, x):
  function avg (line 36) | def avg(x, y):
  function fun (line 40) | def fun(x):

FILE: examples/rkhs.py
  class RKHSFun (line 15) | class RKHSFun:
    method __init__ (line 16) | def __init__(self, kernel, alphas={}):
    method __call__ (line 22) | def __call__(self, x):
    method __add__ (line 25) | def __add__(self, f):
    method __mul__ (line 28) | def __mul__(self, a):
  class RKHSFunBox (line 36) | class RKHSFunBox(Box, RKHSFun):
    method kernel (line 38) | def kernel(self):
  class RKHSFunVSpace (line 45) | class RKHSFunVSpace(VSpace):
    method __init__ (line 46) | def __init__(self, value):
    method zeros (line 49) | def zeros(self):
    method randn (line 52) | def randn(self):
    method _add (line 57) | def _add(self, f, g):
    method _scalar_mul (line 61) | def _scalar_mul(self, f, a):
    method _inner_prod (line 64) | def _inner_prod(self, f, g):
  function add_dicts (line 74) | def add_dicts(d1, d2):
  function sq_exp_kernel (line 83) | def sq_exp_kernel(x1, x2):
  function logprob (line 89) | def logprob(f, xs, ys):

FILE: examples/rnn.py
  function sigmoid (line 16) | def sigmoid(x):
  function concat_and_multiply (line 20) | def concat_and_multiply(weights, *args):
  function create_rnn_params (line 28) | def create_rnn_params(input_size, state_size, output_size, param_scale=0...
  function rnn_predict (line 36) | def rnn_predict(params, inputs):
  function rnn_log_likelihood (line 54) | def rnn_log_likelihood(params, inputs, targets):
  function string_to_one_hot (line 66) | def string_to_one_hot(string, maxchar):
  function one_hot_to_string (line 72) | def one_hot_to_string(one_hot_matrix):
  function build_dataset (line 76) | def build_dataset(filename, sequence_length, alphabet_size, max_lines=-1):
  function print_training_prediction (line 98) | def print_training_prediction(weights):
  function training_loss (line 106) | def training_loss(params, iter):
  function callback (line 109) | def callback(weights, iter, gradient):

FILE: examples/rosenbrock.py
  function rosenbrock (line 7) | def rosenbrock(x):

FILE: examples/sinusoid.py
  function fun (line 7) | def fun(x):
  function fun (line 25) | def fun(x):

FILE: examples/tanh.py
  function tanh (line 22) | def tanh(x):

FILE: examples/variational_autoencoder.py
  function diag_gaussian_log_density (line 13) | def diag_gaussian_log_density(x, mu, log_std):
  function unpack_gaussian_params (line 17) | def unpack_gaussian_params(params):
  function sample_diag_gaussian (line 24) | def sample_diag_gaussian(mean, log_std, rs):
  function bernoulli_log_density (line 28) | def bernoulli_log_density(targets, unnormalized_logprobs):
  function relu (line 35) | def relu(x):
  function init_net_params (line 39) | def init_net_params(scale, layer_sizes, rs=npr.RandomState(0)):
  function batch_normalize (line 50) | def batch_normalize(activations):
  function neural_net_predict (line 55) | def neural_net_predict(params, inputs):
  function nn_predict_gaussian (line 67) | def nn_predict_gaussian(params, inputs):
  function generate_from_prior (line 72) | def generate_from_prior(gen_params, num_samples, noise_dim, rs):
  function p_images_given_latents (line 77) | def p_images_given_latents(gen_params, images, latents):
  function vae_lower_bound (line 82) | def vae_lower_bound(gen_params, rec_params, data, rs):
  function binarise (line 109) | def binarise(images):
  function batch_indices (line 125) | def batch_indices(iter):
  function objective (line 132) | def objective(combined_params, iter):
  function print_perf (line 142) | def print_perf(combined_params, iter, grad):

FILE: noxfile.py
  function check (line 21) | def check(session):
  function run_tests (line 29) | def run_tests(session):
  function ruff (line 42) | def ruff(session):
  function run_nightly_tests (line 49) | def run_nightly_tests(session):

FILE: tests/_test_complexity.py
  function timefunction (line 9) | def timefunction(f):
  function assert_linear_time (line 15) | def assert_linear_time(f):
  function test_array_creation (line 24) | def test_array_creation():
  function test_array_indexing (line 32) | def test_array_indexing():
  function test_list_indexing (line 39) | def test_list_indexing():
  function test_list_creation (line 46) | def test_list_creation():
  function test_array_creation_fwd (line 54) | def test_array_creation_fwd():

FILE: tests/conftest.py
  function random_seed (line 6) | def random_seed():

FILE: tests/numpy_utils.py
  function stat_check (line 5) | def stat_check(fun, test_complex=True, **kwargs):
  function unary_ufunc_check (line 23) | def unary_ufunc_check(fun, lims=[-2, 2], test_complex=True, **kwargs):
  function binary_ufunc_check (line 36) | def binary_ufunc_check(fun, lims_A=[-2, 2], lims_B=[-2, 2], test_complex...
  function binary_ufunc_check_no_same_args (line 54) | def binary_ufunc_check_no_same_args(fun, lims_A=[-2, 2], lims_B=[-2, 2],...
  function transform (line 81) | def transform(lims, x):

FILE: tests/profiling.py
  function tictoc (line 10) | def tictoc(text=""):
  function fan_out_fan_in (line 18) | def fan_out_fan_in():
  function convolution (line 30) | def convolution():
  function dot_equivalent (line 41) | def dot_equivalent():

FILE: tests/test_binary_ops.py
  function arg_pairs (line 12) | def arg_pairs():
  function test_mul (line 21) | def test_mul():
  function test_add (line 27) | def test_add():
  function test_sub (line 33) | def test_sub():
  function test_div (line 39) | def test_div():
  function test_mod (line 48) | def test_mod():
  function test_pow (line 58) | def test_pow():
  function test_arctan2 (line 66) | def test_arctan2():
  function test_hypot (line 71) | def test_hypot():
  function test_comparison_grads (line 76) | def test_comparison_grads():
  function test_comparison_values (line 94) | def test_comparison_values():

FILE: tests/test_builtins.py
  function test_isinstance (line 6) | def test_isinstance():

FILE: tests/test_complex.py
  function test_real_type (line 9) | def test_real_type():
  function test_real_if_close_type (line 16) | def test_real_if_close_type():
  function test_angle_real (line 23) | def test_angle_real():
  function test_angle_complex (line 30) | def test_angle_complex():
  function test_abs_real (line 37) | def test_abs_real():
  function test_abs_complex (line 44) | def test_abs_complex():

FILE: tests/test_core.py
  function grad (line 11) | def grad(fun, x):
  function nd (line 17) | def nd(f, x, eps=1e-4):
  function check_close (line 21) | def check_close(a, b, atol=1e-4, rtol=1e-4):
  function check_binary_func (line 25) | def check_binary_func(fun, independent=False):
  function test_add (line 37) | def test_add():
  function test_sub (line 41) | def test_sub():
  function test_div (line 45) | def test_div():
  function test_mul (line 49) | def test_mul():
  function test_pow (line 53) | def test_pow():
  function test_mod (line 57) | def test_mod():
  function test_eq (line 61) | def test_eq():
  function test_neq (line 65) | def test_neq():
  function test_leq (line 69) | def test_leq():
  function test_geq (line 73) | def test_geq():
  function test_lt (line 77) | def test_lt():
  function test_gt (line 81) | def test_gt():

FILE: tests/test_dict.py
  function test_getter (line 13) | def test_getter():
  function test_grads (line 29) | def test_grads():
  function test_iter (line 48) | def test_iter():
  function test_items_values_keys (line 70) | def test_items_values_keys():
  function test_get (line 96) | def test_get():
  function test_make_dict (line 105) | def test_make_dict():
  function test_isinstance (line 122) | def test_isinstance():

FILE: tests/test_direct.py
  function test_grad (line 13) | def test_grad():
  function test_deriv (line 20) | def test_deriv():
  function test_grad_complex_output (line 27) | def test_grad_complex_output():
  function test_holomorphic_grad (line 35) | def test_holomorphic_grad():

FILE: tests/test_fft.py
  function test_fft (line 15) | def test_fft():
  function test_fft_ortho (line 24) | def test_fft_ortho():
  function test_fft_axis (line 33) | def test_fft_axis():
  function match_complex (line 42) | def match_complex(fft_fun, mat):
  function check_fft_n (line 50) | def check_fft_n(fft_fun, D, n):
  function test_fft_n_smaller (line 59) | def test_fft_n_smaller():
  function test_fft_n_bigger (line 63) | def test_fft_n_bigger():
  function test_ifft_n_smaller (line 67) | def test_ifft_n_smaller():
  function test_ifft_n_bigger (line 71) | def test_ifft_n_bigger():
  function test_rfft_n_smaller (line 75) | def test_rfft_n_smaller():
  function test_rfft_n_bigger (line 79) | def test_rfft_n_bigger():
  function test_irfft_n_smaller (line 83) | def test_irfft_n_smaller():
  function test_irfft_n_bigger (line 87) | def test_irfft_n_bigger():
  function check_fft_s (line 91) | def check_fft_s(fft_fun, D):
  function test_fft2_s (line 102) | def test_fft2_s():
  function test_ifft2_s (line 106) | def test_ifft2_s():
  function test_fftn_s (line 110) | def test_fftn_s():
  function test_ifftn_s (line 114) | def test_ifftn_s():
  function test_rfft2_s (line 118) | def test_rfft2_s():
  function test_irfft2_s (line 122) | def test_irfft2_s():
  function test_rfftn_s (line 126) | def test_rfftn_s():
  function test_irfftn_s (line 130) | def test_irfftn_s():
  function test_ifft (line 147) | def test_ifft():
  function test_fft2 (line 156) | def test_fft2():
  function test_ifft2 (line 165) | def test_ifft2():
  function test_fftn (line 174) | def test_fftn():
  function test_ifftn (line 183) | def test_ifftn():
  function test_rfft (line 192) | def test_rfft():
  function test_rfft_ortho (line 201) | def test_rfft_ortho():
  function test_rfft_axes (line 210) | def test_rfft_axes():
  function test_irfft (line 219) | def test_irfft():
  function test_irfft_ortho (line 230) | def test_irfft_ortho():
  function test_rfft2 (line 241) | def test_rfft2():
  function test_irfft2 (line 250) | def test_irfft2():
  function test_rfftn (line 261) | def test_rfftn():
  function test_rfftn_odd_not_implemented (line 270) | def test_rfftn_odd_not_implemented():
  function test_rfftn_subset (line 280) | def test_rfftn_subset():
  function test_rfftn_axes (line 289) | def test_rfftn_axes():
  function test_irfftn (line 298) | def test_irfftn():
  function test_irfftn_subset (line 309) | def test_irfftn_subset():
  function test_fftshift (line 320) | def test_fftshift():
  function test_fftshift_even (line 329) | def test_fftshift_even():
  function test_fftshift_axes (line 338) | def test_fftshift_axes():
  function test_ifftshift (line 347) | def test_ifftshift():
  function test_ifftshift_even (line 356) | def test_ifftshift_even():
  function test_ifftshift_axes (line 365) | def test_ifftshift_axes():

FILE: tests/test_graphs.py
  function test_grad_fanout (line 13) | def test_grad_fanout():
  function test_grad_const (line 20) | def test_grad_const():
  function test_grad_identity (line 28) | def test_grad_identity():
  function test_hess_vector_prod (line 36) | def test_hess_vector_prod():
  function test_enclosing_scope_ref (line 55) | def test_enclosing_scope_ref():
  function test_enclosing_scope_ref_2 (line 63) | def test_enclosing_scope_ref_2():
  function test_mutating_outgrad (line 71) | def test_mutating_outgrad():
  function test_mutating_outgrad_from_indexing (line 83) | def test_mutating_outgrad_from_indexing():
  function test_complex_mutating_outgrad_from_indexing (line 95) | def test_complex_mutating_outgrad_from_indexing():
  function test_complex_separate_real_and_imaginary (line 109) | def test_complex_separate_real_and_imaginary():
  function test_third_derivative (line 121) | def test_third_derivative():
  function test_third_derivative_other_args (line 132) | def test_third_derivative_other_args():
  function test_third_derivative_other_args2 (line 143) | def test_third_derivative_other_args2():
  function test_singleton_array_output (line 154) | def test_singleton_array_output():
  function test_singleton_array_output_axis0 (line 160) | def test_singleton_array_output_axis0():
  function test_singleton_array_output_axis1 (line 166) | def test_singleton_array_output_axis1():
  function test_singleton_array_output_axis0_keepdims (line 172) | def test_singleton_array_output_axis0_keepdims():
  function test_singleton_array_output_axis1_keepdims (line 178) | def test_singleton_array_output_axis1_keepdims():
  function test_assignment_raises_error (line 184) | def test_assignment_raises_error():

FILE: tests/test_jacobian.py
  function test_jacobian_against_grad (line 9) | def test_jacobian_against_grad():
  function test_jacobian_scalar_to_vector (line 15) | def test_jacobian_scalar_to_vector():
  function test_jacobian_against_stacked_grads (line 21) | def test_jacobian_against_stacked_grads():
  function test_jacobian_higher_order (line 37) | def test_jacobian_higher_order():

FILE: tests/test_linalg.py
  function check_symmetric_matrix_grads (line 17) | def check_symmetric_matrix_grads(fun, **grad_check_kwargs):
  function rand_psd (line 32) | def rand_psd(D):
  function test_inv (line 37) | def test_inv():
  function test_pinv (line 47) | def test_pinv():
  function test_inv_3d (line 73) | def test_inv_3d():
  function test_solve_arg1 (line 84) | def test_solve_arg1():
  function test_solve_arg1_1d (line 95) | def test_solve_arg1_1d():
  function test_solve_arg2 (line 106) | def test_solve_arg2():
  function test_solve_arg1_3d (line 117) | def test_solve_arg1_3d():
  function test_solve_arg1_3d_3d (line 128) | def test_solve_arg1_3d_3d():
  function test_det (line 136) | def test_det():
  function test_det_3d (line 145) | def test_det_3d():
  function test_slogdet (line 152) | def test_slogdet():
  function test_slogdet_3d (line 163) | def test_slogdet_3d():
  function test_vector_2norm (line 169) | def test_vector_2norm():
  function test_vector_2norm_complex (line 178) | def test_vector_2norm_complex():
  function test_frobenius_norm (line 187) | def test_frobenius_norm():
  function test_frobenius_norm_complex (line 196) | def test_frobenius_norm_complex():
  function test_frobenius_norm_axis (line 205) | def test_frobenius_norm_axis():
  function test_frobenius_norm_axis_complex (line 214) | def test_frobenius_norm_axis_complex():
  function test_vector_norm_ord (line 225) | def test_vector_norm_ord(size, ord):
  function test_vector_norm_ord_complex (line 235) | def test_vector_norm_ord_complex(size, ord):
  function test_norm_axis (line 245) | def test_norm_axis(shape, axis):
  function test_norm_axis_complex (line 255) | def test_norm_axis_complex(shape, axis):
  function test_norm_nuclear (line 263) | def test_norm_nuclear():
  function test_norm_nuclear_complex (line 273) | def test_norm_nuclear_complex():
  function test_norm_nuclear_axis (line 282) | def test_norm_nuclear_axis():
  function test_norm_nuclear_axis_complex (line 292) | def test_norm_nuclear_axis_complex():
  function test_eigvalh_lower (line 301) | def test_eigvalh_lower():
  function test_eigvalh_upper (line 311) | def test_eigvalh_upper():
  function test_eigvalh_lower_broadcasting (line 324) | def test_eigvalh_lower_broadcasting():
  function test_eigvalh_upper_broadcasting (line 335) | def test_eigvalh_upper_broadcasting():
  function test_eigvalh_lower_complex (line 349) | def test_eigvalh_lower_complex():
  function test_eigvalh_upper_complex (line 359) | def test_eigvalh_upper_complex():
  function test_eig_real (line 370) | def test_eig_real():
  function test_eig_complex (line 380) | def test_eig_complex():
  function test_eig_batched (line 390) | def test_eig_batched():
  function test_cholesky (line 401) | def test_cholesky():
  function test_cholesky_broadcast (line 406) | def test_cholesky_broadcast():
  function test_cholesky_reparameterization_trick (line 412) | def test_cholesky_reparameterization_trick():
  function test_svd_wide_2d (line 421) | def test_svd_wide_2d():
  function test_svd_wide_2d_complex (line 432) | def test_svd_wide_2d_complex():
  function test_svd_wide_3d (line 443) | def test_svd_wide_3d():
  function test_svd_wide_3d_complex (line 455) | def test_svd_wide_3d_complex():
  function test_svd_square_2d (line 467) | def test_svd_square_2d():
  function test_svd_square_2d_complex (line 478) | def test_svd_square_2d_complex():
  function test_svd_square_3d (line 489) | def test_svd_square_3d():
  function test_svd_square_3d_complex (line 501) | def test_svd_square_3d_complex():
  function test_svd_tall_2d (line 513) | def test_svd_tall_2d():
  function test_svd_tall_2d_complex (line 524) | def test_svd_tall_2d_complex():
  function test_svd_tall_3d (line 535) | def test_svd_tall_3d():
  function test_svd_tall_3d_complex (line 547) | def test_svd_tall_3d_complex():
  function test_svd_only_s_2d (line 559) | def test_svd_only_s_2d():
  function test_svd_only_s_2d_complex (line 570) | def test_svd_only_s_2d_complex():
  function test_svd_only_s_3d (line 581) | def test_svd_only_s_3d():
  function test_svd_only_s_3d_complex (line 593) | def test_svd_only_s_3d_complex():

FILE: tests/test_list.py
  function test_getter (line 11) | def test_getter():
  function test_grads (line 27) | def test_grads():
  function test_slices (line 46) | def test_slices():
  function test_nested_list (line 61) | def test_nested_list():
  function test_make_list (line 70) | def test_make_list():
  function test_isinstance (line 77) | def test_isinstance():

FILE: tests/test_logic.py
  function test_assert (line 13) | def test_assert():
  function test_nograd (line 22) | def test_nograd():
  function test_no_vjp_def (line 30) | def test_no_vjp_def():
  function test_no_jvp_def (line 36) | def test_no_jvp_def():
  function test_falseyness (line 42) | def test_falseyness():
  function test_unimplemented_falseyness (line 48) | def test_unimplemented_falseyness():

FILE: tests/test_misc.py
  function test_const_graph (line 9) | def test_const_graph():
  function test_const_graph_args (line 27) | def test_const_graph_args():
  function test_flatten (line 54) | def test_flatten():
  function test_flatten_empty (line 68) | def test_flatten_empty():
  function test_flatten_dict (line 76) | def test_flatten_dict():
  function unflatten_tracing (line 85) | def unflatten_tracing():
  function test_flatten_nodes_in_containers (line 96) | def test_flatten_nodes_in_containers():
  function test_flatten_complex (line 105) | def test_flatten_complex():

FILE: tests/test_numpy.py
  function test_numpy_version (line 13) | def test_numpy_version():
  function test_dot (line 19) | def test_dot():
  function test_dot_with_floats (line 35) | def test_dot_with_floats():
  function test_outer (line 65) | def test_outer():
  function test_max (line 77) | def test_max():
  function test_max_axis (line 85) | def test_max_axis():
  function test_max_axis_keepdims (line 93) | def test_max_axis_keepdims():
  function test_min (line 101) | def test_min():
  function test_min_axis (line 109) | def test_min_axis():
  function test_min_axis_keepdims (line 117) | def test_min_axis_keepdims():
  function test_sum_1 (line 125) | def test_sum_1():
  function test_sum_2 (line 133) | def test_sum_2():
  function test_sum_3 (line 141) | def test_sum_3():
  function test_sum_with_axis_tuple (line 149) | def test_sum_with_axis_tuple():
  function test_flipud (line 157) | def test_flipud():
  function test_fliplr (line 165) | def test_fliplr():
  function test_rot90 (line 173) | def test_rot90():
  function test_cumsum_axis0 (line 181) | def test_cumsum_axis0():
  function test_cumsum_axis1 (line 189) | def test_cumsum_axis1():
  function test_cumsum_1d (line 197) | def test_cumsum_1d():
  function test_cumsum_no_axis (line 205) | def test_cumsum_no_axis():
  function test_non_numpy_sum (line 213) | def test_non_numpy_sum():
  function test_mean_1 (line 222) | def test_mean_1():
  function test_mean_2 (line 230) | def test_mean_2():
  function test_mean_3 (line 238) | def test_mean_3():
  function test_index_ints (line 246) | def test_index_ints():
  function test_index_slice (line 255) | def test_index_slice():
  function test_index_lists (line 264) | def test_index_lists():
  function test_index_mixed (line 273) | def test_index_mixed():
  function test_vector_slice (line 282) | def test_vector_slice():
  function test_index_slice_fanout (line 291) | def test_index_slice_fanout():
  function test_index_multiple_slices (line 302) | def test_index_multiple_slices():
  function test_reshape_method (line 313) | def test_reshape_method():
  function test_reshape_call (line 322) | def test_reshape_call():
  function test_reshape_method_nolist (line 331) | def test_reshape_method_nolist():
  function test_ravel_method (line 343) | def test_ravel_method():
  function test_ravel_call (line 352) | def test_ravel_call():
  function test_flatten_method (line 361) | def test_flatten_method():
  function test_simple_append_list (line 370) | def test_simple_append_list():
  function test_simple_append_arr (line 376) | def test_simple_append_arr():
  function test_simple_append_list_2D (line 382) | def test_simple_append_list_2D():
  function test_simple_concatenate (line 388) | def test_simple_concatenate():
  function test_concatenate_axis_0 (line 398) | def test_concatenate_axis_0():
  function test_concatenate_axis_1 (line 408) | def test_concatenate_axis_1():
  function test_concatenate_axis_1_unnamed (line 418) | def test_concatenate_axis_1_unnamed():
  function test_trace (line 429) | def test_trace():
  function test_trace2 (line 438) | def test_trace2():
  function test_trace_extradims (line 447) | def test_trace_extradims():
  function test_diag (line 464) | def test_diag():
  function test_transpose (line 472) | def test_transpose():
  function test_roll (line 480) | def test_roll():
  function test_roll_no_axis (line 488) | def test_roll_no_axis():
  function test_triu (line 496) | def test_triu():
  function test_tril (line 504) | def test_tril():
  function test_clip (line 512) | def test_clip():
  function test_prod_1 (line 520) | def test_prod_1():
  function test_prod_2 (line 528) | def test_prod_2():
  function test_prod_3 (line 536) | def test_prod_3():
  function test_prod_4 (line 544) | def test_prod_4():
  function test_1d_array (line 552) | def test_1d_array():
  function test_2d_array (line 559) | def test_2d_array():
  function test_1d_array_fanout (line 566) | def test_1d_array_fanout():
  function test_2d_array_fanout (line 574) | def test_2d_array_fanout():
  function test_array_from_scalar (line 582) | def test_array_from_scalar():
  function test_array_from_arrays (line 589) | def test_array_from_arrays():
  function test_array_from_arrays_2 (line 597) | def test_array_from_arrays_2():
  function test_len (line 605) | def test_len():
  function test_r_basic (line 614) | def test_r_basic():
  function test_r_double (line 626) | def test_r_double():
  function test_no_relation (line 638) | def test_no_relation():
  function test_r_no_relation (line 649) | def test_r_no_relation():
  function test_r_node_and_const (line 661) | def test_r_node_and_const():
  function test_r_mixed (line 673) | def test_r_mixed():
  function test_r_slicing (line 685) | def test_r_slicing():
  function test_c_ (line 697) | def test_c_():
  function test_c_mixed (line 709) | def test_c_mixed():
  function test_var_ddof (line 721) | def test_var_ddof():
  function test_std_ddof (line 729) | def test_std_ddof():
  function test_where (line 737) | def test_where():
  function test_squeeze_func (line 748) | def test_squeeze_func():
  function test_squeeze_method (line 757) | def test_squeeze_method():
  function test_repeat (line 766) | def test_repeat():
  function test_repeat_axis1_rep1 (line 775) | def test_repeat_axis1_rep1():
  function test_repeat_axis0 (line 784) | def test_repeat_axis0():
  function test_repeat_1d_axis0 (line 793) | def test_repeat_1d_axis0():
  function test_repeat_axis0_rep1 (line 802) | def test_repeat_axis0_rep1():
  function test_expand_dims (line 811) | def test_expand_dims():
  function test_tensordot_kwargs_by_position (line 820) | def test_tensordot_kwargs_by_position():
  function test_multi_index (line 827) | def test_multi_index():
  function test_multi_index2 (line 833) | def test_multi_index2():
  function test_index_dot_slices (line 839) | def test_index_dot_slices():
  function test_cast_to_int (line 866) | def test_cast_to_int():
  function test_make_diagonal (line 884) | def test_make_diagonal():
  function test_diagonal (line 899) | def test_diagonal():
  function test_nan_to_num (line 912) | def test_nan_to_num():
  function test_max_equal_values (line 928) | def test_max_equal_values():
  function test_max_equal_values_2d (line 935) | def test_max_equal_values_2d():
  function test_min_3_way_equality (line 943) | def test_min_3_way_equality():
  function test_maximum_equal_values (line 951) | def test_maximum_equal_values():
  function test_maximum_equal_values_2d (line 958) | def test_maximum_equal_values_2d():
  function test_linspace (line 967) | def test_linspace():
  function test_astype (line 978) | def test_astype():
  function test_gradient (line 987) | def test_gradient():

FILE: tests/test_scalar_ops.py
  function test_abs (line 9) | def test_abs():
  function test_absolute (line 16) | def test_absolute():
  function test_sin (line 23) | def test_sin():
  function test_sign (line 28) | def test_sign():
  function test_exp (line 34) | def test_exp():
  function test_log (line 39) | def test_log():
  function test_log2 (line 44) | def test_log2():
  function test_log10 (line 49) | def test_log10():
  function test_log1p (line 54) | def test_log1p():
  function test_expm1 (line 59) | def test_expm1():
  function test_exp2 (line 64) | def test_exp2():
  function test_neg (line 69) | def test_neg():
  function test_cos (line 74) | def test_cos():
  function test_tan (line 79) | def test_tan():
  function test_cosh (line 84) | def test_cosh():
  function test_sinh (line 89) | def test_sinh():
  function test_tanh (line 94) | def test_tanh():
  function test_arccos (line 99) | def test_arccos():
  function test_arcsin (line 104) | def test_arcsin():
  function test_arctan (line 109) | def test_arctan():
  function test_arccosh (line 114) | def test_arccosh():
  function test_arcsinh (line 119) | def test_arcsinh():
  function test_arctanh (line 124) | def test_arctanh():
  function test_sqrt (line 129) | def test_sqrt():
  function test_power_arg0 (line 134) | def test_power_arg0():
  function test_power_arg1 (line 146) | def test_power_arg1():
  function test_power_arg1_zero (line 152) | def test_power_arg1_zero():
  function test_mod_arg0 (line 157) | def test_mod_arg0():
  function test_mod_arg1 (line 162) | def test_mod_arg1():
  function test_divide_arg0 (line 167) | def test_divide_arg0():
  function test_divide_arg1 (line 172) | def test_divide_arg1():
  function test_multiply_arg0 (line 177) | def test_multiply_arg0():
  function test_multiply_arg1 (line 182) | def test_multiply_arg1():
  function test_true_divide_arg0 (line 187) | def test_true_divide_arg0():
  function test_true_divide_arg1 (line 192) | def test_true_divide_arg1():
  function test_reciprocal (line 197) | def test_reciprocal():
  function test_negative (line 202) | def test_negative():
  function test_rad2deg (line 207) | def test_rad2deg():
  function test_deg2rad (line 212) | def test_deg2rad():
  function test_radians (line 217) | def test_radians():
  function test_degrees (line 222) | def test_degrees():
  function test_sinc (line 227) | def test_sinc():

FILE: tests/test_scipy.py
  function symmetrize_matrix_arg (line 35) | def symmetrize_matrix_arg(fun, argnum):
  function test_chi2_pdf (line 50) | def test_chi2_pdf():
  function test_chi2_cdf (line 53) | def test_chi2_cdf():
  function test_chi2_logpdf (line 56) | def test_chi2_logpdf():
  function test_beta_cdf (line 59) | def test_beta_cdf():
  function test_beta_pdf (line 62) | def test_beta_pdf():
  function test_beta_logpdf (line 65) | def test_beta_logpdf():
  function test_gamma_cdf (line 68) | def test_gamma_cdf():
  function test_gamma_pdf (line 71) | def test_gamma_pdf():
  function test_gamma_logpdf (line 74) | def test_gamma_logpdf():
  function test_norm_pdf (line 77) | def test_norm_pdf():
  function test_norm_cdf (line 80) | def test_norm_cdf():
  function test_norm_sf (line 83) | def test_norm_sf():
  function test_norm_logpdf (line 86) | def test_norm_logpdf():
  function test_norm_logcdf (line 89) | def test_norm_logcdf():
  function test_norm_logsf (line 92) | def test_norm_logsf():
  function test_norm_pdf_broadcast (line 95) | def test_norm_pdf_broadcast():
  function test_norm_cdf_broadcast (line 98) | def test_norm_cdf_broadcast():
  function test_norm_sf_broadcast (line 101) | def test_norm_sf_broadcast():
  function test_norm_logpdf_broadcast (line 104) | def test_norm_logpdf_broadcast():
  function test_norm_logcdf_broadcast (line 107) | def test_norm_logcdf_broadcast():
  function test_norm_logsf_broadcast (line 110) | def test_norm_logsf_broadcast():
  function test_poisson_cdf (line 113) | def test_poisson_cdf():
  function test_poisson_logpmf (line 116) | def test_poisson_logpmf():
  function test_poisson_pmf (line 119) | def test_poisson_pmf():
  function test_poisson_cdf_broadcast (line 122) | def test_poisson_cdf_broadcast():
  function test_poisson_logpmf_broadcast (line 125) | def test_poisson_logpmf_broadcast():
  function test_poisson_pmf_broadcast (line 128) | def test_poisson_pmf_broadcast():
  function test_t_pdf (line 131) | def test_t_pdf():
  function test_t_cdf (line 134) | def test_t_cdf():
  function test_t_logpdf (line 137) | def test_t_logpdf():
  function test_t_logcdf (line 140) | def test_t_logcdf():
  function test_t_pdf_broadcast (line 143) | def test_t_pdf_broadcast():
  function test_t_cdf_broadcast (line 148) | def test_t_cdf_broadcast():
  function test_t_logpdf_broadcast (line 151) | def test_t_logpdf_broadcast():
  function test_t_logcdf_broadcast (line 156) | def test_t_logcdf_broadcast():
  function make_psd (line 159) | def make_psd(mat):
  function test_mvn_pdf (line 162) | def test_mvn_pdf():
  function test_mvn_logpdf (line 167) | def test_mvn_logpdf():
  function test_mvn_entropy (line 172) | def test_mvn_entropy():
  function test_mvn_sing_cov (line 175) | def test_mvn_sing_cov():
  function test_mvn_pdf_broadcast (line 197) | def test_mvn_pdf_broadcast():
  function test_mvn_logpdf_broadcast (line 200) | def test_mvn_logpdf_broadcast():
  function normalize (line 207) | def normalize(x):
  function normalized_dirichlet_pdf (line 210) | def normalized_dirichlet_pdf(x, alpha):
  function normalized_dirichlet_logpdf (line 213) | def normalized_dirichlet_logpdf(x, alpha):
  function test_dirichlet_pdf_x (line 216) | def test_dirichlet_pdf_x():
  function test_dirichlet_pdf_alpha (line 219) | def test_dirichlet_pdf_alpha():
  function test_dirichlet_logpdf_x (line 222) | def test_dirichlet_logpdf_x():
  function test_dirichlet_logpdf_alpha (line 225) | def test_dirichlet_logpdf_alpha():
  function test_logsumexp1 (line 229) | def test_logsumexp1():
  function test_logsumexp2 (line 234) | def test_logsumexp2():
  function test_logsumexp3 (line 239) | def test_logsumexp3():
  function test_logsumexp4 (line 244) | def test_logsumexp4():
  function test_logsumexp5 (line 254) | def test_logsumexp5():
  function test_logsumexp6 (line 259) | def test_logsumexp6():
  function test_convolve_generalization (line 269) | def test_convolve_generalization():
  function test_convolve (line 302) | def test_convolve():
  function test_convolve_2d (line 307) | def test_convolve_2d():
  function test_convolve_ignore (line 312) | def test_convolve_ignore():
  function test_convolve_ignore_dot (line 320) | def test_convolve_ignore_dot():
  function test_beta (line 330) | def test_beta():
  function test_betainc (line 333) | def test_betainc():
  function test_betaln (line 336) | def test_betaln():
  function test_gammainc (line 339) | def test_gammainc():
  function test_gammaincc (line 342) | def test_gammaincc():
  function test_polygamma (line 345) | def test_polygamma():
  function test_jn (line 348) | def test_jn():
  function test_yn (line 351) | def test_yn():
  function test_psi (line 354) | def test_psi():
  function test_digamma (line 357) | def test_digamma():
  function test_gamma (line 360) | def test_gamma():
  function test_gammaln (line 363) | def test_gammaln():
  function test_gammasgn (line 366) | def test_gammasgn():
  function test_rgamma (line 369) | def test_rgamma():
  function test_multigammaln (line 372) | def test_multigammaln():
  function test_j0 (line 375) | def test_j0():
  function test_j1 (line 378) | def test_j1():
  function test_y0 (line 381) | def test_y0():
  function test_y1 (line 384) | def test_y1():
  function test_i0 (line 387) | def test_i0():
  function test_i1 (line 390) | def test_i1():
  function test_iv (line 393) | def test_iv():
  function test_ive (line 396) | def test_ive():
  function test_erf (line 399) | def test_erf():
  function test_erfc (line 402) | def test_erfc():
  function test_erfinv (line 405) | def test_erfinv():
  function test_erfcinv (line 408) | def test_erfcinv():
  function test_logit (line 411) | def test_logit():
  function test_expit (line 414) | def test_expit():
  function func (line 418) | def func(y, t, arg1, arg2):
  function test_odeint (line 421) | def test_odeint():
  function test_sqrtm (line 425) | def test_sqrtm():
  function test_sqrtm (line 428) | def test_sqrtm():
  function test_solve_sylvester (line 431) | def test_solve_sylvester():
  function test_solve_banded (line 436) | def test_solve_banded():

FILE: tests/test_systematic.py
  function test_max (line 14) | def test_max():
  function test_max (line 20) | def test_max():
  function test_mean (line 24) | def test_mean():
  function test_min (line 28) | def test_min():
  function test_sum (line 32) | def test_sum():
  function test_prod (line 36) | def test_prod():
  function test_var (line 40) | def test_var():
  function test_std (line 44) | def test_std():
  function test_sin (line 51) | def test_sin():
  function test_abs (line 55) | def test_abs():
  function test_absolute (line 59) | def test_absolute():
  function test_arccosh (line 63) | def test_arccosh():
  function test_arcsinh (line 67) | def test_arcsinh():
  function test_arctanh (line 71) | def test_arctanh():
  function test_ceil (line 75) | def test_ceil():
  function test_cos (line 79) | def test_cos():
  function test_cosh (line 83) | def test_cosh():
  function test_deg2rad (line 87) | def test_deg2rad():
  function test_degrees (line 91) | def test_degrees():
  function test_exp (line 95) | def test_exp():
  function test_exp2 (line 99) | def test_exp2():
  function test_expm1 (line 103) | def test_expm1():
  function test_fabs (line 107) | def test_fabs():
  function test_floor (line 111) | def test_floor():
  function test_log (line 115) | def test_log():
  function test_log10 (line 119) | def test_log10():
  function test_log1p (line 123) | def test_log1p():
  function test_log2 (line 127) | def test_log2():
  function test_rad2deg (line 131) | def test_rad2deg():
  function test_radians (line 135) | def test_radians():
  function test_sign (line 139) | def test_sign():
  function test_sin (line 143) | def test_sin():
  function test_sinh (line 147) | def test_sinh():
  function test_sqrt (line 151) | def test_sqrt():
  function test_square (line 155) | def test_square():
  function test_tan (line 159) | def test_tan():
  function test_tanh (line 163) | def test_tanh():
  function test_real (line 167) | def test_real():
  function test_real_ic (line 171) | def test_real_ic():
  function test_imag (line 175) | def test_imag():
  function test_conj (line 179) | def test_conj():
  function test_conjugate (line 183) | def test_conjugate():
  function test_angle (line 187) | def test_angle():
  function test_add (line 194) | def test_add():
  function test_logaddexp (line 198) | def test_logaddexp():
  function test_logaddexp2 (line 202) | def test_logaddexp2():
  function test_remainder (line 206) | def test_remainder():
  function test_true_divide (line 210) | def test_true_divide():
  function test_mod (line 214) | def test_mod():
  function test_true_divide_neg (line 218) | def test_true_divide_neg():
  function test_mod_neg (line 222) | def test_mod_neg():
  function test_op_mul (line 226) | def test_op_mul():
  function test_op_add (line 230) | def test_op_add():
  function test_op_sub (line 234) | def test_op_sub():
  function test_op_mod (line 238) | def test_op_mod():
  function test_op_mod_neg (line 242) | def test_op_mod_neg():
  function test_transpose (line 252) | def test_transpose():
  function test_moveaxis (line 258) | def test_moveaxis():
  function test_repeat (line 262) | def test_repeat():
  function test_diff (line 266) | def test_diff():
  function test_gradient (line 272) | def test_gradient():
  function test_tile (line 277) | def test_tile():
  function test_kron (line 283) | def test_kron():
  function test_inner (line 290) | def test_inner():
  function test_dot (line 294) | def test_dot():
  function test_outer (line 300) | def test_outer():
  function test_matmul (line 304) | def test_matmul():
  function test_matmul_broadcast (line 310) | def test_matmul_broadcast():
  function test_tensordot_1 (line 314) | def test_tensordot_1():
  function test_tensordot_2 (line 320) | def test_tensordot_2():
  function test_tensordot_3 (line 326) | def test_tensordot_3():
  function test_tensordot_4 (line 332) | def test_tensordot_4():
  function test_tensordot_5 (line 336) | def test_tensordot_5():
  function test_tensordot_6 (line 340) | def test_tensordot_6():
  function test_tensordot_7 (line 344) | def test_tensordot_7():
  function test_tensordot_8 (line 348) | def test_tensordot_8():
  function test_maximum (line 353) | def test_maximum():
  function test_fmax (line 357) | def test_fmax():
  function test_minimum (line 361) | def test_minimum():
  function test_fmin (line 365) | def test_fmin():
  function test_sort (line 369) | def test_sort():
  function test_msort (line 375) | def test_msort():
  function test_partition (line 379) | def test_partition():
  function test_atleast_1d (line 383) | def test_atleast_1d():
  function test_atleast_2d (line 387) | def test_atleast_2d():
  function test_atleast_3d (line 391) | def test_atleast_3d():
  function test_einsum_transpose (line 395) | def test_einsum_transpose():
  function test_einsum_matmult (line 399) | def test_einsum_matmult():
  function test_einsum_matmult_broadcast (line 403) | def test_einsum_matmult_broadcast():
  function test_einsum_matmult_broadcast_leadzero (line 407) | def test_einsum_matmult_broadcast_leadzero():
  function test_einsum_covsum (line 411) | def test_einsum_covsum():
  function test_einsum_ellipses (line 415) | def test_einsum_ellipses():
  function test_einsum_ellipses_tail (line 421) | def test_einsum_ellipses_tail():
  function test_einsum_ellipses_center (line 425) | def test_einsum_ellipses_center():
  function test_einsum_three_args (line 429) | def test_einsum_three_args():
  function test_einsum2_transpose (line 433) | def test_einsum2_transpose():
  function test_einsum2_matmult (line 437) | def test_einsum2_matmult():
  function test_einsum2_matmult_broadcast (line 441) | def test_einsum2_matmult_broadcast():
  function test_einsum2_covsum (line 451) | def test_einsum2_covsum():
  function test_einsum2_three_args (line 455) | def test_einsum2_three_args():
  function test_einsum_naked_sum (line 461) | def test_einsum_naked_sum():
  function test_einsum_naked_sum2 (line 465) | def test_einsum_naked_sum2():
  function test_einsum_naked_sum_ellipsis (line 469) | def test_einsum_naked_sum_ellipsis():
  function test_einsum_no_output_indices (line 473) | def test_einsum_no_output_indices():
  function test_trace (line 477) | def test_trace():
  function test_diag (line 481) | def test_diag():
  function test_diag_flat (line 485) | def test_diag_flat():
  function test_tril (line 489) | def test_tril():
  function test_triu (line 493) | def test_triu():
  function test_tril_3d (line 497) | def test_tril_3d():
  function test_triu_3d (line 501) | def test_triu_3d():
  function test_swapaxes (line 505) | def test_swapaxes():
  function test_rollaxis (line 509) | def test_rollaxis():
  function test_cross (line 513) | def test_cross():
  function test_vsplit_2d (line 519) | def test_vsplit_2d():
  function test_vsplit_3d (line 523) | def test_vsplit_3d():
  function test_hsplit_2d (line 527) | def test_hsplit_2d():
  function test_hsplit_3d (line 531) | def test_hsplit_3d():
  function test_dsplit_3d (line 535) | def test_dsplit_3d():
  function test_split_1d (line 539) | def test_split_1d():
  function test_split_2d (line 543) | def test_split_2d():
  function test_split_3d (line 547) | def test_split_3d():
  function test_array_split_1d (line 551) | def test_array_split_1d():
  function test_array_split_2d (line 555) | def test_array_split_2d():
  function test_array_split_3d (line 559) | def test_array_split_3d():
  function test_concatenate_1ist (line 563) | def test_concatenate_1ist():
  function test_concatenate_tuple (line 567) | def test_concatenate_tuple():
  function test_concatenate_2d (line 571) | def test_concatenate_2d():
  function test_concatenate_3d (line 575) | def test_concatenate_3d():
  function test_vstack_1d (line 579) | def test_vstack_1d():
  function test_vstack_2d (line 583) | def test_vstack_2d():
  function test_vstack_3d (line 587) | def test_vstack_3d():
  function test_hstack_1d (line 591) | def test_hstack_1d():
  function test_hstack_2d (line 595) | def test_hstack_2d():
  function test_hstack_3d (line 599) | def test_hstack_3d():
  function test_stack_1d (line 603) | def test_stack_1d():
  function test_row_stack_1d (line 607) | def test_row_stack_1d():
  function test_row_stack_2d (line 611) | def test_row_stack_2d():
  function test_column_stack_1d (line 615) | def test_column_stack_1d():
  function test_column_stack_2d (line 619) | def test_column_stack_2d():
  function test_select (line 623) | def test_select():
  function test_pad (line 631) | def test_pad():

FILE: tests/test_tests.py
  function test_check_vjp_1st_order_fail (line 8) | def test_check_vjp_1st_order_fail():
  function test_check_vjp_2nd_order_fail (line 19) | def test_check_vjp_2nd_order_fail():

FILE: tests/test_truediv.py
  function test_div (line 9) | def test_div():

FILE: tests/test_tuple.py
  function test_getter (line 11) | def test_getter():
  function test_grads (line 27) | def test_grads():
  function test_nested_higher_order (line 46) | def test_nested_higher_order():
  function test_isinstance (line 58) | def test_isinstance():

FILE: tests/test_vspaces.py
  function check_vspace (line 10) | def check_vspace(value):
  function test_array_vspace (line 136) | def test_array_vspace():
  function test_array_vspace_0_dim (line 140) | def test_array_vspace_0_dim():
  function test_array_vspace_complex (line 144) | def test_array_vspace_complex():
  function test_list_vspace (line 148) | def test_list_vspace():
  function test_tuple_vspace (line 152) | def test_tuple_vspace():
  function test_dict_vspace (line 156) | def test_dict_vspace():
  function test_mixed_vspace (line 160) | def test_mixed_vspace():

FILE: tests/test_wrappers.py
  function test_return_both (line 28) | def test_return_both():
  function test_value_and_grad (line 39) | def test_value_and_grad():
  function test_hessian (line 54) | def test_hessian():
  function test_multigrad (line 67) | def test_multigrad():
  function test_value_and_multigrad (line 88) | def test_value_and_multigrad():
  function test_multigrad_onearg (line 108) | def test_multigrad_onearg():
  function test_elementwise_grad (line 115) | def test_elementwise_grad():
  function test_elementwise_grad_multiple_args (line 126) | def test_elementwise_grad_multiple_args():
  function test_hessian_tensor_product (line 139) | def test_hessian_tensor_product():
  function test_hvp (line 147) | def test_hvp():
  function test_hessian_matrix_product (line 156) | def test_hessian_matrix_product():
  function test_hessian_tensor_product_3d (line 164) | def test_hessian_tensor_product_3d():
  function test_tensor_jacobian_product (line 172) | def test_tensor_jacobian_product():
  function test_matrix_jacobian_product (line 181) | def test_matrix_jacobian_product():
  function test_tensor_jacobian_product (line 189) | def test_tensor_jacobian_product():
  function test_deprecated_defgrad_wrapper (line 197) | def test_deprecated_defgrad_wrapper():
  function test_deprecated_defvjp_wrapper (line 216) | def test_deprecated_defvjp_wrapper():
  function test_deprecated_defvjp_is_zero_wrapper (line 235) | def test_deprecated_defvjp_is_zero_wrapper():
  function test_deprecated_quick_grad_check_wrapper (line 254) | def test_deprecated_quick_grad_check_wrapper():
  function test_partial (line 261) | def test_partial():
  function test_dtypes (line 269) | def test_dtypes():
  function test_checkpoint_correctness (line 287) | def test_checkpoint_correctness():
  function checkpoint_memory (line 303) | def checkpoint_memory():
  function test_make_jvp (line 332) | def test_make_jvp():
  function _make_explicit_ggnvp (line 344) | def _make_explicit_ggnvp(f, g=lambda x: 1.0 / 2 * np.dot(x, x)):
  function test_make_ggnvp (line 357) | def test_make_ggnvp():
  function test_make_ggnvp_nondefault_g (line 369) | def test_make_ggnvp_nondefault_g():
  function test_grad_and_aux (line 383) | def test_grad_and_aux():
  function test_wrapped_name_and_docs (line 408) | def test_wrapped_name_and_docs():
Condensed preview: 120 files, each showing path, character count, and a content snippet. The full structured content is 460K chars.
[
  {
    "path": ".github/workflows/check.yml",
    "chars": 817,
    "preview": "name: Style and package checks\n\non:\n  pull_request:\n    branches:\n    - master\n  push:\n    branches:\n    - master\n  work"
  },
  {
    "path": ".github/workflows/publish.yml",
    "chars": 1449,
    "preview": "name: Publish\n\non:\n  workflow_dispatch:\n  release:\n    types: [published]\n\nenv:\n  PIP_DISABLE_PIP_VERSION_CHECK: '1'\n  F"
  },
  {
    "path": ".github/workflows/test.yml",
    "chars": 2603,
    "preview": "name: CI\n\non:\n  pull_request:\n    branches:\n      - master\n  push:\n    branches:\n      - master\n  workflow_dispatch:\n  s"
  },
  {
    "path": ".gitignore",
    "chars": 676,
    "preview": "__pycache__/\n*.py[cod]\n*$py.class\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/"
  },
  {
    "path": ".pre-commit-config.yaml",
    "chars": 1064,
    "preview": "ci:\n  autoupdate_commit_msg: \"chore: update pre-commit hooks\"\n  autofix_commit_msg: \"style: pre-commit fixes\"\n\nrepos:\n  "
  },
  {
    "path": "CONTRIBUTING.md",
    "chars": 2504,
    "preview": "# Contributing\n\nUse [Nox](https://nox.thea.codes/en/stable/) to run tests and linting, e.g.,\n\n```shell\npip install nox\n`"
  },
  {
    "path": "README.md",
    "chars": 4780,
    "preview": "# Autograd  [![Checks status][checks-badge]][checks-url] [![Tests status][tests-badge]][tests-url] [![Publish status][pu"
  },
  {
    "path": "autograd/__init__.py",
    "chars": 528,
    "preview": "from autograd.core import primitive_with_deprecation_warnings as primitive\n\nfrom .builtins import dict, isinstance, list"
  },
  {
    "path": "autograd/builtins.py",
    "chars": 6475,
    "preview": "from .extend import (\n    Box,\n    SparseObject,\n    VSpace,\n    defjvp,\n    defjvp_argnum,\n    defvjp,\n    defvjp_argnu"
  },
  {
    "path": "autograd/core.py",
    "chars": 11985,
    "preview": "from functools import reduce\nfrom itertools import count\n\nfrom .tracer import Box, Node, getval, isbox, primitive, topos"
  },
  {
    "path": "autograd/differential_operators.py",
    "chars": 8468,
    "preview": "\"\"\"Convenience functions built on top of `make_vjp`.\"\"\"\n\nfrom collections import OrderedDict\n\ntry:\n    from inspect impo"
  },
  {
    "path": "autograd/extend.py",
    "chars": 317,
    "preview": "# Exposes API for extending autograd\nfrom .core import (\n    JVPNode,\n    SparseObject,\n    VJPNode,\n    VSpace,\n    def"
  },
  {
    "path": "autograd/misc/__init__.py",
    "chars": 62,
    "preview": "from .flatten import flatten\nfrom .tracers import const_graph\n"
  },
  {
    "path": "autograd/misc/fixed_points.py",
    "chars": 748,
    "preview": "from autograd import make_vjp\nfrom autograd.builtins import tuple\nfrom autograd.extend import defvjp, primitive, vspace\n"
  },
  {
    "path": "autograd/misc/flatten.py",
    "chars": 1125,
    "preview": "\"\"\"\nHandy functions for flattening nested containers containing numpy\narrays. The main purpose is to make examples and o"
  },
  {
    "path": "autograd/misc/optimizers.py",
    "chars": 2714,
    "preview": "\"\"\"Some standard gradient-based stochastic optimizers.\n\nThese are just standard routines that don't make any use of auto"
  },
  {
    "path": "autograd/misc/tracers.py",
    "chars": 2220,
    "preview": "from functools import partial\nfrom itertools import repeat\n\nfrom autograd.tracer import Node, trace\nfrom autograd.util i"
  },
  {
    "path": "autograd/numpy/__init__.py",
    "chars": 171,
    "preview": "from . import fft, linalg, numpy_boxes, numpy_jvps, numpy_vjps, numpy_vspaces, random\nfrom .numpy_wrapper import *\nfrom "
  },
  {
    "path": "autograd/numpy/fft.py",
    "chars": 5101,
    "preview": "import numpy.fft as ffto\n\nfrom autograd.extend import defvjp, primitive, vspace\n\nfrom . import numpy_wrapper as anp\nfrom"
  },
  {
    "path": "autograd/numpy/linalg.py",
    "chars": 10998,
    "preview": "from functools import partial\n\nimport numpy.linalg as npla\n\nfrom autograd.extend import defjvp, defvjp\n\nfrom . import nu"
  },
  {
    "path": "autograd/numpy/numpy_boxes.py",
    "chars": 4180,
    "preview": "import numpy as np\n\nfrom autograd.builtins import SequenceBox\nfrom autograd.extend import Box, primitive\n\nfrom . import "
  },
  {
    "path": "autograd/numpy/numpy_jvps.py",
    "chars": 10161,
    "preview": "import numpy as onp\n\nfrom autograd.extend import JVPNode, def_linear, defjvp, defjvp_argnum, register_notrace, vspace\n\nf"
  },
  {
    "path": "autograd/numpy/numpy_vjps.py",
    "chars": 32396,
    "preview": "from functools import partial\n\nimport numpy as onp\n\nfrom autograd.extend import SparseObject, VJPNode, defvjp, defvjp_ar"
  },
  {
    "path": "autograd/numpy/numpy_vspaces.py",
    "chars": 3673,
    "preview": "import numpy as np\n\nfrom autograd.builtins import NamedTupleVSpace\nfrom autograd.extend import VSpace\n\n\nclass ArrayVSpac"
  },
  {
    "path": "autograd/numpy/numpy_wrapper.py",
    "chars": 5658,
    "preview": "import warnings\n\nimport numpy as _np\n\nimport autograd.builtins as builtins\nfrom autograd.extend import notrace_primitive"
  },
  {
    "path": "autograd/numpy/random.py",
    "chars": 111,
    "preview": "import numpy.random as npr\n\nfrom .numpy_wrapper import wrap_namespace\n\nwrap_namespace(npr.__dict__, globals())\n"
  },
  {
    "path": "autograd/scipy/__init__.py",
    "chars": 48,
    "preview": "from . import integrate, signal, special, stats\n"
  },
  {
    "path": "autograd/scipy/integrate.py",
    "chars": 2782,
    "preview": "import scipy.integrate\n\nimport autograd.numpy as np\nfrom autograd import make_vjp\nfrom autograd.builtins import tuple\nfr"
  },
  {
    "path": "autograd/scipy/linalg.py",
    "chars": 4295,
    "preview": "from functools import partial\n\nimport scipy.linalg\n\nimport autograd.numpy as anp\nfrom autograd.extend import defjvp, def"
  },
  {
    "path": "autograd/scipy/signal.py",
    "chars": 5750,
    "preview": "from functools import partial\n\nimport numpy as npo  # original numpy\nfrom numpy.lib.stride_tricks import as_strided\n\nimp"
  },
  {
    "path": "autograd/scipy/special.py",
    "chars": 5083,
    "preview": "import scipy.special\n\nimport autograd.numpy as np\nfrom autograd.extend import defjvp, defvjp, primitive\nfrom autograd.nu"
  },
  {
    "path": "autograd/scipy/stats/__init__.py",
    "chars": 287,
    "preview": "from . import beta, chi2, gamma, norm, poisson, t\n\n# Try block needed in case the user has an\n# old version of scipy wit"
  },
  {
    "path": "autograd/scipy/stats/beta.py",
    "chars": 1334,
    "preview": "import scipy.stats\n\nimport autograd.numpy as np\nfrom autograd.extend import defvjp, primitive\nfrom autograd.numpy.numpy_"
  },
  {
    "path": "autograd/scipy/stats/chi2.py",
    "chars": 800,
    "preview": "import scipy.stats\n\nimport autograd.numpy as np\nfrom autograd.extend import defvjp, primitive\nfrom autograd.numpy.numpy_"
  },
  {
    "path": "autograd/scipy/stats/dirichlet.py",
    "chars": 684,
    "preview": "import scipy.stats\n\nimport autograd.numpy as np\nfrom autograd.extend import defvjp, primitive\nfrom autograd.scipy.specia"
  },
  {
    "path": "autograd/scipy/stats/gamma.py",
    "chars": 970,
    "preview": "import scipy.stats\n\nimport autograd.numpy as np\nfrom autograd.extend import defvjp, primitive\nfrom autograd.numpy.numpy_"
  },
  {
    "path": "autograd/scipy/stats/multivariate_normal.py",
    "chars": 2505,
    "preview": "import scipy.stats\n\nimport autograd.numpy as np\nfrom autograd.extend import defvjp, primitive\nfrom autograd.numpy.numpy_"
  },
  {
    "path": "autograd/scipy/stats/norm.py",
    "chars": 2749,
    "preview": "\"\"\"Gradients of the normal distribution.\"\"\"\n\nimport scipy.stats\n\nimport autograd.numpy as anp\nfrom autograd.extend impor"
  },
  {
    "path": "autograd/scipy/stats/poisson.py",
    "chars": 684,
    "preview": "import scipy.stats\n\nimport autograd.numpy as np\nfrom autograd.extend import defvjp, primitive\nfrom autograd.numpy.numpy_"
  },
  {
    "path": "autograd/scipy/stats/t.py",
    "chars": 2728,
    "preview": "\"\"\"Gradients of the univariate t distribution.\"\"\"\n\nimport scipy.stats\n\nimport autograd.numpy as np\nfrom autograd.extend "
  },
  {
    "path": "autograd/test_util.py",
    "chars": 2796,
    "preview": "from itertools import product\n\nfrom .core import make_jvp, make_vjp, vspace\nfrom .wrap_util import get_name, unary_to_na"
  },
  {
    "path": "autograd/tracer.py",
    "chars": 3906,
    "preview": "import warnings\nfrom collections import defaultdict\nfrom contextlib import contextmanager\n\nfrom .util import subvals, to"
  },
  {
    "path": "autograd/util.py",
    "chars": 1394,
    "preview": "import operator\n\n\ndef subvals(x, ivs):\n    x_ = list(x)\n    for i, v in ivs:\n        x_[i] = v\n    return tuple(x_)\n\n\nde"
  },
  {
    "path": "autograd/wrap_util.py",
    "chars": 1614,
    "preview": "from .util import subvals\n\n\ndef unary_to_nary(unary_operator):\n    @wraps(unary_operator)\n    def nary_operator(fun, arg"
  },
  {
    "path": "benchmarks/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "benchmarks/asv.conf.json.sample",
    "chars": 387,
    "preview": "{\n    \"version\": 1,\n    \"project\": \"autograd\",\n    \"project_url\": \"http://github.com/hips/autograd\",\n    \"branches\": [\"m"
  },
  {
    "path": "benchmarks/bench_core.py",
    "chars": 3206,
    "preview": "import numpy as onp\n\nimport autograd.numpy as np\nfrom autograd import grad\n\ntry:\n    from autograd.core import VJPNode, "
  },
  {
    "path": "benchmarks/bench_mem.py",
    "chars": 231,
    "preview": "import autograd.numpy as np\nfrom autograd import grad\n\n\ndef peakmem_needless_nodes():\n    N, M = 1000, 100\n\n    def fun("
  },
  {
    "path": "benchmarks/bench_numpy_vjps.py",
    "chars": 2191,
    "preview": "import autograd.numpy as np\nimport autograd.numpy.random as npr\nfrom autograd import make_vjp\n\ndot_0 = lambda a, b, g: m"
  },
  {
    "path": "benchmarks/bench_rnn.py",
    "chars": 5812,
    "preview": "# Write the benchmarking functions here.\n# See \"Writing benchmarks\" in the asv docs for more information.\n# http://asv.r"
  },
  {
    "path": "benchmarks/bench_util.py",
    "chars": 1272,
    "preview": "import autograd.numpy as np\nimport autograd.numpy.random as npr\nfrom autograd import grad\n\ntry:\n    from autograd.misc.f"
  },
  {
    "path": "conda_recipe/conda.yaml",
    "chars": 706,
    "preview": "package:\n  name: autograd\n  # there are ways to derive version from other sources; for now, it's hard-coded\n  version: 1"
  },
  {
    "path": "docs/tutorial.md",
    "chars": 17287,
    "preview": "# Autograd tutorial\n\n## Motivation\n\nImagine you want to test out a new machine learning model for your data. This\nusuall"
  },
  {
    "path": "docs/updateguide.md",
    "chars": 4444,
    "preview": "# Autograd v1.2 update guide\n\nAutograd v1.2 changed the interface for defining custom vector-Jacobian\nproducts (VJPs). L"
  },
  {
    "path": "examples/README.md",
    "chars": 720,
    "preview": "# Autograd examples\n\n## Usage instructions\n\nSome of the examples require additional dependencies beyond Autograd and its"
  },
  {
    "path": "examples/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "examples/bayesian_neural_net.py",
    "chars": 3772,
    "preview": "import matplotlib.pyplot as plt\nfrom black_box_svi import black_box_variational_inference\n\nimport autograd.numpy as np\ni"
  },
  {
    "path": "examples/bayesian_optimization.py",
    "chars": 4952,
    "preview": "\"\"\"This Bayesian optimization demo using gradient-based optimization\nto find the next query point.\"\"\"\n\nimport matplotlib"
  },
  {
    "path": "examples/black_box_svi.py",
    "chars": 3041,
    "preview": "import matplotlib.pyplot as plt\n\nimport autograd.numpy as np\nimport autograd.numpy.random as npr\nimport autograd.scipy.s"
  },
  {
    "path": "examples/convnet.py",
    "chars": 7083,
    "preview": "\"\"\"Convolutional neural net on MNIST, modeled on 'LeNet-5',\nhttp://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf\"\"\"\n\nimpor"
  },
  {
    "path": "examples/data.py",
    "chars": 2753,
    "preview": "import data_mnist\nimport matplotlib.image\nimport matplotlib.pyplot as plt\n\nimport autograd.numpy as np\nimport autograd.n"
  },
  {
    "path": "examples/data_mnist.py",
    "chars": 1437,
    "preview": "import array\nimport gzip\nimport os\nimport struct\nfrom urllib.request import urlretrieve\n\nimport numpy as np\n\n\ndef downlo"
  },
  {
    "path": "examples/deep_gaussian_process.py",
    "chars": 5017,
    "preview": "import matplotlib.pyplot as plt\nfrom gaussian_process import make_gp_funs, rbf_covariance\nfrom scipy.optimize import min"
  },
  {
    "path": "examples/define_gradient.py",
    "chars": 2261,
    "preview": "\"\"\"This example shows how to define the gradient of your own functions.\nThis can be useful for speed, numerical stabilit"
  },
  {
    "path": "examples/dot_graph.py",
    "chars": 2113,
    "preview": "\"\"\"Generates a graphviz DOT file of an evaluation trace.\nUsage (need the dot binary, from the graphviz package, www.grap"
  },
  {
    "path": "examples/fixed_points.py",
    "chars": 636,
    "preview": "import autograd.numpy as np\nfrom autograd import grad\nfrom autograd.misc.fixed_points import fixed_point\n\n\ndef newton_sq"
  },
  {
    "path": "examples/fluidsim/fluidsim.py",
    "chars": 4786,
    "preview": "import os\n\nimport matplotlib\nimport matplotlib.pyplot as plt\nfrom matplotlib.pyplot import imread\nfrom scipy.optimize im"
  },
  {
    "path": "examples/fluidsim/wing.py",
    "chars": 6101,
    "preview": "import os\n\nimport matplotlib.pyplot as plt\nfrom scipy.optimize import minimize\n\nimport autograd.numpy as np\nfrom autogra"
  },
  {
    "path": "examples/gaussian_process.py",
    "chars": 3937,
    "preview": "import matplotlib.pyplot as plt\nfrom scipy.optimize import minimize\n\nimport autograd.numpy as np\nimport autograd.numpy.r"
  },
  {
    "path": "examples/generative_adversarial_net.py",
    "chars": 5988,
    "preview": "# Implements a Generative Adversarial Network, from\n# arxiv.org/abs/1406.2661\n# but, it always collapses to generating a"
  },
  {
    "path": "examples/gmm.py",
    "chars": 2962,
    "preview": "\"\"\"Implements a Gaussian mixture model, in which parameters are fit using\ngradient descent.  This example runs on 2-dime"
  },
  {
    "path": "examples/gplvm.py",
    "chars": 3027,
    "preview": "# Implements a Gaussian process latent-variable model.\n# The (high-dimensional) data, Y is explained by some low-dimensi"
  },
  {
    "path": "examples/hmm_em.py",
    "chars": 2793,
    "preview": "import string\nfrom functools import partial\nfrom os.path import dirname, join\n\nimport autograd.numpy as np\nimport autogr"
  },
  {
    "path": "examples/ica.py",
    "chars": 4394,
    "preview": "import matplotlib.cm as cm\nimport matplotlib.pyplot as plt\nfrom scipy.optimize import minimize\n\nimport autograd.numpy as"
  },
  {
    "path": "examples/logistic_regression.py",
    "chars": 1237,
    "preview": "import autograd.numpy as np\nfrom autograd import grad\nfrom autograd.test_util import check_grads\n\n\ndef sigmoid(x):\n    r"
  },
  {
    "path": "examples/lstm.py",
    "chars": 4216,
    "preview": "\"\"\"Implements the long-short term memory character model.\nThis version vectorizes over multiple examples, but each strin"
  },
  {
    "path": "examples/mixture_variational_inference.py",
    "chars": 6663,
    "preview": "# Implements black-box variational inference, where the variational\n# distribution is a mixture of Gaussians.\n#\n# This t"
  },
  {
    "path": "examples/natural_gradient_black_box_svi.py",
    "chars": 4636,
    "preview": "import matplotlib.pyplot as plt\n\n# same BBSVI function!\nfrom black_box_svi import black_box_variational_inference\n\nimpor"
  },
  {
    "path": "examples/negative_binomial_maxlike.py",
    "chars": 2117,
    "preview": "import scipy.optimize\n\nimport autograd.numpy as np\nimport autograd.numpy.random as npr\nfrom autograd import grad\nfrom au"
  },
  {
    "path": "examples/neural_net.py",
    "chars": 3143,
    "preview": "\"\"\"A multi-layer perceptron for classification of MNIST handwritten digits.\"\"\"\n\nfrom data import load_mnist\n\nimport auto"
  },
  {
    "path": "examples/neural_net_regression.py",
    "chars": 2611,
    "preview": "import matplotlib.pyplot as plt\n\nimport autograd.numpy as np\nimport autograd.numpy.random as npr\nimport autograd.scipy.s"
  },
  {
    "path": "examples/ode_net.py",
    "chars": 3879,
    "preview": "# A demo of gradients through scipy.integrate.odeint,\n# estimating the dynamics of a system given a trajectory.\n\n\nimport"
  },
  {
    "path": "examples/print_trace.py",
    "chars": 1403,
    "preview": "\"\"\"Demonstrates how to use the tracer module, independent of autodiff, by\ncreating a trace that prints out functions and"
  },
  {
    "path": "examples/rkhs.py",
    "chars": 2460,
    "preview": "\"\"\"\nInferring a function from a reproducing kernel Hilbert space (RKHS) by taking\ngradients of eval with respect to the "
  },
  {
    "path": "examples/rnn.py",
    "chars": 4697,
    "preview": "\"\"\"Implements the long-short term memory character model.\nThis version vectorizes over multiple examples, but each strin"
  },
  {
    "path": "examples/rosenbrock.py",
    "chars": 461,
    "preview": "from scipy.optimize import minimize\n\nimport autograd.numpy as np\nfrom autograd import value_and_grad\n\n\ndef rosenbrock(x)"
  },
  {
    "path": "examples/sinusoid.py",
    "chars": 1003,
    "preview": "import matplotlib.pyplot as plt\n\nimport autograd.numpy as np\nfrom autograd import grad\n\n\ndef fun(x):\n    return np.sin(x"
  },
  {
    "path": "examples/tanh.py",
    "chars": 1491,
    "preview": "import matplotlib.pyplot as plt\n\nimport autograd.numpy as np\nfrom autograd import elementwise_grad as egrad\n\n\"\"\"\nMathema"
  },
  {
    "path": "examples/variational_autoencoder.py",
    "chars": 5422,
    "preview": "# Implements auto-encoding variational Bayes.\n\nfrom data import load_mnist, save_images\n\nimport autograd.numpy as np\nimp"
  },
  {
    "path": "license.txt",
    "chars": 1117,
    "preview": "The MIT License (MIT)\n\nCopyright (c) 2025 by the President and Fellows of Harvard University\n\nPermission is hereby grant"
  },
  {
    "path": "noxfile.py",
    "chars": 2374,
    "preview": "import platform\n\nimport nox\n\nNIGHTLY_INDEX_URL = \"https://pypi.anaconda.org/scientific-python-nightly-wheels/simple\"\nUV_"
  },
  {
    "path": "pyproject.toml",
    "chars": 2189,
    "preview": "[build-system]\nrequires = [\"hatchling\"]\nbuild-backend = \"hatchling.build\"\n\n[project]\nname = \"autograd\"\nversion = \"1.8.0\""
  },
  {
    "path": "tests/_test_complexity.py",
    "chars": 1459,
    "preview": "import time\nimport warnings\n\nimport autograd.numpy as np\nfrom autograd import deriv, grad\nfrom autograd.builtins import "
  },
  {
    "path": "tests/check_examples_run.sh",
    "chars": 950,
    "preview": "#!/bin/bash\n\nPYTHONPATH=\".:$PYTHONPATH\"\ntrap 'kill -INT -$pid && exit 1' INT\n\nworking=()\nfailing=()\n\nexamples=$(find exa"
  },
  {
    "path": "tests/conftest.py",
    "chars": 107,
    "preview": "import numpy as np\nimport pytest\n\n\n@pytest.fixture(autouse=True)\ndef random_seed():\n    np.random.seed(42)\n"
  },
  {
    "path": "tests/numpy_utils.py",
    "chars": 2902,
    "preview": "import autograd.numpy.random as npr\nfrom autograd.test_util import combo_check\n\n\ndef stat_check(fun, test_complex=True, "
  },
  {
    "path": "tests/profiling.py",
    "chars": 1112,
    "preview": "from contextlib import contextmanager\nfrom time import time\n\nimport autograd.numpy as np\nimport autograd.numpy.random as"
  },
  {
    "path": "tests/test_binary_ops.py",
    "chars": 3003,
    "preview": "import itertools as it\nimport warnings\n\nimport autograd.numpy as np\nimport autograd.numpy.random as npr\nfrom autograd im"
  },
  {
    "path": "tests/test_builtins.py",
    "chars": 674,
    "preview": "import autograd.numpy as np\nfrom autograd import grad\nfrom autograd.builtins import isinstance\n\n\ndef test_isinstance():\n"
  },
  {
    "path": "tests/test_complex.py",
    "chars": 1119,
    "preview": "import autograd.numpy as np\nimport autograd.numpy.random as npr\nfrom autograd import grad\nfrom autograd.test_util import"
  },
  {
    "path": "tests/test_core.py",
    "chars": 1679,
    "preview": "\"\"\"This file doesn't import the numpy wrapper, to check if core works\non basic operations even without numpy.\"\"\"\n\nimport"
  },
  {
    "path": "tests/test_dict.py",
    "chars": 3500,
    "preview": "import operator as op\n\nimport autograd.numpy as np\nimport autograd.numpy.random as npr\nfrom autograd import dict as ag_d"
  },
  {
    "path": "tests/test_direct.py",
    "chars": 902,
    "preview": "\"\"\"\nSet of tests that are as explicit as possible, in case the test helpers like\nautograd.test_util break and start lett"
  },
  {
    "path": "tests/test_fft.py",
    "chars": 6431,
    "preview": "from functools import partial\n\nimport pytest\n\nimport autograd.numpy as np\nimport autograd.numpy.random as npr\nfrom autog"
  },
  {
    "path": "tests/test_graphs.py",
    "chars": 4958,
    "preview": "import warnings\n\nimport pytest\n\nimport autograd.numpy as np\nimport autograd.numpy.random as npr\nfrom autograd import gra"
  },
  {
    "path": "tests/test_jacobian.py",
    "chars": 1429,
    "preview": "import autograd.numpy as np\nimport autograd.numpy.random as npr\nfrom autograd import grad, jacobian\nfrom autograd.test_u"
  },
  {
    "path": "tests/test_linalg.py",
    "chars": 13015,
    "preview": "from functools import partial\n\nimport numpy as onp\nimport pytest\n\nimport autograd.numpy as np\nimport autograd.numpy.rand"
  },
  {
    "path": "tests/test_list.py",
    "chars": 1787,
    "preview": "import autograd.numpy as np\nimport autograd.numpy.random as npr\nfrom autograd import grad\nfrom autograd import isinstanc"
  },
  {
    "path": "tests/test_logic.py",
    "chars": 1559,
    "preview": "import warnings\nfrom contextlib import contextmanager\n\nimport pytest\n\nimport autograd.numpy as np\nfrom autograd import d"
  },
  {
    "path": "tests/test_misc.py",
    "chars": 2843,
    "preview": "import autograd.numpy as np\nimport autograd.numpy.random as npr\nfrom autograd import grad, make_vjp\nfrom autograd.misc i"
  },
  {
    "path": "tests/test_numpy.py",
    "chars": 18309,
    "preview": "import warnings\n\nfrom numpy_utils import combo_check\n\nimport autograd.numpy as np\nimport autograd.numpy.random as npr\nfr"
  },
  {
    "path": "tests/test_performance.py",
    "chars": 144,
    "preview": "# TODO:\n# Do a huge calculation with trivial primitive computations\n# and lots of diamonds and get a benchmark per-node "
  },
  {
    "path": "tests/test_scalar_ops.py",
    "chars": 4661,
    "preview": "import autograd.numpy as np\nimport autograd.numpy.random as npr\nfrom autograd import grad\nfrom autograd.test_util import"
  },
  {
    "path": "tests/test_scipy.py",
    "chars": 15547,
    "preview": "from functools import partial\n\nimport numpy as npo\n\ntry:\n    import scipy\nexcept:\n    from warnings import warn\n\n    war"
  },
  {
    "path": "tests/test_systematic.py",
    "chars": 14289,
    "preview": "import operator as op\n\nimport numpy as onp\nfrom numpy_utils import binary_ufunc_check, binary_ufunc_check_no_same_args, "
  },
  {
    "path": "tests/test_tests.py",
    "chars": 778,
    "preview": "from pytest import raises\n\nfrom autograd.extend import defvjp\nfrom autograd.test_util import check_grads\nfrom autograd.t"
  },
  {
    "path": "tests/test_truediv.py",
    "chars": 417,
    "preview": "# This file is to check that future division works.\n\nfrom test_binary_ops import arg_pairs\n\nimport autograd.numpy as np\n"
  },
  {
    "path": "tests/test_tuple.py",
    "chars": 1675,
    "preview": "import autograd.numpy as np\nimport autograd.numpy.random as npr\nfrom autograd import grad\nfrom autograd import isinstanc"
  },
  {
    "path": "tests/test_vspaces.py",
    "chars": 4770,
    "preview": "import itertools as it\nfrom functools import reduce\n\nimport numpy as np\n\nfrom autograd.core import vspace\nfrom autograd."
  },
  {
    "path": "tests/test_wrappers.py",
    "chars": 10842,
    "preview": "import warnings\nfrom functools import partial\n\nimport pytest\n\nimport autograd.numpy as np\nimport autograd.numpy.random a"
  }
]

About this extraction

This page contains the full source code of the HIPS/autograd GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 120 files (426.4 KB), approximately 134.9k tokens, and a symbol index with 1282 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
