Repository: facebookresearch/functorch Branch: main Commit: b71aa0b4387b Files: 18 Total size: 98.5 KB Directory structure: gitextract_sltc66ra/ ├── .flake8 ├── .github/ │ └── workflows/ │ └── wheels.yml ├── .gitignore ├── .lintrunner.toml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── notebooks/ │ ├── README.md │ └── colab/ │ ├── aot_autograd_optimizations.ipynb │ ├── jacobians_hessians_colab.ipynb │ └── per_sample_grads_colab.ipynb ├── packaging/ │ └── windows/ │ └── internal/ │ ├── cuda_install.bat │ └── driver_update.bat ├── pull_request_template.md ├── setup.cfg ├── setup.py └── version.txt ================================================ FILE CONTENTS ================================================ ================================================ FILE: .flake8 ================================================ [flake8] select = B,C,E,F,P,T4,W,B9 max-line-length = 120 # C408 ignored because we like the dict keyword argument syntax # E501 is not flexible enough, we're using B950 instead ignore = E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303, # shebang has extra meaning in fbcode lints, so I think it's not worth trying # to line this up with executable bit EXE001, # these ignores are from flake8-bugbear; please fix! B007,B008, # these ignores are from flake8-comprehensions; please fix! 
C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415 exclude = ./.git, ./benchmarks, ./docs, ./examples, ./notebooks ================================================ FILE: .github/workflows/wheels.yml ================================================ name: Wheels on: pull_request: types: [opened, synchronize, reopened] push: branches: - main jobs: build-wheel: runs-on: ubuntu-22.04 steps: - name: Setup Python uses: actions/setup-python@v2 with: python-version: 3.9 architecture: x64 - name: Checkout functorch uses: actions/checkout@v2 - name: Install PyTorch Nightly run: | python3 -mpip install --pre torch>=1.13.0.dev -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html - name: Build wheel run: | python3 -mpip install wheel python3 setup.py bdist_wheel - name: Upload wheel as GHA artifact uses: actions/upload-artifact@v2 with: name: functorch.whl path: dist/*.whl ================================================ FILE: .gitignore ================================================ build/ dist/ functorch.egg-info/ *__pycache__* functorch/version.py functorch/_C.so .gdbinit t.py .vscode/ ccache.sh docs/build docs/src docs/source/generated .DS_Store op_analysis/*.txt # Editor temporaries *.swn *.swo *.swp *.swm ================================================ FILE: .lintrunner.toml ================================================ [[linter]] code = 'FLAKE8' include_patterns = ['**/*.py'] exclude_patterns = [ '.git/**', 'benchmarks/**', 'docs/**', 'examples/**', 'notebooks/**', ] command = [ 'python3', 'tools/lint/flake8_linter.py', '--', '@{{PATHSFILE}}' ] init_command = [ 'python3', 'tools/lint/pip_init.py', '--dry-run={{DRYRUN}}', 'flake8==3.8.2', 'flake8-bugbear==20.1.4', 'flake8-comprehensions==3.3.0', 'flake8-executable==2.0.4', 'flake8-pyi==20.5.0', 'mccabe==0.6.1', 'pycodestyle==2.6.0', 'pyflakes==2.2.0', ] # [[linter]] # code = 'BLACK' # include_patterns = [ # '**/*.py', # ] # command = [ # 'python3', # 'tools/lint/black_linter.py', # '--', # 
'@{{PATHSFILE}}' # ] # init_command = [ # 'python3', # 'tools/lint/pip_init.py', # '--dry-run={{DRYRUN}}', # 'black==22.3.0', # ] # is_formatter = true ================================================ FILE: CODE_OF_CONDUCT.md ================================================ # Code of Conduct ## Our Pledge In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to make participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. ## Our Standards Examples of behavior that contributes to creating a positive environment include: * Using welcoming and inclusive language * Being respectful of differing viewpoints and experiences * Gracefully accepting constructive criticism * Focusing on what is best for the community * Showing empathy towards other community members Examples of unacceptable behavior by participants include: * The use of sexualized language or imagery and unwelcome sexual attention or advances * Trolling, insulting/derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or electronic address, without explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Our Responsibilities Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. 
Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. ## Scope This Code of Conduct applies within all project spaces, and it also applies when an individual is representing the project or its community in public spaces. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. ## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at . All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. ## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html [homepage]: https://www.contributor-covenant.org For answers to common questions about this code of conduct, see https://www.contributor-covenant.org/faq ================================================ FILE: CONTRIBUTING.md ================================================ ## Contributing Feedback on our APIs, as well as finding bugs, would be very helpful. 
Please feel free to chat us up on the PyTorch Slack, or open an issue at https://github.com/pytorch/functorch if you're interested in contributing. To contribute a change to functorch, please make sure you are submitting a Pull Request to the functorch folder in https://github.com/pytorch/pytorch repository. The source of truth for functorch has moved there from https://github.com/pytorch/functorch ; the code in the pytorch/functorch repository is read-only. ================================================ FILE: LICENSE ================================================ Copyright (c) 2021 Facebook, Inc. and its affiliates. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ================================================ FILE: README.md ================================================ # functorch [**Why functorch?**](#why-composable-function-transforms) | [**Install guide**](#install) | [**Transformations**](#what-are-the-transforms) | [**Documentation**](#documentation) | [**Future Plans**](#future-plans) **This library is currently under heavy development - if you have suggestions on the API or use-cases you'd like to be covered, please open an github issue or reach out. We'd love to hear about how you're using the library.** `functorch` is [JAX-like](https://github.com/google/jax) composable function transforms for PyTorch. It aims to provide composable `vmap` and `grad` transforms that work with PyTorch modules and PyTorch autograd with good eager-mode performance. In addition, there is experimental functionality to trace through these transformations using FX in order to capture the results of these transforms ahead of time. This would allow us to compile the results of vmap or grad to improve performance. ## Why composable function transforms? 
There are a number of use cases that are tricky to do in PyTorch today: - computing per-sample-gradients (or other per-sample quantities) - running ensembles of models on a single machine - efficiently batching together tasks in the inner-loop of MAML - efficiently computing Jacobians and Hessians - efficiently computing batched Jacobians and Hessians Composing `vmap`, `grad`, `vjp`, and `jvp` transforms allows us to express the above without designing a separate subsystem for each. This idea of composable function transforms comes from the [JAX framework](https://github.com/google/jax). ## Install There are two ways to install functorch: 1. functorch from source 2. functorch beta (compatible with recent PyTorch releases) We recommend trying out the functorch beta first. ### Installing functorch from source
Click to expand

#### Using Colab Follow the instructions [in this Colab notebook](https://colab.research.google.com/drive/1CrLkqIrydBYP_svnF89UUO-aQEqNPE8x?usp=sharing) #### Locally As of 9/21/2022, `functorch` comes installed alongside a nightly PyTorch binary. Please install a Preview (nightly) PyTorch binary; see https://pytorch.org/ for instructions. Once you've done that, run a quick sanity check in Python: ```py import torch from functorch import vmap x = torch.randn(3) y = vmap(torch.sin)(x) assert torch.allclose(y, x.sin()) ``` #### functorch development setup As of 9/21/2022, `functorch` comes installed alongside PyTorch and is in the PyTorch source tree. Please install [PyTorch from source](https://github.com/pytorch/pytorch#from-source), then, you will be able to `import functorch`. Try to run some tests to make sure all is OK: ```bash pytest test/test_vmap.py -v pytest test/test_eager_transforms.py -v ``` AOTAutograd has some additional optional requirements. You can install them via: ```bash pip install networkx ``` To run functorch tests, please install our test dependencies (`expecttest`, `pyyaml`).

### Installing functorch beta (compatible with recent PyTorch releases)
Click to expand

#### Using Colab Follow the instructions [here](https://colab.research.google.com/drive/1GNfb01W_xf8JRu78ZKoNnLqiwcrJrbYG#scrollTo=HJ1srOGeNCGA) #### pip Prerequisite: [Install PyTorch](https://pytorch.org/get-started/locally/) ```bash pip install functorch ``` Finally, run a quick sanity check in python: ```py import torch from functorch import vmap x = torch.randn(3) y = vmap(torch.sin)(x) assert torch.allclose(y, x.sin()) ```

## What are the transforms? Right now, we support the following transforms: - `grad`, `vjp`, `jvp`, - `jacrev`, `jacfwd`, `hessian` - `vmap` Furthermore, we have some utilities for working with PyTorch modules. - `make_functional(model)` - `make_functional_with_buffers(model)` ### vmap Note: `vmap` imposes restrictions on the code that it can be used on. For more details, please read its docstring. `vmap(func)(*inputs)` is a transform that adds a dimension to all Tensor operations in `func`. `vmap(func)` returns a new function that maps `func` over some dimension (default: 0) of each Tensor in `inputs`. `vmap` is useful for hiding batch dimensions: one can write a function `func` that runs on examples and then lift it to a function that can take batches of examples with `vmap(func)`, leading to a simpler modeling experience: ```py from functorch import vmap batch_size, feature_size = 3, 5 weights = torch.randn(feature_size, requires_grad=True) def model(feature_vec): # Very simple linear model with activation assert feature_vec.dim() == 1 return feature_vec.dot(weights).relu() examples = torch.randn(batch_size, feature_size) result = vmap(model)(examples) ``` ### grad `grad(func)(*inputs)` assumes `func` returns a single-element Tensor. It compute the gradients of the output of func w.r.t. to `inputs[0]`. 
```py from functorch import grad x = torch.randn([]) cos_x = grad(lambda x: torch.sin(x))(x) assert torch.allclose(cos_x, x.cos()) # Second-order gradients neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x) assert torch.allclose(neg_sin_x, -x.sin()) ``` When composed with `vmap`, `grad` can be used to compute per-sample-gradients: ```py from functorch import vmap batch_size, feature_size = 3, 5 def model(weights,feature_vec): # Very simple linear model with activation assert feature_vec.dim() == 1 return feature_vec.dot(weights).relu() def compute_loss(weights, example, target): y = model(weights, example) return ((y - target) ** 2).mean() # MSELoss weights = torch.randn(feature_size, requires_grad=True) examples = torch.randn(batch_size, feature_size) targets = torch.randn(batch_size) inputs = (weights,examples, targets) grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs) ``` ### vjp The `vjp` transform applies `func` to `inputs` and returns a new function that computes vjps given some `cotangents` Tensors. ```py from functorch import vjp outputs, vjp_fn = vjp(func, inputs); vjps = vjp_fn(*cotangents) ``` ### jvp The `jvp` transforms computes Jacobian-vector-products and is also known as "forward-mode AD". It is not a higher-order function unlike most other transforms, but it returns the outputs of `func(inputs)` as well as the `jvp`s. ```py from functorch import jvp x = torch.randn(5) y = torch.randn(5) f = lambda x, y: (x * y) _, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5))) assert torch.allclose(output, x + y) ``` ### jacrev, jacfwd, and hessian The `jacrev` transform returns a new function that takes in `x` and returns the Jacobian of `torch.sin` with respect to `x` using reverse-mode AD. ```py from functorch import jacrev x = torch.randn(5) jacobian = jacrev(torch.sin)(x) expected = torch.diag(torch.cos(x)) assert torch.allclose(jacobian, expected) ``` Use `jacrev` to compute the jacobian. 
This can be composed with vmap to produce batched jacobians: ```py x = torch.randn(64, 5) jacobian = vmap(jacrev(torch.sin))(x) assert jacobian.shape == (64, 5, 5) ``` `jacfwd` is a drop-in replacement for `jacrev` that computes Jacobians using forward-mode AD: ```py from functorch import jacfwd x = torch.randn(5) jacobian = jacfwd(torch.sin)(x) expected = torch.diag(torch.cos(x)) assert torch.allclose(jacobian, expected) ``` Composing `jacrev` with itself or `jacfwd` can produce hessians: ```py def f(x): return x.sin().sum() x = torch.randn(5) hessian0 = jacrev(jacrev(f))(x) hessian1 = jacfwd(jacrev(f))(x) ``` The `hessian` is a convenience function that combines `jacfwd` and `jacrev`: ```py from functorch import hessian def f(x): return x.sin().sum() x = torch.randn(5) hess = hessian(f)(x) ``` ### Tracing through the transformations We can also trace through these transformations in order to capture the results as new code using `make_fx`. There is also experimental integration with the NNC compiler (only works on CPU for now!). ```py from functorch import make_fx, grad def f(x): return torch.sin(x).sum() x = torch.randn(100) grad_f = make_fx(grad(f))(x) print(grad_f.code) def forward(self, x_1): sin = torch.ops.aten.sin(x_1) sum_1 = torch.ops.aten.sum(sin, None); sin = None cos = torch.ops.aten.cos(x_1); x_1 = None _tensor_constant0 = self._tensor_constant0 mul = torch.ops.aten.mul(_tensor_constant0, cos); _tensor_constant0 = cos = None return mul ``` ### Working with NN modules: make_functional and friends Sometimes you may want to perform a transform with respect to the parameters and/or buffers of an nn.Module. 
This can happen for example in: - model ensembling, where all of your weights and buffers have an additional dimension - per-sample-gradient computation where you want to compute per-sample-grads of the loss with respect to the model parameters Our solution to this right now is an API that, given an nn.Module, creates a stateless version of it that can be called like a function. - `make_functional(model)` returns a functional version of `model` and the `model.parameters()` - `make_functional_with_buffers(model)` returns a functional version of `model` and the `model.parameters()` and `model.buffers()`. Here's an example where we compute per-sample-gradients using an nn.Linear layer: ```py import torch from functorch import make_functional, vmap, grad model = torch.nn.Linear(3, 3) data = torch.randn(64, 3) targets = torch.randn(64, 3) func_model, params = make_functional(model) def compute_loss(params, data, targets): preds = func_model(params, data) return torch.mean((preds - targets) ** 2) per_sample_grads = vmap(grad(compute_loss), (None, 0, 0))(params, data, targets) ``` If you're making an ensemble of models, you may find `combine_state_for_ensemble` useful. ## Documentation For more documentation, see [our docs website](https://pytorch.org/functorch). ## Debugging `torch._C._functorch.dump_tensor`: Dumps dispatch keys on stack `torch._C._functorch._set_vmap_fallback_warning_enabled(False)` if the vmap warning spam bothers you. ## Future Plans In the end state, we'd like to upstream this into PyTorch once we iron out the design details. To figure out the details, we need your help -- please send us your use cases by starting a conversation in the issue tracker or trying our project out. ## License Functorch has a BSD-style license, as found in the [LICENSE](LICENSE) file. ## Citing functorch If you use functorch in your publication, please cite it by using the following BibTeX entry. 
```bibtex @Misc{functorch2021, author = {Horace He, Richard Zou}, title = {functorch: JAX-like composable function transforms for PyTorch}, howpublished = {\url{https://github.com/pytorch/functorch}}, year = {2021} } ``` ================================================ FILE: notebooks/README.md ================================================ The new, updated versions of these notebooks may be found in the pytorch/pytorch repo. We're leaving the old notebooks here as a temporary solution so that our website still points to the correct thing. We plan to rewrite the links on the website to point to their newer counterparts soon. ================================================ FILE: notebooks/colab/aot_autograd_optimizations.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# AOT Autograd - How to use and optimize?\n", "\n", "\n", " \"Open\n", "\n", "\n", "## Background\n", "In this tutorial, we will learn how to use AOT Autograd to speedup training of deep learning models.\n", "\n", "For background, AOT Autograd is a toolkit to assist developers in accelerating training on PyTorch. Broadly, it has two key features\n", "* AOT Autograd traces the forward and backward graph ahead of time. Presence of forward and backward graph ahead of time facilitates joint graph optimizations such as recomputation or activation checkpointing.\n", "* AOT Autograd provides simple mechanisms to compile the extracted forward and backward graphs through deep learning compilers, such as NVFuser, NNC, TVM and others.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## What will you learn?\n", "In this tutorial, we will look at how AOT Autograd can be used, in conjunction with backend compilers, to accelerate the training of PyTorch models. 
More specifically, you will learn\n", "* How to use AOT Autograd?\n", "* How AOT Autograd uses backend compilers to perform operation fusion?\n", "* How AOT Autograd enables training-specific optimizations such as Recomputation?\n", "\n", "So, lets get started.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup\n", "\n", "Let's setup a simple model.\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "def fn(a, b, c, d):\n", " x = a + b + c + d\n", " return x.cos().cos()" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Test that it works\n", "a, b, c, d = [torch.randn(2, 4, requires_grad=True) for _ in range(4)]\n", "ref = fn(a, b, c, d)\n", "loss = ref.sum()\n", "loss.backward()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Use AOT Autograd\n", "\n", "Now, lets use AOT Autograd and look at the extracted forward and backward graphs. Internally, AOT uses `__torch_dispatch__` based tracing mechanism to extract forward and backward graphs, and wraps them in `torch.Fx` GraphModule containers. Note that AOT Autograd tracing is different from the usual Fx symbolic tracing. AOT Autograd uses Fx GraphModule just to represent the traced graphs (and not for tracing).\n", "\n", "AOT Autograd then sends these forward and backward graphs to the user supplied compilers. So, lets write a compiler that just prints the graph." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\n", "def forward(self, primals_1, primals_2, primals_3, primals_4):\n", " add = torch.ops.aten.add(primals_1, primals_2); primals_1 = primals_2 = None\n", " add_1 = torch.ops.aten.add(add, primals_3); add = primals_3 = None\n", " add_2 = torch.ops.aten.add(add_1, primals_4); add_1 = primals_4 = None\n", " cos = torch.ops.aten.cos(add_2)\n", " cos_1 = torch.ops.aten.cos(cos)\n", " return [cos_1, add_2, cos]\n", " \n", "\n", "\n", "\n", "def forward(self, add_2, cos, tangents_1):\n", " sin = torch.ops.aten.sin(cos); cos = None\n", " neg = torch.ops.aten.neg(sin); sin = None\n", " mul = torch.ops.aten.mul(tangents_1, neg); tangents_1 = neg = None\n", " sin_1 = torch.ops.aten.sin(add_2); add_2 = None\n", " neg_1 = torch.ops.aten.neg(sin_1); sin_1 = None\n", " mul_1 = torch.ops.aten.mul(mul, neg_1); mul = neg_1 = None\n", " return [mul_1, mul_1, mul_1, mul_1]\n", " \n" ] } ], "source": [ "from functorch.compile import aot_function\n", "\n", "# The compiler_fn is called after the forward and backward graphs are extracted.\n", "# Here, we just print the code in the compiler_fn. Return of this function is a callable.\n", "def compiler_fn(fx_module: torch.fx.GraphModule, _):\n", " print(fx_module.code)\n", " return fx_module\n", "\n", "# Pass on the compiler_fn to the aot_function API\n", "aot_print_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn)\n", "\n", "# Run the aot_print_fn once to trigger the compilation and print the graphs\n", "res = aot_print_fn(a, b, c, d).sum().backward()\n", "assert torch.allclose(ref, res)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The above code prints the Fx graph for the forward and backward graph. You can see that in addition to the original input of the forward pass, the forward graph outputs some additional tensors. 
These tensors are saved for the backward pass for gradient calculation. We will come back to these later while talking about recomputation." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Operator Fusion\n", "Now that we understand how to use AOT Autograd to print forward and backward graphs, let us use AOT Autograd to use some actual deep learning compiler. In this tutorial, we use PyTorch Neural Network Compiler (NNC) to perform pointwise operator fusion for CPU devices. For CUDA devices, a suitable alternative is NvFuser. So, lets use NNC" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# AOT Autograd has a suite of already integrated backends. Lets import the NNC compiler backend - ts_compile\n", "from functorch.compile import ts_compile\n", "\n", "# Lets compile the forward and backward through ts_compile.\n", "aot_nnc_fn = aot_function(fn, fw_compiler=ts_compile, bw_compiler=ts_compile)\n", "\n", "# Correctness checking. Lets clone the input so that we can check grads.\n", "cloned_inputs = [x.clone().detach().requires_grad_(True) for x in (a, b, c, d)]\n", "cloned_a, cloned_b, cloned_c, cloned_d = cloned_inputs\n", "\n", "res = aot_nnc_fn(*cloned_inputs)\n", "loss = res.sum()\n", "loss.backward()\n", "assert torch.allclose(ref, res)\n", "assert torch.allclose(a.grad, cloned_a.grad)\n", "assert torch.allclose(b.grad, cloned_b.grad)\n", "assert torch.allclose(c.grad, cloned_c.grad)\n", "assert torch.allclose(d.grad, cloned_d.grad)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Lets benchmark the original and AOT Autograd + NNC compiled function." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Lets write a function to benchmark the forward and backward pass\n", "import time\n", "import statistics\n", "\n", "def bench(fn, args, prefix):\n", " warmup = 10\n", " iterations = 100\n", "\n", " for _ in range(warmup):\n", " ref = fn(*args)\n", " ref.sum().backward()\n", " \n", " fw_latencies = []\n", " bw_latencies = []\n", " for _ in range(iterations):\n", " for arg in args:\n", " arg.grad = None\n", "\n", " fw_begin = time.perf_counter()\n", " ref = fn(*args)\n", " fw_end = time.perf_counter()\n", "\n", " loss = ref.sum() \n", "\n", " bw_begin = time.perf_counter()\n", " loss.backward()\n", " bw_end = time.perf_counter()\n", "\n", " fw_latencies.append(fw_end - fw_begin)\n", " bw_latencies.append(bw_end - bw_begin)\n", " \n", " avg_fw_latency = statistics.mean(fw_latencies) * 10**6\n", " avg_bw_latency = statistics.mean(bw_latencies) * 10**6\n", " print(prefix, \"Fwd = \" + str(avg_fw_latency) + \" us\", \"Bwd = \" + str(avg_bw_latency) + \" us\", sep=', ')\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Eager, Fwd = 982.6959593920038 us, Bwd = 1899.7003795811906 us\n", "AOT, Fwd = 734.2723174951971 us, Bwd = 831.1696897726506 us\n" ] } ], "source": [ "large_inputs = [torch.randn(1024, 2048, requires_grad=True) for _ in range(4)]\n", "\n", "# Benchmark the Eager and AOT Autograd functions\n", "bench(fn, large_inputs, \"Eager\")\n", "bench(aot_nnc_fn, large_inputs, \"AOT\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With the help of NNC, AOT Autograd speeds up both the forward and backward pass. If we look at the printed graphs earlier, all the operators are pointwise. The pointwise operators are memory bandwidth bound, and thus benefit from operator fusion. Looking closely at the numbers, the backward pass gets higher speedup. 
This is because forward pass has to output some intermediate tensors for gradient calculation for the backward pass, preventing it from saving some memory reads and writes. However, such restriction does not exist in the backward graph." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Recomputation (aka Activation Checkpointing)\n", "Recomputation (often called activation checkpointing) is a technique in which, instead of saving some activations for use in backwards, we recompute them **during** the backwards pass. Recomputing saves memory, but we incur performance overhead.\n", "\n", "However, in the presence of fusing compiler, we can do better that that. We can recompute the fusion-friendly operators to save memory, and then rely on the fusing compiler to fuse the recomputed operators. This reduces both memory and runtime. Please refer to this [discuss post](https://dev-discuss.pytorch.org/t/min-cut-optimal-recomputation-i-e-activation-checkpointing-with-aotautograd/467) for more details.\n", "\n", "Here, we use AOT Autograd with NNC to perform similar type of recomputation. At the end of `__torch_dispatch__` tracing, AOT Autograd has a forward graph and joint forward-backward graph. AOT Autograd then uses a partitioner to isolate the forward and backward graph. In the example above, we used a default partitioner. For this experiment, we will use another partitioner called `min_cut_rematerialization_partition` to perform smarter fusion-aware recomputation. The partitioner is configurable and one can write their own partitioner to plug it in AOT Autograd." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\n", "def forward(self, primals_1, primals_2, primals_3, primals_4):\n", " add = torch.ops.aten.add(primals_1, primals_2); primals_1 = primals_2 = None\n", " add_1 = torch.ops.aten.add(add, primals_3); add = primals_3 = None\n", " add_2 = torch.ops.aten.add(add_1, primals_4); add_1 = primals_4 = None\n", " cos = torch.ops.aten.cos(add_2)\n", " cos_1 = torch.ops.aten.cos(cos); cos = None\n", " return [cos_1, add_2]\n", " \n", "\n", "\n", "\n", "def forward(self, add_2, tangents_1):\n", " cos = torch.ops.aten.cos(add_2)\n", " sin = torch.ops.aten.sin(cos); cos = None\n", " neg = torch.ops.aten.neg(sin); sin = None\n", " mul = torch.ops.aten.mul(tangents_1, neg); tangents_1 = neg = None\n", " sin_1 = torch.ops.aten.sin(add_2); add_2 = None\n", " neg_1 = torch.ops.aten.neg(sin_1); sin_1 = None\n", " mul_1 = torch.ops.aten.mul(mul, neg_1); mul = neg_1 = None\n", " return [mul_1, mul_1, mul_1, mul_1]\n", " \n" ] } ], "source": [ "from functorch.compile import min_cut_rematerialization_partition\n", "\n", "# Lets set up the partitioner. Also set the fwd and bwd compilers to the printer function that we used earlier.\n", "# This will show us how the recomputation has modified the graph.\n", "aot_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn, partition_fn=min_cut_rematerialization_partition)\n", "res = aot_fn(a, b, c, d).sum().backward()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see that compared to default partitioner, forward pass now outputs fewer tensors, and recomputes some operations in the backward pass. Let us try NNC compiler now to perform operator fusions (note that we also have a wrapper function - `memory_efficient_fusion` which internally uses `min_cut_rematerialization_partition` and Torchscript compiler to achieve the same effect as following code)." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "\n", "# Lets set up the partitioner and NNC compiler.\n", "aot_recompute_nnc_fn = aot_function(fn, fw_compiler=ts_compile, bw_compiler=ts_compile, partition_fn=min_cut_rematerialization_partition)\n", "\n", "# Correctness checking. Lets clone the input so that we can check grads.\n", "cloned_inputs = [x.clone().detach().requires_grad_(True) for x in (a, b, c, d)]\n", "cloned_a, cloned_b, cloned_c, cloned_d = cloned_inputs\n", "\n", "res = aot_recompute_nnc_fn(*cloned_inputs)\n", "loss = res.sum()\n", "loss.backward()\n", "assert torch.allclose(ref, res)\n", "assert torch.allclose(a.grad, cloned_a.grad)\n", "assert torch.allclose(b.grad, cloned_b.grad)\n", "assert torch.allclose(c.grad, cloned_c.grad)\n", "assert torch.allclose(d.grad, cloned_d.grad)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, lets benchmark the different functions" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Eager, Fwd = 740.7676504226401 us, Bwd = 1560.5240693548694 us\n", "AOT, Fwd = 713.8530415249988 us, Bwd = 909.1200679540634 us\n", "AOT_Recomp, Fwd = 712.2249767417088 us, Bwd = 791.4606417762116 us\n" ] } ], "source": [ "bench(fn, large_inputs, \"Eager\")\n", "bench(aot_nnc_fn, large_inputs, \"AOT\")\n", "bench(aot_recompute_nnc_fn, large_inputs, \"AOT_Recomp\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We observe that both forward and backward latency improve over the default partitioner (and a lot better than eager). Fewer outputs in the forward pass and fewer inputs in the backward pass, along with fusion, allows better memory bandwidth utilization leading to further speedups." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Actual Usage\n", "For actual usage on CUDA devices, we've wrapped AOTAutograd in a convenient wrapper - `memory_efficient_fusion`. Use this for fusion on GPU!\n", "\n", "```\n", "from functorch.compile import memory_efficient_fusion\n", "```\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.9.5 ('base')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.5" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "73b6e0ee7c860e06bb349c72324473b318d6cb6c97bcad772bce0703fb8d0dfb" } } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: notebooks/colab/jacobians_hessians_colab.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "source": [ "# Jacobians, Hessians, hvp, vhp, and more: composing functorch transforms\n", "\n", "\n", " \"Open\n", "\n", "\n", "Computing jacobians or hessians are useful in a number of non-traditional\n", "deep learning models. It is difficult (or annoying) to compute these quantities\n", "efficiently using a standard autodiff system like PyTorch Autograd; functorch\n", "provides ways of computing various higher-order autodiff quantities efficiently." 
], "metadata": { "id": "zPbR6-eP51fe" }, "id": "zPbR6-eP51fe" }, { "cell_type": "markdown", "source": [ "## Computing the Jacobian" ], "metadata": { "id": "3kDj8fhn52j3" }, "id": "3kDj8fhn52j3" }, { "cell_type": "code", "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from functools import partial\n", "_ = torch.manual_seed(0)" ], "metadata": { "id": "w_IinyjzflUH" }, "execution_count": null, "outputs": [], "id": "w_IinyjzflUH" }, { "cell_type": "markdown", "source": [ "Let’s start with a function that we’d like to compute the jacobian of. This is a simple linear function with non-linear activation.\n", "\n" ], "metadata": { "id": "cibF_PEYflUH" }, "id": "cibF_PEYflUH" }, { "cell_type": "code", "source": [ "def predict(weight, bias, x):\n", " return F.linear(x, weight, bias).tanh()" ], "metadata": { "id": "qhcD9hWYflUH" }, "execution_count": null, "outputs": [], "id": "qhcD9hWYflUH" }, { "cell_type": "markdown", "source": [ "Let's add some dummy data: a weight, a bias, and a feature vector x.\n", "\n" ], "metadata": { "id": "G8tqQrO_flUH" }, "id": "G8tqQrO_flUH" }, { "cell_type": "code", "source": [ "D = 16\n", "weight = torch.randn(D, D)\n", "bias = torch.randn(D)\n", "x = torch.randn(D) # feature vector" ], "metadata": { "id": "FZ4uJfZGflUH" }, "execution_count": null, "outputs": [], "id": "FZ4uJfZGflUH" }, { "cell_type": "markdown", "source": [ "Let's think of `predict` as a function that maps the input `x` from $R^D -> R^D$.\n", "PyTorch Autograd computes vector-Jacobian products. In order to compute the full\n", "Jacobian of this $R^D -> R^D$ function, we would have to compute it row-by-row\n", "by using a different unit vector each time." 
], "metadata": { "id": "uMAW-ArQflUH" }, "id": "uMAW-ArQflUH" }, { "cell_type": "code", "source": [ "def compute_jac(xp):\n", " jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]\n", " for vec in unit_vectors]\n", " return torch.stack(jacobian_rows)" ], "metadata": { "id": "z-BJPtbpflUI" }, "execution_count": null, "outputs": [], "id": "z-BJPtbpflUI" }, { "cell_type": "code", "source": [ "xp = x.clone().requires_grad_()\n", "unit_vectors = torch.eye(D)\n", "\n", "jacobian = compute_jac(xp)\n", "\n", "print(jacobian.shape)\n", "print(jacobian[0]) # show first row" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "f1f1ec12-56ef-40f7-8c3c-cbad7bf86644", "id": "zuWGSXspflUI" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "torch.Size([16, 16])\n", "tensor([-0.5956, -0.6096, -0.1326, -0.2295, 0.4490, 0.3661, -0.1672, -1.1190,\n", " 0.1705, -0.6683, 0.1851, 0.1630, 0.0634, 0.6547, 0.5908, -0.1308])\n" ] } ], "id": "zuWGSXspflUI" }, { "cell_type": "markdown", "source": [ "Instead of computing the jacobian row-by-row, we can use vmap to get rid of the for-loop and vectorize the computation. \n", "We can’t directly apply vmap to PyTorch Autograd; instead, functorch provides a vjp transform:\n", "\n" ], "metadata": { "id": "mxlEOUieflUI" }, "id": "mxlEOUieflUI" }, { "cell_type": "code", "source": [ "from functorch import vmap, vjp\n", "\n", "_, vjp_fn = vjp(partial(predict, weight, bias), x)\n", "\n", "ft_jacobian, = vmap(vjp_fn)(unit_vectors)\n", "\n", "# lets confirm both methods compute the same result\n", "assert torch.allclose(ft_jacobian, jacobian)" ], "metadata": { "id": "DeF6uy4WflUI" }, "execution_count": null, "outputs": [], "id": "DeF6uy4WflUI" }, { "cell_type": "markdown", "source": [ "In future tutorial a composition of reverse-mode AD and vmap will give us per-sample-gradients. 
\n", "In this tutorial, composing reverse-mode AD and vmap gives us Jacobian computation! \n", "Various compositions of vmap and autodiff transforms can give us different interesting quantities.\n", "\n", "functorch provides **jacrev** as a convenience function that performs the vmap-vjp composition to compute jacobians. **jacrev** accepts an argnums argument that says which argument we would like to compute Jacobians with respect to.\n", "\n" ], "metadata": { "id": "Hy4REmwDflUI" }, "id": "Hy4REmwDflUI" }, { "cell_type": "code", "source": [ "from functorch import jacrev\n", "\n", "ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)\n", "\n", "# confirm \n", "assert torch.allclose(ft_jacobian, jacobian)" ], "metadata": { "id": "Rt7i6_YlflUI" }, "execution_count": null, "outputs": [], "id": "Rt7i6_YlflUI" }, { "cell_type": "markdown", "source": [ "Let’s compare the performance of the two ways to compute the jacobian. The functorch version is much faster (and becomes even faster the more outputs there are). \n", "\n", "In general, we expect that vectorization via vmap can help eliminate overhead and give better utilization of your hardware.\n", "\n", "Vmap does this magic by pushing the outer loop down into the functions primitive operations in order to obtain better performance.\n", "\n", "\n" ], "metadata": { "id": "JYe2H1UcflUJ" }, "id": "JYe2H1UcflUJ" }, { "cell_type": "markdown", "source": [ "Let's make a quick function to evaluate performance and deal with microseconds and milliseconds measurements:" ], "metadata": { "id": "i_143LZwflUJ" }, "id": "i_143LZwflUJ" }, { "cell_type": "code", "source": [ "def get_perf(first, first_descriptor, second, second_descriptor):\n", " \"\"\" takes torch.benchmark objects and compares delta of second vs first. 
\"\"\"\n", " faster = second.times[0]\n", " slower = first.times[0]\n", " gain = (slower-faster)/slower\n", " if gain < 0: gain *=-1 \n", " final_gain = gain*100\n", " print(f\" Performance delta: {final_gain:.4f} percent improvement with {second_descriptor} \")" ], "metadata": { "id": "II7r6jBtflUJ" }, "execution_count": null, "outputs": [], "id": "II7r6jBtflUJ" }, { "cell_type": "markdown", "source": [ "And then run the performance comparison:" ], "metadata": { "id": "r4clPnPKflUJ" }, "id": "r4clPnPKflUJ" }, { "cell_type": "code", "source": [ "from torch.utils.benchmark import Timer\n", "\n", "without_vmap = Timer(stmt=\"compute_jac(xp)\", globals=globals())\n", "with_vmap = Timer(stmt=\"jacrev(predict, argnums=2)(weight, bias, x)\", globals=globals())\n", "\n", "no_vmap_timer = without_vmap.timeit(500)\n", "with_vmap_timer = with_vmap.timeit(500)\n", "\n", "print(no_vmap_timer)\n", "print(with_vmap_timer)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "cbf77a19-aac9-428d-eba1-74d337c53e49", "id": "ZPtoxF6eflUJ" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\n", "compute_jac(xp)\n", " 2.25 ms\n", " 1 measurement, 500 runs , 1 thread\n", "\n", "jacrev(predict, argnums=2)(weight, bias, x)\n", " 884.34 us\n", " 1 measurement, 500 runs , 1 thread\n" ] } ], "id": "ZPtoxF6eflUJ" }, { "cell_type": "markdown", "source": [ "Lets do a relative performance comparison of the above with our get_perf function:" ], "metadata": { "id": "nGBBi4dZflUJ" }, "id": "nGBBi4dZflUJ" }, { "cell_type": "code", "source": [ "get_perf(no_vmap_timer, \"without vmap\", with_vmap_timer, \"vmap\");" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "85d0bc5f-34aa-4826-f953-6c637404490c", "id": "zqV2RzEXflUJ" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ " Performance delta: 60.7170 percent improvement with vmap \n" ] } ], "id": 
"zqV2RzEXflUJ" }, { "cell_type": "markdown", "source": [ "Furthermore, it’s pretty easy to flip the problem around and say we want to compute Jacobians of the parameters to our model (weight, bias) instead of the input." ], "metadata": { "id": "EQAB99EQflUJ" }, "id": "EQAB99EQflUJ" }, { "cell_type": "code", "source": [ "# note the change in input via argnums params of 0,1 to map to weight and bias\n", "ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)" ], "metadata": { "id": "8UZpC8DnflUK" }, "execution_count": null, "outputs": [], "id": "8UZpC8DnflUK" }, { "cell_type": "markdown", "source": [ "## reverse-mode Jacobian (jacrev) vs forward-mode Jacobian (jacfwd)\n" ], "metadata": { "id": "F3USYENIflUK" }, "id": "F3USYENIflUK" }, { "cell_type": "markdown", "source": [ "We offer two APIs to compute jacobians: **jacrev** and **jacfwd**: \n", "- jacrev uses reverse-mode AD. As you saw above it is a composition of our vjp and vmap transforms. \n", "- jacfwd uses forward-mode AD. It is implemented as a composition of our jvp and vmap transforms. \n", "\n", "jacfwd and jacrev can be substituted for each other but they have different performance characteristics.\n", "\n", "As a general rule of thumb, if you’re computing the jacobian of an $𝑅^N \to R^M$ function, and there are many more outputs than inputs (i.e. $M > N$) then jacfwd is preferred, otherwise use jacrev. There are exceptions to this rule, but a non-rigorous argument for this follows:\n", "\n", "In reverse-mode AD, we are computing the jacobian row-by-row, while in forward-mode AD (which computes Jacobian-vector products), we are computing it column-by-column.
The Jacobian matrix has M rows and N columns, so if it is taller or wider one way we may prefer the method that deals with fewer rows or columns.\n", "\n" ], "metadata": { "id": "V7B3vE8dflUK" }, "id": "V7B3vE8dflUK" }, { "cell_type": "code", "source": [ "from functorch import jacrev, jacfwd" ], "metadata": { "id": "k7Tok7m3flUK" }, "execution_count": null, "outputs": [], "id": "k7Tok7m3flUK" }, { "cell_type": "markdown", "source": [ "First, let's benchmark with more inputs than outputs:\n", "\n" ], "metadata": { "id": "YrV-gZAaflUL" }, "id": "YrV-gZAaflUL" }, { "cell_type": "code", "source": [ "Din = 32\n", "Dout = 2048\n", "weight = torch.randn(Dout, Din)\n", "\n", "bias = torch.randn(Dout)\n", "x = torch.randn(Din)\n", "\n", "# remember the general rule about taller vs wider...here we have a taller matrix:\n", "print(weight.shape)\n", "\n", "using_fwd = Timer(stmt=\"jacfwd(predict, argnums=2)(weight, bias, x)\", globals=globals())\n", "using_bwd = Timer(stmt=\"jacrev(predict, argnums=2)(weight, bias, x)\", globals=globals())\n", "\n", "jacfwd_timing = using_fwd.timeit(500)\n", "jacrev_timing = using_bwd.timeit(500)\n", "\n", "print(f'jacfwd time: {jacfwd_timing}')\n", "print(f'jacrev time: {jacrev_timing}')\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "dd882726-9723-47c0-a72f-3c7835a85aa1", "id": "m5j-4hSxflUL" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "torch.Size([2048, 32])\n", "jacfwd time: \n", "jacfwd(predict, argnums=2)(weight, bias, x)\n", " 1.32 ms\n", " 1 measurement, 500 runs , 1 thread\n", "jacrev time: \n", "jacrev(predict, argnums=2)(weight, bias, x)\n", " 12.46 ms\n", " 1 measurement, 500 runs , 1 thread\n" ] } ], "id": "m5j-4hSxflUL" }, { "cell_type": "markdown", "source": [ "and then do a relative benchmark:" ], "metadata": { "id": "k_Sg-4tVflUL" }, "id": "k_Sg-4tVflUL" }, { "cell_type": "code", "source": [ "get_perf(jacfwd_timing, \"jacfwd\", 
jacrev_timing, \"jacrev\", );" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "3a6586a1-269d-46d8-d119-e24f6d46277f", "id": "_4T96zGjflUL" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ " Performance delta: 842.8274 percent improvement with jacrev \n" ] } ], "id": "_4T96zGjflUL" }, { "cell_type": "markdown", "source": [ "and now the reverse - more outputs (M) than inputs (N):" ], "metadata": { "id": "RCDPot1yflUL" }, "id": "RCDPot1yflUL" }, { "cell_type": "code", "source": [ "Din = 2048\n", "Dout = 32\n", "weight = torch.randn(Dout, Din)\n", "bias = torch.randn(Dout)\n", "x = torch.randn(Din)\n", "\n", "using_fwd = Timer(stmt=\"jacfwd(predict, argnums=2)(weight, bias, x)\", globals=globals())\n", "using_bwd = Timer(stmt=\"jacrev(predict, argnums=2)(weight, bias, x)\", globals=globals())\n", "\n", "jacfwd_timing = using_fwd.timeit(500)\n", "jacrev_timing = using_bwd.timeit(500)\n", "\n", "print(f'jacfwd time: {jacfwd_timing}')\n", "print(f'jacrev time: {jacrev_timing}')" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "913e9ccd-3d4f-472a-a749-19cee36d0a16", "id": "_DRFqzqZflUM" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "jacfwd time: \n", "jacfwd(predict, argnums=2)(weight, bias, x)\n", " 7.99 ms\n", " 1 measurement, 500 runs , 1 thread\n", "jacrev time: \n", "jacrev(predict, argnums=2)(weight, bias, x)\n", " 1.09 ms\n", " 1 measurement, 500 runs , 1 thread\n" ] } ], "id": "_DRFqzqZflUM" }, { "cell_type": "markdown", "source": [ "and a relative perf comparison:" ], "metadata": { "id": "5SRbMCNsflUM" }, "id": "5SRbMCNsflUM" }, { "cell_type": "code", "source": [ "get_perf(jacrev_timing, \"jacrev\", jacfwd_timing, \"jacfwd\")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "c282ce25-4f6e-44cd-aed7-60f6f5010e5b", "id": "uF_9GaoiflUM" }, "execution_count": null, 
"outputs": [ { "output_type": "stream", "name": "stdout", "text": [ " Performance delta: 635.2095 percent improvement with jacfwd \n" ] } ], "id": "uF_9GaoiflUM" }, { "cell_type": "markdown", "source": [ "## Hessian computation with functorch.hessian\n" ], "metadata": { "id": "J29FQaBQflUM" }, "id": "J29FQaBQflUM" }, { "cell_type": "markdown", "source": [ "We offer a convenience API to compute hessians: `functorch.hessian`. \n", "Hessians are the jacobian of the jacobian (or the partial derivative of the partial derivative, aka second order).\n", "\n", "This suggests that one can just compose functorch’s jacobian transforms to compute the Hessian. \n", "Indeed, under the hood, `hessian(f)` is simply `jacfwd(jacrev(f))`.\n", "\n" ], "metadata": { "id": "My4DPH97flUM" }, "id": "My4DPH97flUM" }, { "cell_type": "markdown", "source": [ "Note: to boost performance: depending on your model, you may also want to use `jacfwd(jacfwd(f))` or `jacrev(jacrev(f))` instead to compute hessians leveraging the rule of thumb above regarding wider vs taller matrices.\n", "\n" ], "metadata": { "id": "FJt038l5flUM" }, "id": "FJt038l5flUM" }, { "cell_type": "code", "source": [ "from functorch import hessian\n", "\n", "# lets reduce the size in order not to blow out colab. 
Hessians require significant memory:\n", "Din = 512\n", "Dout = 32\n", "weight = torch.randn(Dout, Din)\n", "bias = torch.randn(Dout)\n", "x = torch.randn(Din)\n", "\n", "hess_api = hessian(predict, argnums=2)(weight, bias, x)\n", "hess_fwdfwd = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)\n", "#hess_revrev = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)\n" ], "metadata": { "id": "jEqr2ywZflUM" }, "execution_count": null, "outputs": [], "id": "jEqr2ywZflUM" }, { "cell_type": "markdown", "source": [ "Let's verify we have the same result regardless of using hessian api or using jacfwd(jacfwd())" ], "metadata": { "id": "n9BHcICQflUN" }, "id": "n9BHcICQflUN" }, { "cell_type": "code", "source": [ "torch.allclose(hess_api, hess_fwdfwd)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "e457e3bc-f085-4f90-966d-f98893b98ea8", "id": "eHiWRkjJflUN" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "True" ] }, "metadata": {}, "execution_count": 18 } ], "id": "eHiWRkjJflUN" }, { "cell_type": "markdown", "source": [ "## Batch Jacobian and Batch Hessian\n" ], "metadata": { "id": "Gjt1RO8HflUN" }, "id": "Gjt1RO8HflUN" }, { "cell_type": "markdown", "source": [ "In the above examples we’ve been operating with a single feature vector. In some cases you might want to take the Jacobian of a batch of outputs with respect to a batch of inputs. That is, given a batch of inputs of shape `(B, N)` and a function that goes from $R^N \\to R^M$, we would like a Jacobian of shape `(B, M, N)`. 
\n", "\n", "The easiest way to do this is to use vmap:" ], "metadata": { "id": "RjIzdoQNflUN" }, "id": "RjIzdoQNflUN" }, { "cell_type": "code", "source": [ "batch_size = 64\n", "Din = 31\n", "Dout = 33\n", "\n", "weight = torch.randn(Dout, Din)\n", "print(f\"weight shape = {weight.shape}\")\n", "\n", "bias = torch.randn(Dout)\n", "\n", "x = torch.randn(batch_size, Din)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "561eb618-e00f-40d5-bd99-fa51ab82051f", "id": "B1eoEO4UflUN" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "weight shape = torch.Size([33, 31])\n" ] } ], "id": "B1eoEO4UflUN" }, { "cell_type": "code", "source": [ "compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))\n", "batch_jacobian0 = compute_batch_jacobian(weight, bias, x)" ], "metadata": { "id": "nZ_V02NhflUN" }, "execution_count": null, "outputs": [], "id": "nZ_V02NhflUN" }, { "cell_type": "markdown", "source": [ "If you have a function that goes from (B, N) -> (B, M) instead and are certain that each input produces an independent output, then it’s also sometimes possible to do this without using vmap by summing the outputs and then computing the Jacobian of that function:\n", "\n" ], "metadata": { "id": "_OLDiY3MflUN" }, "id": "_OLDiY3MflUN" }, { "cell_type": "code", "source": [ "def predict_with_output_summed(weight, bias, x):\n", " return predict(weight, bias, x).sum(0)\n", "\n", "batch_jacobian1 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x).movedim(1, 0)\n", "assert torch.allclose(batch_jacobian0, batch_jacobian1)" ], "metadata": { "id": "_QH4hD8PflUO" }, "execution_count": null, "outputs": [], "id": "_QH4hD8PflUO" }, { "cell_type": "markdown", "source": [ "If you instead have a function that goes from $𝑅^𝑁 \\to 𝑅^𝑀$ but inputs that are batched, you compose vmap with jacrev to compute batched jacobians:\n", "\n", "Finally, batch hessians can be computed 
similarly. It’s easiest to think about them by using vmap to batch over hessian computation, but in some cases the sum trick also works.\n", "\n" ], "metadata": { "id": "eUjw65cCflUO" }, "id": "eUjw65cCflUO" }, { "cell_type": "code", "source": [ "compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0))\n", "\n", "batch_hess = compute_batch_hessian(weight, bias, x)\n", "batch_hess.shape" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "f3135cfa-e9e5-4f18-8cb7-0655e8a37cb5", "id": "3vAyQjMsflUO" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "torch.Size([64, 33, 31, 31])" ] }, "metadata": {}, "execution_count": 22 } ], "id": "3vAyQjMsflUO" }, { "cell_type": "markdown", "source": [ "## Computing Hessian-vector products\n", "\n", "The naive way to compute a Hessian-vector product (hvp) is to materialize the full Hessian and perform a dot-product with a vector. We can do better: it turns out we don't need to materialize the full Hessian to do this. We'll go through two (of many) different strategies to compute Hessian-vector products:\n", "- composing reverse-mode AD with reverse-mode AD\n", "- composing reverse-mode AD with forward-mode AD\n", "\n", "Composing reverse-mode AD with forward-mode AD (as opposed to reverse-mode with reverse-mode) is generally the more memory efficient way to compute a hvp because forward-mode AD doesn't need to construct an Autograd graph and save intermediates for backward:" ], "metadata": { "id": "Wa8E48sQgpkb" }, "id": "Wa8E48sQgpkb" }, { "cell_type": "code", "source": [ "from functorch import jvp, grad, vjp\n", "\n", "def hvp(f, primals, tangents):\n", " return jvp(grad(f), primals, tangents)[1]" ], "metadata": { "id": "trw6WbAth6BM" }, "execution_count": null, "outputs": [], "id": "trw6WbAth6BM" }, { "cell_type": "markdown", "source": [ "Here's some sample usage." 
], "metadata": { "id": "DQMpRo6nitfr" }, "id": "DQMpRo6nitfr" }, { "cell_type": "code", "source": [ "def f(x):\n", " return x.sin().sum()\n", "\n", "x = torch.randn(2048)\n", "tangent = torch.randn(2048)\n", "\n", "result = hvp(f, (x,), (tangent,))" ], "metadata": { "id": "sPwg8SOdiVAK" }, "execution_count": null, "outputs": [], "id": "sPwg8SOdiVAK" }, { "cell_type": "markdown", "source": [ "If PyTorch forward-AD does not have coverage for your operations, then we can instead compose reverse-mode AD with reverse-mode AD:" ], "metadata": { "id": "zGvUIcB0j1Ez" }, "id": "zGvUIcB0j1Ez" }, { "cell_type": "code", "source": [ "def hvp_revrev(f, primals, tangents):\n", " _, vjp_fn = vjp(grad(f), *primals)\n", " return vjp_fn(*tangents)" ], "metadata": { "id": "mdDFZdlekAOK" }, "execution_count": null, "outputs": [], "id": "mdDFZdlekAOK" }, { "cell_type": "code", "source": [ "result_hvp_revrev = hvp_revrev(f, (x,), (tangent,))\n", "assert torch.allclose(result, result_hvp_revrev[0])" ], "metadata": { "id": "_CuCk9X0lW7C" }, "execution_count": null, "outputs": [], "id": "_CuCk9X0lW7C" } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.3" }, "colab": { "name": "jacobians_hessians.ipynb", "provenance": [] } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: notebooks/colab/per_sample_grads_colab.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "a474c143-05c4-43b6-b12c-17b592d07a6a", "metadata": { "id": "a474c143-05c4-43b6-b12c-17b592d07a6a" }, "source": [ "# Per-sample-gradients\n", "\n", "\n", " \"Open\n", "\n", "\n", "## What is it?\n", "\n", "Per-sample-gradient computation is computing the gradient for 
each and every\n", "sample in a batch of data. It is a useful quantity in differential privacy, meta-learning,\n", "and optimization research.\n" ] }, { "cell_type": "code", "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from functools import partial\n", "\n", "torch.manual_seed(0);" ], "metadata": { "id": "Gb-yt4VKUUuc" }, "execution_count": null, "outputs": [], "id": "Gb-yt4VKUUuc" }, { "cell_type": "code", "source": [ "# Here's a simple CNN and loss function:\n", "\n", "class SimpleCNN(nn.Module):\n", " def __init__(self):\n", " super(SimpleCNN, self).__init__()\n", " self.conv1 = nn.Conv2d(1, 32, 3, 1)\n", " self.conv2 = nn.Conv2d(32, 64, 3, 1)\n", " self.fc1 = nn.Linear(9216, 128)\n", " self.fc2 = nn.Linear(128, 10)\n", "\n", " def forward(self, x):\n", " x = self.conv1(x)\n", " x = F.relu(x)\n", " x = self.conv2(x)\n", " x = F.relu(x)\n", " x = F.max_pool2d(x, 2)\n", " x = torch.flatten(x, 1)\n", " x = self.fc1(x)\n", " x = F.relu(x)\n", " x = self.fc2(x)\n", " output = F.log_softmax(x, dim=1)\n", " output = x\n", " return output\n", "\n", "def loss_fn(predictions, targets):\n", " return F.nll_loss(predictions, targets)" ], "metadata": { "id": "tf-HKHjUUbyY" }, "execution_count": null, "outputs": [], "id": "tf-HKHjUUbyY" }, { "cell_type": "markdown", "source": [ "Let’s generate a batch of dummy data and pretend that we’re working with an MNIST dataset. 
\n", "\n", "The dummy images are 28 by 28 and we use a minibatch of size 64.\n", "\n" ], "metadata": { "id": "VEDPe-EoU5Fa" }, "id": "VEDPe-EoU5Fa" }, { "cell_type": "code", "source": [ "device = 'cuda'\n", "\n", "num_models = 10\n", "batch_size = 64\n", "data = torch.randn(batch_size, 1, 28, 28, device=device)\n", "\n", "targets = torch.randint(10, (64,), device=device)" ], "metadata": { "id": "WB2Qe3AHUvPN" }, "execution_count": null, "outputs": [], "id": "WB2Qe3AHUvPN" }, { "cell_type": "markdown", "source": [ "In regular model training, one would forward the minibatch through the model, and then call .backward() to compute gradients. This would generate an 'average' gradient of the entire mini-batch:\n", "\n" ], "metadata": { "id": "GOGJ-OUxVcT5" }, "id": "GOGJ-OUxVcT5" }, { "cell_type": "code", "source": [ "model = SimpleCNN().to(device=device)\n", "predictions = model(data) # move the entire mini-batch through the model\n", "\n", "loss = loss_fn(predictions, targets)\n", "loss.backward() # back propagate the 'average' gradient of this mini-batch" ], "metadata": { "id": "WYjMx8QTUvRu" }, "execution_count": null, "outputs": [], "id": "WYjMx8QTUvRu" }, { "cell_type": "markdown", "source": [ "In contrast to the above approach, per-sample-gradient computation is equivalent to: \n", "- for each individual sample of the data, perform a forward and a backward pass to get an individual (per-sample) gradient.\n", "\n" ], "metadata": { "id": "HNw4_IVzU5Pz" }, "id": "HNw4_IVzU5Pz" }, { "cell_type": "code", "source": [ "def compute_grad(sample, target):\n", " \n", " sample = sample.unsqueeze(0) # prepend batch dimension for processing\n", " target = target.unsqueeze(0)\n", "\n", " prediction = model(sample)\n", " loss = loss_fn(prediction, target)\n", "\n", " return torch.autograd.grad(loss, list(model.parameters()))\n", "\n", "\n", "def compute_sample_grads(data, targets):\n", " \"\"\" manually process each sample with per sample gradient \"\"\"\n", " sample_grads = 
[compute_grad(data[i], targets[i]) for i in range(batch_size)]\n", " sample_grads = zip(*sample_grads)\n", " sample_grads = [torch.stack(shards) for shards in sample_grads]\n", " return sample_grads\n", "\n", "per_sample_grads = compute_sample_grads(data, targets)" ], "metadata": { "id": "vUsb3VfexJrY" }, "execution_count": null, "outputs": [], "id": "vUsb3VfexJrY" }, { "cell_type": "markdown", "source": [ "`sample_grads[0]` is the per-sample-grad for model.conv1.weight. `model.conv1.weight.shape` is `[32, 1, 3, 3]`; notice how there is one gradient, per sample, in the batch for a total of 64.\n", "\n", "\n", "\n" ], "metadata": { "id": "aNkX6lFIxzcm" }, "id": "aNkX6lFIxzcm" }, { "cell_type": "code", "source": [ "print(per_sample_grads[0].shape)" ], "metadata": { "id": "C3a9_clvyPho", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "407abc1a-846f-4e50-83bc-c90719a26073" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "torch.Size([64, 32, 1, 3, 3])\n" ] } ], "id": "C3a9_clvyPho" }, { "cell_type": "markdown", "source": [ "## Per-sample-grads, *the efficient way*, using functorch\n", "\n", "\n" ], "metadata": { "id": "mFJDWMM9yaYZ" }, "id": "mFJDWMM9yaYZ" }, { "cell_type": "markdown", "source": [ "We can compute per-sample-gradients efficiently by using function transforms. \n", "\n", "First, let’s create a stateless functional version of `model` by using `functorch.make_functional_with_buffers`. 
\n", "\n", "This will separate state (the parameters) from the model and turn the model into a pure function:\n", "\n" ], "metadata": { "id": "tlkmyQyfY6XU" }, "id": "tlkmyQyfY6XU" }, { "cell_type": "code", "source": [ "from functorch import make_functional_with_buffers, vmap, grad\n", "\n", "fmodel, params, buffers = make_functional_with_buffers(model)" ], "metadata": { "id": "WiSMupvCyecd" }, "execution_count": null, "outputs": [], "id": "WiSMupvCyecd" }, { "cell_type": "markdown", "source": [ "Let's review the changes - first, the model has become the stateless FunctionalModuleWithBuffers:" ], "metadata": { "id": "wMsbppPNZklo" }, "id": "wMsbppPNZklo" }, { "cell_type": "code", "source": [ "fmodel" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Xj0cZOJMZbbB", "outputId": "2e87dfde-3af2-4e1f-cd91-5c232446fb53" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "FunctionalModuleWithBuffers(\n", " (stateless_model): SimpleCNN(\n", " (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))\n", " (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))\n", " (fc1): Linear(in_features=9216, out_features=128, bias=True)\n", " (fc2): Linear(in_features=128, out_features=10, bias=True)\n", " )\n", ")" ] }, "metadata": {}, "execution_count": 15 } ], "id": "Xj0cZOJMZbbB" }, { "cell_type": "markdown", "source": [ "And the model parameters now exist independently of the model, stored as a tuple:" ], "metadata": { "id": "zv4_YYPxZvvg" }, "id": "zv4_YYPxZvvg" }, { "cell_type": "code", "source": [ "for x in params:\n", " print(f\"{x.shape}\")\n", "\n", "print(f\"\\n{type(params)}\")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "tH0TAZhBZ3bS", "outputId": "97c4401f-cccb-43f6-b071-c85a18fc439b" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "torch.Size([32, 1, 3, 3])\n", "torch.Size([32])\n", "torch.Size([64, 32, 3, 
3])\n", "torch.Size([64])\n", "torch.Size([128, 9216])\n", "torch.Size([128])\n", "torch.Size([10, 128])\n", "torch.Size([10])\n", "\n", "\n" ] } ], "id": "tH0TAZhBZ3bS" }, { "cell_type": "markdown", "source": [ "Next, let’s define a function to compute the loss of the model given a single input rather than a batch of inputs. It is important that this function accepts the parameters, the input, and the target, because we will be transforming over them. \n", "\n", "Note - because the model was originally written to handle batches, we’ll use `torch.unsqueeze` to add a batch dimension.\n", "\n" ], "metadata": { "id": "cTgIIZ9Wyih8" }, "id": "cTgIIZ9Wyih8" }, { "cell_type": "code", "source": [ "def compute_loss_stateless_model (params, buffers, sample, target):\n", " batch = sample.unsqueeze(0)\n", " targets = target.unsqueeze(0)\n", "\n", " predictions = fmodel(params, buffers, batch) \n", " loss = loss_fn(predictions, targets)\n", " return loss" ], "metadata": { "id": "ItURFU3M-p98" }, "execution_count": null, "outputs": [], "id": "ItURFU3M-p98" }, { "cell_type": "markdown", "source": [ "Now, let’s use functorch's `grad` to create a new function that computes the gradient with respect to the first argument of `compute_loss` (i.e. the params)." ], "metadata": { "id": "Qo3sbDK2i_bH" }, "id": "Qo3sbDK2i_bH" }, { "cell_type": "code", "source": [ "ft_compute_grad = grad(compute_loss_stateless_model)" ], "metadata": { "id": "sqRp_Sxni-Xm" }, "execution_count": null, "outputs": [], "id": "sqRp_Sxni-Xm" }, { "cell_type": "markdown", "source": [ "The `ft_compute_grad` function computes the gradient for a single (sample, target) pair. We can use vmap to get it to compute the gradient over an entire batch of samples and targets. 
Note that `in_dims=(None, None, 0, 0)` because we wish to map `ft_compute_grad` over the 0th dimension of the data and targets, and use the same params and buffers for each.\n", "\n" ], "metadata": { "id": "2pG3Ofqjjc8O" }, "id": "2pG3Ofqjjc8O" }, { "cell_type": "code", "source": [ "ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))" ], "metadata": { "id": "62ecNMO6inqX" }, "execution_count": null, "outputs": [], "id": "62ecNMO6inqX" }, { "cell_type": "markdown", "source": [ "Finally, let’s used our transformed function to compute per-sample-gradients:\n", "\n" ], "metadata": { "id": "_alXdQ3QkETu" }, "id": "_alXdQ3QkETu" }, { "cell_type": "code", "source": [ "ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets)\n", "\n", "# we can double check that the results using functorch grad and vmap match the results of hand processing each one individually:\n", "for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads):\n", " assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)" ], "metadata": { "id": "1gehVA1c-BHd" }, "execution_count": null, "outputs": [], "id": "1gehVA1c-BHd" }, { "cell_type": "markdown", "source": [ "A quick note: there are limitations around what types of functions can be transformed by vmap. The best functions to transform are ones that are pure functions: a function where the outputs are only determined by the inputs, and that have no side effects (e.g. mutation). 
vmap is unable to handle mutation of arbitrary Python data structures, but it is able to handle many in-place PyTorch operations.\n", "\n", "\n", "\n" ], "metadata": { "id": "BEZaNt1d_bc1" }, "id": "BEZaNt1d_bc1" }, { "cell_type": "markdown", "source": [ "## Performance comparison" ], "metadata": { "id": "BASP151Iml7B" }, "id": "BASP151Iml7B" }, { "cell_type": "markdown", "source": [ "Curious about how the performance of vmap compares?\n", "\n", "Currently the best results are obtained on newer GPU's such as the A100 (Ampere) where we've seen up to 25x speedups on this example, but here are some results done in Colab:" ], "metadata": { "id": "jr1xNpV4nJ7u" }, "id": "jr1xNpV4nJ7u" }, { "cell_type": "code", "source": [ "def get_perf(first, first_descriptor, second, second_descriptor):\n", " \"\"\" takes torch.benchmark objects and compares delta of second vs first. \"\"\"\n", " second_res = second.times[0]\n", " first_res = first.times[0]\n", "\n", " gain = (first_res-second_res)/first_res\n", " if gain < 0: gain *=-1 \n", " final_gain = gain*100\n", "\n", " print(f\" Performance delta: {final_gain:.4f} percent improvement with {first_descriptor} \")" ], "metadata": { "id": "GnAnMkYmoc-j" }, "execution_count": null, "outputs": [], "id": "GnAnMkYmoc-j" }, { "cell_type": "code", "source": [ "from torch.utils.benchmark import Timer\n", "\n", "without_vmap = Timer( stmt=\"compute_sample_grads(data, targets)\", globals=globals())\n", "with_vmap = Timer(stmt=\"ft_compute_sample_grad(params, buffers, data, targets)\",globals=globals())\n", "no_vmap_timing = without_vmap.timeit(100)\n", "with_vmap_timing = with_vmap.timeit(100)\n", "\n", "print(f'Per-sample-grads without vmap {no_vmap_timing}')\n", "print(f'Per-sample-grads with vmap {with_vmap_timing}')" ], "metadata": { "id": "Zfnn2C2g-6Fb", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "922f3901-773f-446b-b562-88e78f49036c" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": 
"stdout", "text": [ "Per-sample-grads without vmap \n", "compute_sample_grads(data, targets)\n", " 79.86 ms\n", " 1 measurement, 100 runs , 1 thread\n", "Per-sample-grads with vmap \n", "ft_compute_sample_grad(params, buffers, data, targets)\n", " 12.93 ms\n", " 1 measurement, 100 runs , 1 thread\n" ] } ], "id": "Zfnn2C2g-6Fb" }, { "cell_type": "code", "source": [ "get_perf(with_vmap_timing, \"vmap\", no_vmap_timing,\"no vmap\" )" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "NV9R3LZQoavl", "outputId": "e11e8be9-287d-4e60-e517-e08f8d6909bd" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ " Performance delta: 517.5791 percent improvement with vmap \n" ] } ], "id": "NV9R3LZQoavl" }, { "cell_type": "markdown", "source": [ "There are other optimized solutions (like in https://github.com/pytorch/opacus) to computing per-sample-gradients in PyTorch that also perform better than the naive method. But it’s cool that composing `vmap` and `grad` give us a nice speedup.\n", "\n", "\n", "In general, vectorization with vmap should be faster than running a function in a for-loop and competitive with manual batching. There are some exceptions though, like if we haven’t implemented the vmap rule for a particular operation or if the underlying kernels weren’t optimized for older hardware (GPUs). 
If you see any of these cases, please let us know by opening an issue at our [GitHub](https://github.com/pytorch/functorch)!\n", "\n" ], "metadata": { "id": "UI74G9JarQU8" }, "id": "UI74G9JarQU8" } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" }, "colab": { "name": "per_sample_grads.ipynb", "provenance": [] } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: packaging/windows/internal/cuda_install.bat ================================================ @echo on if "%CU_VERSION%" == "cpu" ( echo Skipping for CPU builds exit /b 0 ) set SRC_DIR=%~dp0\.. if not exist "%SRC_DIR%\temp_build" mkdir "%SRC_DIR%\temp_build" rem in unit test workflow, we get CUDA_VERSION, for example 11.1 if defined CUDA_VERSION ( set CUDA_VER=%CUDA_VERSION:.=% ) else ( set CUDA_VER=%CU_VERSION:cu=% ) set /a CUDA_VER=%CU_VERSION:cu=% set CUDA_VER_MAJOR=%CUDA_VER:~0,-1% set CUDA_VER_MINOR=%CUDA_VER:~-1,1% set CUDA_VERSION_STR=%CUDA_VER_MAJOR%.%CUDA_VER_MINOR% if %CUDA_VER% EQU 92 goto cuda92 if %CUDA_VER% EQU 100 goto cuda100 if %CUDA_VER% EQU 101 goto cuda101 if %CUDA_VER% EQU 102 goto cuda102 if %CUDA_VER% EQU 110 goto cuda110 if %CUDA_VER% EQU 111 goto cuda111 if %CUDA_VER% EQU 112 goto cuda112 if %CUDA_VER% EQU 113 goto cuda113 if %CUDA_VER% EQU 115 goto cuda115 echo CUDA %CUDA_VERSION_STR% is not supported exit /b 1 :cuda92 if not exist "%SRC_DIR%\temp_build\cuda_9.2.148_win10.exe" ( curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cuda_9.2.148_win10.exe --output "%SRC_DIR%\temp_build\cuda_9.2.148_win10.exe" if errorlevel 1 exit /b 1 set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_9.2.148_win10.exe" set "ARGS=nvcc_9.2 cuobjdump_9.2 nvprune_9.2 cupti_9.2 
cublas_9.2 cublas_dev_9.2 cudart_9.2 cufft_9.2 cufft_dev_9.2 curand_9.2 curand_dev_9.2 cusolver_9.2 cusolver_dev_9.2 cusparse_9.2 cusparse_dev_9.2 nvgraph_9.2 nvgraph_dev_9.2 npp_9.2 npp_dev_9.2 nvrtc_9.2 nvrtc_dev_9.2 nvml_dev_9.2" ) if not exist "%SRC_DIR%\temp_build\cudnn-9.2-windows10-x64-v7.2.1.38.zip" ( curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cudnn-9.2-windows10-x64-v7.2.1.38.zip --output "%SRC_DIR%\temp_build\cudnn-9.2-windows10-x64-v7.2.1.38.zip" if errorlevel 1 exit /b 1 set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-9.2-windows10-x64-v7.2.1.38.zip" ) goto cuda_common :cuda100 if not exist "%SRC_DIR%\temp_build\cuda_10.0.130_411.31_win10.exe" ( curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cuda_10.0.130_411.31_win10.exe --output "%SRC_DIR%\temp_build\cuda_10.0.130_411.31_win10.exe" if errorlevel 1 exit /b 1 set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_10.0.130_411.31_win10.exe" set "ARGS=nvcc_10.0 cuobjdump_10.0 nvprune_10.0 cupti_10.0 cublas_10.0 cublas_dev_10.0 cudart_10.0 cufft_10.0 cufft_dev_10.0 curand_10.0 curand_dev_10.0 cusolver_10.0 cusolver_dev_10.0 cusparse_10.0 cusparse_dev_10.0 nvgraph_10.0 nvgraph_dev_10.0 npp_10.0 npp_dev_10.0 nvrtc_10.0 nvrtc_dev_10.0 nvml_dev_10.0" ) if not exist "%SRC_DIR%\temp_build\cudnn-10.0-windows10-x64-v7.4.1.5.zip" ( curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cudnn-10.0-windows10-x64-v7.4.1.5.zip --output "%SRC_DIR%\temp_build\cudnn-10.0-windows10-x64-v7.4.1.5.zip" if errorlevel 1 exit /b 1 set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-10.0-windows10-x64-v7.4.1.5.zip" ) goto cuda_common :cuda101 if not exist "%SRC_DIR%\temp_build\cuda_10.1.243_426.00_win10.exe" ( curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_10.1.243_426.00_win10.exe --output "%SRC_DIR%\temp_build\cuda_10.1.243_426.00_win10.exe" if errorlevel 1 exit /b 1 set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_10.1.243_426.00_win10.exe" set "ARGS=nvcc_10.1 cuobjdump_10.1 nvprune_10.1 cupti_10.1 
cublas_10.1 cublas_dev_10.1 cudart_10.1 cufft_10.1 cufft_dev_10.1 curand_10.1 curand_dev_10.1 cusolver_10.1 cusolver_dev_10.1 cusparse_10.1 cusparse_dev_10.1 nvgraph_10.1 nvgraph_dev_10.1 npp_10.1 npp_dev_10.1 nvjpeg_10.1 nvjpeg_dev_10.1 nvrtc_10.1 nvrtc_dev_10.1 nvml_dev_10.1" ) if not exist "%SRC_DIR%\temp_build\cudnn-10.1-windows10-x64-v7.6.4.38.zip" ( curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-10.1-windows10-x64-v7.6.4.38.zip --output "%SRC_DIR%\temp_build\cudnn-10.1-windows10-x64-v7.6.4.38.zip" if errorlevel 1 exit /b 1 set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-10.1-windows10-x64-v7.6.4.38.zip" ) goto cuda_common :cuda102 if not exist "%SRC_DIR%\temp_build\cuda_10.2.89_441.22_win10.exe" ( curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_10.2.89_441.22_win10.exe --output "%SRC_DIR%\temp_build\cuda_10.2.89_441.22_win10.exe" if errorlevel 1 exit /b 1 set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_10.2.89_441.22_win10.exe" set "ARGS=nvcc_10.2 cuobjdump_10.2 nvprune_10.2 cupti_10.2 cublas_10.2 cublas_dev_10.2 cudart_10.2 cufft_10.2 cufft_dev_10.2 curand_10.2 curand_dev_10.2 cusolver_10.2 cusolver_dev_10.2 cusparse_10.2 cusparse_dev_10.2 nvgraph_10.2 nvgraph_dev_10.2 npp_10.2 npp_dev_10.2 nvjpeg_10.2 nvjpeg_dev_10.2 nvrtc_10.2 nvrtc_dev_10.2 nvml_dev_10.2" ) if not exist "%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip" ( curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-10.2-windows10-x64-v7.6.5.32.zip --output "%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip" if errorlevel 1 exit /b 1 set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip" ) rem The below only for cu102, if it's used in other version, e.g. cu111, torch.cuda.is_availabe() would be False. 
if not exist "%SRC_DIR%\temp_build\gpu_driver_dlls.7z" ( curl -k -L "https://drive.google.com/u/0/uc?id=1injUyo3lnarMgWyRcXqKg4UGnN0ysmuq&export=download" --output "%SRC_DIR%\temp_build\gpu_driver_dlls.zip" if errorlevel 1 exit /b 1 ) echo Installing GPU driver DLLs 7z x %SRC_DIR%\temp_build\gpu_driver_dlls.zip -aoa -o"C:\Windows\System32" goto cuda_common :cuda110 if not exist "%SRC_DIR%\temp_build\cuda_11.0.2_451.48_win10.exe" ( curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_11.0.2_451.48_win10.exe --output "%SRC_DIR%\temp_build\cuda_11.0.2_451.48_win10.exe" if errorlevel 1 exit /b 1 set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_11.0.2_451.48_win10.exe" set "ARGS=nvcc_11.0 cuobjdump_11.0 nvprune_11.0 nvprof_11.0 cupti_11.0 cublas_11.0 cublas_dev_11.0 cudart_11.0 cufft_11.0 cufft_dev_11.0 curand_11.0 curand_dev_11.0 cusolver_11.0 cusolver_dev_11.0 cusparse_11.0 cusparse_dev_11.0 npp_11.0 npp_dev_11.0 nvjpeg_11.0 nvjpeg_dev_11.0 nvrtc_11.0 nvrtc_dev_11.0 nvml_dev_11.0" ) if not exist "%SRC_DIR%\temp_build\cudnn-11.0-windows-x64-v8.0.4.30.zip" ( curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-11.0-windows-x64-v8.0.4.30.zip --output "%SRC_DIR%\temp_build\cudnn-11.0-windows-x64-v8.0.4.30.zip" if errorlevel 1 exit /b 1 set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-11.0-windows-x64-v8.0.4.30.zip" ) goto cuda_common :cuda111 if not exist "%SRC_DIR%\temp_build\cuda_11.1.1_456.81_win10.exe" ( curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_11.1.1_456.81_win10.exe --output "%SRC_DIR%\temp_build\cuda_11.1.1_456.81_win10.exe" if errorlevel 1 exit /b 1 set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_11.1.1_456.81_win10.exe" set "ARGS=nvcc_11.1 cuobjdump_11.1 nvprune_11.1 nvprof_11.1 cupti_11.1 cublas_11.1 cublas_dev_11.1 cudart_11.1 cufft_11.1 cufft_dev_11.1 curand_11.1 curand_dev_11.1 cusolver_11.1 cusolver_dev_11.1 cusparse_11.1 cusparse_dev_11.1 npp_11.1 npp_dev_11.1 nvjpeg_11.1 nvjpeg_dev_11.1 nvrtc_11.1 nvrtc_dev_11.1 nvml_dev_11.1" ) if not 
exist "%SRC_DIR%\temp_build\cudnn-11.1-windows-x64-v8.0.5.39.zip" ( curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-11.1-windows-x64-v8.0.5.39.zip --output "%SRC_DIR%\temp_build\cudnn-11.1-windows-x64-v8.0.5.39.zip" if errorlevel 1 exit /b 1 set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-11.1-windows-x64-v8.0.5.39.zip" ) goto cuda_common :cuda112 if not exist "%SRC_DIR%\temp_build\cuda_11.2.0_460.89_win10.exe" ( curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_11.2.0_460.89_win10.exe --output "%SRC_DIR%\temp_build\cuda_11.2.0_460.89_win10.exe" if errorlevel 1 exit /b 1 set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_11.2.0_460.89_win10.exe" set "ARGS=nvcc_11.2 cuobjdump_11.2 nvprune_11.2 nvprof_11.2 cupti_11.2 cublas_11.2 cublas_dev_11.2 cudart_11.2 cufft_11.2 cufft_dev_11.2 curand_11.2 curand_dev_11.2 cusolver_11.2 cusolver_dev_11.2 cusparse_11.2 cusparse_dev_11.2 npp_11.2 npp_dev_11.2 nvjpeg_11.2 nvjpeg_dev_11.2 nvrtc_11.2 nvrtc_dev_11.2 nvml_dev_11.2" ) if not exist "%SRC_DIR%\temp_build\cudnn-11.2-windows-x64-v8.1.0.77.zip" ( curl -k -L http://s3.amazonaws.com/ossci-windows/cudnn-11.2-windows-x64-v8.1.0.77.zip --output "%SRC_DIR%\temp_build\cudnn-11.2-windows-x64-v8.1.0.77.zip" if errorlevel 1 exit /b 1 set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-11.2-windows-x64-v8.1.0.77.zip" ) goto cuda_common :cuda113 set CUDA_INSTALL_EXE=cuda_11.3.0_465.89_win10.exe if not exist "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" ( curl -k -L "https://ossci-windows.s3.amazonaws.com/%CUDA_INSTALL_EXE%" --output "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" if errorlevel 1 exit /b 1 set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" set "ARGS=thrust_11.3 nvcc_11.3 cuobjdump_11.3 nvprune_11.3 nvprof_11.3 cupti_11.3 cublas_11.3 cublas_dev_11.3 cudart_11.3 cufft_11.3 cufft_dev_11.3 curand_11.3 curand_dev_11.3 cusolver_11.3 cusolver_dev_11.3 cusparse_11.3 cusparse_dev_11.3 npp_11.3 npp_dev_11.3 nvjpeg_11.3 nvjpeg_dev_11.3 nvrtc_11.3 nvrtc_dev_11.3 nvml_dev_11.3" 
) set CUDNN_INSTALL_ZIP=cudnn-11.3-windows-x64-v8.2.0.53.zip if not exist "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" ( curl -k -L "http://s3.amazonaws.com/ossci-windows/%CUDNN_INSTALL_ZIP%" --output "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" if errorlevel 1 exit /b 1 set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" ) goto cuda_common :cuda115 set CUDA_INSTALL_EXE=cuda_11.5.0_496.13_win10.exe if not exist "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" ( curl -k -L "https://ossci-windows.s3.amazonaws.com/%CUDA_INSTALL_EXE%" --output "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" if errorlevel 1 exit /b 1 set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" set "ARGS=thrust_11.5 nvcc_11.5 cuobjdump_11.5 nvprune_11.5 nvprof_11.5 cupti_11.5 cublas_11.5 cublas_dev_11.5 cudart_11.5 cufft_11.5 cufft_dev_11.5 curand_11.5 curand_dev_11.5 cusolver_11.5 cusolver_dev_11.5 cusparse_11.5 cusparse_dev_11.5 npp_11.5 npp_dev_11.5 nvrtc_11.5 nvrtc_dev_11.5 nvml_dev_11.5" ) set CUDNN_INSTALL_ZIP=cudnn-11.3-windows-x64-v8.2.0.53.zip if not exist "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" ( curl -k -L "http://s3.amazonaws.com/ossci-windows/%CUDNN_INSTALL_ZIP%" --output "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" if errorlevel 1 exit /b 1 set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" ) goto cuda_common :cuda_common if not exist "%SRC_DIR%\temp_build\NvToolsExt.7z" ( curl -k -L https://www.dropbox.com/s/9mcolalfdj4n979/NvToolsExt.7z?dl=1 --output "%SRC_DIR%\temp_build\NvToolsExt.7z" if errorlevel 1 exit /b 1 ) echo Installing CUDA toolkit... 7z x %CUDA_SETUP_FILE% -o"%SRC_DIR%\temp_build\cuda" pushd "%SRC_DIR%\temp_build\cuda" sc config wuauserv start= disabled sc stop wuauserv sc query wuauserv start /wait setup.exe -s %ARGS% -loglevel:6 -log:"%cd%/cuda_install_logs" echo %errorlevel% popd echo Installing VS integration... 
rem It's for VS 2019 if "%CUDA_VER_MAJOR%" == "10" ( xcopy /Y "%SRC_DIR%\temp_build\cuda\CUDAVisualStudioIntegration\extras\visual_studio_integration\MSBuildExtensions\*.*" "C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\MSBuild\Microsoft\VC\v160\BuildCustomizations" ) if "%CUDA_VER_MAJOR%" == "11" ( xcopy /Y "%SRC_DIR%\temp_build\cuda\visual_studio_integration\CUDAVisualStudioIntegration\extras\visual_studio_integration\MSBuildExtensions\*.*" "C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\MSBuild\Microsoft\VC\v160\BuildCustomizations" ) echo Installing NvToolsExt... 7z x %SRC_DIR%\temp_build\NvToolsExt.7z -o"%SRC_DIR%\temp_build\NvToolsExt" mkdir "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\bin\x64" mkdir "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\include" mkdir "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\lib\x64" xcopy /Y "%SRC_DIR%\temp_build\NvToolsExt\bin\x64\*.*" "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\bin\x64" xcopy /Y "%SRC_DIR%\temp_build\NvToolsExt\include\*.*" "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\include" xcopy /Y "%SRC_DIR%\temp_build\NvToolsExt\lib\x64\*.*" "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\lib\x64" echo Setting up environment... set "PATH=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\bin;%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\libnvvp;%PATH%" set "CUDA_PATH=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%" set "CUDA_PATH_V%CUDA_VER_MAJOR%_%CUDA_VER_MINOR%=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%" set "NVTOOLSEXT_PATH=%ProgramFiles%\NVIDIA Corporation\NvToolsExt\bin\x64" if not exist "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\bin\nvcc.exe" ( echo CUDA %CUDA_VERSION_STR% installed failed. 
echo --------- RunDll32.exe.log type "%SRC_DIR%\temp_build\cuda\cuda_install_logs\LOG.RunDll32.exe.log" echo --------- setup.exe.log ------- type "%SRC_DIR%\temp_build\cuda\cuda_install_logs\LOG.setup.exe.log" exit /b 1 ) echo Installing cuDNN... 7z x %CUDNN_SETUP_FILE% -o"%SRC_DIR%\temp_build\cudnn" xcopy /Y "%SRC_DIR%\temp_build\cudnn\cuda\bin\*.*" "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\bin" xcopy /Y "%SRC_DIR%\temp_build\cudnn\cuda\lib\x64\*.*" "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\lib\x64" xcopy /Y "%SRC_DIR%\temp_build\cudnn\cuda\include\*.*" "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\include" echo Cleaning temp files rd /s /q "%SRC_DIR%\temp_build" || ver > nul ================================================ FILE: packaging/windows/internal/driver_update.bat ================================================ set "DRIVER_DOWNLOAD_LINK=https://ossci-windows.s3.amazonaws.com/461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe" curl --retry 3 -kL %DRIVER_DOWNLOAD_LINK% --output 461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe if errorlevel 1 exit /b 1 start /wait 461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe -s -noreboot if errorlevel 1 exit /b 1 del 461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe || ver > NUL setlocal EnableDelayedExpansion set NVIDIA_GPU_EXISTS=0 for /F "delims=" %%i in ('wmic path win32_VideoController get name') do ( set GPUS=%%i if not "x!GPUS:NVIDIA=!" == "x!GPUS!" 
( SET NVIDIA_GPU_EXISTS=1 goto gpu_check_end ) ) :gpu_check_end endlocal & set NVIDIA_GPU_EXISTS=%NVIDIA_GPU_EXISTS% if "%NVIDIA_GPU_EXISTS%" == "0" ( echo "CUDA Driver installation Failed" exit /b 1 ) ================================================ FILE: pull_request_template.md ================================================ To contribute a change to functorch, please make sure you are submitting a Pull Request to the functorch folder in https://github.com/pytorch/pytorch repository. The source of truth for functorch has moved there from https://github.com/pytorch/functorch ; the pytorch/functorch repository is now read-only. ================================================ FILE: setup.cfg ================================================ [bdist_wheel] universal=1 [metadata] license_file = LICENSE [pep8] max-line-length = 120 [flake8] max-line-length = 120 exclude = docs, benchmarks, notebooks, tools per-file-ignores = __init__.py: F401 functorch/_src/decompositions.py: E501 [pydocstyle] select = D417 # Missing argument descriptions in the docstring ================================================ FILE: setup.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import os import subprocess from setuptools import setup cwd = os.path.dirname(os.path.abspath(__file__)) version_txt = os.path.join(cwd, 'version.txt') with open(version_txt, 'r') as f: version = f.readline().strip() try: sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=cwd).decode('ascii').strip() except Exception: sha = 'Unknown' package_name = 'functorch' if os.getenv('BUILD_VERSION'): version = os.getenv('BUILD_VERSION') elif sha != 'Unknown': version += '+' + sha[:7] requirements = [ # This represents a nightly version of PyTorch. # It can be installed as a binary or from source. 
"torch>=1.14.0.dev", ] extras = {} extras["aot"] = ["networkx", ] if __name__ == '__main__': try: setup( # Metadata name=package_name, version=version, author='PyTorch Core Team', url="https://github.com/pytorch/functorch", description='JAX-like composable function transforms for PyTorch', license='BSD', # Package info packages=[], install_requires=requirements, extras_require=extras, ) except Exception as e: print(e, file=sys.stderr) sys.exit(1) ================================================ FILE: version.txt ================================================ 1.14.0a0