Full Code of facebookresearch/functorch for AI

Repository: facebookresearch/functorch
Branch: main
Commit: b71aa0b4387b
Files: 18
Total size: 98.5 KB

Directory structure:
gitextract_sltc66ra/

├── .flake8
├── .github/
│   └── workflows/
│       └── wheels.yml
├── .gitignore
├── .lintrunner.toml
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── notebooks/
│   ├── README.md
│   └── colab/
│       ├── aot_autograd_optimizations.ipynb
│       ├── jacobians_hessians_colab.ipynb
│       └── per_sample_grads_colab.ipynb
├── packaging/
│   └── windows/
│       └── internal/
│           ├── cuda_install.bat
│           └── driver_update.bat
├── pull_request_template.md
├── setup.cfg
├── setup.py
└── version.txt

================================================
FILE CONTENTS
================================================

================================================
FILE: .flake8
================================================
[flake8]
select = B,C,E,F,P,T4,W,B9
max-line-length = 120
# C408 ignored because we like the dict keyword argument syntax
# E501 is not flexible enough, we're using B950 instead
ignore =
    E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,
    # shebang has extra meaning in fbcode lints, so I think it's not worth trying
    # to line this up with executable bit
    EXE001,
    # these ignores are from flake8-bugbear; please fix!
    B007,B008,
    # these ignores are from flake8-comprehensions; please fix!
    C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415
exclude =
    ./.git,
    ./benchmarks,
    ./docs,
    ./examples,
    ./notebooks


================================================
FILE: .github/workflows/wheels.yml
================================================
name: Wheels
on:
  pull_request:
    types: [opened, synchronize, reopened]
  push:
    branches:
      - main

jobs:

  build-wheel:
    runs-on: ubuntu-22.04
    steps:
      - name: Setup Python
        uses: actions/setup-python@v2
        with:
          python-version: 3.9
          architecture: x64
      - name: Checkout functorch
        uses: actions/checkout@v2
      - name: Install PyTorch Nightly
        run: |
          python3 -mpip install --pre "torch>=1.13.0.dev" -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
      - name: Build wheel
        run: |
          python3 -mpip install wheel
          python3 setup.py bdist_wheel
      - name: Upload wheel as GHA artifact
        uses: actions/upload-artifact@v2
        with:
          name: functorch.whl
          path: dist/*.whl


================================================
FILE: .gitignore
================================================
build/
dist/
functorch.egg-info/
*__pycache__*
functorch/version.py
functorch/_C.so
.gdbinit
t.py
.vscode/
ccache.sh
docs/build
docs/src
docs/source/generated
.DS_Store
op_analysis/*.txt

# Editor temporaries
*.swn
*.swo
*.swp
*.swm


================================================
FILE: .lintrunner.toml
================================================
[[linter]]
code = 'FLAKE8'
include_patterns = ['**/*.py']
exclude_patterns = [
    '.git/**',
    'benchmarks/**',
    'docs/**',
    'examples/**',
    'notebooks/**',
]
command = [
    'python3',
    'tools/lint/flake8_linter.py',
    '--',
    '@{{PATHSFILE}}'
]
init_command = [
    'python3',
    'tools/lint/pip_init.py',
    '--dry-run={{DRYRUN}}',
    'flake8==3.8.2',
    'flake8-bugbear==20.1.4',
    'flake8-comprehensions==3.3.0',
    'flake8-executable==2.0.4',
    'flake8-pyi==20.5.0',
    'mccabe==0.6.1',
    'pycodestyle==2.6.0',
    'pyflakes==2.2.0',
]

# [[linter]]
# code = 'BLACK'
# include_patterns = [
#     '**/*.py',
# ]
# command = [
#     'python3',
#     'tools/lint/black_linter.py',
#     '--',
#     '@{{PATHSFILE}}'
# ]
# init_command = [
#     'python3',
#     'tools/lint/pip_init.py',
#     '--dry-run={{DRYRUN}}',
#     'black==22.3.0',
# ]
# is_formatter = true


================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.

## Our Standards

Examples of behavior that contributes to creating a positive environment
include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.

## Scope

This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at <conduct@pytorch.org>. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq


================================================
FILE: CONTRIBUTING.md
================================================
## Contributing
Feedback on our APIs, as well as bug reports, would be very helpful.

Please feel free to chat us up on the PyTorch Slack, or open an issue
at https://github.com/pytorch/functorch if you're interested in
contributing.

To contribute a change to functorch, please submit a Pull Request against the
`functorch` folder of the https://github.com/pytorch/pytorch repository. The
source of truth for functorch has moved there from
https://github.com/pytorch/functorch; the code in the pytorch/functorch
repository is read-only.


================================================
FILE: LICENSE
================================================
Copyright (c) 2021 Facebook, Inc. and its affiliates. All rights reserved.

Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice,
   this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its contributors
   may be used to endorse or promote products derived from this software
   without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


================================================
FILE: README.md
================================================
# functorch

[**Why functorch?**](#why-composable-function-transforms)
| [**Install guide**](#install)
| [**Transformations**](#what-are-the-transforms)
| [**Documentation**](#documentation)
| [**Future Plans**](#future-plans)

**This library is currently under heavy development. If you have suggestions
on the API or use cases you'd like to be covered, please open a GitHub issue
or reach out. We'd love to hear about how you're using the library.**

`functorch` provides [JAX-like](https://github.com/google/jax) composable function
transforms for PyTorch.

It aims to provide composable `vmap` and `grad` transforms that work with
PyTorch modules and PyTorch autograd with good eager-mode performance.

In addition, there is experimental functionality to trace through these
transformations using FX in order to capture the results of these transforms
ahead of time. This would allow us to compile the results of vmap or grad
to improve performance.

## Why composable function transforms?

There are a number of use cases that are tricky to do in
PyTorch today:
- computing per-sample-gradients (or other per-sample quantities)
- running ensembles of models on a single machine
- efficiently batching together tasks in the inner-loop of MAML
- efficiently computing Jacobians and Hessians
- efficiently computing batched Jacobians and Hessians

Composing `vmap`, `grad`, `vjp`, and `jvp` transforms allows us to express the above
without designing a separate subsystem for each. This idea of composable function
transforms comes from the [JAX framework](https://github.com/google/jax).

## Install

There are two ways to install functorch:
1. functorch from source
2. functorch beta (compatible with recent PyTorch releases)

We recommend trying out the functorch beta first.

### Installing functorch from source

<details><summary>Click to expand</summary>
<p>

#### Using Colab

Follow the instructions [in this Colab notebook](https://colab.research.google.com/drive/1CrLkqIrydBYP_svnF89UUO-aQEqNPE8x?usp=sharing)

#### Locally

As of 9/21/2022, `functorch` comes installed alongside a nightly PyTorch binary.
Please install a Preview (nightly) PyTorch binary; see  https://pytorch.org/
for instructions.

Once you've done that, run a quick sanity check in Python:
```py
import torch
from functorch import vmap
x = torch.randn(3)
y = vmap(torch.sin)(x)
assert torch.allclose(y, x.sin())
```

#### functorch development setup

As of 9/21/2022, `functorch` comes installed alongside PyTorch and is in the
PyTorch source tree. Please install
[PyTorch from source](https://github.com/pytorch/pytorch#from-source); after
that, you will be able to `import functorch`.

Try to run some tests to make sure all is OK:
```bash
pytest test/test_vmap.py -v
pytest test/test_eager_transforms.py -v
```

AOTAutograd has some additional optional requirements. You can install them via:
```bash
pip install networkx
```

To run functorch tests, please install our test dependencies (`expecttest`, `pyyaml`).


</p>
</details>

### Installing functorch beta (compatible with recent PyTorch releases)

<details><summary>Click to expand</summary>
<p>

#### Using Colab

Follow the instructions [here](https://colab.research.google.com/drive/1GNfb01W_xf8JRu78ZKoNnLqiwcrJrbYG#scrollTo=HJ1srOGeNCGA)

#### pip

Prerequisite: [Install PyTorch](https://pytorch.org/get-started/locally/)


```bash
pip install functorch
```

Finally, run a quick sanity check in Python:
```py
import torch
from functorch import vmap
x = torch.randn(3)
y = vmap(torch.sin)(x)
assert torch.allclose(y, x.sin())
```

</p>
</details>

## What are the transforms?

Right now, we support the following transforms:
- `grad`, `vjp`, `jvp`
- `jacrev`, `jacfwd`, `hessian`
- `vmap`

Furthermore, we have some utilities for working with PyTorch modules.
- `make_functional(model)`
- `make_functional_with_buffers(model)`

### vmap

Note: `vmap` imposes restrictions on the code that it can be used on.
For more details, please read its docstring.

`vmap(func)(*inputs)` is a transform that adds a dimension to all Tensor
operations in `func`. `vmap(func)` returns a new function that maps `func` over
some dimension (default: 0) of each Tensor in `inputs`.

`vmap` is useful for hiding batch dimensions: one can write a function `func`
that runs on examples and then lift it to a function that can take batches of
examples with `vmap(func)`, leading to a simpler modeling experience:

```py
import torch
from functorch import vmap
batch_size, feature_size = 3, 5
weights = torch.randn(feature_size, requires_grad=True)

def model(feature_vec):
    # Very simple linear model with activation
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights).relu()

examples = torch.randn(batch_size, feature_size)
result = vmap(model)(examples)
```
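To make the `in_dims` behaviour concrete, here's a small sketch (assuming `functorch` imports as described in the install section) that checks `vmap` against an explicit Python loop. The two-argument `model(weights, feature_vec)` is a variant of the example above, written only for illustration: `in_dims=(None, 0)` shares `weights` across the batch, and `in_dims=(None, 1)` maps over dimension 1 instead of the default 0.

```python
import torch
from functorch import vmap

torch.manual_seed(0)
batch_size, feature_size = 3, 5
weights = torch.randn(feature_size)

def model(weights, feature_vec):
    # Written for a single example: feature_vec has shape (feature_size,)
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights).relu()

examples = torch.randn(batch_size, feature_size)

# in_dims=(None, 0): reuse the same `weights` for every example and
# map over dimension 0 of `examples`.
batched = vmap(model, in_dims=(None, 0))(weights, examples)

# The batch dimension does not have to be dimension 0: here the examples
# are stored column-wise, so we map over dimension 1 instead.
batched_cols = vmap(model, in_dims=(None, 1))(weights, examples.T)

# Reference result: an explicit Python loop over the examples.
looped = torch.stack([model(weights, ex) for ex in examples])
```

Comparing against a loop like this is a cheap way to sanity-check a function before relying on `vmap` for it.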

### grad

`grad(func)(*inputs)` assumes `func` returns a single-element Tensor. It computes
the gradient of the output of `func` with respect to `inputs[0]`.

```py
import torch
from functorch import grad
x = torch.randn([])
cos_x = grad(lambda x: torch.sin(x))(x)
assert torch.allclose(cos_x, x.cos())

# Second-order gradients
neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
assert torch.allclose(neg_sin_x, -x.sin())
```
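By default `grad` differentiates only with respect to the first argument. functorch's `grad` also accepts an `argnums` argument for selecting other (or several) inputs. Here's a minimal sketch using a made-up function `f(x, y) = sum(x * y)`, whose gradients (`y` and `x`) are easy to check by hand:

```python
import torch
from functorch import grad

x = torch.randn(4)
y = torch.randn(4)

def f(x, y):
    # Scalar-valued function of two tensors
    return (x * y).sum()

# Default (argnums=0): gradient with respect to the first argument only
dx = grad(f)(x, y)

# argnums=(0, 1): a tuple of gradients, one per selected argument
dx2, dy2 = grad(f, argnums=(0, 1))(x, y)
```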

When composed with `vmap`, `grad` can be used to compute per-sample-gradients:
```py
import torch
from functorch import vmap, grad
batch_size, feature_size = 3, 5

def model(weights, feature_vec):
    # Very simple linear model with activation
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights).relu()

def compute_loss(weights, example, target):
    y = model(weights, example)
    return ((y - target) ** 2).mean()  # MSELoss

weights = torch.randn(feature_size, requires_grad=True)
examples = torch.randn(batch_size, feature_size)
targets = torch.randn(batch_size)
inputs = (weights, examples, targets)
grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)
```
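One way to convince yourself that `vmap(grad(...))` computes the same thing as looping over examples is to compare the two directly. This sketch reuses the model and loss from the example above (minus the `dim()` assert) and is meant as a correctness check, not a benchmark:

```python
import torch
from functorch import vmap, grad

torch.manual_seed(0)
batch_size, feature_size = 3, 5

def model(weights, feature_vec):
    return feature_vec.dot(weights).relu()

def compute_loss(weights, example, target):
    y = model(weights, example)
    return ((y - target) ** 2).mean()  # MSE loss for a single example

weights = torch.randn(feature_size)
examples = torch.randn(batch_size, feature_size)
targets = torch.randn(batch_size)

# Vectorized per-example gradients: one row per example
per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(weights, examples, targets)

# Reference: one grad call per example, stacked afterwards
looped = torch.stack([
    grad(compute_loss)(weights, ex, t) for ex, t in zip(examples, targets)
])
```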

### vjp

The `vjp` transform applies `func` to `inputs` and returns the outputs together
with a new function that computes vector-Jacobian products (VJPs) given
`cotangents` with the same structure as the outputs.
```py
from functorch import vjp
outputs, vjp_fn = vjp(func, *inputs)
vjps = vjp_fn(cotangents)
```
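Here's a concrete, runnable version of that pattern, using `torch.sin` and a made-up two-output function as examples. The cotangent passed to `vjp_fn` mirrors the structure of the outputs, and `vjp_fn` returns a tuple with one VJP per primal input:

```python
import torch
from functorch import vjp

x = torch.randn(5)

# Single output: the cotangent has the same shape as the output
out, vjp_fn = vjp(torch.sin, x)
(vjp_x,) = vjp_fn(torch.ones(5))  # one VJP per primal input

# Multiple outputs: pass the cotangents as one tuple matching the outputs
def sin_and_cos(x):
    return x.sin(), x.cos()

outs, vjp_fn2 = vjp(sin_and_cos, x)
(vjp_x2,) = vjp_fn2((torch.ones(5), torch.ones(5)))
```

With all-ones cotangents, the second VJP is simply `cos(x) - sin(x)`, the sum of the two derivatives.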

### jvp

The `jvp` transform computes Jacobian-vector products and is also known as
"forward-mode AD". Unlike most other transforms, it is not a higher-order
function: it returns the outputs of `func(*inputs)` together with the JVPs.
```py
import torch
from functorch import jvp
x = torch.randn(5)
y = torch.randn(5)
f = lambda x, y: (x * y)
_, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
assert torch.allclose(output, x + y)
```
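Since `jvp` computes a Jacobian-vector product without materializing the Jacobian, a small sketch can check it against the explicit product `J @ v`, using `jacfwd` (covered below) to build `J` for `torch.sin`. The tangent `v` is an arbitrary random direction chosen for illustration:

```python
import torch
from functorch import jvp, jacfwd

x = torch.randn(5)
v = torch.randn(5)  # tangent (direction) vector

# jvp returns (func(*primals), J @ v) without building J
out, jvp_out = jvp(torch.sin, (x,), (v,))

# Reference: build the full Jacobian and multiply by v explicitly
J = jacfwd(torch.sin)(x)
```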

### jacrev, jacfwd, and hessian

The `jacrev` transform returns a new function that takes in `x` and returns the
Jacobian of the given function with respect to `x`, computed using reverse-mode
AD. For example, with `torch.sin`:
```py
import torch
from functorch import jacrev
x = torch.randn(5)
jacobian = jacrev(torch.sin)(x)
expected = torch.diag(torch.cos(x))
assert torch.allclose(jacobian, expected)
```
`jacrev` can be composed with `vmap` to produce batched Jacobians:

```py
import torch
from functorch import vmap, jacrev

x = torch.randn(64, 5)
jacobian = vmap(jacrev(torch.sin))(x)
assert jacobian.shape == (64, 5, 5)
```
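For an elementwise function like `torch.sin`, each per-example Jacobian is diagonal, so the batched result can be checked against `torch.diag_embed`, which builds a whole stack of diagonal matrices at once. A quick sketch of that check:

```python
import torch
from functorch import vmap, jacrev

x = torch.randn(64, 5)
batched_jac = vmap(jacrev(torch.sin))(x)

# Each example's Jacobian of the elementwise sin is diag(cos(x_i))
expected = torch.diag_embed(torch.cos(x))
```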

`jacfwd` is a drop-in replacement for `jacrev` that computes Jacobians using
forward-mode AD:
```py
import torch
from functorch import jacfwd
x = torch.randn(5)
jacobian = jacfwd(torch.sin)(x)
expected = torch.diag(torch.cos(x))
assert torch.allclose(jacobian, expected)
```

Composing `jacrev` with itself or with `jacfwd` produces Hessians:
```py
import torch
from functorch import jacfwd, jacrev

def f(x):
    return x.sin().sum()

x = torch.randn(5)
hessian0 = jacrev(jacrev(f))(x)
hessian1 = jacfwd(jacrev(f))(x)
```

`hessian` is a convenience function that combines `jacfwd` and `jacrev`:
```py
import torch
from functorch import hessian

def f(x):
    return x.sin().sum()

x = torch.randn(5)
hess = hessian(f)(x)
```
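As a sketch, the three ways of computing the Hessian shown above should agree. For `f(x) = sum(sin(x))` the Hessian is the diagonal matrix with entries `-sin(x_i)`, which gives a closed form to test against:

```python
import torch
from functorch import hessian, jacfwd, jacrev

def f(x):
    return x.sin().sum()

x = torch.randn(5)
hess = hessian(f)(x)
rev_rev = jacrev(jacrev(f))(x)
fwd_rev = jacfwd(jacrev(f))(x)

# Second derivative of sum(sin(x)): -sin(x_i) on the diagonal, 0 elsewhere
expected = torch.diag(-x.sin())
```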

### Tracing through the transformations
We can also trace through these transformations in order to capture the results as new code using `make_fx`. There is also experimental integration with the NNC compiler (only works on CPU for now!).

```py
import torch
from functorch import make_fx, grad

def f(x):
    return torch.sin(x).sum()

x = torch.randn(100)
grad_f = make_fx(grad(f))(x)
print(grad_f.code)

# Prints something like:
#
# def forward(self, x_1):
#     sin = torch.ops.aten.sin(x_1)
#     sum_1 = torch.ops.aten.sum(sin, None);  sin = None
#     cos = torch.ops.aten.cos(x_1);  x_1 = None
#     _tensor_constant0 = self._tensor_constant0
#     mul = torch.ops.aten.mul(_tensor_constant0, cos);  _tensor_constant0 = cos = None
#     return mul
```
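The traced result is an FX `GraphModule`, so it can be called like the original function. Here's a small sketch, assuming the same `functorch` imports as above, that checks the traced gradient against the analytic answer `cos(x)`:

```python
import torch
from functorch import make_fx, grad

def f(x):
    return torch.sin(x).sum()

x = torch.randn(100)
traced = make_fx(grad(f))(x)

# The traced GraphModule is callable and should compute
# d/dx sum(sin(x)) = cos(x) for the input it is given.
result = traced(x)
```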

### Working with NN modules: make_functional and friends

Sometimes you may want to perform a transform with respect to the parameters
and/or buffers of an nn.Module. This can happen for example in:
- model ensembling, where all of your weights and buffers have an additional
dimension
- per-sample-gradient computation where you want to compute per-sample-grads
of the loss with respect to the model parameters

Our solution to this right now is an API that, given an nn.Module, creates a
stateless version of it that can be called like a function.

- `make_functional(model)` returns a functional version of `model` and the
`model.parameters()`
- `make_functional_with_buffers(model)` returns a functional version of
`model` and the `model.parameters()` and `model.buffers()`.

Here's an example where we compute per-sample-gradients using an nn.Linear
layer:

```py
import torch
from functorch import make_functional, vmap, grad

model = torch.nn.Linear(3, 3)
data = torch.randn(64, 3)
targets = torch.randn(64, 3)

func_model, params = make_functional(model)

def compute_loss(params, data, targets):
    preds = func_model(params, data)
    return torch.mean((preds - targets) ** 2)

per_sample_grads = vmap(grad(compute_loss), (None, 0, 0))(params, data, targets)
```
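To see what `per_sample_grads` contains, here's a sketch that re-runs the example above and cross-checks the first example's gradient against a regular `loss.backward()` on the original `model`. It assumes, as we believe is the case, that `make_functional` leaves `model` itself usable:

```python
import torch
from functorch import make_functional, vmap, grad

torch.manual_seed(0)
model = torch.nn.Linear(3, 3)
data = torch.randn(64, 3)
targets = torch.randn(64, 3)

func_model, params = make_functional(model)

def compute_loss(params, data, targets):
    preds = func_model(params, data)
    return torch.mean((preds - targets) ** 2)

per_sample_grads = vmap(grad(compute_loss), (None, 0, 0))(params, data, targets)

# per_sample_grads mirrors `params` (weight, bias), each with an extra
# leading dimension of size 64: one gradient per example.
weight_grads, bias_grads = per_sample_grads

# Cross-check example 0 against ordinary autograd on a single example.
model.zero_grad()
loss0 = torch.mean((model(data[0]) - targets[0]) ** 2)
loss0.backward()
```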

If you're making an ensemble of models, you may find
`combine_state_for_ensemble` useful.
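A minimal ensembling sketch, assuming four identical `nn.Linear(5, 3)` models chosen only for illustration. `combine_state_for_ensemble` stacks the models' parameters and buffers along a new leading dimension, and `vmap` with `in_dims=(0, 0, None)` runs every model on the same minibatch:

```python
import torch
from functorch import combine_state_for_ensemble, vmap

torch.manual_seed(0)
num_models, batch_size = 4, 8
models = [torch.nn.Linear(5, 3) for _ in range(num_models)]
data = torch.randn(batch_size, 5)

# Stack the parameters/buffers of all models along a new leading dimension.
fmodel, params, buffers = combine_state_for_ensemble(models)

# in_dims=(0, 0, None): each model gets its own slice of params/buffers,
# and every model sees the same minibatch `data`.
outputs = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, data)
```

To give each model its own minibatch instead, stack the minibatches into a `(num_models, batch_size, 5)` tensor and map over it with `in_dims=(0, 0, 0)`.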

## Documentation

For more documentation, see [our docs website](https://pytorch.org/functorch).

## Debugging
- `torch._C._functorch.dump_tensor`: dumps the dispatch keys on the stack.
- `torch._C._functorch._set_vmap_fallback_warning_enabled(False)`: disables the vmap fallback warnings, if the warning spam bothers you.

## Future Plans

In the end state, we'd like to upstream this into PyTorch once we iron out the
design details. To figure out the details, we need your help -- please send us
your use cases by starting a conversation in the issue tracker or trying our
project out.

## License
Functorch has a BSD-style license, as found in the [LICENSE](LICENSE) file.

## Citing functorch

If you use functorch in your publication, please cite it by using the following BibTeX entry.

```bibtex
@Misc{functorch2021,
  author =       {Horace He and Richard Zou},
  title =        {functorch: JAX-like composable function transforms for PyTorch},
  howpublished = {\url{https://github.com/pytorch/functorch}},
  year =         {2021}
}
```


================================================
FILE: notebooks/README.md
================================================
The new, updated versions of these notebooks may be found in the pytorch/pytorch repo.

We're leaving the old notebooks here as a temporary solution so that our website still
points to the correct thing. We plan to rewrite the links on the website to point to
their newer counterparts soon.


================================================
FILE: notebooks/colab/aot_autograd_optimizations.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# AOT Autograd - How to use and optimize?\n",
    "\n",
    "<a href=\"https://colab.research.google.com/github/pytorch/pytorch/blob/master/functorch/notebooks/aot_autograd_optimizations.ipynb\">\n",
    "  <img style=\"width: auto\" src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
    "</a>\n",
    "\n",
    "## Background\n",
    "In this tutorial, we will learn how to use AOT Autograd to speed up training of deep learning models.\n",
    "\n",
    "For background, AOT Autograd is a toolkit to assist developers in accelerating training on PyTorch. Broadly, it has two key features:\n",
    "* AOT Autograd traces the forward and backward graphs ahead of time. Having both graphs ahead of time facilitates joint graph optimizations such as recomputation or activation checkpointing.\n",
    "* AOT Autograd provides simple mechanisms to compile the extracted forward and backward graphs through deep learning compilers, such as NVFuser, NNC, TVM and others.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## What will you learn?\n",
    "In this tutorial, we will look at how AOT Autograd can be used, in conjunction with backend compilers, to accelerate the training of PyTorch models. More specifically, you will learn:\n",
    "* How to use AOT Autograd\n",
    "* How AOT Autograd uses backend compilers to perform operator fusion\n",
    "* How AOT Autograd enables training-specific optimizations such as recomputation\n",
    "\n",
    "So, let's get started.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "Let's set up a simple model.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def fn(a, b, c, d):\n",
    "    x = a + b + c + d\n",
    "    return x.cos().cos()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test that it works\n",
    "a, b, c, d = [torch.randn(2, 4, requires_grad=True) for _ in range(4)]\n",
    "ref = fn(a, b, c, d)\n",
    "loss = ref.sum()\n",
    "loss.backward()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use AOT Autograd\n",
    "\n",
    "Now, let's use AOT Autograd and look at the extracted forward and backward graphs. Internally, AOT Autograd uses a `__torch_dispatch__`-based tracing mechanism to extract the forward and backward graphs, and wraps them in `torch.fx` GraphModule containers. Note that AOT Autograd tracing is different from the usual FX symbolic tracing: AOT Autograd uses FX GraphModules only to represent the traced graphs, not to trace them.\n",
    "\n",
    "AOT Autograd then sends these forward and backward graphs to the user-supplied compilers. So, let's write a compiler that just prints the graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\n",
      "def forward(self, primals_1, primals_2, primals_3, primals_4):\n",
      "    add = torch.ops.aten.add(primals_1, primals_2);  primals_1 = primals_2 = None\n",
      "    add_1 = torch.ops.aten.add(add, primals_3);  add = primals_3 = None\n",
      "    add_2 = torch.ops.aten.add(add_1, primals_4);  add_1 = primals_4 = None\n",
      "    cos = torch.ops.aten.cos(add_2)\n",
      "    cos_1 = torch.ops.aten.cos(cos)\n",
      "    return [cos_1, add_2, cos]\n",
      "    \n",
      "\n",
      "\n",
      "\n",
      "def forward(self, add_2, cos, tangents_1):\n",
      "    sin = torch.ops.aten.sin(cos);  cos = None\n",
      "    neg = torch.ops.aten.neg(sin);  sin = None\n",
      "    mul = torch.ops.aten.mul(tangents_1, neg);  tangents_1 = neg = None\n",
      "    sin_1 = torch.ops.aten.sin(add_2);  add_2 = None\n",
      "    neg_1 = torch.ops.aten.neg(sin_1);  sin_1 = None\n",
      "    mul_1 = torch.ops.aten.mul(mul, neg_1);  mul = neg_1 = None\n",
      "    return [mul_1, mul_1, mul_1, mul_1]\n",
      "    \n"
     ]
    }
   ],
   "source": [
    "from functorch.compile import aot_function\n",
    "\n",
    "# The compiler_fn is called after the forward and backward graphs are extracted.\n",
    "# Here, we just print the code in the compiler_fn. The function must return a callable.\n",
    "def compiler_fn(fx_module: torch.fx.GraphModule, _):\n",
    "    print(fx_module.code)\n",
    "    return fx_module\n",
    "\n",
    "# Pass on the compiler_fn to the aot_function API\n",
    "aot_print_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn)\n",
    "\n",
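    "# Run the aot_print_fn once to trigger the compilation and print the graphs.\n",
    "# Clone the inputs so that the gradients already stored in a, b, c, d stay untouched.\n",
    "cloned_inputs = [x.clone().detach().requires_grad_(True) for x in (a, b, c, d)]\n",
    "res = aot_print_fn(*cloned_inputs)\n",
    "res.sum().backward()\n",
    "assert torch.allclose(ref, res)"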
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The above code prints the FX graphs for the forward and backward passes. You can see that in addition to the original output of the forward pass, the forward graph outputs some additional tensors. These tensors are saved for the backward pass for gradient calculation. We will come back to these later while talking about recomputation."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Operator Fusion\n",
    "Now that we understand how to use AOT Autograd to print forward and backward graphs, let us use AOT Autograd with an actual deep learning compiler. In this tutorial, we use PyTorch Neural Network Compiler (NNC) to perform pointwise operator fusion for CPU devices. For CUDA devices, a suitable alternative is NvFuser. So, let's use NNC."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# AOT Autograd has a suite of already integrated backends. Let's import the NNC compiler backend, ts_compile.\n",
    "from functorch.compile import ts_compile\n",
    "\n",
    "# Let's compile the forward and backward passes through ts_compile.\n",
    "aot_nnc_fn = aot_function(fn, fw_compiler=ts_compile, bw_compiler=ts_compile)\n",
    "\n",
    "# Correctness checking. Let's clone the inputs so that we can check the grads.\n",
    "cloned_inputs = [x.clone().detach().requires_grad_(True) for x in (a, b, c, d)]\n",
    "cloned_a, cloned_b, cloned_c, cloned_d = cloned_inputs\n",
    "\n",
    "res = aot_nnc_fn(*cloned_inputs)\n",
    "loss = res.sum()\n",
    "loss.backward()\n",
    "assert torch.allclose(ref, res)\n",
    "assert torch.allclose(a.grad, cloned_a.grad)\n",
    "assert torch.allclose(b.grad, cloned_b.grad)\n",
    "assert torch.allclose(c.grad, cloned_c.grad)\n",
    "assert torch.allclose(d.grad, cloned_d.grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's benchmark the original and the AOT Autograd + NNC compiled functions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's write a function to benchmark the forward and backward passes\n",
    "import time\n",
    "import statistics\n",
    "\n",
    "def bench(fn, args, prefix):\n",
    "    warmup = 10\n",
    "    iterations = 100\n",
    "\n",
    "    for _ in range(warmup):\n",
    "        ref = fn(*args)\n",
    "        ref.sum().backward()\n",
    "    \n",
    "    fw_latencies = []\n",
    "    bw_latencies = []\n",
    "    for _ in range(iterations):\n",
    "        for arg in args:\n",
    "            arg.grad = None\n",
    "\n",
    "        fw_begin = time.perf_counter()\n",
    "        ref = fn(*args)\n",
    "        fw_end = time.perf_counter()\n",
    "\n",
    "        loss = ref.sum() \n",
    "\n",
    "        bw_begin = time.perf_counter()\n",
    "        loss.backward()\n",
    "        bw_end = time.perf_counter()\n",
    "\n",
    "        fw_latencies.append(fw_end - fw_begin)\n",
    "        bw_latencies.append(bw_end - bw_begin)\n",
    "    \n",
    "    avg_fw_latency = statistics.mean(fw_latencies) * 10**6\n",
    "    avg_bw_latency = statistics.mean(bw_latencies) * 10**6\n",
    "    print(prefix, \"Fwd = \" + str(avg_fw_latency) + \" us\", \"Bwd = \" + str(avg_bw_latency) + \" us\", sep=', ')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eager, Fwd = 982.6959593920038 us, Bwd = 1899.7003795811906 us\n",
      "AOT, Fwd = 734.2723174951971 us, Bwd = 831.1696897726506 us\n"
     ]
    }
   ],
   "source": [
    "large_inputs = [torch.randn(1024, 2048, requires_grad=True) for _ in range(4)]\n",
    "\n",
    "# Benchmark the Eager and AOT Autograd functions\n",
    "bench(fn, large_inputs, \"Eager\")\n",
    "bench(aot_nnc_fn, large_inputs, \"AOT\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the help of NNC, AOT Autograd speeds up both the forward and backward passes. If we look at the printed graphs earlier, all the operators are pointwise. Pointwise operators are memory-bandwidth bound, and thus benefit from operator fusion. Looking closely at the numbers, the backward pass gets a higher speedup. This is because the forward pass has to output some intermediate tensors for gradient calculation in the backward pass, which prevents it from saving some memory reads and writes. The backward graph has no such restriction."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Recomputation (aka Activation Checkpointing)\n",
    "Recomputation (often called activation checkpointing) is a technique in which, instead of saving some activations for use in the backward pass, we recompute them **during** the backward pass. Recomputing saves memory, but incurs a performance overhead.\n",
    "\n",
    "However, in the presence of a fusing compiler, we can do better than that. We can recompute the fusion-friendly operators to save memory, and then rely on the fusing compiler to fuse the recomputed operators. This reduces both memory and runtime. Please refer to this [discuss post](https://dev-discuss.pytorch.org/t/min-cut-optimal-recomputation-i-e-activation-checkpointing-with-aotautograd/467) for more details.\n",
    "\n",
    "Here, we use AOT Autograd with NNC to perform a similar type of recomputation. At the end of `__torch_dispatch__` tracing, AOT Autograd has a forward graph and a joint forward-backward graph. AOT Autograd then uses a partitioner to isolate the forward and backward graphs. In the example above, we used the default partitioner. For this experiment, we will use another partitioner called `min_cut_rematerialization_partition` to perform smarter fusion-aware recomputation. The partitioner is configurable, and one can write a custom partitioner and plug it into AOT Autograd."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\n",
      "def forward(self, primals_1, primals_2, primals_3, primals_4):\n",
      "    add = torch.ops.aten.add(primals_1, primals_2);  primals_1 = primals_2 = None\n",
      "    add_1 = torch.ops.aten.add(add, primals_3);  add = primals_3 = None\n",
      "    add_2 = torch.ops.aten.add(add_1, primals_4);  add_1 = primals_4 = None\n",
      "    cos = torch.ops.aten.cos(add_2)\n",
      "    cos_1 = torch.ops.aten.cos(cos);  cos = None\n",
      "    return [cos_1, add_2]\n",
      "    \n",
      "\n",
      "\n",
      "\n",
      "def forward(self, add_2, tangents_1):\n",
      "    cos = torch.ops.aten.cos(add_2)\n",
      "    sin = torch.ops.aten.sin(cos);  cos = None\n",
      "    neg = torch.ops.aten.neg(sin);  sin = None\n",
      "    mul = torch.ops.aten.mul(tangents_1, neg);  tangents_1 = neg = None\n",
      "    sin_1 = torch.ops.aten.sin(add_2);  add_2 = None\n",
      "    neg_1 = torch.ops.aten.neg(sin_1);  sin_1 = None\n",
      "    mul_1 = torch.ops.aten.mul(mul, neg_1);  mul = neg_1 = None\n",
      "    return [mul_1, mul_1, mul_1, mul_1]\n",
      "    \n"
     ]
    }
   ],
   "source": [
    "from functorch.compile import min_cut_rematerialization_partition\n",
    "\n",
    "# Let's set up the partitioner. Also set the fwd and bwd compilers to the printer function that we used earlier.\n",
    "# This will show us how the recomputation has modified the graph.\n",
    "aot_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn, partition_fn=min_cut_rematerialization_partition)\n",
    "res = aot_fn(a, b, c, d).sum().backward()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that, compared to the default partitioner, the forward pass now outputs fewer tensors, and the backward pass recomputes some operations. Let's now try the NNC compiler to perform operator fusion (note that we also provide a wrapper function, `memory_efficient_fusion`, which internally uses `min_cut_rematerialization_partition` and the TorchScript compiler to achieve the same effect as the following code)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Let's set up the partitioner and the NNC compiler.\n",
    "aot_recompute_nnc_fn = aot_function(fn, fw_compiler=ts_compile, bw_compiler=ts_compile, partition_fn=min_cut_rematerialization_partition)\n",
    "\n",
    "# Correctness checking. Let's clone the inputs so that we can check grads.\n",
    "cloned_inputs = [x.clone().detach().requires_grad_(True) for x in (a, b, c, d)]\n",
    "cloned_a, cloned_b, cloned_c, cloned_d = cloned_inputs\n",
    "\n",
    "res = aot_recompute_nnc_fn(*cloned_inputs)\n",
    "loss = res.sum()\n",
    "loss.backward()\n",
    "assert torch.allclose(ref, res)\n",
    "assert torch.allclose(a.grad, cloned_a.grad)\n",
    "assert torch.allclose(b.grad, cloned_b.grad)\n",
    "assert torch.allclose(c.grad, cloned_c.grad)\n",
    "assert torch.allclose(d.grad, cloned_d.grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, let's benchmark the different functions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eager, Fwd = 740.7676504226401 us, Bwd = 1560.5240693548694 us\n",
      "AOT, Fwd = 713.8530415249988 us, Bwd = 909.1200679540634 us\n",
      "AOT_Recomp, Fwd = 712.2249767417088 us, Bwd = 791.4606417762116 us\n"
     ]
    }
   ],
   "source": [
    "bench(fn, large_inputs, \"Eager\")\n",
    "bench(aot_nnc_fn, large_inputs, \"AOT\")\n",
    "bench(aot_recompute_nnc_fn, large_inputs, \"AOT_Recomp\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We observe that the backward latency improves over the default partitioner (from about 909 us to 791 us) and is roughly half that of eager, while the forward latency is essentially unchanged. Fewer outputs in the forward pass and fewer inputs in the backward pass, along with fusion, allow better memory-bandwidth utilization, leading to further speedups."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Actual Usage\n",
    "For actual usage on CUDA devices, we've wrapped AOT Autograd in a convenient wrapper, `memory_efficient_fusion`. Use this for fusion on GPU!\n",
    "\n",
    "```python\n",
    "from functorch.compile import memory_efficient_fusion\n",
    "```\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.5 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "73b6e0ee7c860e06bb349c72324473b318d6cb6c97bcad772bce0703fb8d0dfb"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}


================================================
FILE: notebooks/colab/jacobians_hessians_colab.ipynb
================================================
{
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Jacobians, Hessians, hvp, vhp, and more: composing functorch transforms\n",
        "\n",
        "<a href=\"https://colab.research.google.com/github/pytorch/pytorch/blob/master/functorch/notebooks/jacobians_hessians.ipynb\">\n",
        "  <img style=\"width: auto\" src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
        "</a>\n",
        "\n",
        "Computing Jacobians or Hessians is useful in a number of non-traditional\n",
        "deep learning models. It is difficult (or annoying) to compute these quantities\n",
        "efficiently using a standard autodiff system like PyTorch Autograd; functorch\n",
        "provides ways of computing various higher-order autodiff quantities efficiently."
      ],
      "metadata": {
        "id": "zPbR6-eP51fe"
      },
      "id": "zPbR6-eP51fe"
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Computing the Jacobian"
      ],
      "metadata": {
        "id": "3kDj8fhn52j3"
      },
      "id": "3kDj8fhn52j3"
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from functools import partial\n",
        "_ = torch.manual_seed(0)"
      ],
      "metadata": {
        "id": "w_IinyjzflUH"
      },
      "execution_count": null,
      "outputs": [],
      "id": "w_IinyjzflUH"
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let’s start with a function whose Jacobian we’d like to compute. This is a simple linear function followed by a non-linear activation.\n",
        "\n"
      ],
      "metadata": {
        "id": "cibF_PEYflUH"
      },
      "id": "cibF_PEYflUH"
    },
    {
      "cell_type": "code",
      "source": [
        "def predict(weight, bias, x):\n",
        "    return F.linear(x, weight, bias).tanh()"
      ],
      "metadata": {
        "id": "qhcD9hWYflUH"
      },
      "execution_count": null,
      "outputs": [],
      "id": "qhcD9hWYflUH"
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let's add some dummy data: a weight, a bias, and a feature vector `x`.\n",
        "\n"
      ],
      "metadata": {
        "id": "G8tqQrO_flUH"
      },
      "id": "G8tqQrO_flUH"
    },
    {
      "cell_type": "code",
      "source": [
        "D = 16\n",
        "weight = torch.randn(D, D)\n",
        "bias = torch.randn(D)\n",
        "x = torch.randn(D) # feature vector"
      ],
      "metadata": {
        "id": "FZ4uJfZGflUH"
      },
      "execution_count": null,
      "outputs": [],
      "id": "FZ4uJfZGflUH"
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let's think of `predict` as a function that maps the input `x` from $R^D \\to R^D$.\n",
        "PyTorch Autograd computes vector-Jacobian products. In order to compute the full\n",
        "Jacobian of this $R^D \\to R^D$ function, we would have to compute it row-by-row\n",
        "by using a different unit vector each time."
      ],
      "metadata": {
        "id": "uMAW-ArQflUH"
      },
      "id": "uMAW-ArQflUH"
    },
    {
      "cell_type": "code",
      "source": [
        "def compute_jac(xp):\n",
        "    jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]\n",
        "                     for vec in unit_vectors]\n",
        "    return torch.stack(jacobian_rows)"
      ],
      "metadata": {
        "id": "z-BJPtbpflUI"
      },
      "execution_count": null,
      "outputs": [],
      "id": "z-BJPtbpflUI"
    },
    {
      "cell_type": "code",
      "source": [
        "xp = x.clone().requires_grad_()\n",
        "unit_vectors = torch.eye(D)\n",
        "\n",
        "jacobian = compute_jac(xp)\n",
        "\n",
        "print(jacobian.shape)\n",
        "print(jacobian[0])  # show first row"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "f1f1ec12-56ef-40f7-8c3c-cbad7bf86644",
        "id": "zuWGSXspflUI"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "torch.Size([16, 16])\n",
            "tensor([-0.5956, -0.6096, -0.1326, -0.2295,  0.4490,  0.3661, -0.1672, -1.1190,\n",
            "         0.1705, -0.6683,  0.1851,  0.1630,  0.0634,  0.6547,  0.5908, -0.1308])\n"
          ]
        }
      ],
      "id": "zuWGSXspflUI"
    },
    {
      "cell_type": "markdown",
      "source": [
        "Instead of computing the jacobian row-by-row, we can use vmap to get rid of the for-loop and vectorize the computation. \n",
        "We can’t directly apply vmap to PyTorch Autograd; instead, functorch provides a vjp transform:\n",
        "\n"
      ],
      "metadata": {
        "id": "mxlEOUieflUI"
      },
      "id": "mxlEOUieflUI"
    },
    {
      "cell_type": "code",
      "source": [
        "from functorch import vmap, vjp\n",
        "\n",
        "_, vjp_fn = vjp(partial(predict, weight, bias), x)\n",
        "\n",
        "ft_jacobian, = vmap(vjp_fn)(unit_vectors)\n",
        "\n",
        "# let's confirm that both methods compute the same result\n",
        "assert torch.allclose(ft_jacobian, jacobian)"
      ],
      "metadata": {
        "id": "DeF6uy4WflUI"
      },
      "execution_count": null,
      "outputs": [],
      "id": "DeF6uy4WflUI"
    },
    {
      "cell_type": "markdown",
      "source": [
        "In a future tutorial, a composition of reverse-mode AD and vmap will give us per-sample gradients.\n",
        "In this tutorial, composing reverse-mode AD and vmap gives us Jacobian computation!\n",
        "Various compositions of vmap and autodiff transforms can give us different interesting quantities.\n",
        "\n",
        "functorch provides **jacrev** as a convenience function that performs the vmap-vjp composition to compute Jacobians. **jacrev** accepts an `argnums` argument that specifies which argument(s) to compute the Jacobian with respect to.\n",
        "\n"
      ],
      "metadata": {
        "id": "Hy4REmwDflUI"
      },
      "id": "Hy4REmwDflUI"
    },
    {
      "cell_type": "code",
      "source": [
        "from functorch import jacrev\n",
        "\n",
        "ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)\n",
        "\n",
        "# confirm that jacrev matches the row-by-row computation\n",
        "assert torch.allclose(ft_jacobian, jacobian)"
      ],
      "metadata": {
        "id": "Rt7i6_YlflUI"
      },
      "execution_count": null,
      "outputs": [],
      "id": "Rt7i6_YlflUI"
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let’s compare the performance of the two ways to compute the Jacobian. The functorch version is much faster (and becomes even faster the more outputs there are).\n",
        "\n",
        "In general, we expect that vectorization via vmap can help eliminate overhead and give better utilization of your hardware.\n",
        "\n",
        "vmap does this by pushing the outer loop down into the function's primitive operations in order to obtain better performance.\n",
        "\n",
        "\n"
      ],
      "metadata": {
        "id": "JYe2H1UcflUJ"
      },
      "id": "JYe2H1UcflUJ"
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let's write a quick helper that compares two benchmark measurements, whether they are reported in microseconds or milliseconds:"
      ],
      "metadata": {
        "id": "i_143LZwflUJ"
      },
      "id": "i_143LZwflUJ"
    },
    {
      "cell_type": "code",
      "source": [
        "def get_perf(first, first_descriptor, second, second_descriptor):\n",
        "  \"\"\"Takes torch.utils.benchmark Measurement objects and prints the relative\n",
        "  improvement of `second` over `first`. Assumes `second` is the faster one.\"\"\"\n",
        "  faster = second.times[0]  # time per run, in seconds\n",
        "  slower = first.times[0]\n",
        "  gain = abs(slower - faster) / slower\n",
        "  final_gain = gain * 100\n",
        "  print(f\" Performance delta: {final_gain:.4f} percent improvement with {second_descriptor} \")"
      ],
      "metadata": {
        "id": "II7r6jBtflUJ"
      },
      "execution_count": null,
      "outputs": [],
      "id": "II7r6jBtflUJ"
    },
    {
      "cell_type": "markdown",
      "source": [
        "And then run the performance comparison:"
      ],
      "metadata": {
        "id": "r4clPnPKflUJ"
      },
      "id": "r4clPnPKflUJ"
    },
    {
      "cell_type": "code",
      "source": [
        "from torch.utils.benchmark import Timer\n",
        "\n",
        "without_vmap = Timer(stmt=\"compute_jac(xp)\", globals=globals())\n",
        "with_vmap = Timer(stmt=\"jacrev(predict, argnums=2)(weight, bias, x)\", globals=globals())\n",
        "\n",
        "no_vmap_timer = without_vmap.timeit(500)\n",
        "with_vmap_timer = with_vmap.timeit(500)\n",
        "\n",
        "print(no_vmap_timer)\n",
        "print(with_vmap_timer)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "cbf77a19-aac9-428d-eba1-74d337c53e49",
        "id": "ZPtoxF6eflUJ"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "<torch.utils.benchmark.utils.common.Measurement object at 0x7fa9a911b350>\n",
            "compute_jac(xp)\n",
            "  2.25 ms\n",
            "  1 measurement, 500 runs , 1 thread\n",
            "<torch.utils.benchmark.utils.common.Measurement object at 0x7fa9a6a99d50>\n",
            "jacrev(predict, argnums=2)(weight, bias, x)\n",
            "  884.34 us\n",
            "  1 measurement, 500 runs , 1 thread\n"
          ]
        }
      ],
      "id": "ZPtoxF6eflUJ"
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let's do a relative performance comparison of the above with our get_perf function:"
      ],
      "metadata": {
        "id": "nGBBi4dZflUJ"
      },
      "id": "nGBBi4dZflUJ"
    },
    {
      "cell_type": "code",
      "source": [
        "get_perf(no_vmap_timer, \"without vmap\",  with_vmap_timer, \"vmap\");"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "85d0bc5f-34aa-4826-f953-6c637404490c",
        "id": "zqV2RzEXflUJ"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            " Performance delta: 60.7170 percent improvement with vmap \n"
          ]
        }
      ],
      "id": "zqV2RzEXflUJ"
    },
    {
      "cell_type": "markdown",
      "source": [
        "Furthermore, it’s easy to flip the problem around and compute Jacobians with respect to the model's parameters (weight, bias) instead of the input."
      ],
      "metadata": {
        "id": "EQAB99EQflUJ"
      },
      "id": "EQAB99EQflUJ"
    },
    {
      "cell_type": "code",
      "source": [
        "# note the change in argnums: (0, 1) selects weight and bias\n",
        "ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)"
      ],
      "metadata": {
        "id": "8UZpC8DnflUK"
      },
      "execution_count": null,
      "outputs": [],
      "id": "8UZpC8DnflUK"
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Reverse-mode Jacobian (jacrev) vs. forward-mode Jacobian (jacfwd)\n"
      ],
      "metadata": {
        "id": "F3USYENIflUK"
      },
      "id": "F3USYENIflUK"
    },
    {
      "cell_type": "markdown",
      "source": [
        "We offer two APIs to compute Jacobians, **jacrev** and **jacfwd**:\n",
        "- jacrev uses reverse-mode AD. As you saw above, it is a composition of our vjp and vmap transforms.\n",
        "- jacfwd uses forward-mode AD. It is implemented as a composition of our jvp and vmap transforms.\n",
        "\n",
        "jacfwd and jacrev can be substituted for each other, but they have different performance characteristics.\n",
        "\n",
        "As a general rule of thumb, if you’re computing the Jacobian of an $R^N \\to R^M$ function, and there are many more outputs than inputs (i.e., $M > N$), then jacfwd is preferred; otherwise, use jacrev. There are exceptions to this rule, but a non-rigorous argument for it follows:\n",
        "\n",
        "In reverse-mode AD, we compute the Jacobian row-by-row, while in forward-mode AD (which computes Jacobian-vector products), we compute it column-by-column. The Jacobian matrix has M rows and N columns, so we prefer the method that iterates over the smaller dimension: jacfwd (N columns) for a tall Jacobian, jacrev (M rows) for a wide one.\n",
        "\n"
      ],
      "metadata": {
        "id": "V7B3vE8dflUK"
      },
      "id": "V7B3vE8dflUK"
    },
    {
      "cell_type": "code",
      "source": [
        "from functorch import jacrev, jacfwd"
      ],
      "metadata": {
        "id": "k7Tok7m3flUK"
      },
      "execution_count": null,
      "outputs": [],
      "id": "k7Tok7m3flUK"
    },
    {
      "cell_type": "markdown",
      "source": [
        "First, let's benchmark with more outputs than inputs:\n",
        "\n"
      ],
      "metadata": {
        "id": "YrV-gZAaflUL"
      },
      "id": "YrV-gZAaflUL"
    },
    {
      "cell_type": "code",
      "source": [
        "Din = 32\n",
        "Dout = 2048\n",
        "weight = torch.randn(Dout, Din)\n",
        "\n",
        "bias = torch.randn(Dout)\n",
        "x = torch.randn(Din)\n",
        "\n",
        "# remember the rule of thumb about taller vs. wider matrices; here we have a taller one:\n",
        "print(weight.shape)\n",
        "\n",
        "using_fwd = Timer(stmt=\"jacfwd(predict, argnums=2)(weight, bias, x)\", globals=globals())\n",
        "using_bwd = Timer(stmt=\"jacrev(predict, argnums=2)(weight, bias, x)\", globals=globals())\n",
        "\n",
        "jacfwd_timing = using_fwd.timeit(500)\n",
        "jacrev_timing = using_bwd.timeit(500)\n",
        "\n",
        "print(f'jacfwd time: {jacfwd_timing}')\n",
        "print(f'jacrev time: {jacrev_timing}')\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "dd882726-9723-47c0-a72f-3c7835a85aa1",
        "id": "m5j-4hSxflUL"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "torch.Size([2048, 32])\n",
            "jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fa9a5d792d0>\n",
            "jacfwd(predict, argnums=2)(weight, bias, x)\n",
            "  1.32 ms\n",
            "  1 measurement, 500 runs , 1 thread\n",
            "jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fa9a4dee450>\n",
            "jacrev(predict, argnums=2)(weight, bias, x)\n",
            "  12.46 ms\n",
            "  1 measurement, 500 runs , 1 thread\n"
          ]
        }
      ],
      "id": "m5j-4hSxflUL"
    },
    {
      "cell_type": "markdown",
      "source": [
        "and then do a relative benchmark:"
      ],
      "metadata": {
        "id": "k_Sg-4tVflUL"
      },
      "id": "k_Sg-4tVflUL"
    },
    {
      "cell_type": "code",
      "source": [
        "get_perf(jacrev_timing, \"jacrev\", jacfwd_timing, \"jacfwd\");"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "3a6586a1-269d-46d8-d119-e24f6d46277f",
        "id": "_4T96zGjflUL"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            " Performance delta: 89.3936 percent improvement with jacfwd \n"
          ]
        }
      ],
      "id": "_4T96zGjflUL"
    },
    {
      "cell_type": "markdown",
      "source": [
        "and now the reverse: more inputs (N) than outputs (M):"
      ],
      "metadata": {
        "id": "RCDPot1yflUL"
      },
      "id": "RCDPot1yflUL"
    },
    {
      "cell_type": "code",
      "source": [
        "Din = 2048\n",
        "Dout = 32\n",
        "weight = torch.randn(Dout, Din)\n",
        "bias = torch.randn(Dout)\n",
        "x = torch.randn(Din)\n",
        "\n",
        "using_fwd = Timer(stmt=\"jacfwd(predict, argnums=2)(weight, bias, x)\", globals=globals())\n",
        "using_bwd = Timer(stmt=\"jacrev(predict, argnums=2)(weight, bias, x)\", globals=globals())\n",
        "\n",
        "jacfwd_timing = using_fwd.timeit(500)\n",
        "jacrev_timing = using_bwd.timeit(500)\n",
        "\n",
        "print(f'jacfwd time: {jacfwd_timing}')\n",
        "print(f'jacrev time: {jacrev_timing}')"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "913e9ccd-3d4f-472a-a749-19cee36d0a16",
        "id": "_DRFqzqZflUM"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fa9a5d64790>\n",
            "jacfwd(predict, argnums=2)(weight, bias, x)\n",
            "  7.99 ms\n",
            "  1 measurement, 500 runs , 1 thread\n",
            "jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fa9a5d67b50>\n",
            "jacrev(predict, argnums=2)(weight, bias, x)\n",
            "  1.09 ms\n",
            "  1 measurement, 500 runs , 1 thread\n"
          ]
        }
      ],
      "id": "_DRFqzqZflUM"
    },
    {
      "cell_type": "markdown",
      "source": [
        "and a relative perf comparison:"
      ],
      "metadata": {
        "id": "5SRbMCNsflUM"
      },
      "id": "5SRbMCNsflUM"
    },
    {
      "cell_type": "code",
      "source": [
        "get_perf(jacfwd_timing, \"jacfwd\", jacrev_timing, \"jacrev\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "c282ce25-4f6e-44cd-aed7-60f6f5010e5b",
        "id": "uF_9GaoiflUM"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            " Performance delta: 86.3984 percent improvement with jacrev \n"
          ]
        }
      ],
      "id": "uF_9GaoiflUM"
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Hessian computation with functorch.hessian\n"
      ],
      "metadata": {
        "id": "J29FQaBQflUM"
      },
      "id": "J29FQaBQflUM"
    },
    {
      "cell_type": "markdown",
      "source": [
        "We offer a convenience API to compute Hessians: `functorch.hessian`.\n",
        "The Hessian is the Jacobian of the Jacobian (i.e., the second-order partial derivatives).\n",
        "\n",
        "This suggests that one can just compose functorch’s Jacobian transforms to compute the Hessian.\n",
        "Indeed, under the hood, `hessian(f)` is simply `jacfwd(jacrev(f))`.\n",
        "\n"
      ],
      "metadata": {
        "id": "My4DPH97flUM"
      },
      "id": "My4DPH97flUM"
    },
    {
      "cell_type": "markdown",
      "source": [
        "Note: depending on your model, you may get better performance by computing Hessians with `jacfwd(jacfwd(f))` or `jacrev(jacrev(f))` instead, following the rule of thumb above about wider vs. taller matrices.\n",
        "\n"
      ],
      "metadata": {
        "id": "FJt038l5flUM"
      },
      "id": "FJt038l5flUM"
    },
    {
      "cell_type": "code",
      "source": [
        "from functorch import hessian\n",
        "\n",
        "# let's reduce the size so we don't run out of memory on Colab; Hessians require significant memory:\n",
        "Din = 512\n",
        "Dout = 32\n",
        "weight = torch.randn(Dout, Din)\n",
        "bias = torch.randn(Dout)\n",
        "x = torch.randn(Din)\n",
        "\n",
        "hess_api = hessian(predict, argnums=2)(weight, bias, x)\n",
        "hess_fwdfwd = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)\n",
        "#hess_revrev = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)\n"
      ],
      "metadata": {
        "id": "jEqr2ywZflUM"
      },
      "execution_count": null,
      "outputs": [],
      "id": "jEqr2ywZflUM"
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let's verify that we get the same result whether we use the hessian API or jacfwd(jacfwd())."
      ],
      "metadata": {
        "id": "n9BHcICQflUN"
      },
      "id": "n9BHcICQflUN"
    },
    {
      "cell_type": "code",
      "source": [
        "torch.allclose(hess_api, hess_fwdfwd)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "e457e3bc-f085-4f90-966d-f98893b98ea8",
        "id": "eHiWRkjJflUN"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "True"
            ]
          },
          "metadata": {},
          "execution_count": 18
        }
      ],
      "id": "eHiWRkjJflUN"
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Batch Jacobian and Batch Hessian\n"
      ],
      "metadata": {
        "id": "Gjt1RO8HflUN"
      },
      "id": "Gjt1RO8HflUN"
    },
    {
      "cell_type": "markdown",
      "source": [
        "In the above examples we’ve been operating with a single feature vector. In some cases you might want to take the Jacobian of a batch of outputs with respect to a batch of inputs. That is, given a batch of inputs of shape `(B, N)` and a function that goes from $R^N \\to R^M$, we would like a Jacobian of shape `(B, M, N)`. \n",
        "\n",
        "The easiest way to do this is to use vmap:"
      ],
      "metadata": {
        "id": "RjIzdoQNflUN"
      },
      "id": "RjIzdoQNflUN"
    },
    {
      "cell_type": "code",
      "source": [
        "batch_size = 64\n",
        "Din = 31\n",
        "Dout = 33\n",
        "\n",
        "weight = torch.randn(Dout, Din)\n",
        "print(f\"weight shape = {weight.shape}\")\n",
        "\n",
        "bias = torch.randn(Dout)\n",
        "\n",
        "x = torch.randn(batch_size, Din)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "561eb618-e00f-40d5-bd99-fa51ab82051f",
        "id": "B1eoEO4UflUN"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "weight shape = torch.Size([33, 31])\n"
          ]
        }
      ],
      "id": "B1eoEO4UflUN"
    },
    {
      "cell_type": "code",
      "source": [
        "compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))\n",
        "batch_jacobian0 = compute_batch_jacobian(weight, bias, x)"
      ],
      "metadata": {
        "id": "nZ_V02NhflUN"
      },
      "execution_count": null,
      "outputs": [],
      "id": "nZ_V02NhflUN"
    },
    {
      "cell_type": "markdown",
      "source": [
        "If instead you have a function that goes from `(B, N)` to `(B, M)` and are certain that each input produces an independent output, then it’s sometimes possible to do this without vmap: sum the outputs over the batch dimension and compute the Jacobian of that function. The result has shape `(M, B, N)`, so we move the batch dimension to the front with `movedim(1, 0)`:\n",
        "\n"
      ],
      "metadata": {
        "id": "_OLDiY3MflUN"
      },
      "id": "_OLDiY3MflUN"
    },
    {
      "cell_type": "code",
      "source": [
        "def predict_with_output_summed(weight, bias, x):\n",
        "    return predict(weight, bias, x).sum(0)\n",
        "\n",
        "batch_jacobian1 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x).movedim(1, 0)\n",
        "assert torch.allclose(batch_jacobian0, batch_jacobian1)"
      ],
      "metadata": {
        "id": "_QH4hD8PflUO"
      },
      "execution_count": null,
      "outputs": [],
      "id": "_QH4hD8PflUO"
    },
    {
      "cell_type": "markdown",
      "source": [
        "If instead you have a function that goes from $R^N \\to R^M$ but the inputs are batched, you compose vmap with jacrev to compute batched Jacobians, as we did above.\n",
        "\n",
        "Finally, batch Hessians can be computed similarly. It’s easiest to think about them by using vmap to batch over the Hessian computation, but in some cases the sum trick also works.\n",
        "\n"
      ],
      "metadata": {
        "id": "eUjw65cCflUO"
      },
      "id": "eUjw65cCflUO"
    },
    {
      "cell_type": "code",
      "source": [
        "compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0))\n",
        "\n",
        "batch_hess = compute_batch_hessian(weight, bias, x)\n",
        "batch_hess.shape"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "f3135cfa-e9e5-4f18-8cb7-0655e8a37cb5",
        "id": "3vAyQjMsflUO"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "torch.Size([64, 33, 31, 31])"
            ]
          },
          "metadata": {},
          "execution_count": 22
        }
      ],
      "id": "3vAyQjMsflUO"
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Computing Hessian-vector products\n",
        "\n",
        "The naive way to compute a Hessian-vector product (hvp) is to materialize the full Hessian and perform a dot-product with a vector. We can do better: it turns out we don't need to materialize the full Hessian to do this. We'll go through two (of many) different strategies to compute Hessian-vector products:\n",
        "- composing reverse-mode AD with reverse-mode AD\n",
        "- composing reverse-mode AD with forward-mode AD\n",
        "\n",
        "Composing reverse-mode AD with forward-mode AD (as opposed to reverse-mode with reverse-mode) is generally the more memory-efficient way to compute an hvp, because forward-mode AD doesn't need to construct an Autograd graph and save intermediates for the backward pass:"
      ],
      "metadata": {
        "id": "Wa8E48sQgpkb"
      },
      "id": "Wa8E48sQgpkb"
    },
    {
      "cell_type": "code",
      "source": [
        "from functorch import jvp, grad, vjp\n",
        "\n",
        "def hvp(f, primals, tangents):\n",
        "  # Forward-mode (jvp) over reverse-mode (grad): jvp returns (grad(f)(primals), H @ tangents)\n",
        "  return jvp(grad(f), primals, tangents)[1]"
      ],
      "metadata": {
        "id": "trw6WbAth6BM"
      },
      "execution_count": null,
      "outputs": [],
      "id": "trw6WbAth6BM"
    },
    {
      "cell_type": "markdown",
      "source": [
        "Here's some sample usage."
      ],
      "metadata": {
        "id": "DQMpRo6nitfr"
      },
      "id": "DQMpRo6nitfr"
    },
    {
      "cell_type": "code",
      "source": [
        "def f(x):\n",
        "  return x.sin().sum()\n",
        "\n",
        "x = torch.randn(2048)\n",
        "tangent = torch.randn(2048)\n",
        "\n",
        "result = hvp(f, (x,), (tangent,))"
      ],
      "metadata": {
        "id": "sPwg8SOdiVAK"
      },
      "execution_count": null,
      "outputs": [],
      "id": "sPwg8SOdiVAK"
    },
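    {
      "cell_type": "markdown",
      "source": [
        "As a sanity check, here is a sketch that compares `hvp` against the naive approach of materializing the full Hessian. It uses a smaller input so the Hessian stays cheap to build (the size of 128 is an arbitrary choice). For $f(x) = \\sum_i \\sin(x_i)$ the Hessian is $\\mathrm{diag}(-\\sin(x))$, which also gives a closed form to check against. The definitions of `f` and `hvp` are repeated so the cell runs on its own."
      ],
      "metadata": {
        "id": "hvpNaiveCheckMd"
      },
      "id": "hvpNaiveCheckMd"
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from functorch import grad, hessian, jvp\n",
        "\n",
        "def f(x):\n",
        "  return x.sin().sum()\n",
        "\n",
        "def hvp(f, primals, tangents):\n",
        "  return jvp(grad(f), primals, tangents)[1]\n",
        "\n",
        "x_small = torch.randn(128)\n",
        "v_small = torch.randn(128)\n",
        "\n",
        "naive = hessian(f)(x_small) @ v_small   # builds the full 128 x 128 Hessian first\n",
        "fast = hvp(f, (x_small,), (v_small,))   # never materializes the Hessian\n",
        "closed_form = -x_small.sin() * v_small  # Hessian of sum(sin(x)) is diag(-sin(x))\n",
        "\n",
        "assert torch.allclose(naive, fast, atol=1e-5)\n",
        "assert torch.allclose(fast, closed_form, atol=1e-5)"
      ],
      "metadata": {
        "id": "hvpNaiveCheck"
      },
      "execution_count": null,
      "outputs": [],
      "id": "hvpNaiveCheck"
    },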
    {
      "cell_type": "markdown",
      "source": [
        "If PyTorch's forward-mode AD does not cover some of the operations in your function, you can instead compose reverse-mode AD with reverse-mode AD:"
      ],
      "metadata": {
        "id": "zGvUIcB0j1Ez"
      },
      "id": "zGvUIcB0j1Ez"
    },
    {
      "cell_type": "code",
      "source": [
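        "def hvp_revrev(f, primals, tangents):\n",
        "  # Reverse-mode (vjp) over reverse-mode (grad): vjp_fn gives v^T H, which equals H v\n",
        "  # because the Hessian is symmetric. vjp_fn returns a tuple with one entry per primal.\n",
        "  _, vjp_fn = vjp(grad(f), *primals)\n",
        "  return vjp_fn(*tangents)"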
      ],
      "metadata": {
        "id": "mdDFZdlekAOK"
      },
      "execution_count": null,
      "outputs": [],
      "id": "mdDFZdlekAOK"
    },
    {
      "cell_type": "code",
      "source": [
        "result_hvp_revrev = hvp_revrev(f, (x,), (tangent,))\n",
        "assert torch.allclose(result, result_hvp_revrev[0])"
      ],
      "metadata": {
        "id": "_CuCk9X0lW7C"
      },
      "execution_count": null,
      "outputs": [],
      "id": "_CuCk9X0lW7C"
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.3"
    },
    "colab": {
      "name": "jacobians_hessians.ipynb",
      "provenance": []
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}


================================================
FILE: notebooks/colab/per_sample_grads_colab.ipynb
================================================
{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "a474c143-05c4-43b6-b12c-17b592d07a6a",
      "metadata": {
        "id": "a474c143-05c4-43b6-b12c-17b592d07a6a"
      },
      "source": [
        "# Per-sample-gradients\n",
        "\n",
        "<a href=\"https://colab.research.google.com/github/pytorch/pytorch/blob/master/functorch/notebooks/per_sample_grads.ipynb\">\n",
        "  <img style=\"width: auto\" src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
        "</a>\n",
        "\n",
        "## What is it?\n",
        "\n",
        "Per-sample-gradient computation means computing the gradient separately for each\n",
        "sample in a batch of data. It is a useful quantity in differential privacy, meta-learning,\n",
        "and optimization research.\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from functools import partial\n",
        "\n",
        "torch.manual_seed(0);"
      ],
      "metadata": {
        "id": "Gb-yt4VKUUuc"
      },
      "execution_count": null,
      "outputs": [],
      "id": "Gb-yt4VKUUuc"
    },
    {
      "cell_type": "code",
      "source": [
        "# Here's a simple CNN and loss function:\n",
        "\n",
        "class SimpleCNN(nn.Module):\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.conv1 = nn.Conv2d(1, 32, 3, 1)\n",
        "        self.conv2 = nn.Conv2d(32, 64, 3, 1)\n",
        "        self.fc1 = nn.Linear(9216, 128)  # 9216 = 64 channels * 12 * 12 for a 28x28 input\n",
        "        self.fc2 = nn.Linear(128, 10)\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.conv1(x)\n",
        "        x = F.relu(x)\n",
        "        x = self.conv2(x)\n",
        "        x = F.relu(x)\n",
        "        x = F.max_pool2d(x, 2)\n",
        "        x = torch.flatten(x, 1)\n",
        "        x = self.fc1(x)\n",
        "        x = F.relu(x)\n",
        "        x = self.fc2(x)\n",
        "        output = F.log_softmax(x, dim=1)  # F.nll_loss expects log-probabilities\n",
        "        return output\n",
        "\n",
        "def loss_fn(predictions, targets):\n",
        "    return F.nll_loss(predictions, targets)"
      ],
      "metadata": {
        "id": "tf-HKHjUUbyY"
      },
      "execution_count": null,
      "outputs": [],
      "id": "tf-HKHjUUbyY"
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let’s generate a batch of dummy data and pretend that we’re working with an MNIST dataset.  \n",
        "\n",
        "The dummy images are 28 by 28 and we use a minibatch of size 64.\n",
        "\n"
      ],
      "metadata": {
        "id": "VEDPe-EoU5Fa"
      },
      "id": "VEDPe-EoU5Fa"
    },
    {
      "cell_type": "code",
      "source": [
        "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
        "\n",
        "batch_size = 64\n",
        "data = torch.randn(batch_size, 1, 28, 28, device=device)\n",
        "\n",
        "targets = torch.randint(10, (batch_size,), device=device)"
      ],
      "metadata": {
        "id": "WB2Qe3AHUvPN"
      },
      "execution_count": null,
      "outputs": [],
      "id": "WB2Qe3AHUvPN"
    },
    {
      "cell_type": "markdown",
      "source": [
        "In regular model training, one would forward the minibatch through the model and then call `.backward()` to compute gradients. This would generate an 'average' gradient of the entire minibatch:\n",
        "\n"
      ],
      "metadata": {
        "id": "GOGJ-OUxVcT5"
      },
      "id": "GOGJ-OUxVcT5"
    },
    {
      "cell_type": "code",
      "source": [
        "model = SimpleCNN().to(device=device)\n",
        "predictions = model(data)  # move the entire minibatch through the model\n",
        "\n",
        "loss = loss_fn(predictions, targets)\n",
        "loss.backward()  # backpropagate the 'average' gradient of this minibatch"
      ],
      "metadata": {
        "id": "WYjMx8QTUvRu"
      },
      "execution_count": null,
      "outputs": [],
      "id": "WYjMx8QTUvRu"
    },
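    {
      "cell_type": "markdown",
      "source": [
        "Because `F.nll_loss` averages over the batch by default (`reduction='mean'`), this minibatch gradient is the mean of the per-sample gradients. Here is a minimal sketch of that identity on a hypothetical least-squares toy model (`toy_w`, `toy_X`, `toy_y` are illustrative names, not part of `SimpleCNN`):"
      ],
      "metadata": {
        "id": "meanGradCheckMd"
      },
      "id": "meanGradCheckMd"
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "\n",
        "# Hypothetical toy model: a linear predictor with a mean-reduced squared-error loss\n",
        "toy_w = torch.randn(3, requires_grad=True)\n",
        "toy_X = torch.randn(8, 3)\n",
        "toy_y = torch.randn(8)\n",
        "\n",
        "# One backward pass over the whole minibatch gives the averaged gradient\n",
        "batch_loss = ((toy_X @ toy_w - toy_y) ** 2).mean()\n",
        "(batch_grad,) = torch.autograd.grad(batch_loss, toy_w)\n",
        "\n",
        "# One backward pass per sample gives the individual gradients\n",
        "per_sample = torch.stack([\n",
        "    torch.autograd.grad((toy_X[i] @ toy_w - toy_y[i]) ** 2, toy_w)[0]\n",
        "    for i in range(len(toy_X))\n",
        "])\n",
        "\n",
        "assert per_sample.shape == (8, 3)\n",
        "assert torch.allclose(batch_grad, per_sample.mean(dim=0), atol=1e-5)"
      ],
      "metadata": {
        "id": "meanGradCheck"
      },
      "execution_count": null,
      "outputs": [],
      "id": "meanGradCheck"
    },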
    {
      "cell_type": "markdown",
      "source": [
        "In contrast to the above approach, per-sample-gradient computation is equivalent to: \n",
        "- for each individual sample of the data, perform a forward and a backward pass to get an individual (per-sample) gradient.\n",
        "\n"
      ],
      "metadata": {
        "id": "HNw4_IVzU5Pz"
      },
      "id": "HNw4_IVzU5Pz"
    },
    {
      "cell_type": "code",
      "source": [
        "def compute_grad(sample, target):\n",
        "    \n",
        "    sample = sample.unsqueeze(0)  # prepend batch dimension for processing\n",
        "    target = target.unsqueeze(0)\n",
        "\n",
        "    prediction = model(sample)\n",
        "    loss = loss_fn(prediction, target)\n",
        "\n",
        "    return torch.autograd.grad(loss, list(model.parameters()))\n",
        "\n",
        "\n",
        "def compute_sample_grads(data, targets):\n",
        "    \"\"\"Manually compute a gradient for each sample, one sample at a time.\"\"\"\n",
        "    sample_grads = [compute_grad(data[i], targets[i]) for i in range(len(data))]\n",
        "    # regroup the per-sample tuples into one stacked tensor per parameter\n",
        "    sample_grads = zip(*sample_grads)\n",
        "    sample_grads = [torch.stack(shards) for shards in sample_grads]\n",
        "    return sample_grads\n",
        "\n",
        "per_sample_grads = compute_sample_grads(data, targets)"
      ],
      "metadata": {
        "id": "vUsb3VfexJrY"
      },
      "execution_count": null,
      "outputs": [],
      "id": "vUsb3VfexJrY"
    },
    {
      "cell_type": "markdown",
      "source": [
        "`per_sample_grads[0]` is the per-sample gradient for `model.conv1.weight`. `model.conv1.weight.shape` is `[32, 1, 3, 3]`; notice that there is one such gradient per sample in the batch, for a total of 64.\n",
        "\n",
        "\n",
        "\n"
      ],
      "metadata": {
        "id": "aNkX6lFIxzcm"
      },
      "id": "aNkX6lFIxzcm"
    },
    {
      "cell_type": "code",
      "source": [
        "print(per_sample_grads[0].shape)"
      ],
      "metadata": {
        "id": "C3a9_clvyPho",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "407abc1a-846f-4e50-83bc-c90719a26073"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "torch.Size([64, 32, 1, 3, 3])\n"
          ]
        }
      ],
      "id": "C3a9_clvyPho"
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Per-sample-grads, *the efficient way*, using functorch\n",
        "\n",
        "\n"
      ],
      "metadata": {
        "id": "mFJDWMM9yaYZ"
      },
      "id": "mFJDWMM9yaYZ"
    },
    {
      "cell_type": "markdown",
      "source": [
        "We can compute per-sample-gradients efficiently by using function transforms.\n",
        "\n",
        "First, let’s create a stateless functional version of `model` by using `functorch.make_functional_with_buffers`.\n",
        "\n",
        "This separates the state (the parameters and buffers) from the model and turns the model into a pure function:\n",
        "\n"
      ],
      "metadata": {
        "id": "tlkmyQyfY6XU"
      },
      "id": "tlkmyQyfY6XU"
    },
    {
      "cell_type": "code",
      "source": [
        "from functorch import make_functional_with_buffers, vmap, grad\n",
        "\n",
        "fmodel, params, buffers = make_functional_with_buffers(model)"
      ],
      "metadata": {
        "id": "WiSMupvCyecd"
      },
      "execution_count": null,
      "outputs": [],
      "id": "WiSMupvCyecd"
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let's review the changes. First, the model has become a stateless `FunctionalModuleWithBuffers`:"
      ],
      "metadata": {
        "id": "wMsbppPNZklo"
      },
      "id": "wMsbppPNZklo"
    },
    {
      "cell_type": "code",
      "source": [
        "fmodel"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Xj0cZOJMZbbB",
        "outputId": "2e87dfde-3af2-4e1f-cd91-5c232446fb53"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "FunctionalModuleWithBuffers(\n",
              "  (stateless_model): SimpleCNN(\n",
              "    (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))\n",
              "    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))\n",
              "    (fc1): Linear(in_features=9216, out_features=128, bias=True)\n",
              "    (fc2): Linear(in_features=128, out_features=10, bias=True)\n",
              "  )\n",
              ")"
            ]
          },
          "metadata": {},
          "execution_count": 15
        }
      ],
      "id": "Xj0cZOJMZbbB"
    },
    {
      "cell_type": "markdown",
      "source": [
        "And the model parameters now exist independently of the model, stored as a tuple:"
      ],
      "metadata": {
        "id": "zv4_YYPxZvvg"
      },
      "id": "zv4_YYPxZvvg"
    },
    {
      "cell_type": "code",
      "source": [
        "for x in params:\n",
        "  print(f\"{x.shape}\")\n",
        "\n",
        "print(f\"\\n{type(params)}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "tH0TAZhBZ3bS",
        "outputId": "97c4401f-cccb-43f6-b071-c85a18fc439b"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "torch.Size([32, 1, 3, 3])\n",
            "torch.Size([32])\n",
            "torch.Size([64, 32, 3, 3])\n",
            "torch.Size([64])\n",
            "torch.Size([128, 9216])\n",
            "torch.Size([128])\n",
            "torch.Size([10, 128])\n",
            "torch.Size([10])\n",
            "\n",
            "<class 'tuple'>\n"
          ]
        }
      ],
      "id": "tH0TAZhBZ3bS"
    },
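    {
      "cell_type": "markdown",
      "source": [
        "`SimpleCNN` has no buffers because it contains no layers, such as `nn.BatchNorm1d`, that track running statistics. As a result, `buffers` is an empty tuple here. The sketch below uses a hypothetical toy module (`toy`) that does register buffers. It shows that the functional model, called with the extracted params and buffers, reproduces the original module's forward pass."
      ],
      "metadata": {
        "id": "toyBuffersMd"
      },
      "id": "toyBuffersMd"
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "from functorch import make_functional_with_buffers\n",
        "\n",
        "# Hypothetical toy module that, unlike SimpleCNN, registers buffers\n",
        "toy = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))\n",
        "toy.eval()  # use the running statistics so the forward pass is deterministic\n",
        "\n",
        "ftoy, toy_params, toy_buffers = make_functional_with_buffers(toy)\n",
        "\n",
        "toy_inp = torch.randn(5, 4)\n",
        "assert torch.allclose(toy(toy_inp), ftoy(toy_params, toy_buffers, toy_inp))\n",
        "\n",
        "# 4 parameters (Linear weight/bias, BatchNorm weight/bias) and\n",
        "# 3 buffers (running_mean, running_var, num_batches_tracked)\n",
        "print(len(toy_params), len(toy_buffers))"
      ],
      "metadata": {
        "id": "toyBuffers"
      },
      "execution_count": null,
      "outputs": [],
      "id": "toyBuffers"
    },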
    {
      "cell_type": "markdown",
      "source": [
        "Next, let’s define a function to compute the loss of the model given a single input rather than a batch of inputs. It is important that this function accepts the parameters, the buffers, the input, and the target, because we will be transforming over them.\n",
        "\n",
        "Note: because the model was originally written to handle batches, we’ll use `torch.unsqueeze` to add a batch dimension.\n",
        "\n"
      ],
      "metadata": {
        "id": "cTgIIZ9Wyih8"
      },
      "id": "cTgIIZ9Wyih8"
    },
    {
      "cell_type": "code",
      "source": [
        "def compute_loss_stateless_model(params, buffers, sample, target):\n",
        "    batch = sample.unsqueeze(0)\n",
        "    targets = target.unsqueeze(0)\n",
        "\n",
        "    predictions = fmodel(params, buffers, batch)\n",
        "    loss = loss_fn(predictions, targets)\n",
        "    return loss"
      ],
      "metadata": {
        "id": "ItURFU3M-p98"
      },
      "execution_count": null,
      "outputs": [],
      "id": "ItURFU3M-p98"
    },
    {
      "cell_type": "markdown",
      "source": [
        "Now, let’s use functorch's `grad` to create a new function that computes the gradient with respect to the first argument of `compute_loss_stateless_model` (i.e. the params)."
      ],
      "metadata": {
        "id": "Qo3sbDK2i_bH"
      },
      "id": "Qo3sbDK2i_bH"
    },
    {
      "cell_type": "code",
      "source": [
        "ft_compute_grad = grad(compute_loss_stateless_model)"
      ],
      "metadata": {
        "id": "sqRp_Sxni-Xm"
      },
      "execution_count": null,
      "outputs": [],
      "id": "sqRp_Sxni-Xm"
    },
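    {
      "cell_type": "markdown",
      "source": [
        "Here is a minimal sketch of that behavior, using a hypothetical toy loss (`toy_loss`, not the CNN above). By default, `grad` differentiates with respect to the first positional argument. When that argument is a tuple of tensors, `grad` returns a tuple of gradients with the same structure. Passing `argnums` selects a different argument."
      ],
      "metadata": {
        "id": "toyGradMd"
      },
      "id": "toyGradMd"
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from functorch import grad\n",
        "\n",
        "# Hypothetical toy loss: the first argument is a (weight, bias) tuple, like params above\n",
        "def toy_loss(toy_params, toy_x):\n",
        "    w, b = toy_params\n",
        "    return ((toy_x @ w + b) ** 2).sum()\n",
        "\n",
        "w = torch.randn(3)\n",
        "b = torch.randn(())\n",
        "toy_x = torch.randn(5, 3)\n",
        "\n",
        "gw, gb = grad(toy_loss)((w, b), toy_x)         # gradient w.r.t. the (w, b) tuple\n",
        "gx = grad(toy_loss, argnums=1)((w, b), toy_x)  # gradient w.r.t. toy_x instead\n",
        "\n",
        "# Closed form: with residual r = x @ w + b, dL/dw = 2 x^T r and dL/db = 2 sum(r)\n",
        "r = toy_x @ w + b\n",
        "assert torch.allclose(gw, 2 * toy_x.T @ r, atol=1e-5)\n",
        "assert torch.allclose(gb, 2 * r.sum(), atol=1e-5)\n",
        "assert gx.shape == toy_x.shape"
      ],
      "metadata": {
        "id": "toyGrad"
      },
      "execution_count": null,
      "outputs": [],
      "id": "toyGrad"
    },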
    {
      "cell_type": "markdown",
      "source": [
        "The `ft_compute_grad` function computes the gradient for a single (sample, target) pair. We can use `vmap` to get it to compute the gradient over an entire batch of samples and targets. Note that `in_dims=(None, None, 0, 0)` because we wish to map `ft_compute_grad` over the 0th dimension of the data and targets, and use the same params and buffers for each.\n",
        "\n"
      ],
      "metadata": {
        "id": "2pG3Ofqjjc8O"
      },
      "id": "2pG3Ofqjjc8O"
    },
    {
      "cell_type": "code",
      "source": [
        "ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))"
      ],
      "metadata": {
        "id": "62ecNMO6inqX"
      },
      "execution_count": null,
      "outputs": [],
      "id": "62ecNMO6inqX"
    },
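    {
      "cell_type": "markdown",
      "source": [
        "To make the `in_dims` semantics concrete, here is a sketch on a hypothetical squared-error loss (`toy_sample_loss` with params `lin_params`, not the CNN above). `None` passes the same params to every call, while `0` slices the samples and targets along their first dimension. The vectorized result should match a Python loop over the batch."
      ],
      "metadata": {
        "id": "toyInDimsMd"
      },
      "id": "toyInDimsMd"
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from functorch import grad, vmap\n",
        "\n",
        "# Hypothetical per-sample loss for a linear model with params (w, b)\n",
        "def toy_sample_loss(params, sample, target):\n",
        "    w, b = params\n",
        "    return (sample @ w + b - target) ** 2\n",
        "\n",
        "lin_params = (torch.randn(3), torch.randn(()))\n",
        "toy_data = torch.randn(8, 3)\n",
        "toy_targets = torch.randn(8)\n",
        "\n",
        "# None: share lin_params across the batch; 0: map over dim 0 of data and targets\n",
        "toy_per_sample = vmap(grad(toy_sample_loss), in_dims=(None, 0, 0))\n",
        "gw, gb = toy_per_sample(lin_params, toy_data, toy_targets)\n",
        "assert gw.shape == (8, 3) and gb.shape == (8,)\n",
        "\n",
        "# The same gradients, computed one sample at a time\n",
        "looped = [grad(toy_sample_loss)(lin_params, s, t) for s, t in zip(toy_data, toy_targets)]\n",
        "assert torch.allclose(gw, torch.stack([g[0] for g in looped]), atol=1e-5)\n",
        "assert torch.allclose(gb, torch.stack([g[1] for g in looped]), atol=1e-5)"
      ],
      "metadata": {
        "id": "toyInDims"
      },
      "execution_count": null,
      "outputs": [],
      "id": "toyInDims"
    },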
    {
      "cell_type": "markdown",
      "source": [
        "Finally, let’s use our transformed function to compute per-sample-gradients:\n",
        "\n"
      ],
      "metadata": {
        "id": "_alXdQ3QkETu"
      },
      "id": "_alXdQ3QkETu"
    },
    {
      "cell_type": "code",
      "source": [
        "ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets)\n",
        "\n",
        "# we can double check that the results using functorch grad and vmap match the results of hand processing each one individually:\n",
        "for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads):\n",
        "    assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)"
      ],
      "metadata": {
        "id": "1gehVA1c-BHd"
      },
      "execution_count": null,
      "outputs": [],
      "id": "1gehVA1c-BHd"
    },
    {
      "cell_type": "markdown",
      "source": [
        "A quick note: there are limitations on what types of functions can be transformed by `vmap`. The best candidates are pure functions: functions whose outputs are determined only by their inputs and that have no side effects (e.g. mutation). `vmap` cannot handle mutation of arbitrary Python data structures, but it can handle many in-place PyTorch operations.\n",
        "\n",
        "\n",
        "\n"
      ],
      "metadata": {
        "id": "BEZaNt1d_bc1"
      },
      "id": "BEZaNt1d_bc1"
    },
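    {
      "cell_type": "markdown",
      "source": [
        "Here is a small sketch of both halves of that note, using hypothetical toy functions. An in-place PyTorch op on a tensor created inside the function vectorizes fine. A Python side effect, such as appending to a list, runs only once, because `vmap` calls the function a single time with batched tensors rather than once per sample."
      ],
      "metadata": {
        "id": "vmapSideEffectsMd"
      },
      "id": "vmapSideEffectsMd"
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from functorch import vmap\n",
        "\n",
        "# In-place PyTorch ops on tensors created inside the function are supported\n",
        "def add_one_inplace(t):\n",
        "    out = t.clone()\n",
        "    out.add_(1)\n",
        "    return out\n",
        "\n",
        "assert torch.equal(vmap(add_one_inplace)(torch.zeros(5, 2)), torch.ones(5, 2))\n",
        "\n",
        "# Python-level side effects are not vectorized: the body runs once, not 5 times\n",
        "calls = []\n",
        "def logged_double(t):\n",
        "    calls.append(1)\n",
        "    return t * 2\n",
        "\n",
        "vmap(logged_double)(torch.ones(5, 2))\n",
        "assert len(calls) == 1"
      ],
      "metadata": {
        "id": "vmapSideEffects"
      },
      "execution_count": null,
      "outputs": [],
      "id": "vmapSideEffects"
    },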
    {
      "cell_type": "markdown",
      "source": [
        "## Performance comparison"
      ],
      "metadata": {
        "id": "BASP151Iml7B"
      },
      "id": "BASP151Iml7B"
    },
    {
      "cell_type": "markdown",
      "source": [
        "Curious about how the performance of vmap compares?\n",
        "\n",
        "Currently the best results are obtained on newer GPUs such as the A100 (Ampere), where we've seen up to 25x speedups on this example, but here are some results from Colab:"
      ],
      "metadata": {
        "id": "jr1xNpV4nJ7u"
      },
      "id": "jr1xNpV4nJ7u"
    },
    {
      "cell_type": "code",
      "source": [
        "def get_perf(first, first_descriptor, second, second_descriptor):\n",
        "  \"\"\"Takes two torch.utils.benchmark Measurement objects and prints |first - second| / first as a percentage.\"\"\"\n",
        "  second_res = second.times[0]\n",
        "  first_res = first.times[0]\n",
        "\n",
        "  gain = abs(first_res - second_res) / first_res\n",
        "  final_gain = gain * 100\n",
        "\n",
        "  print(f\" Performance delta: {final_gain:.4f} percent improvement with {first_descriptor} \")"
      ],
      "metadata": {
        "id": "GnAnMkYmoc-j"
      },
      "execution_count": null,
      "outputs": [],
      "id": "GnAnMkYmoc-j"
    },
    {
      "cell_type": "code",
      "source": [
        "from torch.utils.benchmark import Timer\n",
        "\n",
        "without_vmap = Timer(stmt=\"compute_sample_grads(data, targets)\", globals=globals())\n",
        "with_vmap = Timer(stmt=\"ft_compute_sample_grad(params, buffers, data, targets)\", globals=globals())\n",
        "no_vmap_timing = without_vmap.timeit(100)\n",
        "with_vmap_timing = with_vmap.timeit(100)\n",
        "\n",
        "print(f'Per-sample-grads without vmap {no_vmap_timing}')\n",
        "print(f'Per-sample-grads with vmap {with_vmap_timing}')"
      ],
      "metadata": {
        "id": "Zfnn2C2g-6Fb",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "922f3901-773f-446b-b562-88e78f49036c"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Per-sample-grads without vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f71ac3f1850>\n",
            "compute_sample_grads(data, targets)\n",
            "  79.86 ms\n",
            "  1 measurement, 100 runs , 1 thread\n",
            "Per-sample-grads with vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f7143e26f10>\n",
            "ft_compute_sample_grad(params, buffers, data, targets)\n",
            "  12.93 ms\n",
            "  1 measurement, 100 runs , 1 thread\n"
          ]
        }
      ],
      "id": "Zfnn2C2g-6Fb"
    },
    {
      "cell_type": "code",
      "source": [
        "get_perf(with_vmap_timing, \"vmap\", no_vmap_timing, \"no vmap\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "NV9R3LZQoavl",
        "outputId": "e11e8be9-287d-4e60-e517-e08f8d6909bd"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            " Performance delta: 517.5791 percent improvement with vmap \n"
          ]
        }
      ],
      "id": "NV9R3LZQoavl"
    },
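    {
      "cell_type": "markdown",
      "source": [
        "The `get_perf` figure is a relative time difference, not a speedup factor. Here is a quick sketch of the arithmetic, using the rounded timings from the Colab run above (79.86 ms without vmap, 12.93 ms with vmap). The printed 517.5791 differs slightly because the benchmark used unrounded times."
      ],
      "metadata": {
        "id": "perfArithmeticMd"
      },
      "id": "perfArithmeticMd"
    },
    {
      "cell_type": "code",
      "source": [
        "loop_ms, vmap_ms = 79.86, 12.93  # rounded timings copied from the benchmark output above\n",
        "\n",
        "pct = abs(vmap_ms - loop_ms) / vmap_ms * 100  # the quantity get_perf prints\n",
        "speedup = loop_ms / vmap_ms                   # how many times faster the vmap version ran\n",
        "\n",
        "print(f'{pct:.1f}% -> {speedup:.2f}x')  # 517.6% -> 6.18x"
      ],
      "metadata": {
        "id": "perfArithmetic"
      },
      "execution_count": null,
      "outputs": [],
      "id": "perfArithmetic"
    },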
    {
      "cell_type": "markdown",
      "source": [
        "There are other optimized solutions for computing per-sample-gradients in PyTorch (like in https://github.com/pytorch/opacus) that also perform better than the naive method. Still, it’s neat that simply composing `vmap` and `grad` gives us a nice speedup.\n",
        "\n",
        "\n",
        "In general, vectorization with vmap should be faster than running a function in a for-loop and competitive with manual batching. There are some exceptions though, like if we haven’t implemented the vmap rule for a particular operation or if the underlying kernels weren’t optimized for older hardware (GPUs). If you see any of these cases, please let us know by opening an issue at our [GitHub](https://github.com/pytorch/functorch)!\n",
        "\n"
      ],
      "metadata": {
        "id": "UI74G9JarQU8"
      },
      "id": "UI74G9JarQU8"
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.5"
    },
    "colab": {
      "name": "per_sample_grads.ipynb",
      "provenance": []
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}


================================================
FILE: packaging/windows/internal/cuda_install.bat
================================================
@echo on

if "%CU_VERSION%" == "cpu" (
    echo Skipping for CPU builds
    exit /b 0
)

set SRC_DIR=%~dp0\..

if not exist "%SRC_DIR%\temp_build" mkdir "%SRC_DIR%\temp_build"

rem in unit test workflow, we get CUDA_VERSION, for example 11.1
if defined CUDA_VERSION (
    set CUDA_VER=%CUDA_VERSION:.=%
) else (
    set CUDA_VER=%CU_VERSION:cu=%
)

set CUDA_VER_MAJOR=%CUDA_VER:~0,-1%
set CUDA_VER_MINOR=%CUDA_VER:~-1,1%
set CUDA_VERSION_STR=%CUDA_VER_MAJOR%.%CUDA_VER_MINOR%


if %CUDA_VER% EQU 92 goto cuda92
if %CUDA_VER% EQU 100 goto cuda100
if %CUDA_VER% EQU 101 goto cuda101
if %CUDA_VER% EQU 102 goto cuda102
if %CUDA_VER% EQU 110 goto cuda110
if %CUDA_VER% EQU 111 goto cuda111
if %CUDA_VER% EQU 112 goto cuda112
if %CUDA_VER% EQU 113 goto cuda113
if %CUDA_VER% EQU 115 goto cuda115


echo CUDA %CUDA_VERSION_STR% is not supported
exit /b 1

:cuda92
if not exist "%SRC_DIR%\temp_build\cuda_9.2.148_win10.exe" (
    curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cuda_9.2.148_win10.exe --output "%SRC_DIR%\temp_build\cuda_9.2.148_win10.exe"
    if errorlevel 1 exit /b 1
    set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_9.2.148_win10.exe"
    set "ARGS=nvcc_9.2 cuobjdump_9.2 nvprune_9.2 cupti_9.2 cublas_9.2 cublas_dev_9.2 cudart_9.2 cufft_9.2 cufft_dev_9.2 curand_9.2 curand_dev_9.2 cusolver_9.2 cusolver_dev_9.2 cusparse_9.2 cusparse_dev_9.2 nvgraph_9.2 nvgraph_dev_9.2 npp_9.2 npp_dev_9.2 nvrtc_9.2 nvrtc_dev_9.2 nvml_dev_9.2"
)

if not exist "%SRC_DIR%\temp_build\cudnn-9.2-windows10-x64-v7.2.1.38.zip" (
    curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cudnn-9.2-windows10-x64-v7.2.1.38.zip --output "%SRC_DIR%\temp_build\cudnn-9.2-windows10-x64-v7.2.1.38.zip"
    if errorlevel 1 exit /b 1
    set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-9.2-windows10-x64-v7.2.1.38.zip"
)

goto cuda_common

:cuda100

if not exist "%SRC_DIR%\temp_build\cuda_10.0.130_411.31_win10.exe" (
    curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cuda_10.0.130_411.31_win10.exe --output "%SRC_DIR%\temp_build\cuda_10.0.130_411.31_win10.exe"
    if errorlevel 1 exit /b 1
    set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_10.0.130_411.31_win10.exe"
    set "ARGS=nvcc_10.0 cuobjdump_10.0 nvprune_10.0 cupti_10.0 cublas_10.0 cublas_dev_10.0 cudart_10.0 cufft_10.0 cufft_dev_10.0 curand_10.0 curand_dev_10.0 cusolver_10.0 cusolver_dev_10.0 cusparse_10.0 cusparse_dev_10.0 nvgraph_10.0 nvgraph_dev_10.0 npp_10.0 npp_dev_10.0 nvrtc_10.0 nvrtc_dev_10.0 nvml_dev_10.0"
)

if not exist "%SRC_DIR%\temp_build\cudnn-10.0-windows10-x64-v7.4.1.5.zip" (
    curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cudnn-10.0-windows10-x64-v7.4.1.5.zip --output "%SRC_DIR%\temp_build\cudnn-10.0-windows10-x64-v7.4.1.5.zip"
    if errorlevel 1 exit /b 1
    set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-10.0-windows10-x64-v7.4.1.5.zip"
)

goto cuda_common

:cuda101

if not exist "%SRC_DIR%\temp_build\cuda_10.1.243_426.00_win10.exe" (
    curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_10.1.243_426.00_win10.exe --output "%SRC_DIR%\temp_build\cuda_10.1.243_426.00_win10.exe"
    if errorlevel 1 exit /b 1
    set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_10.1.243_426.00_win10.exe"
    set "ARGS=nvcc_10.1 cuobjdump_10.1 nvprune_10.1 cupti_10.1 cublas_10.1 cublas_dev_10.1 cudart_10.1 cufft_10.1 cufft_dev_10.1 curand_10.1 curand_dev_10.1 cusolver_10.1 cusolver_dev_10.1 cusparse_10.1 cusparse_dev_10.1 nvgraph_10.1 nvgraph_dev_10.1 npp_10.1 npp_dev_10.1 nvjpeg_10.1 nvjpeg_dev_10.1 nvrtc_10.1 nvrtc_dev_10.1 nvml_dev_10.1"
)

if not exist "%SRC_DIR%\temp_build\cudnn-10.1-windows10-x64-v7.6.4.38.zip" (
    curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-10.1-windows10-x64-v7.6.4.38.zip --output "%SRC_DIR%\temp_build\cudnn-10.1-windows10-x64-v7.6.4.38.zip"
    if errorlevel 1 exit /b 1
    set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-10.1-windows10-x64-v7.6.4.38.zip"
)

goto cuda_common

:cuda102

if not exist "%SRC_DIR%\temp_build\cuda_10.2.89_441.22_win10.exe" (
    curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_10.2.89_441.22_win10.exe --output "%SRC_DIR%\temp_build\cuda_10.2.89_441.22_win10.exe"
    if errorlevel 1 exit /b 1
    set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_10.2.89_441.22_win10.exe"
    set "ARGS=nvcc_10.2 cuobjdump_10.2 nvprune_10.2 cupti_10.2 cublas_10.2 cublas_dev_10.2 cudart_10.2 cufft_10.2 cufft_dev_10.2 curand_10.2 curand_dev_10.2 cusolver_10.2 cusolver_dev_10.2 cusparse_10.2 cusparse_dev_10.2 nvgraph_10.2 nvgraph_dev_10.2 npp_10.2 npp_dev_10.2 nvjpeg_10.2 nvjpeg_dev_10.2 nvrtc_10.2 nvrtc_dev_10.2 nvml_dev_10.2"
)

if not exist "%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip" (
    curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-10.2-windows10-x64-v7.6.5.32.zip --output "%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip"
    if errorlevel 1 exit /b 1
    set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip"
)

rem The DLLs below are only for cu102; if they are used with another version, e.g. cu111, torch.cuda.is_available() would be False.
if not exist "%SRC_DIR%\temp_build\gpu_driver_dlls.zip" (
    curl -k -L "https://drive.google.com/u/0/uc?id=1injUyo3lnarMgWyRcXqKg4UGnN0ysmuq&export=download" --output "%SRC_DIR%\temp_build\gpu_driver_dlls.zip"
    if errorlevel 1 exit /b 1
)

echo Installing GPU driver DLLs
7z x "%SRC_DIR%\temp_build\gpu_driver_dlls.zip" -aoa -o"C:\Windows\System32"

goto cuda_common

:cuda110

if not exist "%SRC_DIR%\temp_build\cuda_11.0.2_451.48_win10.exe" (
    curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_11.0.2_451.48_win10.exe --output "%SRC_DIR%\temp_build\cuda_11.0.2_451.48_win10.exe"
    if errorlevel 1 exit /b 1
    set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_11.0.2_451.48_win10.exe"
    set "ARGS=nvcc_11.0 cuobjdump_11.0 nvprune_11.0 nvprof_11.0 cupti_11.0 cublas_11.0 cublas_dev_11.0 cudart_11.0 cufft_11.0 cufft_dev_11.0 curand_11.0 curand_dev_11.0 cusolver_11.0 cusolver_dev_11.0 cusparse_11.0 cusparse_dev_11.0 npp_11.0 npp_dev_11.0 nvjpeg_11.0 nvjpeg_dev_11.0 nvrtc_11.0 nvrtc_dev_11.0 nvml_dev_11.0"
)

if not exist "%SRC_DIR%\temp_build\cudnn-11.0-windows-x64-v8.0.4.30.zip" (
    curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-11.0-windows-x64-v8.0.4.30.zip --output "%SRC_DIR%\temp_build\cudnn-11.0-windows-x64-v8.0.4.30.zip"
    if errorlevel 1 exit /b 1
    set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-11.0-windows-x64-v8.0.4.30.zip"
)

goto cuda_common

:cuda111

if not exist "%SRC_DIR%\temp_build\cuda_11.1.1_456.81_win10.exe" (
    curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_11.1.1_456.81_win10.exe --output "%SRC_DIR%\temp_build\cuda_11.1.1_456.81_win10.exe"
    if errorlevel 1 exit /b 1
    set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_11.1.1_456.81_win10.exe"
    set "ARGS=nvcc_11.1 cuobjdump_11.1 nvprune_11.1 nvprof_11.1 cupti_11.1 cublas_11.1 cublas_dev_11.1 cudart_11.1 cufft_11.1 cufft_dev_11.1 curand_11.1 curand_dev_11.1 cusolver_11.1 cusolver_dev_11.1 cusparse_11.1 cusparse_dev_11.1 npp_11.1 npp_dev_11.1 nvjpeg_11.1 nvjpeg_dev_11.1 nvrtc_11.1 nvrtc_dev_11.1 nvml_dev_11.1"
)

if not exist "%SRC_DIR%\temp_build\cudnn-11.1-windows-x64-v8.0.5.39.zip" (
    curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-11.1-windows-x64-v8.0.5.39.zip --output "%SRC_DIR%\temp_build\cudnn-11.1-windows-x64-v8.0.5.39.zip"
    if errorlevel 1 exit /b 1
    set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-11.1-windows-x64-v8.0.5.39.zip"
)

goto cuda_common

:cuda112

if not exist "%SRC_DIR%\temp_build\cuda_11.2.0_460.89_win10.exe" (
    curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_11.2.0_460.89_win10.exe --output "%SRC_DIR%\temp_build\cuda_11.2.0_460.89_win10.exe"
    if errorlevel 1 exit /b 1
    set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_11.2.0_460.89_win10.exe"
    set "ARGS=nvcc_11.2 cuobjdump_11.2 nvprune_11.2 nvprof_11.2 cupti_11.2 cublas_11.2 cublas_dev_11.2 cudart_11.2 cufft_11.2 cufft_dev_11.2 curand_11.2 curand_dev_11.2 cusolver_11.2 cusolver_dev_11.2 cusparse_11.2 cusparse_dev_11.2 npp_11.2 npp_dev_11.2 nvjpeg_11.2 nvjpeg_dev_11.2 nvrtc_11.2 nvrtc_dev_11.2 nvml_dev_11.2"
)

if not exist "%SRC_DIR%\temp_build\cudnn-11.2-windows-x64-v8.1.0.77.zip" (
    curl -k -L http://s3.amazonaws.com/ossci-windows/cudnn-11.2-windows-x64-v8.1.0.77.zip --output "%SRC_DIR%\temp_build\cudnn-11.2-windows-x64-v8.1.0.77.zip"
    if errorlevel 1 exit /b 1
    set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-11.2-windows-x64-v8.1.0.77.zip"
)

goto cuda_common

:cuda113

set CUDA_INSTALL_EXE=cuda_11.3.0_465.89_win10.exe
if not exist "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" (
    curl -k -L "https://ossci-windows.s3.amazonaws.com/%CUDA_INSTALL_EXE%" --output "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%"
    if errorlevel 1 exit /b 1
    set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%"
    set "ARGS=thrust_11.3 nvcc_11.3 cuobjdump_11.3 nvprune_11.3 nvprof_11.3 cupti_11.3 cublas_11.3 cublas_dev_11.3 cudart_11.3 cufft_11.3 cufft_dev_11.3 curand_11.3 curand_dev_11.3 cusolver_11.3 cusolver_dev_11.3 cusparse_11.3 cusparse_dev_11.3 npp_11.3 npp_dev_11.3 nvjpeg_11.3 nvjpeg_dev_11.3 nvrtc_11.3 nvrtc_dev_11.3 nvml_dev_11.3"

)

set CUDNN_INSTALL_ZIP=cudnn-11.3-windows-x64-v8.2.0.53.zip
if not exist "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" (
    curl -k -L "http://s3.amazonaws.com/ossci-windows/%CUDNN_INSTALL_ZIP%" --output "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%"
    if errorlevel 1 exit /b 1
    set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%"
)

goto cuda_common

:cuda115

set CUDA_INSTALL_EXE=cuda_11.5.0_496.13_win10.exe
if not exist "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" (
    curl -k -L "https://ossci-windows.s3.amazonaws.com/%CUDA_INSTALL_EXE%" --output "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%"
    if errorlevel 1 exit /b 1
    set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%"
    set "ARGS=thrust_11.5 nvcc_11.5 cuobjdump_11.5 nvprune_11.5 nvprof_11.5 cupti_11.5 cublas_11.5 cublas_dev_11.5 cudart_11.5 cufft_11.5 cufft_dev_11.5 curand_11.5 curand_dev_11.5 cusolver_11.5 cusolver_dev_11.5 cusparse_11.5 cusparse_dev_11.5 npp_11.5 npp_dev_11.5 nvrtc_11.5 nvrtc_dev_11.5 nvml_dev_11.5"
)

set CUDNN_INSTALL_ZIP=cudnn-11.3-windows-x64-v8.2.0.53.zip
if not exist "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" (
    curl -k -L "http://s3.amazonaws.com/ossci-windows/%CUDNN_INSTALL_ZIP%" --output "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%"
    if errorlevel 1 exit /b 1
    set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%"
)

goto cuda_common

:cuda_common

if not exist "%SRC_DIR%\temp_build\NvToolsExt.7z" (
    curl -k -L https://www.dropbox.com/s/9mcolalfdj4n979/NvToolsExt.7z?dl=1 --output "%SRC_DIR%\temp_build\NvToolsExt.7z"
    if errorlevel 1 exit /b 1
)

echo Installing CUDA toolkit...
7z x "%CUDA_SETUP_FILE%" -o"%SRC_DIR%\temp_build\cuda"
pushd "%SRC_DIR%\temp_build\cuda"
sc config wuauserv start= disabled
sc stop wuauserv
sc query wuauserv

start /wait setup.exe -s %ARGS% -loglevel:6 -log:"%cd%/cuda_install_logs"
echo %errorlevel%

popd

echo Installing VS integration...
rem It's for VS 2019
if "%CUDA_VER_MAJOR%" == "10" (
    xcopy /Y "%SRC_DIR%\temp_build\cuda\CUDAVisualStudioIntegration\extras\visual_studio_integration\MSBuildExtensions\*.*" "C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\MSBuild\Microsoft\VC\v160\BuildCustomizations"
)
if "%CUDA_VER_MAJOR%" == "11" (
    xcopy /Y "%SRC_DIR%\temp_build\cuda\visual_studio_integration\CUDAVisualStudioIntegration\extras\visual_studio_integration\MSBuildExtensions\*.*" "C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\MSBuild\Microsoft\VC\v160\BuildCustomizations"
)

echo Installing NvToolsExt...
7z x "%SRC_DIR%\temp_build\NvToolsExt.7z" -o"%SRC_DIR%\temp_build\NvToolsExt"
mkdir "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\bin\x64"
mkdir "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\include"
mkdir "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\lib\x64"
xcopy /Y "%SRC_DIR%\temp_build\NvToolsExt\bin\x64\*.*" "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\bin\x64"
xcopy /Y "%SRC_DIR%\temp_build\NvToolsExt\include\*.*" "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\include"
xcopy /Y "%SRC_DIR%\temp_build\NvToolsExt\lib\x64\*.*" "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\lib\x64"

echo Setting up environment...
set "PATH=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\bin;%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\libnvvp;%PATH%"
set "CUDA_PATH=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%"
set "CUDA_PATH_V%CUDA_VER_MAJOR%_%CUDA_VER_MINOR%=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%"
set "NVTOOLSEXT_PATH=%ProgramFiles%\NVIDIA Corporation\NvToolsExt\bin\x64"

if not exist "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\bin\nvcc.exe" (
    echo CUDA %CUDA_VERSION_STR% installation failed.
    echo --------- RunDll32.exe.log -------
    type "%SRC_DIR%\temp_build\cuda\cuda_install_logs\LOG.RunDll32.exe.log"
    echo --------- setup.exe.log -------
    type "%SRC_DIR%\temp_build\cuda\cuda_install_logs\LOG.setup.exe.log"
    exit /b 1
)

echo Installing cuDNN...
7z x "%CUDNN_SETUP_FILE%" -o"%SRC_DIR%\temp_build\cudnn"
xcopy /Y "%SRC_DIR%\temp_build\cudnn\cuda\bin\*.*" "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\bin"
xcopy /Y "%SRC_DIR%\temp_build\cudnn\cuda\lib\x64\*.*" "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\lib\x64"
xcopy /Y "%SRC_DIR%\temp_build\cudnn\cuda\include\*.*" "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\include"

echo Cleaning temp files
rd /s /q "%SRC_DIR%\temp_build" || ver > nul
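The "Setting up environment" step derives every toolkit path from a single install root. The Python sketch below reproduces that derivation so the resulting values can be inspected without running the installer. It assumes `CUDA_VERSION_STR` has the form `major.minor` (for example `11.3`) and that `CUDA_VER_MAJOR` and `CUDA_VER_MINOR` are its two parts; those variables are defined earlier in `cuda_install.bat`, outside this excerpt. `cuda_env` is a hypothetical helper, not part of the build scripts.

```python
from pathlib import PureWindowsPath


def cuda_env(version_str, program_files=r"C:\Program Files", old_path=""):
    """Hypothetical helper mirroring the batch 'Setting up environment' step.

    Assumes version_str looks like "11.3" (CUDA_VERSION_STR in the script).
    """
    major, minor = version_str.split(".")[:2]
    toolkit = PureWindowsPath(program_files) / "NVIDIA GPU Computing Toolkit"
    cuda_root = toolkit / "CUDA" / f"v{version_str}"
    nvtx_bin = PureWindowsPath(program_files) / "NVIDIA Corporation" / "NvToolsExt" / "bin" / "x64"
    return {
        # bin and libnvvp are prepended, so this toolkit is found before any other on PATH
        "PATH": ";".join([str(cuda_root / "bin"), str(cuda_root / "libnvvp"), old_path]),
        "CUDA_PATH": str(cuda_root),
        # Versioned alias, e.g. CUDA_PATH_V11_3
        f"CUDA_PATH_V{major}_{minor}": str(cuda_root),
        "NVTOOLSEXT_PATH": str(nvtx_bin),
    }


env = cuda_env("11.3", old_path=r"C:\Windows\System32")
for key, value in env.items():
    print(f"{key}={value}")
```

Keeping the derivation in one function makes it clear that `CUDA_PATH` and the versioned `CUDA_PATH_V<major>_<minor>` alias always point at the same directory.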


================================================
FILE: packaging/windows/internal/driver_update.bat
================================================
set "DRIVER_DOWNLOAD_LINK=https://ossci-windows.s3.amazonaws.com/461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe"
curl --retry 3 -kL %DRIVER_DOWNLOAD_LINK% --output 461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe
if errorlevel 1 exit /b 1

start /wait 461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe -s -noreboot
if errorlevel 1 exit /b 1

del 461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe || ver > NUL

setlocal EnableDelayedExpansion
set NVIDIA_GPU_EXISTS=0
for /F "delims=" %%i in ('wmic path win32_VideoController get name') do (
    set GPUS=%%i
    if not "x!GPUS:NVIDIA=!" == "x!GPUS!" (
        SET NVIDIA_GPU_EXISTS=1
        goto gpu_check_end
    )
)
:gpu_check_end
endlocal & set NVIDIA_GPU_EXISTS=%NVIDIA_GPU_EXISTS%

if "%NVIDIA_GPU_EXISTS%" == "0" (
    echo CUDA driver installation failed: no NVIDIA GPU detected
    exit /b 1
)
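The GPU check relies on cmd substring substitution: `!GPUS:NVIDIA=!` deletes "NVIDIA" from the current line, so the line differs from `!GPUS!` only if it contained that word. The sketch below expresses the same test in Python over the text printed by `wmic path win32_VideoController get name`. It matches case-insensitively, as cmd's `%var:old=new%` substitution does. `nvidia_gpu_present` and the sample adapter names are hypothetical.

```python
def nvidia_gpu_present(wmic_output):
    """Hypothetical helper: True if any video controller name mentions NVIDIA.

    Mirrors the batch test "x!GPUS:NVIDIA=!" == "x!GPUS!": removing "NVIDIA"
    changes a line only when the line contains it. The "Name" header line
    that wmic prints first never matches, so it is harmless.
    """
    for line in wmic_output.splitlines():
        if "nvidia" in line.lower():
            return True  # the batch script jumps to :gpu_check_end here
    return False


sample = "Name\nMicrosoft Basic Display Adapter\nNVIDIA Tesla T4\n"
print(nvidia_gpu_present(sample))
```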


================================================
FILE: pull_request_template.md
================================================
To contribute a change to functorch, please make sure you are submitting a
Pull Request to the functorch folder in the https://github.com/pytorch/pytorch
repository. The source of truth for functorch has moved there from
https://github.com/pytorch/functorch ; the pytorch/functorch repository
is now read-only.


================================================
FILE: setup.cfg
================================================
[bdist_wheel]
universal=1

[metadata]
license_file = LICENSE

[pep8]
max-line-length = 120

[flake8]
max-line-length = 120
exclude = docs, benchmarks, notebooks, tools
per-file-ignores =
    __init__.py: F401
    functorch/_src/decompositions.py: E501

[pydocstyle]
# D417: Missing argument descriptions in the docstring
select = D417
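Python's standard `configparser` does not strip inline comments unless `inline_comment_prefixes` is set, so a value written as `select = D417 # ...` can come back with its trailing comment attached. Whether pydocstyle, flake8, or any other tool reading `setup.cfg` enables inline-comment stripping is tool-specific and is not asserted here; the sketch only shows the stdlib default. `SNIPPET` and `read_select` are illustrative names.

```python
import configparser

# Illustrative copy of a [pydocstyle] entry written with an inline comment.
SNIPPET = """
[pydocstyle]
select = D417 # Missing argument descriptions in the docstring
"""


def read_select(inline_comment_prefixes=None):
    """Return the raw `select` value as configparser sees it."""
    parser = configparser.ConfigParser(inline_comment_prefixes=inline_comment_prefixes)
    parser.read_string(SNIPPET)
    return parser.get("pydocstyle", "select")


print(repr(read_select()))        # default: the comment is part of the value
print(repr(read_select(("#",))))  # inline '#' comments are stripped
```

Putting the comment on its own line avoids depending on how a particular tool configures its parser.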


================================================
FILE: setup.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import subprocess
import sys
from setuptools import setup

cwd = os.path.dirname(os.path.abspath(__file__))
version_txt = os.path.join(cwd, 'version.txt')
with open(version_txt, 'r') as f:
    version = f.readline().strip()

try:
    sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=cwd).decode('ascii').strip()
except Exception:
    sha = 'Unknown'
package_name = 'functorch'

if os.getenv('BUILD_VERSION'):
    version = os.getenv('BUILD_VERSION')
elif sha != 'Unknown':
    version += '+' + sha[:7]


requirements = [
    # This represents a nightly version of PyTorch.
    # It can be installed as a binary or from source.
    "torch>=1.14.0.dev",
]

extras = {}
extras["aot"] = ["networkx", ]


if __name__ == '__main__':
    try:
        setup(
            # Metadata
            name=package_name,
            version=version,
            author='PyTorch Core Team',
            url="https://github.com/pytorch/functorch",
            description='JAX-like composable function transforms for PyTorch',
            license='BSD',

            # Package info
            packages=[],
            install_requires=requirements,
            extras_require=extras,
        )
    except Exception as e:
        print(e, file=sys.stderr)
        sys.exit(1)
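The version string is resolved in a fixed order: a non-empty `BUILD_VERSION` wins; otherwise, if git reported a commit, its first seven characters are appended as a PEP 440 local version label; otherwise the `version.txt` value is used unchanged. The pure-function sketch below reproduces that order so it can be tested without git or environment variables. `compute_version` is a hypothetical name, not part of functorch, and the hash used in the usage lines is only example input.

```python
def compute_version(base, sha="Unknown", build_version=None):
    """Hypothetical helper mirroring the version logic in setup.py.

    base:          first line of version.txt, e.g. "1.14.0a0"
    sha:           full git commit hash, or "Unknown" when git is unavailable
    build_version: value of the BUILD_VERSION environment variable, if any
    """
    if build_version:
        # An explicit, non-empty BUILD_VERSION overrides everything else.
        return build_version
    if sha != "Unknown":
        # Local version label built from the short (7-character) commit hash.
        return base + "+" + sha[:7]
    return base


print(compute_version("1.14.0a0", sha="b71aa0b4387b"))
print(compute_version("1.14.0a0", sha="b71aa0b4387b", build_version="1.14.0"))
print(compute_version("1.14.0a0"))
```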


================================================
FILE: version.txt
================================================
1.14.0a0